Coverage for amqtt/contrib/auth_db/managers.py: 85%

141 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from collections.abc import Iterator 

2import logging 

3 

4from sqlalchemy import select 

5from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine 

6 

7from amqtt.contexts import Action 

8from amqtt.contrib.auth_db.models import AllowedTopic, Base, TopicAuth, UserAuth 

9from amqtt.errors import MQTTError 

10 

# Module-level logger named after this module; the managers below log
# lookup misses and duplicate/missing-record errors through it.
logger = logging.getLogger(name=__name__)

12 

13 

class UserManager:
    """Manage ``UserAuth`` (username/password) records through SQLAlchemy asyncio.

    Each public method opens its own session and runs inside
    ``db_session.begin()``, which commits when the block exits normally and
    rolls back if an exception escapes it, so no explicit ``commit()`` is
    issued inside those blocks.
    """

    def __init__(self, connection: str) -> None:
        """Create the async engine and session factory.

        Args:
            connection: database URL passed to ``create_async_engine``.

        """
        self._engine = create_async_engine(connection)
        # expire_on_commit=False keeps attributes of the ORM objects returned
        # by these methods readable after their session has committed/closed.
        self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)

    async def db_sync(self) -> None:
        """Sync the database schema by creating the tables in ``Base.metadata``."""
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    @staticmethod
    async def _get_auth_or_raise(db_session: AsyncSession, username: str) -> UserAuth:
        """Return the ``UserAuth`` row for *username* using *db_session*.

        Raises:
            MQTTError: if no row matches *username*.

        """
        stmt = select(UserAuth).where(UserAuth.username == username)
        user_auth = await db_session.scalar(stmt)
        if user_auth is None:
            msg = f"Username '{username}' doesn't exist."
            logger.debug(msg)
            raise MQTTError(msg)

        return user_auth

    async def get_user_auth(self, username: str) -> UserAuth | None:
        """Retrieve a user by username, or ``None`` if it doesn't exist."""
        async with self._db_session_maker() as db_session, db_session.begin():
            try:
                return await self._get_auth_or_raise(db_session, username)
            except MQTTError:
                return None

    async def list_user_auths(self) -> Iterator[UserAuth]:
        """Return an iterator over all users, ordered by username.

        The rows are fully loaded before the session closes. An empty
        database yields an empty iterator.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            stmt = select(UserAuth).order_by(UserAuth.username)
            # Materialize the rows; the result object itself is always truthy,
            # so it cannot be used to test for "no users".
            users = (await db_session.scalars(stmt)).all()
        return iter(users)

    async def create_user_auth(self, username: str, plain_password: str) -> UserAuth | None:
        """Create and persist a new user.

        The plain-text password is assigned through ``UserAuth.password``;
        presumably the model hashes it there (confirm in ``models``).

        Returns:
            The newly created ``UserAuth``.

        Raises:
            MQTTError: if *username* already exists.

        """
        async with self._db_session_maker() as db_session, db_session.begin():
            stmt = select(UserAuth).where(UserAuth.username == username)
            if await db_session.scalar(stmt) is not None:
                msg = f"Username '{username}' already exists."
                logger.info(msg)
                raise MQTTError(msg)

            user_auth = UserAuth(username=username)
            user_auth.password = plain_password

            db_session.add(user_auth)
            # Emit the INSERT now so generated keys are populated and constraint
            # errors surface here; the begin() block commits on exit.
            await db_session.flush()
            return user_auth

    async def delete_user_auth(self, username: str) -> UserAuth | None:
        """Delete a user.

        Returns:
            The deleted ``UserAuth`` (its loaded attributes stay readable), or
            ``None`` if *username* doesn't exist.

        """
        async with self._db_session_maker() as db_session, db_session.begin():
            try:
                user_auth = await self._get_auth_or_raise(db_session, username)
            except MQTTError:
                return None

            await db_session.delete(user_auth)
            await db_session.flush()
            return user_auth

    async def update_user_auth_password(self, username: str, plain_password: str) -> UserAuth | None:
        """Change a user's password.

        Unlike ``get_user_auth``/``delete_user_auth``, a missing user raises
        instead of returning ``None``.

        Returns:
            The updated ``UserAuth``.

        Raises:
            MQTTError: if *username* doesn't exist.

        """
        async with self._db_session_maker() as db_session, db_session.begin():
            user_auth = await self._get_auth_or_raise(db_session, username)
            user_auth.password = plain_password
            await db_session.flush()
            return user_auth

95 

96 

class TopicManager:
    """Manage per-user topic ACL records (``TopicAuth``) through SQLAlchemy asyncio.

    Each public method opens its own session and runs inside
    ``db_session.begin()``, which commits when the block exits normally and
    rolls back if an exception escapes it, so no explicit ``commit()`` is
    issued inside those blocks.
    """

    def __init__(self, connection: str) -> None:
        """Create the async engine and session factory.

        Args:
            connection: database URL passed to ``create_async_engine``.

        """
        self._engine = create_async_engine(connection)
        # expire_on_commit=False keeps attributes of the ORM objects returned
        # by these methods readable after their session has committed/closed.
        self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)

    async def db_sync(self) -> None:
        """Sync the database schema by creating the tables in ``Base.metadata``."""
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    @staticmethod
    async def _get_auth_or_raise(db_session: AsyncSession, username: str) -> TopicAuth:
        """Return the ``TopicAuth`` row for *username* using *db_session*.

        Raises:
            MQTTError: if no row matches *username*.

        """
        stmt = select(TopicAuth).where(TopicAuth.username == username)
        topic_auth = await db_session.scalar(stmt)
        if topic_auth is None:
            msg = f"Username '{username}' doesn't exist."
            logger.debug(msg)
            raise MQTTError(msg)

        return topic_auth

    @staticmethod
    def _field_name(action: Action) -> str:
        """Return the ``TopicAuth`` attribute holding the ACL list for *action*.

        Assumes ``f"{action}"`` renders the action's value (e.g. a StrEnum);
        confirm against ``Action``.
        """
        return f"{action}_acl"

    async def create_topic_auth(self, username: str) -> TopicAuth | None:
        """Create a new ``TopicAuth`` record for *username* (only the username is set).

        Returns:
            The newly created ``TopicAuth``.

        Raises:
            MQTTError: if a record for *username* already exists.

        """
        async with self._db_session_maker() as db_session, db_session.begin():
            stmt = select(TopicAuth).where(TopicAuth.username == username)
            if await db_session.scalar(stmt) is not None:
                msg = f"Username '{username}' already exists."
                logger.info(msg)
                raise MQTTError(msg)

            topic_auth = TopicAuth(username=username)

            db_session.add(topic_auth)
            # Emit the INSERT now so generated keys are populated and constraint
            # errors surface here; the begin() block commits on exit.
            await db_session.flush()
            return topic_auth

    async def get_topic_auth(self, username: str) -> TopicAuth | None:
        """Retrieve the allowed-topics record for *username*, or ``None`` if absent."""
        async with self._db_session_maker() as db_session, db_session.begin():
            try:
                return await self._get_auth_or_raise(db_session, username)
            except MQTTError:
                return None

    async def list_topic_auths(self) -> Iterator[TopicAuth]:
        """Return an iterator over all topic-authorization records, ordered by username.

        The rows are fully loaded before the session closes. An empty
        database yields an empty iterator.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            stmt = select(TopicAuth).order_by(TopicAuth.username)
            # Materialize the rows; the result object itself is always truthy,
            # so it cannot be used to test for "no records".
            topics = (await db_session.scalars(stmt)).all()
        return iter(topics)

    async def add_allowed_topic(self, username: str, topic: str, action: Action) -> list[AllowedTopic] | None:
        """Allow *username* to use *topic* for *action*.

        Adding a topic that is already present is a no-op (no duplicate entry).

        Returns:
            The resulting topic list for *action*.

        Raises:
            MQTTError: if *action* is publish and *topic* starts with ``$``, or
                if *username* has no ``TopicAuth`` record.

        """
        if action == Action.PUBLISH and topic.startswith("$"):
            msg = "MQTT does not allow clients to publish to $ topics."
            raise MQTTError(msg)

        field_name = self._field_name(action)
        async with self._db_session_maker() as db_session, db_session.begin():
            topic_auth = await self._get_auth_or_raise(db_session, username)
            # `or []` tolerates an unset (None) column; an empty list is unaffected.
            topic_list = getattr(topic_auth, field_name) or []
            new_topic = AllowedTopic(topic)

            if new_topic in topic_list:
                logger.debug(f"Client '{username}' already has topic '{topic}' for action '{action}'.")
                return list(topic_list)

            # Assign a new list rather than appending in place; presumably the
            # column type doesn't track in-place mutation (confirm in models).
            updated_list = [*topic_list, new_topic]
            setattr(topic_auth, field_name, updated_list)
            await db_session.flush()
            return updated_list

    async def remove_allowed_topic(self, username: str, topic: str, action: Action) -> list[AllowedTopic] | None:
        """Remove *topic* from *username*'s allowed topics for *action*.

        Returns:
            The remaining topic list for *action*.

        Raises:
            MQTTError: if *username* has no ``TopicAuth`` record, or *topic* is
                not in the list for *action*.

        """
        async with self._db_session_maker() as db_session, db_session.begin():
            topic_auth = await self._get_auth_or_raise(db_session, username)
            topic_list = topic_auth.get_topic_list(action)
            target = AllowedTopic(topic)

            if target not in topic_list:
                msg = f"Client '{username}' doesn't have topic '{topic}' for action '{action}'."
                logger.debug(msg)
                raise MQTTError(msg)

            # Drop every matching entry and assign a new list (see add_allowed_topic).
            updated_list = [allowed_topic for allowed_topic in topic_list if allowed_topic != target]

            setattr(topic_auth, self._field_name(action), updated_list)
            await db_session.flush()
            return updated_list