Coverage for amqtt/contrib/persistence.py: 88%
168 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1from dataclasses import dataclass
2import logging
3from pathlib import Path
5from sqlalchemy import Boolean, Integer, LargeBinary, Result, String, select
6from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
7from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
9from amqtt.broker import BrokerContext, RetainedApplicationMessage
10from amqtt.contrib import DataClassListJSON
11from amqtt.errors import PluginError
12from amqtt.plugins.base import BasePlugin
13from amqtt.session import Session
15logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
    """Declarative base shared by the ORM models of the session persistence plugin."""
@dataclass
class RetainedMessage:
    """Serializable record of one retained message queued for a stored session."""

    # topic the message was published on
    topic: str
    # message payload as text (the plugin stores the decoded bytes payload)
    data: str
    # QoS level the message is kept with
    qos: int
@dataclass
class Subscription:
    """Serializable record of one topic subscription held by a stored session."""

    # topic (filter) the client subscribed to
    topic: str
    # QoS level requested for the subscription
    qos: int
class StoredSession(Base):
    """ORM model for the `stored_sessions` table: one row of persisted state per client id."""

    __tablename__ = "stored_sessions"

    # --- identity ---
    id: Mapped[int] = mapped_column(primary_key=True)
    client_id: Mapped[str] = mapped_column(String)

    # --- connection flags ---
    clean_session: Mapped[bool | None] = mapped_column(Boolean, nullable=True)

    # --- last will ---
    will_flag: Mapped[bool] = mapped_column(
        Boolean,
        default=False,
        server_default="false",
    )
    will_message: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None)
    will_qos: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)
    will_retain: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=None)
    will_topic: Mapped[str | None] = mapped_column(String, nullable=True, default=None)

    # --- session payload ---
    keep_alive: Mapped[int] = mapped_column(Integer, default=0)
    # both lists are persisted through `DataClassListJSON`; a fresh empty list is the default
    retained: Mapped[list[RetainedMessage]] = mapped_column(
        DataClassListJSON(RetainedMessage),
        default=list,
    )
    subscriptions: Mapped[list[Subscription]] = mapped_column(
        DataClassListJSON(Subscription),
        default=list,
    )
class StoredMessage(Base):
    """ORM model for the `stored_messages` table: the retained message kept for a topic."""

    __tablename__ = "stored_messages"

    id: Mapped[int] = mapped_column(primary_key=True)
    # topic the retained message belongs to
    topic: Mapped[str] = mapped_column(String)
    # raw payload bytes; nullable
    data: Mapped[bytes | None] = mapped_column(
        LargeBinary,
        nullable=True,
        default=None,
    )
    qos: Mapped[int] = mapped_column(Integer, default=0)
class SessionDBPlugin(BasePlugin[BrokerContext]):
    """Plugin to store session information and retained topic messages in the event that the broker terminates abnormally.

    Configuration:
    - file *(string)* path & filename to store the session db. default: `amqtt.db`
    - clear_on_shutdown *(bool)* if the broker shuts down normally, don't retain any information. default: `True`

    """

    def __init__(self, context: BrokerContext) -> None:
        """Create the async SQLite engine and the session factory from the plugin config."""
        super().__init__(context)

        # bypass the `test_plugins_correct_has_attr` until it can be updated
        if not hasattr(self.config, "file"):
            logger.warning("`Config` is missing a `file` attribute")
            return

        self._engine = create_async_engine(f"sqlite+aiosqlite:///{self.config.file}")
        self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)

    @staticmethod
    async def _get_or_create_session(db_session: AsyncSession, client_id: str) -> StoredSession:
        """Return the stored session for `client_id`, inserting a new row if none exists."""
        stmt = select(StoredSession).filter(StoredSession.client_id == client_id)
        stored_session = await db_session.scalar(stmt)
        if stored_session is None:
            stored_session = StoredSession(client_id=client_id)
            db_session.add(stored_session)
            # flush so the row is inserted before the caller reads or extends its list columns
            await db_session.flush()
        return stored_session

    @staticmethod
    async def _get_or_create_message(db_session: AsyncSession, topic: str) -> StoredMessage:
        """Return the stored retained message for `topic`, inserting a new row if none exists."""
        stmt = select(StoredMessage).filter(StoredMessage.topic == topic)
        stored_message = await db_session.scalar(stmt)
        if stored_message is None:
            stored_message = StoredMessage(topic=topic)
            db_session.add(stored_message)
            await db_session.flush()
        return stored_message

    async def on_broker_client_connected(self, client_id: str, client_session: Session) -> None:
        """Create or update the stored session when a persistent client connects."""
        # don't store session information for clean or anonymous sessions
        if client_session.clean_session in (None, True) or client_session.is_anonymous:
            return

        # the `begin()` block commits on exit, which also flushes the pending changes
        async with self._db_session_maker() as db_session, db_session.begin():
            stored_session = await self._get_or_create_session(db_session, client_id)

            stored_session.clean_session = client_session.clean_session
            stored_session.will_flag = client_session.will_flag
            stored_session.will_message = client_session.will_message  # type: ignore[assignment]
            stored_session.will_qos = client_session.will_qos
            stored_session.will_retain = client_session.will_retain
            stored_session.will_topic = client_session.will_topic
            stored_session.keep_alive = client_session.keep_alive

    async def on_broker_client_subscribed(self, client_id: str, topic: str, qos: int) -> None:
        """Create/update subscription if clean session = false.

        Subscribing again to a topic that is already stored replaces the stored
        entry (picking up the new QoS) instead of adding a duplicate.
        """
        session = self.context.get_session(client_id)
        if not session:
            logger.warning(f"'{client_id}' is subscribing but doesn't have a session")
            return

        if session.clean_session:
            return

        async with self._db_session_maker() as db_session, db_session.begin():
            # stored sessions shouldn't need to be created here, but we'll use the same helper...
            stored_session = await self._get_or_create_session(db_session, client_id)
            # a new list is assigned rather than mutated in place, as before
            # NOTE(review): presumably required for the ORM to detect the change - confirm
            stored_session.subscriptions = [
                *(sub for sub in stored_session.subscriptions if sub.topic != topic),
                Subscription(topic, qos),
            ]

    async def on_broker_client_unsubscribed(self, client_id: str, topic: str) -> None:
        """Remove subscription if clean session = false."""
        session = self.context.get_session(client_id)
        if not session:
            logger.warning(f"'{client_id}' is unsubscribing but doesn't have a session")
            return

        if session.clean_session:
            return

        async with self._db_session_maker() as db_session, db_session.begin():
            # never create a row just to remove something from it
            stmt = select(StoredSession).filter(StoredSession.client_id == client_id)
            stored_session = await db_session.scalar(stmt)
            if stored_session is None:
                return
            stored_session.subscriptions = [sub for sub in stored_session.subscriptions if sub.topic != topic]

    async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None:
        """Update to retained messages.

        if retained_message.data is None or '', the message is being cleared
        """
        # if client_id is valid, the retained message is for a disconnected client
        if client_id is not None:
            async with self._db_session_maker() as db_session, db_session.begin():
                # stored sessions shouldn't need to be created here, but we'll use the same helper...
                stored_session = await self._get_or_create_session(db_session, client_id)
                # `data` may be None (see docstring); treat it as an empty payload instead of failing
                payload = (retained_message.data or b"").decode()
                stored_session.retained = [
                    *stored_session.retained,
                    RetainedMessage(retained_message.topic, payload, retained_message.qos or 0),
                ]
            return

        async with self._db_session_maker() as db_session, db_session.begin():
            # if the retained message has data, we need to store/update for the topic
            if retained_message.data:
                client_message = await self._get_or_create_message(db_session, retained_message.topic)
                client_message.data = retained_message.data  # type: ignore[assignment]
                client_message.qos = retained_message.qos or 0
                return

            # if there is no data, clear the stored message (if exists) for the topic
            stmt = select(StoredMessage).filter(StoredMessage.topic == retained_message.topic)
            topic_message = await db_session.scalar(stmt)
            if topic_message is not None:
                await db_session.delete(topic_message)

    async def on_broker_pre_start(self) -> None:
        """Initialize the database and db connection."""
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def _restore_sessions(self, db_session: AsyncSession) -> int:
        """Re-create every stored session in the broker context and return how many were restored."""
        stored_sessions = await db_session.execute(select(StoredSession))

        restored_sessions = 0
        for stored_session in stored_sessions.scalars():
            # NOTE(review): the topic-less call presumably registers the session itself - confirm
            await self.context.add_subscription(stored_session.client_id, None, None)
            for subscription in stored_session.subscriptions:
                await self.context.add_subscription(stored_session.client_id, subscription.topic, subscription.qos)

            session = self.context.get_session(stored_session.client_id)
            if not session:
                continue

            session.clean_session = stored_session.clean_session
            session.will_flag = stored_session.will_flag
            session.will_message = stored_session.will_message
            session.will_qos = stored_session.will_qos
            session.will_retain = stored_session.will_retain
            session.will_topic = stored_session.will_topic
            session.keep_alive = stored_session.keep_alive

            # queue the messages that were retained for this client
            for message in stored_session.retained:
                retained_message = RetainedApplicationMessage(
                    source_session=None,
                    topic=message.topic,
                    data=message.data.encode(),
                    qos=message.qos,
                )
                await session.retained_messages.put(retained_message)
            restored_sessions += 1
        return restored_sessions

    async def _restore_retained_messages(self, db_session: AsyncSession) -> int:
        """Load the stored per-topic retained messages into the broker context and return the count."""
        stored_messages: Result[tuple[StoredMessage]] = await db_session.execute(select(StoredMessage))

        restored_messages = 0
        retained_messages = self.context.retained_messages
        for stored_message in stored_messages.scalars():
            retained_messages[stored_message.topic] = RetainedApplicationMessage(
                source_session=None,
                topic=stored_message.topic,
                data=stored_message.data or b"",
                qos=stored_message.qos,
            )
            restored_messages += 1
        return restored_messages

    async def on_broker_post_start(self) -> None:
        """Load stored sessions, subscriptions and retained messages into the broker.

        Raises:
            PluginError: if the broker already has subscriptions or sessions.

        """
        if len(self.context.subscriptions) > 0:
            msg = "SessionDBPlugin : broker shouldn't have any subscriptions yet"
            raise PluginError(msg)

        if len(list(self.context.sessions)) > 0:
            msg = "SessionDBPlugin : broker shouldn't have any sessions yet"
            raise PluginError(msg)

        async with self._db_session_maker() as db_session, db_session.begin():
            restored_sessions = await self._restore_sessions(db_session)
            restored_messages = await self._restore_retained_messages(db_session)
            logger.info(f"Retained messages restored: {restored_messages}")

        logger.info(f"Restored {restored_sessions} sessions.")

    async def on_broker_pre_shutdown(self) -> None:
        """Clean up the db connection."""
        await self._engine.dispose()

    async def on_broker_post_shutdown(self) -> None:
        """Delete the database file after shutdown when `clear_on_shutdown` is set."""
        if not self.config.clear_on_shutdown:
            return
        # `file` is annotated `str | Path`, so coerce before using Path methods;
        # `missing_ok` avoids the race between an existence check and the unlink
        Path(self.config.file).unlink(missing_ok=True)

    @dataclass
    class Config:
        """Configuration variables."""

        file: str | Path = "amqtt.db"
        """path & filename to store the sqlite session db."""
        clear_on_shutdown: bool = True
        """if the broker shuts down normally, don't retain any information."""

        def __post_init__(self) -> None:
            """Create `Path` from string path."""
            if isinstance(self.file, str):
                self.file = Path(self.file)