Coverage for amqtt/contrib/auth_db/models.py: 90%

77 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from dataclasses import dataclass 

2import logging 

3from typing import TYPE_CHECKING, Any, Optional, Union, cast 

4 

5from sqlalchemy import String 

6from sqlalchemy.ext.hybrid import hybrid_property 

7from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 

8 

9from amqtt.contexts import Action 

10from amqtt.contrib import DataClassListJSON 

11from amqtt.plugins import TopicMatcher 

12 

13if TYPE_CHECKING: 

14 from passlib.context import CryptContext 

15 

16 

# Logger named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)

# Single shared TopicMatcher; AllowedTopic uses it to test a concrete topic
# against a stored topic filter.
matcher: TopicMatcher = TopicMatcher()

20 

21 

@dataclass
class AllowedTopic:
    """A single topic filter stored in an ACL list.

    Comparison semantics:
    - with a ``str``: True when that topic is allowed by this filter
      (delegates to the module-level ``matcher.is_topic_allowed``);
    - with another ``AllowedTopic``: True when both filters are the same string;
    - with anything else: ``NotImplemented``, so Python falls back to its default
      comparison (``==`` is False, ``!=`` is True) instead of raising.
    """

    topic: str

    def __contains__(self, item: Union[str, "AllowedTopic"]) -> bool:
        """Determine `in`: whether ``item`` is covered by this topic filter.

        Unsupported item types evaluate to ``False``.
        """
        # Go through the ``==`` operator rather than calling ``__eq__`` directly:
        # a ``NotImplemented`` result is then resolved by Python to ``False``
        # instead of being returned as a (truthy) sentinel.
        return self == item

    def __eq__(self, item: object) -> bool:
        """Determine `==` or `!=`.

        Returns ``NotImplemented`` for unsupported types (per the Python data
        model) so that comparisons such as ``topic == None`` or membership tests
        in mixed lists evaluate cleanly instead of raising.
        """
        if isinstance(item, str):
            return matcher.is_topic_allowed(item, self.topic)
        if isinstance(item, AllowedTopic):
            return item.topic == self.topic
        return NotImplemented

    def __str__(self) -> str:
        """Display topic."""
        return self.topic

    def __repr__(self) -> str:
        """Display topic (so lists of AllowedTopic print as plain topic strings)."""
        return self.topic

46 

47 

class PasswordHasher:
    """Singleton holding the CryptContext used for hashing and verifying passwords.

    The context is assigned once through the ``crypt_context`` setter and then read
    from anywhere via ``PasswordHasher().crypt_context``. Every ``PasswordHasher()``
    call returns the same instance.
    """

    _instance: Optional["PasswordHasher"] = None

    def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]) -> "PasswordHasher":
        existing = cls._instance
        if existing is not None:
            return existing
        created = super().__new__(cls, *args, **kwargs)
        cls._instance = created
        return created

    def __init__(self) -> None:
        # __init__ runs on every PasswordHasher() call, even when __new__ returns the
        # already-existing instance; seed the attribute only the first time so a
        # previously assigned context is not wiped out.
        if hasattr(self, "_crypt_context"):
            return
        self._crypt_context: CryptContext | None = None

    @property
    def crypt_context(self) -> "CryptContext":
        """Return the configured context; raise ValueError if none is set (falsy)."""
        ctx = self._crypt_context
        if ctx:
            return ctx
        msg = "CryptContext is empty"
        raise ValueError(msg)

    @crypt_context.setter
    def crypt_context(self, value: "CryptContext") -> None:
        """Install ``value`` as the shared context."""
        self._crypt_context = value

73 

74 

class Base(DeclarativeBase):
    """Declarative base class for the ORM models defined in this module."""

77 

78 

class UserAuth(Base):
    """ORM row for table ``user_auth``: a username, its password hash and topic ACLs.

    The plain-text password is never stored. Assigning ``password`` stores the hash
    produced by the shared ``PasswordHasher`` context, and reading ``password``
    raises ``AttributeError``.
    """

    __tablename__ = "user_auth"

    id: Mapped[int] = mapped_column(primary_key=True)
    username: Mapped[str] = mapped_column(String, unique=True)
    # Private Python attribute backed by the database column "password_hash".
    _password_hash: Mapped[str] = mapped_column("password_hash", String(128))

    # Per-action ACLs: lists of AllowedTopic stored via DataClassListJSON,
    # defaulting to an empty list.
    publish_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
    subscribe_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
    receive_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)

    @hybrid_property
    def password(self) -> None:
        """Refuse reads: the password is write-only."""
        msg = "Password is write-only"
        raise AttributeError(msg)

    @password.inplace.setter  # type: ignore[arg-type]
    def _password_setter(self, plain_password: str) -> None:
        """Hash ``plain_password`` with the shared context and store the hash."""
        context = PasswordHasher().crypt_context
        self._password_hash = context.hash(plain_password)

    def verify_password(self, plain_password: str) -> bool:
        """Return True when ``plain_password`` matches the stored hash."""
        context = PasswordHasher().crypt_context
        matched = context.verify(plain_password, self._password_hash)
        return bool(matched)

    def __str__(self) -> str:
        """Display client id and password hash."""
        return f"'{self.username}' with password hash: {self._password_hash}"

105 

106 

class TopicAuth(Base):
    """ORM row for table ``topic_auth``: a username and its per-action topic ACLs."""

    __tablename__ = "topic_auth"

    id: Mapped[int] = mapped_column(primary_key=True)
    username: Mapped[str] = mapped_column(String, unique=True)

    # Per-action ACLs: lists of AllowedTopic stored via DataClassListJSON,
    # defaulting to an empty list.
    publish_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
    subscribe_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
    receive_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)

    def get_topic_list(self, action: Action) -> list[AllowedTopic]:
        """Return the ACL list for ``action`` by reading the ``<action>_acl`` attribute.

        NOTE(review): relies on ``f"{action}"`` rendering as the attribute prefix
        (e.g. ``publish``); confirm against the Action type's formatting.
        """
        attribute_name = f"{action}_acl"
        topics = getattr(self, attribute_name)
        return cast("list[AllowedTopic]", topics)

    def __str__(self) -> str:
        """Display the username followed by its publish, subscribe and receive ACLs."""
        return (
            f"'{self.username}':\n"
            f"\tpublish: {self.publish_acl}, subscribe: {self.subscribe_acl}, receive: {self.receive_acl}\n"
        )