Coverage for amqtt/contexts.py: 79%
234 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1from dataclasses import dataclass, field, fields, replace
2import logging
3import warnings
try:
    from enum import Enum, StrEnum
except ImportError:
    # Python 3.10 has no `enum.StrEnum` (added in 3.11); provide a compatible stand-in.
    from enum import Enum

    class StrEnum(str, Enum):  # type: ignore[no-redef]
        """Minimal backport of `enum.StrEnum` for Python 3.10.

        A bare ``(str, Enum)`` mix-in on 3.10 renders ``str(member)`` as
        ``"ClassName.MEMBER"``, whereas the 3.11+ `StrEnum` renders the member's
        value. Overriding `__str__` and `__format__` keeps `str()` and f-strings
        returning the value on every supported interpreter.
        """

        def __str__(self) -> str:
            # Same definition the stdlib StrEnum uses: the underlying string value.
            return str.__str__(self)

        def __format__(self, format_spec: str) -> str:
            return str.__format__(self, format_spec)
13from collections.abc import Iterator
14from pathlib import Path
15from typing import TYPE_CHECKING, Any, Literal
17from dacite import Config as DaciteConfig, from_dict as dict_to_dataclass
19from amqtt.mqtt.constants import QOS_0, QOS_2
21if TYPE_CHECKING:
22 import asyncio
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class BaseContext:
    """State shared by a broker or client context: event loop, logger and configuration.

    Apart from the logger, every attribute starts out as ``None``.
    """

    def __init__(self) -> None:
        # Event loop for this context; unset until someone assigns it.
        self.loop: asyncio.AbstractEventLoop | None = None
        # Logger named after this module.
        self.logger: logging.Logger = logging.getLogger(name=__name__)
        # Either structured config or a raw dict.
        # TODO: replace this union with a `Generic` type parameter.
        self.config: ClientConfig | BrokerConfig | dict[str, Any] | None = None
35class Action(StrEnum):
36 """Actions issued by the broker."""
38 SUBSCRIBE = "subscribe"
39 PUBLISH = "publish"
40 RECEIVE = "receive"
43class ListenerType(StrEnum):
44 """Types of mqtt listeners."""
46 TCP = "tcp"
47 WS = "ws"
48 EXTERNAL = "external"
50 def __repr__(self) -> str:
51 """Display the string value, instead of the enum member."""
52 return f'"{self.value!s}"'
55class Dictable:
56 """Add dictionary methods to a dataclass."""
58 def __getitem__(self, key: str) -> Any:
59 """Allow dict-style `[]` access to a dataclass."""
60 return self.get(key)
62 def get(self, name: str, default: Any = None) -> Any:
63 """Allow dict-style access to a dataclass."""
64 name = name.replace("-", "_")
65 if hasattr(self, name):
66 return getattr(self, name)
67 if default is not None: 67 ↛ 69line 67 didn't jump to line 69 because the condition on line 67 was always true
68 return default
69 msg = f"'{name}' is not defined"
70 raise ValueError(msg)
72 def __contains__(self, name: str) -> bool:
73 """Provide dict-style 'in' check."""
74 return getattr(self, name.replace("-", "_"), None) is not None
76 def __iter__(self) -> Iterator[Any]:
77 """Provide dict-style iteration."""
78 for f in fields(self): # type: ignore[arg-type]
79 yield getattr(self, f.name)
81 def copy(self) -> dataclass: # type: ignore[valid-type]
82 """Return a copy of the dataclass."""
83 return replace(self) # type: ignore[type-var]
85 @staticmethod
86 def _coerce_lists(value: list[Any] | dict[str, Any] | Any) -> list[dict[str, Any]]:
87 if isinstance(value, list):
88 return value # It's already a list of dicts
89 if isinstance(value, dict):
90 return [value] # Promote single dict to a list
91 msg = "Could not convert 'list' to 'list[dict[str, Any]]'"
92 raise ValueError(msg)
@dataclass
class ListenerConfig(Dictable):
    """Structured configuration for a single broker listener.

    After construction, the TLS-related path fields are converted to `Path`
    and checked for existence.
    """

    type: ListenerType = ListenerType.TCP
    """Type of listener: `tcp` for 'mqtt' or `ws` for 'websocket' when specified in dictionary or yaml."""
    bind: str | None = "0.0.0.0:1883"
    """address and port for the listener to bind to"""
    max_connections: int = 0
    """max number of connections allowed for this listener"""
    ssl: bool = False
    """secured by ssl"""
    cafile: str | Path | None = None
    """Path to a file of concatenated CA certificates in PEM format. See
    [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
    capath: str | Path | None = None
    """Path to a directory containing one or more CA certificates in PEM format, following the
    [OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
    cadata: str | Path | None = None
    """Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates."""
    certfile: str | Path | None = None
    """Full path to file in PEM format containing the server's certificate (as well as any number of CA
    certificates needed to establish the certificate's authenticity.)"""
    keyfile: str | Path | None = None
    """Full path to file in PEM format containing the server's private key."""
    # NOTE(review): the meaning of `reader`/`writer` is not visible in this module.
    reader: str | None = None
    writer: str | None = None

    def __post_init__(self) -> None:
        """Validate the certificate settings and convert path-like string fields to `Path`.

        Raises:
            ValueError: if only one of `certfile` / `keyfile` is given.
            FileNotFoundError: if a configured cafile/capath/certfile/keyfile does not exist.
        """
        has_cert = self.certfile is not None
        has_key = self.keyfile is not None
        if has_cert != has_key:
            msg = "If specifying the 'certfile' or 'keyfile', both are required."
            raise ValueError(msg)

        for attr in ("cafile", "capath", "certfile", "keyfile"):
            value = getattr(self, attr)
            if isinstance(value, str):
                value = Path(value)
                setattr(self, attr, value)
            if value and not value.exists():
                msg = f"'{attr}' does not exist : {value}"
                raise FileNotFoundError(msg)

    def apply(self, other: "ListenerConfig") -> None:
        """Copy each field from *other* into every field of ``self`` that still equals its declared default."""
        untouched = [fld.name for fld in fields(self) if getattr(self, fld.name) == fld.default]
        for name in untouched:
            setattr(self, name, other[name])
def default_listeners() -> dict[str, Any]:
    """Build the value of `BrokerConfig.listeners` when none is given.

    The result is a single listener named ``"default"`` with default settings. A
    new `ListenerConfig` is created on every call, so separate broker configs never
    share the same listener object.
    """
    return dict(default=ListenerConfig())
def default_broker_plugins() -> dict[str, Any]:
    """Build the value of `BrokerConfig.plugins` when none is given.

    Keys are dotted class paths; values are each plugin's option dictionary. A new
    mapping, with new option dicts, is returned on every call.
    """
    logging_pkg = "amqtt.plugins.logging_amqtt"
    plugins: dict[str, Any] = {}
    plugins[f"{logging_pkg}.EventLoggerPlugin"] = {}
    plugins[f"{logging_pkg}.PacketLoggerPlugin"] = {}
    plugins["amqtt.plugins.authentication.AnonymousAuthPlugin"] = {"allow_anonymous": True}
    plugins["amqtt.plugins.sys.broker.BrokerSysPlugin"] = {"sys_interval": 20}
    return plugins
@dataclass
class BrokerConfig(Dictable):
    """Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""

    listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners)  # noqa: PYI051
    """Network of listeners used by the services. a 'default' named listener is required; if another listener
    does not set a value, the 'default' settings are applied. See
    [`ListenerConfig`](broker_config.md#amqtt.contexts.ListenerConfig) for more information."""
    sys_interval: int | None = None
    """*Deprecated field to configure the `BrokerSysPlugin`. See [`BrokerSysPlugin`](../plugins/packaged_plugins.md#sys-topics)
    for recommended configuration.*"""
    timeout_disconnect_delay: int | None = 0
    """Client disconnect timeout without a keep-alive."""
    session_expiry_interval: int | None = None
    """Seconds for an inactive session to be retained."""
    auth: dict[str, Any] | None = None
    """*Deprecated field used to config EntryPoint-loaded plugins. See
    [`AnonymousAuthPlugin`](../plugins/packaged_plugins.md#anonymous-auth-plugin) and
    [`FileAuthPlugin`](../plugins/packaged_plugins.md#password-file-auth-plugin) for recommended configuration.*"""
    topic_check: dict[str, Any] | None = None
    """*Deprecated field used to config EntryPoint-loaded plugins. See
    [`TopicTabooPlugin`](../plugins/packaged_plugins.md#taboo-topic-plugin) and
    [`TopicACLPlugin`](../plugins/packaged_plugins.md#acl-topic-plugin) for recommended configuration method.*"""
    plugins: dict[str, Any] | list[str | dict[str, Any]] | None = field(default_factory=default_broker_plugins)
    """The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
    or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
    [custom plugins](../plugins/custom_plugins.md) for more information. `list[str | dict[str,Any]]` is deprecated but available
    to support legacy use cases."""

    def __post_init__(self) -> None:
        """Warn about deprecated options, fill listeners from 'default', and normalise legacy plugin lists.

        Raises:
            KeyError: if `listeners` has no entry named ``"default"``.
        """
        if self.sys_interval is not None:
            logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")

        if self.auth is not None or self.topic_check is not None:
            logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")

        # Every non-default listener inherits any setting it left at its default value.
        default_listener = self.listeners["default"]
        for listener_name, listener in self.listeners.items():
            if listener_name == "default":
                continue
            listener.apply(default_listener)

        # Legacy list form -> {dotted.path: options} mapping.
        if isinstance(self.plugins, list):
            _plugins: dict[str, Any] = {}
            for plugin in self.plugins:
                # in case a plugin in a yaml file is listed without config map
                if isinstance(plugin, str):
                    _plugins[plugin] = {}
                    continue
                _plugins |= plugin
            self.plugins = _plugins

    @classmethod
    def from_dict(cls, d: dict[str, Any] | None) -> "BrokerConfig":
        """Create a broker config from a dictionary.

        The caller's dictionary is not modified. ``None`` yields a default config.
        A legacy ``topic-check`` key is accepted and mapped to `topic_check`.
        """
        if d is None:
            return cls()

        # Patch a shallow copy so the incoming dictionary can be loaded correctly
        # without mutating the caller's data.
        d = dict(d)
        if "topic-check" in d:
            d["topic_check"] = d.pop("topic-check")

        # Identify EntryPoint plugin loading and prevent 'plugins' from getting defaults.
        # The key has already been renamed above, so test the underscore spelling.
        if ("auth" in d or "topic_check" in d) and "plugins" not in d:
            d["plugins"] = None

        return dict_to_dataclass(data_class=cls,
                                 data=d,
                                 config=DaciteConfig(
                                     cast=[StrEnum, ListenerType],
                                     strict=True,
                                     type_hooks={list[dict[str, Any]]: cls._coerce_lists}
                                 ))
@dataclass
class ConnectionConfig(Dictable):
    """Properties for connecting to the broker."""

    uri: str | None = "mqtt://127.0.0.1:1883"
    """URI of the broker"""
    cafile: str | Path | None = None
    """Path to a file of concatenated CA certificates in PEM format to verify broker's authenticity. See
    [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
    capath: str | Path | None = None
    """Path to a directory containing one or more CA certificates in PEM format, following the
    [OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
    cadata: str | None = None
    """The certificate to verify the broker's authenticity in an ASCII string format of one or more PEM-encoded
    certificates or a bytes-like object of DER-encoded certificates."""
    certfile: str | Path | None = None
    """Full path to file in PEM format containing the client's certificate (as well as any number of CA
    certificates needed to establish the certificate's authenticity.)"""
    keyfile: str | Path | None = None
    """Full path to file in PEM format containing the client's private key associated with the certfile."""

    def __post_init__(self) -> None:
        """Check config for errors and convert path-like string fields to `Path`.

        `dataclasses` only calls this hook automatically if it is spelled exactly
        ``__post_init__``.

        Raises:
            ValueError: if only one of `certfile` / `keyfile` is given.
        """
        if (self.certfile is None) ^ (self.keyfile is None):
            msg = "If specifying the 'certfile' or 'keyfile', both are required."
            raise ValueError(msg)

        for fn in ("cafile", "capath", "certfile", "keyfile"):
            if isinstance(getattr(self, fn), str):
                setattr(self, fn, Path(getattr(self, fn)))
@dataclass
class TopicConfig(Dictable):
    """Configuration of how messages to specific topics are published.

    The topic name is specified as the key in the dictionary of `ClientConfig.topics`.
    """

    qos: int = 0
    """The quality of service associated with the publishing to this topic."""
    retain: bool = False
    """Determines if the message should be retained by the topic it was published."""

    def __post_init__(self) -> None:
        """Check config for errors.

        `dataclasses` only calls this hook automatically if it is spelled exactly
        ``__post_init__``.

        Raises:
            ValueError: if `qos` is set but lies outside ``QOS_0``..``QOS_2``.
        """
        if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
            msg = "Topic config: default QoS must be 0, 1 or 2."
            raise ValueError(msg)
@dataclass
class WillConfig(Dictable):
    """Configuration of the 'last will & testament' of the client upon improper disconnection."""

    topic: str
    """The will message will be published to this topic."""
    message: str
    """The contents of the message to be published."""
    qos: int | None = QOS_0
    """The quality of service associated with sending this message."""
    retain: bool | None = False
    """Determines if the message should be retained by the topic it was published."""

    def __post_init__(self) -> None:
        """Check config for errors.

        `dataclasses` only calls this hook automatically if it is spelled exactly
        ``__post_init__``.

        Raises:
            ValueError: if `qos` is set but lies outside ``QOS_0``..``QOS_2``.
        """
        if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
            msg = "Will config: default QoS must be 0, 1 or 2."
            raise ValueError(msg)
def default_client_plugins() -> dict[str, Any]:
    """Build the value of `ClientConfig.plugins` when none is given.

    The result maps the packet-logger plugin's dotted path to an empty options
    dict. A new mapping is returned on every call.
    """
    packet_logger = "amqtt.plugins.logging_amqtt.PacketLoggerPlugin"
    return {packet_logger: {}}
@dataclass
class ClientConfig(Dictable):
    """Structured configuration for a client. Can be passed directly to `amqtt.client.MQTTClient` or created from a dictionary."""

    keep_alive: int | None = 10
    """Keep-alive timeout sent to the broker."""
    ping_delay: int | None = 1
    """Auto-ping delay before keep-alive timeout. Setting to 0 will disable which may lead to broker disconnection."""
    default_qos: int | None = QOS_0
    """Default QoS for messages published."""
    default_retain: bool | None = False
    """Default retain value to messages published."""
    auto_reconnect: bool | None = True
    """Enable or disable auto-reconnect if connection with the broker is interrupted."""
    connection_timeout: int | None = 60
    """The number of seconds before a connection times out"""
    reconnect_retries: int | None = 2
    """Number of reconnection retry attempts. Negative value will cause client to reconnect indefinitely."""
    reconnect_max_interval: int | None = 10
    """Maximum seconds to wait before retrying a connection."""
    cleansession: bool | None = True
    """Upon reconnect, should subscriptions be cleared. Can be overridden by `MQTTClient.connect`"""
    topics: dict[str, TopicConfig] | None = field(default_factory=dict)
    """Specify the topics and what flags should be set for messages published to them."""
    broker: ConnectionConfig | None = None
    """*Deprecated* Configuration for connecting to the broker. Use `connection` field instead."""
    connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    """Configuration for connecting to the broker. See
    [`ConnectionConfig`](client_config.md#amqtt.contexts.ConnectionConfig) for more information."""
    plugins: dict[str, Any] | list[dict[str, Any]] | None = field(default_factory=default_client_plugins)
    """The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`; the value is
    a dictionary of configuration options for that plugin. See [custom plugins](../plugins/custom_plugins.md) for
    more information. `list[dict[str, Any]]` is deprecated but available to support legacy use cases."""
    check_hostname: bool | None = True
    """If establishing a secure connection, should the hostname of the certificate be verified."""
    will: WillConfig | None = None
    """Message, topic and flags that should be sent to if the client disconnects. See
    [`WillConfig`](client_config.md#amqtt.contexts.WillConfig) for more information."""

    def __post_init__(self) -> None:
        """Validate the client settings and migrate the deprecated `broker` field.

        Raises:
            ValueError: if `default_qos` is set but lies outside ``QOS_0``..``QOS_2``,
                or if only one of the connection's keyfile / certfile is set.
        """
        if self.default_qos is not None and (self.default_qos < QOS_0 or self.default_qos > QOS_2):
            msg = "Client config: default QoS must be 0, 1 or 2."
            raise ValueError(msg)

        # The deprecated 'broker' field, when given, replaces 'connection'.
        if self.broker is not None:
            warnings.warn("The 'broker' option is deprecated, please use 'connection' instead.", stacklevel=2)
            self.connection = self.broker

        if bool(not self.connection.keyfile) ^ bool(not self.connection.certfile):
            msg = "Connection key and certificate files are _both_ required."
            raise ValueError(msg)

    @classmethod
    def from_dict(cls, d: dict[str, Any] | None) -> "ClientConfig":
        """Create a client config from a dictionary; ``None`` yields a default config."""
        if d is None:
            return cls()

        return dict_to_dataclass(data_class=cls,
                                 data=d,
                                 config=DaciteConfig(
                                     cast=[StrEnum],
                                     strict=True)
                                 )