Coverage for amqtt/contrib/auth_db/__init__.py: 86%

27 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1"""Plugin to determine authentication of clients with DB storage.""" 

2from dataclasses import dataclass 

3 

4import click 

5 

try:
    from enum import StrEnum
except ImportError:
    # support for python 3.10: ``enum.StrEnum`` only exists from 3.11 on.
    from enum import Enum

    class StrEnum(str, Enum):  # type: ignore[no-redef]
        """Minimal backport of :class:`enum.StrEnum` for Python 3.10.

        A bare ``(str, Enum)`` mixin renders ``str(member)`` as
        ``"ClassName.MEMBER"``, whereas the stdlib ``StrEnum`` renders the
        member's value. ``__str__`` is overridden so both Python versions
        produce the same text (``format()``/f-strings follow ``__str__``).
        """

        def __str__(self) -> str:
            # Same approach as the stdlib: defer to the plain str value.
            return str.__str__(self)

13 

14from .plugin import TopicAuthDBPlugin, UserAuthDBPlugin 

15 

16 

class DBType(StrEnum):
    """Enumeration for supported relational databases.

    Members are ``str`` instances, so each one compares equal to its
    lowercase value (e.g. ``DBType.SQLITE == "sqlite"``). Declaration order
    is preserved on iteration.
    """

    # MariaDB and MySQL are listed separately but map to the same driver info.
    MARIA = "mariadb"
    MYSQL = "mysql"
    # Networked server backend.
    POSTGRESQL = "postgresql"
    # File-based backend; connection strings use a filename, not host/port.
    SQLITE = "sqlite"

24 

25 

26@dataclass 

27class DBInfo: 

28 """SQLAlchemy database information.""" 

29 

30 connect_str: str 

31 connect_port: int | None 

32 

33 

# Backend -> (SQLAlchemy "dialect+driver" prefix, fallback port).
# Each entry is its own DBInfo instance, even where the values coincide.
_db_map: dict[DBType, DBInfo] = {
    # MariaDB reuses the MySQL dialect/driver and port.
    DBType.MARIA: DBInfo("mysql+aiomysql", 3306),
    DBType.MYSQL: DBInfo("mysql+aiomysql", 3306),
    DBType.POSTGRESQL: DBInfo("postgresql+asyncpg", 5432),
    # SQLite is addressed by filename, so there is no port.
    DBType.SQLITE: DBInfo("sqlite+aiosqlite", None),
}

40 

41 

def db_connection_str(db_type: DBType, db_username: str, db_host: str, db_port: int | None, db_filename: str) -> str:
    """Create sqlalchemy database connection string.

    For SQLite only ``db_filename`` is used; the username, host and port are
    ignored and no prompt is shown. For every other backend the user is
    prompted (input hidden) for a password, which may be left empty.

    Args:
        db_type: Backend to build the URL for.
        db_username: Database user name.
        db_host: Database server host.
        db_port: Server port; when falsy, the backend's default port from
            ``_db_map`` is used.
        db_filename: Path of the SQLite database file (SQLite only).

    Returns:
        ``<dialect+driver>:///<db_filename>`` for SQLite, otherwise
        ``<dialect+driver>://<user>[:<password>]@<host>:<port>`` with the
        user name and password percent-encoded.

    """
    from urllib.parse import quote

    db_info = _db_map[db_type]
    if db_type == DBType.SQLITE:
        return f"{db_info.connect_str}:///{db_filename}"

    # Without a default, click re-prompts on empty input, so "press enter for
    # none" would never work; an empty default makes an empty answer valid.
    db_password = click.prompt(
        "Enter the db password (press enter for none)",
        hide_input=True,
        default="",
        show_default=False,
    )

    # Percent-encode credentials so characters such as '@', ':' or '/' cannot
    # change the URL structure. quote(..., safe="") encodes spaces as %20 and
    # '+' as %2B, both of which round-trip through standard URL unquoting.
    user = quote(db_username, safe="")
    # The ':' separator is added here only when a password exists. The
    # original code added it a second time in the template, producing "user::pw".
    pwd = f":{quote(db_password, safe='')}" if db_password else ""
    port = db_port or db_info.connect_port
    return f"{db_info.connect_str}://{user}{pwd}@{db_host}:{port}"

50 

51 

# Public API of the auth_db package.
__all__ = [
    "DBType",
    "TopicAuthDBPlugin",
    "UserAuthDBPlugin",
    "db_connection_str",
]