Coverage for amqtt/contrib/auth_db/models.py: 90%
77 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1from dataclasses import dataclass
2import logging
3from typing import TYPE_CHECKING, Any, Optional, Union, cast
5from sqlalchemy import String
6from sqlalchemy.ext.hybrid import hybrid_property
7from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
9from amqtt.contexts import Action
10from amqtt.contrib import DataClassListJSON
11from amqtt.plugins import TopicMatcher
13if TYPE_CHECKING:
14 from passlib.context import CryptContext
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Single TopicMatcher instance shared by every AllowedTopic comparison against a string.
matcher = TopicMatcher()
@dataclass
class AllowedTopic:
    """One ACL entry wrapping a single topic string.

    Comparing against a plain ``str`` is delegated to the module-level
    ``matcher``; comparing against another ``AllowedTopic`` checks that both
    hold an equal ``topic`` value.
    """

    topic: str

    def __eq__(self, item: object) -> bool:
        """Implement ``==`` / ``!=`` for strings and other ``AllowedTopic`` objects.

        Raises:
            AttributeError: if ``item`` is neither a ``str`` nor an ``AllowedTopic``.
        """
        if isinstance(item, str):
            # Strings are checked through the matcher, not by literal equality.
            return matcher.is_topic_allowed(item, self.topic)
        if not isinstance(item, AllowedTopic):
            msg = "AllowedTopic can only be compared to another AllowedTopic or string."
            raise AttributeError(msg)
        return item.topic == self.topic

    def __contains__(self, item: Union[str, "AllowedTopic"]) -> bool:
        """Implement ``in`` with exactly the same rule as ``==``."""
        return self.__eq__(item)

    def __str__(self) -> str:
        """Return the bare topic string."""
        return self.topic

    def __repr__(self) -> str:
        """Return the bare topic string (same output as ``str()``)."""
        return self.topic
class PasswordHasher:
    """Singleton that holds the CryptContext so it can be set once and used elsewhere.

    Every ``PasswordHasher()`` call returns the same object, so a context
    assigned through ``crypt_context`` stays available to later callers.
    """

    _instance: Optional["PasswordHasher"] = None

    def __new__(cls, *args: Any, **kwargs: Any) -> "PasswordHasher":
        """Return the shared instance, creating it on the first call."""
        if cls._instance is None:
            # object.__new__ rejects extra arguments when __new__ is overridden,
            # so nothing beyond the class itself is forwarded.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Set up the empty context slot the first time only."""
        # __init__ runs on every PasswordHasher() call; the guard keeps a
        # previously assigned context from being reset to None.
        if not hasattr(self, "_crypt_context"):
            self._crypt_context: Optional["CryptContext"] = None

    @property
    def crypt_context(self) -> "CryptContext":
        """Return the configured context.

        Raises:
            ValueError: if no context has been assigned yet.
        """
        # Compare against None explicitly: a configured object must never be
        # mistaken for "unset" just because it evaluates as falsy.
        if self._crypt_context is None:
            msg = "CryptContext is empty"
            raise ValueError(msg)
        return self._crypt_context

    @crypt_context.setter
    def crypt_context(self, value: "CryptContext") -> None:
        """Store the context used by later hash and verify calls."""
        self._crypt_context = value
class Base(DeclarativeBase):
    """Declarative base class shared by the ORM models in this module."""
class UserAuth(Base):
    """ORM model for the ``user_auth`` table: username, password hash and three ACL lists."""

    __tablename__ = "user_auth"

    id: Mapped[int] = mapped_column(primary_key=True)
    username: Mapped[str] = mapped_column(String, unique=True)
    # Stored in the "password_hash" column; written through the `password` setter below.
    _password_hash: Mapped[str] = mapped_column("password_hash", String(128))

    # Each ACL column holds a list of AllowedTopic entries and defaults to an empty list.
    publish_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
    subscribe_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
    receive_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)

    @hybrid_property
    def password(self) -> None:
        """Refuse reads; only assignment is supported.

        Raises:
            AttributeError: always.
        """
        raise AttributeError("Password is write-only")

    @password.inplace.setter  # type: ignore[arg-type]
    def _password_setter(self, plain_password: str) -> None:
        """Hash ``plain_password`` with the shared CryptContext and keep only the hash."""
        context = PasswordHasher().crypt_context
        self._password_hash = context.hash(plain_password)

    def verify_password(self, plain_password: str) -> bool:
        """Return whether ``plain_password`` verifies against the stored hash."""
        context = PasswordHasher().crypt_context
        outcome = context.verify(plain_password, self._password_hash)
        return bool(outcome)

    def __str__(self) -> str:
        """Display the username and the stored password hash."""
        return f"'{self.username}' with password hash: {self._password_hash}"
class TopicAuth(Base):
    """ORM model for the ``topic_auth`` table: a username and its three ACL lists."""

    __tablename__ = "topic_auth"

    id: Mapped[int] = mapped_column(primary_key=True)
    username: Mapped[str] = mapped_column(String, unique=True)

    # Each ACL column holds a list of AllowedTopic entries and defaults to an empty list.
    publish_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
    subscribe_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
    receive_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)

    def get_topic_list(self, action: Action) -> list[AllowedTopic]:
        """Return the ACL list stored in the attribute named ``<action>_acl``."""
        # NOTE(review): assumes formatting `action` yields "publish", "subscribe"
        # or "receive" -- confirm against the Action definition.
        attribute_name = f"{action}_acl"
        return cast("list[AllowedTopic]", getattr(self, attribute_name))

    def __str__(self) -> str:
        """Display the username followed by its publish, subscribe and receive ACLs."""
        return f"""'{self.username}':
\tpublish: {self.publish_acl}, subscribe: {self.subscribe_acl}, receive: {self.receive_acl}
"""