Coverage for amqtt/contexts.py: 79%

234 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from dataclasses import dataclass, field, fields, replace 

2import logging 

3import warnings 

4 

5try: 

6 from enum import Enum, StrEnum 

7except ImportError: 

8 # support for python 3.10 

9 from enum import Enum 

10 class StrEnum(str, Enum): # type: ignore[no-redef] 

11 pass 

12 

13from collections.abc import Iterator 

14from pathlib import Path 

15from typing import TYPE_CHECKING, Any, Literal 

16 

17from dacite import Config as DaciteConfig, from_dict as dict_to_dataclass 

18 

19from amqtt.mqtt.constants import QOS_0, QOS_2 

20 

21if TYPE_CHECKING: 

22 import asyncio 

23 

24logger = logging.getLogger(__name__) 

25 

26 

class BaseContext:
    """Base container for the event loop, logger and configuration attributes."""

    def __init__(self) -> None:
        """Create a context with no event loop and no configuration attached yet."""
        # The loop starts unset; nothing in this class assigns it.
        self.loop: asyncio.AbstractEventLoop | None = None
        # Logger is looked up by this module's name.
        module_logger = logging.getLogger(__name__)
        self.logger: logging.Logger = module_logger
        # NOTE: a `Generic` type parameter would be a cleaner way to pin the config type.
        self.config: ClientConfig | BrokerConfig | dict[str, Any] | None = None

33 

34 

35class Action(StrEnum): 

36 """Actions issued by the broker.""" 

37 

38 SUBSCRIBE = "subscribe" 

39 PUBLISH = "publish" 

40 RECEIVE = "receive" 

41 

42 

43class ListenerType(StrEnum): 

44 """Types of mqtt listeners.""" 

45 

46 TCP = "tcp" 

47 WS = "ws" 

48 EXTERNAL = "external" 

49 

50 def __repr__(self) -> str: 

51 """Display the string value, instead of the enum member.""" 

52 return f'"{self.value!s}"' 

53 

54 

55class Dictable: 

56 """Add dictionary methods to a dataclass.""" 

57 

58 def __getitem__(self, key: str) -> Any: 

59 """Allow dict-style `[]` access to a dataclass.""" 

60 return self.get(key) 

61 

62 def get(self, name: str, default: Any = None) -> Any: 

63 """Allow dict-style access to a dataclass.""" 

64 name = name.replace("-", "_") 

65 if hasattr(self, name): 

66 return getattr(self, name) 

67 if default is not None: 67 ↛ 69line 67 didn't jump to line 69 because the condition on line 67 was always true

68 return default 

69 msg = f"'{name}' is not defined" 

70 raise ValueError(msg) 

71 

72 def __contains__(self, name: str) -> bool: 

73 """Provide dict-style 'in' check.""" 

74 return getattr(self, name.replace("-", "_"), None) is not None 

75 

76 def __iter__(self) -> Iterator[Any]: 

77 """Provide dict-style iteration.""" 

78 for f in fields(self): # type: ignore[arg-type] 

79 yield getattr(self, f.name) 

80 

81 def copy(self) -> dataclass: # type: ignore[valid-type] 

82 """Return a copy of the dataclass.""" 

83 return replace(self) # type: ignore[type-var] 

84 

85 @staticmethod 

86 def _coerce_lists(value: list[Any] | dict[str, Any] | Any) -> list[dict[str, Any]]: 

87 if isinstance(value, list): 

88 return value # It's already a list of dicts 

89 if isinstance(value, dict): 

90 return [value] # Promote single dict to a list 

91 msg = "Could not convert 'list' to 'list[dict[str, Any]]'" 

92 raise ValueError(msg) 

93 

94 

@dataclass
class ListenerConfig(Dictable):
    """Structured configuration for one of a broker's listeners."""

    type: ListenerType = ListenerType.TCP
    """Type of listener: `tcp` for 'mqtt' or `ws` for 'websocket' when specified in a dictionary or yaml."""
    bind: str | None = "0.0.0.0:1883"
    """Address and port for the listener to bind to."""
    max_connections: int = 0
    """Maximum number of connections allowed for this listener."""
    ssl: bool = False
    """Whether the listener is secured by ssl."""
    cafile: str | Path | None = None
    """Path to a file of concatenated CA certificates in PEM format. See
    [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
    capath: str | Path | None = None
    """Path to a directory containing one or more CA certificates in PEM format, following the
    [OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
    cadata: str | Path | None = None
    """Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates."""
    certfile: str | Path | None = None
    """Full path to file in PEM format containing the server's certificate (as well as any number of CA
    certificates needed to establish the certificate's authenticity.)"""
    keyfile: str | Path | None = None
    """Full path to file in PEM format containing the server's private key."""
    reader: str | None = None
    writer: str | None = None

    def __post_init__(self) -> None:
        """Validate the certificate settings and convert string paths to `Path`.

        Raises:
            ValueError: exactly one of `certfile` / `keyfile` was provided.
            FileNotFoundError: a configured `cafile`, `capath`, `certfile` or `keyfile` does not exist.
        """
        has_cert = self.certfile is not None
        has_key = self.keyfile is not None
        if has_cert != has_key:
            msg = "If specifying the 'certfile' or 'keyfile', both are required."
            raise ValueError(msg)

        # `cadata` is deliberately absent: it is not included in the path conversion.
        for attr in ("cafile", "capath", "certfile", "keyfile"):
            location = getattr(self, attr)
            if isinstance(location, str):
                location = Path(location)
                setattr(self, attr, location)
            if location and not location.exists():
                msg = f"'{attr}' does not exist : {location}"
                raise FileNotFoundError(msg)

    def apply(self, other: "ListenerConfig") -> None:
        """Take each field's value from `other` wherever this instance still holds the field's default."""
        for spec in fields(self):
            still_default = getattr(self, spec.name) == spec.default
            if still_default:
                setattr(self, spec.name, other[spec.name])

141 

142 

def default_listeners() -> dict[str, Any]:
    """Build the default for `BrokerConfig.listeners`: one all-defaults listener named 'default'."""
    fallback = ListenerConfig()
    return {"default": fallback}

148 

149 

def default_broker_plugins() -> dict[str, Any]:
    """Build the default for `BrokerConfig.plugins`.

    Keys are dotted paths of plugin classes; values are each plugin's option dictionary.
    A fresh dictionary is created on every call.
    """
    enabled: dict[str, Any] = {}
    enabled["amqtt.plugins.logging_amqtt.EventLoggerPlugin"] = {}
    enabled["amqtt.plugins.logging_amqtt.PacketLoggerPlugin"] = {}
    enabled["amqtt.plugins.authentication.AnonymousAuthPlugin"] = {"allow_anonymous": True}
    enabled["amqtt.plugins.sys.broker.BrokerSysPlugin"] = {"sys_interval": 20}
    return enabled

158 

159 

@dataclass
class BrokerConfig(Dictable):
    """Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""

    listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners)  # noqa: PYI051
    """Network of listeners used by the services. a 'default' named listener is required; if another listener
    does not set a value, the 'default' settings are applied. See
    [`ListenerConfig`](broker_config.md#amqtt.contexts.ListenerConfig) for more information."""
    sys_interval: int | None = None
    """*Deprecated field to configure the `BrokerSysPlugin`. See [`BrokerSysPlugin`](../plugins/packaged_plugins.md#sys-topics)
    for recommended configuration.*"""
    timeout_disconnect_delay: int | None = 0
    """Client disconnect timeout without a keep-alive."""
    session_expiry_interval: int | None = None
    """Seconds for an inactive session to be retained."""
    auth: dict[str, Any] | None = None
    """*Deprecated field used to config EntryPoint-loaded plugins. See
    [`AnonymousAuthPlugin`](../plugins/packaged_plugins.md#anonymous-auth-plugin) and
    [`FileAuthPlugin`](../plugins/packaged_plugins.md#password-file-auth-plugin) for recommended configuration.*"""
    topic_check: dict[str, Any] | None = None
    """*Deprecated field used to config EntryPoint-loaded plugins. See
    [`TopicTabooPlugin`](../plugins/packaged_plugins.md#taboo-topic-plugin) and
    [`TopicACLPlugin`](../plugins/packaged_plugins.md#acl-topic-plugin) for recommended configuration method.*"""
    plugins: dict[str, Any] | list[str | dict[str, Any]] | None = field(default_factory=default_broker_plugins)
    """The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
    or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
    [custom plugins](../plugins/custom_plugins.md) for more information. `list[str | dict[str,Any]]` is deprecated but available
    to support legacy use cases."""

    def __post_init__(self) -> None:
        """Check config for errors and transform fields for easier use.

        Logs a warning for each deprecated field in use, fills unset values of every
        non-default listener from the 'default' listener, and converts a legacy
        list-style `plugins` value into a single dictionary.

        Raises:
            KeyError: `listeners` has no entry named 'default'.
        """
        if self.sys_interval is not None:
            logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")

        if self.auth is not None or self.topic_check is not None:
            logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")

        default_listener = self.listeners["default"]
        for listener_name, listener in self.listeners.items():
            if listener_name == "default":
                continue
            listener.apply(default_listener)

        if isinstance(self.plugins, list):
            _plugins: dict[str, Any] = {}
            for plugin in self.plugins:
                # in case a plugin in a yaml file is listed without config map
                if isinstance(plugin, str):
                    _plugins |= {plugin: {}}
                    continue
                _plugins |= plugin
            self.plugins = _plugins

    @classmethod
    def from_dict(cls, d: dict[str, Any] | None) -> "BrokerConfig":
        """Create a broker config from a dictionary.

        Args:
            d: configuration mapping, or `None` for an all-defaults configuration.
                The mapping itself is not modified.

        Returns:
            An instance of `cls` populated from `d`.
        """
        if d is None:
            return cls()

        # Work on a shallow copy so the key patching below never leaks back to the caller.
        data = dict(d)

        # Rename the legacy hyphenated key to the field name so it can be loaded.
        if "topic-check" in data:
            data["topic_check"] = data.pop("topic-check")

        # Identify EntryPoint plugin loading and prevent 'plugins' from getting defaults.
        # The check uses the renamed key: the hyphenated one has already been removed above.
        if ("auth" in data or "topic_check" in data) and "plugins" not in data:
            data["plugins"] = None

        return dict_to_dataclass(data_class=cls,
                                 data=data,
                                 config=DaciteConfig(
                                     cast=[StrEnum, ListenerType],
                                     strict=True,
                                     type_hooks={list[dict[str, Any]]: cls._coerce_lists}
                                 ))

235 

236 

@dataclass
class ConnectionConfig(Dictable):
    """Properties for connecting to the broker."""

    uri: str | None = "mqtt://127.0.0.1:1883"
    """URI of the broker"""
    cafile: str | Path | None = None
    """Path to a file of concatenated CA certificates in PEM format to verify broker's authenticity. See
    [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
    capath: str | Path | None = None
    """Path to a directory containing one or more CA certificates in PEM format, following the
    [OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
    cadata: str | None = None
    """The certificate to verify the broker's authenticity in an ASCII string format of one or more PEM-encoded
    certificates or a bytes-like object of DER-encoded certificates."""
    certfile: str | Path | None = None
    """Full path to file in PEM format containing the client's certificate (as well as any number of CA
    certificates needed to establish the certificate's authenticity.)"""
    keyfile: str | Path | None = None
    """Full path to file in PEM format containing the client's private key associated with the certfile."""

    # Must be spelled `__post_init__` exactly: that is the only hook name the dataclass
    # machinery invokes after `__init__`.
    def __post_init__(self) -> None:
        """Check config for errors and transform fields for easier use.

        Raises:
            ValueError: exactly one of `certfile` / `keyfile` was provided.
        """
        if (self.certfile is None) ^ (self.keyfile is None):
            msg = "If specifying the 'certfile' or 'keyfile', both are required."
            raise ValueError(msg)

        # Convert string locations to `Path`; `cadata` is not a path and is left alone.
        for fn in ("cafile", "capath", "certfile", "keyfile"):
            if isinstance(getattr(self, fn), str):
                setattr(self, fn, Path(getattr(self, fn)))

267 

268 

@dataclass
class TopicConfig(Dictable):
    """Configuration of how messages to specific topics are published.

    The topic name is specified as the key in the dictionary of the `ClientConfig.topics`.
    """

    qos: int = 0
    """The quality of service associated with the publishing to this topic."""
    retain: bool = False
    """Determines if the message should be retained by the topic it was published."""

    # Must be spelled `__post_init__` exactly: that is the only hook name the dataclass
    # machinery invokes after `__init__`.
    def __post_init__(self) -> None:
        """Check config for errors and transform fields for easier use.

        Raises:
            ValueError: `qos` is set and lies outside the `QOS_0`..`QOS_2` range.
        """
        if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
            msg = "Topic config: default QoS must be 0, 1 or 2."
            raise ValueError(msg)

286 

287 

@dataclass
class WillConfig(Dictable):
    """Configuration of the 'last will & testament' of the client upon improper disconnection."""

    topic: str
    """The will message will be published to this topic."""
    message: str
    """The contents of the message to be published."""
    qos: int | None = QOS_0
    """The quality of service associated with sending this message."""
    retain: bool | None = False
    """Determines if the message should be retained by the topic it was published."""

    # Must be spelled `__post_init__` exactly: that is the only hook name the dataclass
    # machinery invokes after `__init__`.
    def __post_init__(self) -> None:
        """Check config for errors and transform fields for easier use.

        Raises:
            ValueError: `qos` is set and lies outside the `QOS_0`..`QOS_2` range.
        """
        if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
            msg = "Will config: default QoS must be 0, 1 or 2."
            raise ValueError(msg)

306 

307 

def default_client_plugins() -> dict[str, Any]:
    """Build the default for `ClientConfig.plugins`: the packet logger plugin with no options.

    A fresh dictionary is created on every call.
    """
    packet_logger = "amqtt.plugins.logging_amqtt.PacketLoggerPlugin"
    return {packet_logger: {}}

313 

314 

@dataclass
class ClientConfig(Dictable):
    """Structured configuration for a client. Can be created directly or from a dictionary via `from_dict`."""

    keep_alive: int | None = 10
    """Keep-alive timeout sent to the broker."""
    ping_delay: int | None = 1
    """Auto-ping delay before keep-alive timeout. Setting to 0 will disable which may lead to broker disconnection."""
    default_qos: int | None = QOS_0
    """Default QoS for messages published."""
    default_retain: bool | None = False
    """Default retain value to messages published."""
    auto_reconnect: bool | None = True
    """Enable or disable auto-reconnect if connection with the broker is interrupted."""
    connection_timeout: int | None = 60
    """The number of seconds before a connection times out"""
    reconnect_retries: int | None = 2
    """Number of reconnection retry attempts. Negative value will cause client to reconnect indefinitely."""
    reconnect_max_interval: int | None = 10
    """Maximum seconds to wait before retrying a connection."""
    cleansession: bool | None = True
    """Upon reconnect, should subscriptions be cleared. Can be overridden by `MQTTClient.connect`"""
    topics: dict[str, TopicConfig] | None = field(default_factory=dict)
    """Specify the topics and what flags should be set for messages published to them."""
    broker: ConnectionConfig | None = None
    """*Deprecated* Configuration for connecting to the broker. Use `connection` field instead."""
    connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    """Configuration for connecting to the broker. See
    [`ConnectionConfig`](client_config.md#amqtt.contexts.ConnectionConfig) for more information."""
    plugins: dict[str, Any] | list[dict[str, Any]] | None = field(default_factory=default_client_plugins)
    """The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`; the value is
    a dictionary of configuration options for that plugin. See [custom plugins](../plugins/custom_plugins.md) for
    more information. `list[dict[str,Any]]` is deprecated but available to support legacy use cases."""
    check_hostname: bool | None = True
    """If establishing a secure connection, should the hostname of the certificate be verified."""
    will: WillConfig | None = None
    """Message, topic and flags that should be sent to if the client disconnects. See
    [`WillConfig`](client_config.md#amqtt.contexts.WillConfig) for more information."""

    def __post_init__(self) -> None:
        """Check config for errors and transform fields for easier use.

        Emits a warning and copies `broker` into `connection` when the deprecated
        `broker` field is set.

        Raises:
            ValueError: `default_qos` is outside the `QOS_0`..`QOS_2` range, or only one of
                the connection's `keyfile` / `certfile` is set.
        """
        if self.default_qos is not None and (self.default_qos < QOS_0 or self.default_qos > QOS_2):
            msg = "Client config: default QoS must be 0, 1 or 2."
            raise ValueError(msg)

        if self.broker is not None:
            warnings.warn("The 'broker' option is deprecated, please use 'connection' instead.", stacklevel=2)
            self.connection = self.broker

        # Truthiness (not `is None`) is used here, so an empty string counts as "not set".
        missing_key = not self.connection.keyfile
        missing_cert = not self.connection.certfile
        if missing_key != missing_cert:
            msg = "Connection key and certificate files are _both_ required."
            raise ValueError(msg)

    @classmethod
    def from_dict(cls, d: dict[str, Any] | None) -> "ClientConfig":
        """Create a client config from a dictionary.

        Args:
            d: configuration mapping, or `None` for an all-defaults configuration.

        Returns:
            An instance of `cls` populated from `d`.
        """
        if d is None:
            return cls()

        return dict_to_dataclass(data_class=cls,
                                 data=d,
                                 config=DaciteConfig(
                                     cast=[StrEnum],
                                     strict=True)
                                 )