Coverage for amqtt/contrib/auth_db/managers.py: 85%
141 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1from collections.abc import Iterator
2import logging
4from sqlalchemy import select
5from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
7from amqtt.contexts import Action
8from amqtt.contrib.auth_db.models import AllowedTopic, Base, TopicAuth, UserAuth
9from amqtt.errors import MQTTError
11logger = logging.getLogger(__name__)
class UserManager:
    """Create, read, update and delete ``UserAuth`` records.

    Each public coroutine opens its own session and transaction from the
    session factory built in ``__init__``.
    """

    def __init__(self, connection: str) -> None:
        """Build the async engine and the session factory.

        Args:
            connection: database URL passed to ``create_async_engine``.
        """
        self._engine = create_async_engine(connection)
        # expire_on_commit=False: instances handed back to callers are used
        # after the session that loaded them has committed and closed.
        self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)

    async def db_sync(self) -> None:
        """Sync the database schema."""
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    @staticmethod
    async def _get_auth_or_raise(db_session: AsyncSession, username: str) -> UserAuth:
        """Fetch the ``UserAuth`` row for ``username`` using ``db_session``.

        Raises:
            MQTTError: if no row matches ``username``.
        """
        stmt = select(UserAuth).filter(UserAuth.username == username)
        user_auth = await db_session.scalar(stmt)
        if not user_auth:
            msg = f"Username '{username}' doesn't exist."
            logger.debug(msg)
            raise MQTTError(msg)

        return user_auth

    async def get_user_auth(self, username: str) -> UserAuth | None:
        """Retrieve a user by username.

        Returns:
            The matching ``UserAuth``, or ``None`` when the username is unknown.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            try:
                return await self._get_auth_or_raise(db_session, username)
            except MQTTError:
                return None

    async def list_user_auths(self) -> Iterator[UserAuth]:
        """Return all users, ordered by username.

        Returns:
            The query result; iterating it yields nothing when no users exist.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            stmt = select(UserAuth).order_by(UserAuth.username)
            # The object returned by ``scalars()`` is always truthy, so the
            # former ``if not users`` check could never fire. An empty table
            # is reported to the caller as an empty iteration, not an error.
            return await db_session.scalars(stmt)

    async def create_user_auth(self, username: str, plain_password: str) -> UserAuth | None:
        """Create a new user.

        Args:
            username: name of the user to create.
            plain_password: value assigned to ``UserAuth.password``.

        Returns:
            The newly stored ``UserAuth``. ``None`` is never returned; the
            optional annotation is kept for interface compatibility.

        Raises:
            MQTTError: if ``username`` already exists.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            stmt = select(UserAuth).filter(UserAuth.username == username)
            user_auth = await db_session.scalar(stmt)
            if user_auth:
                msg = f"Username '{username}' already exists."
                logger.info(msg)
                raise MQTTError(msg)

            user_auth = UserAuth(username=username)
            user_auth.password = plain_password

            db_session.add(user_auth)
            # commit() flushes pending changes itself; nothing is left to
            # flush afterwards, so no separate flush() call is needed.
            await db_session.commit()
            return user_auth

    async def delete_user_auth(self, username: str) -> UserAuth | None:
        """Delete a user.

        Returns:
            The deleted ``UserAuth``, or ``None`` when the username is unknown.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            try:
                user_auth = await self._get_auth_or_raise(db_session, username)
            except MQTTError:
                return None

            await db_session.delete(user_auth)
            await db_session.commit()
            return user_auth

    async def update_user_auth_password(self, username: str, plain_password: str) -> UserAuth | None:
        """Change a user's password.

        Returns:
            The updated ``UserAuth``. ``None`` is never returned; the optional
            annotation is kept for interface compatibility.

        Raises:
            MQTTError: if ``username`` doesn't exist (unlike
                ``delete_user_auth``, the error is not swallowed here).
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            user_auth = await self._get_auth_or_raise(db_session, username)
            user_auth.password = plain_password
            await db_session.commit()
            return user_auth
class TopicManager:
    """Create, read and update ``TopicAuth`` records and their allowed-topic lists.

    Each public coroutine opens its own session and transaction from the
    session factory built in ``__init__``.
    """

    def __init__(self, connection: str) -> None:
        """Build the async engine and the session factory.

        Args:
            connection: database URL passed to ``create_async_engine``.
        """
        self._engine = create_async_engine(connection)
        # expire_on_commit=False: instances handed back to callers are used
        # after the session that loaded them has committed and closed.
        self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)

    async def db_sync(self) -> None:
        """Sync the database schema."""
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    @staticmethod
    async def _get_auth_or_raise(db_session: AsyncSession, username: str) -> TopicAuth:
        """Fetch the ``TopicAuth`` row for ``username`` using ``db_session``.

        Raises:
            MQTTError: if no row matches ``username``.
        """
        stmt = select(TopicAuth).filter(TopicAuth.username == username)
        topic_auth = await db_session.scalar(stmt)
        if not topic_auth:
            msg = f"Username '{username}' doesn't exist."
            logger.debug(msg)
            raise MQTTError(msg)

        return topic_auth

    @staticmethod
    def _field_name(action: Action) -> str:
        """Return the ``TopicAuth`` attribute name holding the list for ``action``."""
        return f"{action}_acl"

    async def create_topic_auth(self, username: str) -> TopicAuth | None:
        """Create a new topic authorization record for a user.

        Returns:
            The newly stored ``TopicAuth``. ``None`` is never returned; the
            optional annotation is kept for interface compatibility.

        Raises:
            MQTTError: if a record for ``username`` already exists.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            stmt = select(TopicAuth).filter(TopicAuth.username == username)
            topic_auth = await db_session.scalar(stmt)
            if topic_auth:
                msg = f"Username '{username}' already exists."
                logger.info(msg)
                raise MQTTError(msg)

            topic_auth = TopicAuth(username=username)

            db_session.add(topic_auth)
            # commit() flushes pending changes itself; nothing is left to
            # flush afterwards, so no separate flush() call is needed.
            await db_session.commit()
            return topic_auth

    async def get_topic_auth(self, username: str) -> TopicAuth | None:
        """Retrieve the allowed topics record by username.

        Returns:
            The matching ``TopicAuth``, or ``None`` when the username is unknown.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            try:
                return await self._get_auth_or_raise(db_session, username)
            except MQTTError:
                return None

    async def list_topic_auths(self) -> Iterator[TopicAuth]:
        """Return all topic authorization records, ordered by username.

        Returns:
            The query result; iterating it yields nothing when no records exist.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            stmt = select(TopicAuth).order_by(TopicAuth.username)
            # The object returned by ``scalars()`` is always truthy, so the
            # former ``if not topics`` check could never fire. An empty table
            # is reported to the caller as an empty iteration, not an error.
            return await db_session.scalars(stmt)

    async def add_allowed_topic(self, username: str, topic: str, action: Action) -> list[AllowedTopic] | None:
        """Add allowed topic from action for user.

        Returns:
            The updated list of allowed topics for ``action``.

        Raises:
            MQTTError: if ``action`` is publish and ``topic`` starts with ``$``,
                or if ``username`` doesn't exist.
        """
        if action == Action.PUBLISH and topic.startswith("$"):
            msg = "MQTT does not allow clients to publish to $ topics."
            raise MQTTError(msg)

        field_name = self._field_name(action)
        async with self._db_session_maker() as db_session, db_session.begin():
            topic_auth = await self._get_auth_or_raise(db_session, username)
            topic_list = getattr(topic_auth, field_name)

            # Assign a new list rather than appending in place.
            # NOTE(review): presumably needed so the ORM registers the change
            # to the column; confirm against the model's column type.
            updated_list = [*topic_list, AllowedTopic(topic)]
            setattr(topic_auth, field_name, updated_list)
            await db_session.commit()
            return updated_list

    async def remove_allowed_topic(self, username: str, topic: str, action: Action) -> list[AllowedTopic] | None:
        """Remove topic from action for user.

        Returns:
            The updated list of allowed topics for ``action``.

        Raises:
            MQTTError: if ``username`` doesn't exist or doesn't have ``topic``
                for ``action``.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            topic_auth = await self._get_auth_or_raise(db_session, username)
            topic_list = topic_auth.get_topic_list(action)

            # Build the comparison target once instead of once per list entry.
            target = AllowedTopic(topic)
            if target not in topic_list:
                msg = f"Client '{username}' doesn't have topic '{topic}' for action '{action}'."
                logger.debug(msg)
                raise MQTTError(msg)

            updated_list = [allowed_topic for allowed_topic in topic_list if allowed_topic != target]

            setattr(topic_auth, self._field_name(action), updated_list)
            await db_session.commit()
            return updated_list