Coverage for amqtt/contrib/auth_db/__init__.py: 86%
27 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1"""Plugin to determine authentication of clients with DB storage."""
2from dataclasses import dataclass
4import click
try:
    from enum import StrEnum
except ImportError:
    # Python 3.10 does not ship enum.StrEnum; provide a minimal stand-in.
    from enum import Enum

    class StrEnum(str, Enum):  # type: ignore[no-redef]
        """Fallback for ``enum.StrEnum`` on interpreters that lack it."""

        def __str__(self) -> str:
            # The real StrEnum renders as its value; a plain ``(str, Enum)``
            # mixin would render as "ClassName.MEMBER" instead, so str() would
            # differ between interpreter versions without this override.
            return str(self.value)
14from .plugin import TopicAuthDBPlugin, UserAuthDBPlugin
17class DBType(StrEnum):
18 """Enumeration for supported relational databases."""
20 MARIA = "mariadb"
21 MYSQL = "mysql"
22 POSTGRESQL = "postgresql"
23 SQLITE = "sqlite"
26@dataclass
27class DBInfo:
28 """SQLAlchemy database information."""
30 connect_str: str
31 connect_port: int | None
# Lookup table from each supported backend to its URL scheme and default port.
# MariaDB and MySQL map to the same scheme and port; SQLite has no port.
_db_map: dict[DBType, DBInfo] = {
    DBType.MARIA: DBInfo(connect_str="mysql+aiomysql", connect_port=3306),
    DBType.MYSQL: DBInfo(connect_str="mysql+aiomysql", connect_port=3306),
    DBType.POSTGRESQL: DBInfo(connect_str="postgresql+asyncpg", connect_port=5432),
    DBType.SQLITE: DBInfo(connect_str="sqlite+aiosqlite", connect_port=None),
}
def db_connection_str(db_type: DBType, db_username: str, db_host: str, db_port: int | None, db_filename: str) -> str:
    """Create sqlalchemy database connection string.

    For SQLite only ``db_filename`` is used and no password is requested.
    For every other backend the user is interactively prompted for the
    database password, which may be left blank.

    Args:
        db_type: Backend to connect to; selects the URL scheme and default port.
        db_username: User name placed in the URL (ignored for SQLite).
        db_host: Server host name or address (ignored for SQLite).
        db_port: Server port; when falsy the backend's default port is used.
        db_filename: Database file path (SQLite only).

    Returns:
        The assembled connection URL.

    Raises:
        KeyError: If ``db_type`` has no entry in the backend table.

    """
    from urllib.parse import quote

    db_info = _db_map[db_type]
    if db_type == DBType.SQLITE:
        return f"{db_info.connect_str}:///{db_filename}"

    # An explicit empty default is required: without one, click keeps
    # re-prompting on empty input, so "press enter for none" could never work.
    db_password = click.prompt(
        "Enter the db password (press enter for none)",
        default="",
        hide_input=True,
        show_default=False,
    )
    # Percent-encode the password so characters such as "@", ":" or "/" cannot
    # be mistaken for URL delimiters. The ":" separator lives only here, so the
    # URL contains "user:secret@" with a password and "user@" without one.
    pwd = f":{quote(db_password, safe='')}" if db_password else ""
    port = db_port or db_info.connect_port
    return f"{db_info.connect_str}://{db_username}{pwd}@{db_host}:{port}"
# Names exported by ``from <package> import *``.
__all__ = [
    "DBType",
    "TopicAuthDBPlugin",
    "UserAuthDBPlugin",
    "db_connection_str",
]