Coverage for amqtt/contrib/persistence.py: 88%

168 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from dataclasses import dataclass 

2import logging 

3from pathlib import Path 

4 

5from sqlalchemy import Boolean, Integer, LargeBinary, Result, String, select 

6from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine 

7from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 

8 

9from amqtt.broker import BrokerContext, RetainedApplicationMessage 

10from amqtt.contrib import DataClassListJSON 

11from amqtt.errors import PluginError 

12from amqtt.plugins.base import BasePlugin 

13from amqtt.session import Session 

14 

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

16 

17 

class Base(DeclarativeBase):
    """Declarative base for the ORM models below; its metadata is used to create the tables."""

20 

21 

@dataclass
class RetainedMessage:
    """A message kept for a disconnected client, serialized via ``DataClassListJSON``.

    The payload is held as text: it is decoded from bytes when stored and
    re-encoded to bytes when the session is restored.
    """

    # topic the message was published to
    topic: str
    # payload, decoded to text
    data: str
    # quality-of-service level of the message
    qos: int

27 

28 

@dataclass
class Subscription:
    """A single topic subscription persisted for a non-clean client session."""

    # topic filter the client subscribed to
    topic: str
    # quality-of-service level granted for the subscription
    qos: int

33 

34 

class StoredSession(Base):
    """ORM row holding the persisted state of one client session."""

    __tablename__ = "stored_sessions"

    id: Mapped[int] = mapped_column(primary_key=True)
    # MQTT client identifier; rows are looked up by this value
    client_id: Mapped[str] = mapped_column(String)

    clean_session: Mapped[bool | None] = mapped_column(Boolean, nullable=True)

    # last-will settings, mirrored from the client's `Session`
    will_flag: Mapped[bool] = mapped_column(Boolean, server_default="false", default=False)

    will_message: Mapped[bytes | None] = mapped_column(LargeBinary, default=None, nullable=True)
    will_qos: Mapped[int | None] = mapped_column(Integer, default=None, nullable=True)
    will_retain: Mapped[bool | None] = mapped_column(Boolean, default=None, nullable=True)
    will_topic: Mapped[str | None] = mapped_column(String, default=None, nullable=True)

    keep_alive: Mapped[int] = mapped_column(Integer, default=0)
    # messages kept for the client while disconnected, and its subscriptions;
    # both are lists of dataclasses stored through `DataClassListJSON`
    retained: Mapped[list[RetainedMessage]] = mapped_column(DataClassListJSON(RetainedMessage), default=list)
    subscriptions: Mapped[list[Subscription]] = mapped_column(DataClassListJSON(Subscription), default=list)

53 

54 

class StoredMessage(Base):
    """ORM row holding the retained message currently published on one topic."""

    __tablename__ = "stored_messages"

    id: Mapped[int] = mapped_column(primary_key=True)
    # topic the retained message belongs to; rows are looked up by this value
    topic: Mapped[str] = mapped_column(String)
    # raw message payload
    data: Mapped[bytes | None] = mapped_column(LargeBinary, default=None, nullable=True)
    qos: Mapped[int] = mapped_column(Integer, default=0)

62 

63 

class SessionDBPlugin(BasePlugin[BrokerContext]):
    """Plugin to store session information and retained topic messages in the event that the broker terminates abnormally.

    Configuration:
    - file *(string)* path & filename to store the session db. default: `amqtt.db`
    - clear_on_shutdown *(bool)* if the broker shuts down normally, don't retain any information. default: `True`

    """

    def __init__(self, context: BrokerContext) -> None:
        super().__init__(context)

        # bypass the `test_plugins_correct_has_attr` until it can be updated
        if not hasattr(self.config, "file"):
            logger.warning("`Config` is missing a `file` attribute")
            return

        self._engine = create_async_engine(f"sqlite+aiosqlite:///{self.config.file}")
        # expire_on_commit=False keeps attributes of loaded rows readable after commit
        self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)

    @staticmethod
    async def _find_stored_session(db_session: AsyncSession, client_id: str) -> StoredSession | None:
        """Return the stored session row for `client_id`, or None if there is none."""
        stmt = select(StoredSession).filter(StoredSession.client_id == client_id)
        return await db_session.scalar(stmt)

    @staticmethod
    async def _get_or_create_session(db_session: AsyncSession, client_id: str) -> StoredSession:
        """Return the stored session row for `client_id`, adding a new one if none exists."""
        stored_session = await SessionDBPlugin._find_stored_session(db_session, client_id)
        if stored_session is None:
            stored_session = StoredSession(client_id=client_id)
            db_session.add(stored_session)
        # flush so a newly added row is INSERTed before callers read/modify it
        await db_session.flush()
        return stored_session

    @staticmethod
    async def _get_or_create_message(db_session: AsyncSession, topic: str) -> StoredMessage:
        """Return the stored retained-message row for `topic`, adding a new one if none exists."""
        stmt = select(StoredMessage).filter(StoredMessage.topic == topic)
        stored_message = await db_session.scalar(stmt)
        if stored_message is None:
            stored_message = StoredMessage(topic=topic)
            db_session.add(stored_message)
        await db_session.flush()
        return stored_message

    async def on_broker_client_connected(self, client_id: str, client_session: Session) -> None:
        """Create, update or discard the stored session of a newly connected client.

        Anonymous sessions, and sessions whose `clean_session` is unset, are ignored.
        A clean-session connection discards any previously stored session so it is not
        restored after a broker restart (MQTT 3.1.1, section 3.1.2.4). A persistent
        connection creates/updates the stored session with the client's will and
        keep-alive settings.
        """
        if client_session.clean_session is None or client_session.is_anonymous:
            return

        async with self._db_session_maker() as db_session, db_session.begin():
            if client_session.clean_session:
                previous = await self._find_stored_session(db_session, client_id)
                if previous is not None:
                    await db_session.delete(previous)
                return

            stored_session = await self._get_or_create_session(db_session, client_id)

            stored_session.clean_session = client_session.clean_session
            stored_session.will_flag = client_session.will_flag
            stored_session.will_message = client_session.will_message  # type: ignore[assignment]
            stored_session.will_qos = client_session.will_qos
            stored_session.will_retain = client_session.will_retain
            stored_session.will_topic = client_session.will_topic
            stored_session.keep_alive = client_session.keep_alive

            await db_session.flush()

    async def on_broker_client_subscribed(self, client_id: str, topic: str, qos: int) -> None:
        """Create/update subscription if clean session = false.

        Subscribing again to an already stored topic replaces that entry (and its QoS)
        instead of adding a duplicate.
        """
        session = self.context.get_session(client_id)
        if not session:
            logger.warning(f"'{client_id}' is subscribing but doesn't have a session")
            return

        if session.clean_session:
            return

        async with self._db_session_maker() as db_session, db_session.begin():
            # stored sessions shouldn't need to be created here, but we'll use the same helper...
            stored_session = await self._get_or_create_session(db_session, client_id)
            # an identical topic filter replaces the existing subscription (MQTT 3.1.1, 3.8.4);
            # a new list is assigned (not mutated in place) so the ORM sees the change
            stored_session.subscriptions = [
                *(sub for sub in stored_session.subscriptions if sub.topic != topic),
                Subscription(topic, qos),
            ]
            await db_session.flush()

    async def on_broker_client_unsubscribed(self, client_id: str, topic: str) -> None:
        """Remove subscription if clean session = false.

        Does nothing if the client's session is clean, has no stored row, or has no
        stored subscription for `topic`.
        """
        session = self.context.get_session(client_id)
        if session is not None and session.clean_session:
            return

        async with self._db_session_maker() as db_session, db_session.begin():
            stored_session = await self._find_stored_session(db_session, client_id)
            if stored_session is None:
                return
            remaining = [sub for sub in stored_session.subscriptions if sub.topic != topic]
            if len(remaining) != len(stored_session.subscriptions):
                # assign a new list (not mutate in place) so the ORM sees the change
                stored_session.subscriptions = remaining
                await db_session.flush()

    async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None:
        """Update to retained messages.

        if retained_message.data is None or '', the message is being cleared
        """
        # if client_id is valid, the retained message is for a disconnected client
        if client_id is not None:
            async with self._db_session_maker() as db_session, db_session.begin():
                # stored sessions shouldn't need to be created here, but we'll use the same helper...
                stored_session = await self._get_or_create_session(db_session, client_id)
                # payloads are stored as text; a `None` payload is stored as an empty one.
                # NOTE(review): non-UTF-8 payloads still raise UnicodeDecodeError here.
                payload = (retained_message.data or b"").decode()
                stored_session.retained = [
                    *stored_session.retained,
                    RetainedMessage(retained_message.topic, payload, retained_message.qos or 0),
                ]
                await db_session.flush()
            return

        async with self._db_session_maker() as db_session, db_session.begin():
            # if the retained message has data, we need to store/update for the topic
            if retained_message.data:
                client_message = await self._get_or_create_message(db_session, retained_message.topic)
                client_message.data = retained_message.data  # type: ignore[assignment]
                client_message.qos = retained_message.qos or 0
                await db_session.flush()
                return

            # if there is no data, clear the stored message (if exists) for the topic
            stmt = select(StoredMessage).filter(StoredMessage.topic == retained_message.topic)
            topic_message = await db_session.scalar(stmt)
            if topic_message is not None:
                await db_session.delete(topic_message)
                await db_session.flush()
            return

    async def on_broker_pre_start(self) -> None:
        """Initialize the database and db connection."""
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def on_broker_post_start(self) -> None:
        """Load subscriptions.

        Restores every stored session (subscriptions, will settings, keep-alive and
        queued messages) and every stored retained topic message into the broker.

        Raises:
            PluginError: if the broker already has subscriptions or sessions.

        """
        if len(self.context.subscriptions) > 0:
            msg = "SessionDBPlugin : broker shouldn't have any subscriptions yet"
            raise PluginError(msg)

        if len(list(self.context.sessions)) > 0:
            msg = "SessionDBPlugin : broker shouldn't have any sessions yet"
            raise PluginError(msg)

        async with self._db_session_maker() as db_session, db_session.begin():
            stmt = select(StoredSession)
            stored_sessions = await db_session.execute(stmt)

            restored_sessions = 0
            for stored_session in stored_sessions.scalars():
                await self.context.add_subscription(stored_session.client_id, None, None)
                for subscription in stored_session.subscriptions:
                    await self.context.add_subscription(stored_session.client_id,
                                                        subscription.topic,
                                                        subscription.qos)
                session = self.context.get_session(stored_session.client_id)
                if not session:
                    continue
                session.clean_session = stored_session.clean_session
                session.will_flag = stored_session.will_flag
                session.will_message = stored_session.will_message
                session.will_qos = stored_session.will_qos
                session.will_retain = stored_session.will_retain
                session.will_topic = stored_session.will_topic
                session.keep_alive = stored_session.keep_alive

                # re-queue messages that were kept for the client while disconnected
                for message in stored_session.retained:
                    retained_message = RetainedApplicationMessage(
                        source_session=None,
                        topic=message.topic,
                        data=message.data.encode(),
                        qos=message.qos,
                    )
                    await session.retained_messages.put(retained_message)
                restored_sessions += 1

            stmt = select(StoredMessage)
            stored_messages: Result[tuple[StoredMessage]] = await db_session.execute(stmt)

            restored_messages = 0
            retained_messages = self.context.retained_messages
            for stored_message in stored_messages.scalars():
                retained_messages[stored_message.topic] = RetainedApplicationMessage(
                    source_session=None,
                    topic=stored_message.topic,
                    data=stored_message.data or b"",
                    qos=stored_message.qos,
                )
                restored_messages += 1
            logger.info(f"Retained messages restored: {restored_messages}")

        logger.info(f"Restored {restored_sessions} sessions.")

    async def on_broker_pre_shutdown(self) -> None:
        """Clean up the db connection."""
        await self._engine.dispose()

    async def on_broker_post_shutdown(self) -> None:
        """Delete the session db file after a normal shutdown when `clear_on_shutdown` is set."""
        if self.config.clear_on_shutdown and self.config.file.exists():
            self.config.file.unlink()

    @dataclass
    class Config:
        """Configuration variables."""

        file: str | Path = "amqtt.db"
        """path & filename to store the sqlite session db."""
        clear_on_shutdown: bool = True
        """if the broker shuts down normally, don't retain any information."""

        def __post_init__(self) -> None:
            """Create `Path` from string path."""
            if isinstance(self.file, str):
                self.file = Path(self.file)