Coverage for amqtt/client.py: 78%
303 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1import asyncio
2from collections import deque
3from collections.abc import Callable, Coroutine
4import contextlib
5from functools import wraps
6import logging
7import ssl
8from typing import TYPE_CHECKING, Any, TypeAlias, cast
9from urllib.parse import urlparse, urlunparse
11import websockets
12from websockets import HeadersLike, InvalidHandshake, InvalidURI
14from amqtt.adapters import (
15 StreamReaderAdapter,
16 StreamWriterAdapter,
17 WebSocketsReader,
18 WebSocketsWriter,
19)
20from amqtt.contexts import BaseContext, ClientConfig
21from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError
22from amqtt.mqtt.connack import CONNECTION_ACCEPTED
23from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2
24from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler
25from amqtt.plugins.manager import PluginManager
26from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
27from amqtt.utils import gen_client_id
29if TYPE_CHECKING:
30 from websockets.asyncio.client import ClientConnection
class ClientContext(BaseContext):
    """Context object handed to the plugins that interact with the client.

    Plugins reach client services through this adapter; it starts out without any configuration attached.
    """

    def __init__(self) -> None:
        super().__init__()
        # Populated by the owning client once its configuration is known.
        self.config: ClientConfig | None = None
# Module-level logger, used by the `mqtt_connected` wrapper.
44base_logger = logging.getLogger(__name__)
# Type of the coroutine functions that `mqtt_connected` accepts and returns.
46_F: TypeAlias = Callable[..., Coroutine[Any, Any, Any]]
def mqtt_connected(func: _F) -> _F:
    """MQTTClient coroutines decorator which waits for the connection before calling the decorated method.

    When the client is already connected the decorated coroutine is awaited right away. Otherwise the
    wrapper waits until either the connected event or the "no more connections" event is set,
    whichever comes first.

    :param func: coroutine to be called once connected
    :return: coroutine result.
    :raises ClientError: raised by the wrapper when the client has given up reconnecting.
    """

    @wraps(func)
    async def wrapper(self: "MQTTClient", *args: Any, **kwargs: Any) -> Any:
        if not self._connected_state.is_set():
            base_logger.warning("Client not connected, waiting for it")
            waiters = [
                asyncio.create_task(self._connected_state.wait()),
                asyncio.create_task(self._no_more_connections.wait()),
            ]
            try:
                await asyncio.wait(waiters, return_when=asyncio.FIRST_COMPLETED)
            finally:
                # asyncio.wait() never cancels the awaitables it was given. Cleaning up in
                # `finally` also covers the wrapper itself being cancelled while it waits,
                # which previously left both waiter tasks pending.
                for waiter in waiters:
                    if not waiter.done():
                        waiter.cancel()
            if self._no_more_connections.is_set():
                msg = "Will not reconnect"
                raise ClientError(msg)
        return await func(self, *args, **kwargs)

    return cast("_F", wrapper)
77class MQTTClient:
78 """MQTT client implementation, providing an API for connecting to a broker and send/receive messages using the MQTT protocol.
80 Args:
81 client_id: MQTT client ID to use when connecting to the broker. If none,
82 it will be generated randomly by `amqtt.utils.gen_client_id`
83 config: `ClientConfig` or dictionary of equivalent structure options (see [client configuration](client_config.md)).
85 Raises:
86 PluginImportError: if importing a plugin from configuration fails
87 PluginInitError: if initialization plugin fails
89 """
def __init__(self, client_id: str | None = None, config: ClientConfig | dict[str, Any] | None = None) -> None:
    """Set up the client state, its configuration and its plugin manager.

    Args:
        client_id: identifier announced to the broker; generated with `gen_client_id` when omitted.
        config: a `ClientConfig`, a dictionary convertible to one, or None for the default configuration.

    """
    self.logger = logging.getLogger(__name__)

    # A plain dictionary is converted; anything falsy falls back to the default configuration.
    self.config = ClientConfig.from_dict(config) if isinstance(config, dict) else (config or ClientConfig())

    self.client_id = gen_client_id() if client_id is None else client_id

    # Connection state, filled in by connect() and the connection logic.
    self.session: Session | None = None
    self._handler: ClientProtocolHandler | None = None
    self._disconnect_task: asyncio.Task[Any] | None = None
    self._connected_state = asyncio.Event()
    self._no_more_connections = asyncio.Event()
    self.additional_headers: dict[str, Any] | HeadersLike = {}

    # Plugins receive the client configuration through their context.
    plugin_context = ClientContext()
    plugin_context.config = self.config
    self.plugins_manager: PluginManager[ClientContext] = PluginManager("amqtt.client.plugins", plugin_context)
    self.client_tasks: deque[asyncio.Task[Any]] = deque()
async def connect(
    self,
    uri: str | None = None,
    cleansession: bool | None = None,
    cafile: str | None = None,
    capath: str | None = None,
    cadata: str | None = None,
    additional_headers: dict[str, Any] | HeadersLike | None = None,
) -> int:
    """Connect to a remote broker.

    A new session is initialized from the arguments and the client configuration, then a single
    connection attempt is made. When that attempt fails and the ``auto_reconnect`` configuration
    option is enabled, the client falls back to `reconnect`; otherwise the failure is re-raised.

    Args:
        uri: Broker URI connection, conforming to
            [MQTT URI scheme](https://github.com/mqtt/mqtt.github.io/wiki/URI-Scheme). By default,
            it is taken from the ``uri`` config attribute.
        cleansession: MQTT CONNECT clean session flag
        cafile: server certificate authority file (optional, used for secured connection)
        capath: server certificate authority path (optional, used for secured connection)
        cadata: server certificate authority data (optional, used for secured connection)
        additional_headers: a dictionary with additional http headers that should be sent on the
            initial connection (optional, used only with websocket connections)

    Returns:
        [CONNACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033)'s return code

    Raises:
        ConnectError: could not connect to broker, or the attempt was cancelled

    """
    self.session = self._init_session(uri, cleansession, cafile, capath, cadata)
    self.additional_headers = {} if additional_headers is None else additional_headers
    self.logger.debug(f"Connecting to: {self.session.broker_uri}")

    try:
        return await self._do_connect()
    except asyncio.CancelledError as cancelled:
        msg = "Future or Task was cancelled"
        raise ConnectError(msg) from cancelled
    # whatever went wrong, a reconnection may still be attempted
    except Exception as error:  # pylint: disable=W0718
        self.logger.warning(f"Connection failed: {error!r}")
        if self.config.get("auto_reconnect", False):
            return await self.reconnect()
        raise
async def disconnect(self) -> None:
    """Disconnect from the connected broker.

    Pending client tasks are cancelled first. If the session is connected, a
    [DISCONNECT](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718090)
    message is sent, the protocol handler is stopped and the session is marked as disconnected.
    A missing or unconnected session only produces a warning.

    """
    await self.cancel_tasks()

    if not self.session or not self._handler:
        self.logger.warning("Session or handler not initialized, ignoring disconnect.")
        return

    if not self.session.transitions.is_connected():
        self.logger.warning("Client session not connected, ignoring call.")
        return

    # The close-monitoring task must not react to a disconnection requested by the caller.
    monitor = self._disconnect_task
    if monitor and not monitor.done():
        monitor.cancel()

    await self._handler.mqtt_disconnect()
    self._connected_state.clear()
    await self._handler.stop()
    self.session.transitions.disconnect()
async def cancel_tasks(self) -> None:
    """Cancel every tracked client task, most recently added first, emptying the queue."""
    tracked = self.client_tasks
    while tracked:
        tracked.pop().cancel()
async def reconnect(self, cleansession: bool | None = None) -> int:
    """Reconnect a previously connected broker.

    Reconnection tries to establish a network connection
    and send a [CONNECT](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718028) message.
    Retries interval and attempts can be controlled with the ``reconnect_max_interval``
    and ``reconnect_retries`` configuration parameters. A negative ``reconnect_retries``
    disables the limit on attempts.

    Args:
        cleansession: clean session flag used in MQTT CONNECT messages sent for reconnections.
            ``None`` (the default) keeps the session's current flag; both ``True`` and ``False``
            are applied.

    Returns:
        [CONNACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033) return code

    Raises:
        ConnectError: if re-connection fails after max retries, or if the attempt is cancelled.

    """
    if self.session and self.session.transitions.is_connected():
        self.logger.warning("Client already connected")
        return CONNECTION_ACCEPTED

    # Compare against None so that an explicit ``False`` is honoured instead of being ignored.
    if self.session and cleansession is not None:
        self.session.clean_session = cleansession
    self.logger.debug(f"Reconnecting with session parameters: {self.session}")

    reconnect_max_interval = self.config.get("reconnect_max_interval", 10)
    reconnect_retries = self.config.get("reconnect_retries", 2)
    nb_attempt = 1

    while True:
        try:
            self.logger.debug(f"Reconnect attempt {nb_attempt}...")
            return await self._do_connect()
        except asyncio.CancelledError as e:
            msg = "Future or Task was cancelled"
            raise ConnectError(msg) from e
        # no matter the failure mode, still try to reconnect
        except Exception as e:  # pylint: disable=W0718
            self.logger.warning(f"Reconnection attempt failed: {e!r}")
            self.logger.debug("", exc_info=True)
            if 0 <= reconnect_retries < nb_attempt:
                # logger.exception() already records the traceback.
                self.logger.exception("Maximum connection attempts reached. Reconnection aborted.")
                msg = "Too many failed attempts"
                raise ConnectError(msg) from e
            # Exponential backoff, capped at ``reconnect_max_interval`` seconds.
            delay = min(reconnect_max_interval, 2**nb_attempt)
            self.logger.debug(f"Waiting {delay} seconds before next attempt")
            await asyncio.sleep(delay)
            nb_attempt += 1
248 async def _do_connect(self) -> int:
249 return_code = await self._connect_coro()
250 self._disconnect_task = asyncio.create_task(self.handle_connection_close())
251 return return_code
@mqtt_connected
async def ping(self) -> None:
    """Ping the broker.

    Sends a MQTT [PINGREQ](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718081)
    message through the protocol handler. When the session or the handler is missing, or the
    session is not connected, a warning is logged and nothing is sent.

    """
    if not self.session:
        self.logger.warning("Session is not initialized.")
        return
    if not self._handler:
        self.logger.warning("Handler is not initialized.")
        return
    if not self.session.transitions.is_connected():
        self.logger.warning(f"PING incompatible with state '{self.session.transitions.state}'")
        return
    await self._handler.mqtt_ping()
@mqtt_connected
async def publish(
    self,
    topic: str,
    message: bytes,
    qos: int | None = None,
    retain: bool | None = None,
    ack_timeout: int | None = None,
) -> OutgoingApplicationMessage:
    """Publish a message to the broker.

    Send a MQTT [PUBLISH](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718037)
    message and wait for acknowledgment depending on Quality Of Service

    For both ``qos`` and ``retain`` an explicitly passed value wins; otherwise the per-topic
    entry of the ``topics`` configuration is used when present, else the configured default.

    Args:
        topic: topic name to which message data is published
        message: payload message (as bytes) to send.
        qos: requested publish quality of service : QOS_0, QOS_1 or QOS_2. Defaults to `default_qos`
            config parameter or QOS_0.
        retain: retain flag. Defaults to ``default_retain`` config parameter or False.
            An explicit ``False`` is honoured even when the configuration enables retain.
        ack_timeout: duration to wait for connection acknowledgment from broker.

    Returns:
        the message that was sent

    Raises:
        ClientError: if the protocol handler is not initialized or ``qos`` is not a valid level.

    """
    if self._handler is None:
        msg = "Handler is not initialized."
        raise ClientError(msg)

    def get_retain_and_qos() -> tuple[int, bool]:
        """Resolve the effective QoS and retain flag from the arguments and the configuration."""
        if qos is not None:
            if qos not in (QOS_0, QOS_1, QOS_2):
                msg = f"QOS '{qos}' is not one of QOS_0, QOS_1, QOS_2."
                raise ClientError(msg)
            _qos = qos
        else:
            _qos = self.config["default_qos"]
            with contextlib.suppress(KeyError):
                _qos = self.config["topics"][topic]["qos"]
        # `is not None` rather than truthiness: a caller passing retain=False must not be
        # overridden by the configuration.
        if retain is not None:
            _retain = retain
        else:
            _retain = self.config["default_retain"]
            with contextlib.suppress(KeyError):
                _retain = self.config["topics"][topic]["retain"]
        return _qos, _retain

    (app_qos, app_retain) = get_retain_and_qos()
    return await self._handler.mqtt_publish(
        topic,
        message,
        app_qos,
        app_retain,
        ack_timeout,
    )
@mqtt_connected
async def subscribe(self, topics: list[tuple[str, int]]) -> list[int]:
    """Subscribe to topics.

    Sends a MQTT [SUBSCRIBE](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718063)
    message through the protocol handler and waits for its result.

    Args:
        topics: array of tuples containing topic pattern and QOS from `amqtt.mqtt.constants` to subscribe. For example:
            ```python
            [
                ("$SYS/broker/uptime", QOS_1),
                ("$SYS/broker/load/#", QOS_2),
            ]
            ```

    Returns:
        [SUBACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718068) message return code,
        or ``[0x80]`` when the handler or the session is missing.

    """
    if not (self._handler and self.session):
        return [0x80]
    return await self._handler.mqtt_subscribe(topics, self.session.next_packet_id)
@mqtt_connected
async def unsubscribe(self, topics: list[str]) -> None:
    """Unsubscribe from topics.

    Sends a MQTT [UNSUBSCRIBE](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718072)
    message through the protocol handler and waits for it to complete
    ([UNSUBACK](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718077)).
    Nothing is sent when the handler or the session is missing.

    Args:
        topics: array of topics to unsubscribe from.
            ```
            ["$SYS/broker/uptime", "$SYS/broker/load/#"]
            ```

    """
    if not (self._handler and self.session):
        return
    await self._handler.mqtt_unsubscribe(topics, self.session.next_packet_id)
async def deliver_message(self, timeout_duration: float | None = None) -> ApplicationMessage | None:
    """Deliver the next received message.

    Deliver next message received from the broker. If no message is available, this method waits until next message arrives
    or ``timeout_duration`` occurs.

    Args:
        timeout_duration: maximum number of seconds to wait before returning. If not specified or None, there is no limit.

    Returns:
        instance of `amqtt.session.ApplicationMessage` containing received message information flow.

    Raises:
        asyncio.TimeoutError: if timeout occurs before a message is delivered
        ClientError: if client is not connected

    """
    if self._handler is None:
        msg = "Handler is not initialized."
        raise ClientError(msg)

    deliver_task = asyncio.create_task(self._handler.mqtt_deliver_next_message())
    self.client_tasks.append(deliver_task)
    self.logger.debug("Waiting for message delivery")

    try:
        done, _ = await asyncio.wait(
            [deliver_task],
            return_when=asyncio.FIRST_EXCEPTION,
            timeout=timeout_duration,
        )
    except asyncio.CancelledError:
        # The caller gave up: do not leave a task behind that would consume a message nobody reads.
        deliver_task.cancel()
        raise
    finally:
        # Untrack exactly this call's task. With concurrent calls the last queue entry may belong
        # to another call, and the queue may already have been drained when the connection closed.
        with contextlib.suppress(ValueError):
            self.client_tasks.remove(deliver_task)

    if deliver_task in done:
        # result() re-raises whatever the delivery task raised, cancellation included.
        return deliver_task.result()
    # timeout occurred before message received
    deliver_task.cancel()
    msg = "Timeout waiting for message"
    raise asyncio.TimeoutError(msg)
async def _connect_coro(self) -> int:
    """Perform the core connection logic.

    The broker URI of the current session is parsed to fill in credentials, host and port
    (defaulting to 1883/8883 for ``mqtt``/``mqtts`` and 80/443 for ``ws``/``wss``). A TLS context
    is prepared for secure schemes, the transport is opened (TCP stream or websocket) and the
    MQTT CONNECT exchange is run through a new protocol handler.

    Returns:
        the CONNACK return code (always ``CONNECTION_ACCEPTED`` on a normal return).

    Raises:
        ClientError: if the session is not initialized, the broker URI is missing or its scheme is unsupported.
        ConnectError: if the broker rejects the connection (``return_code`` is set on the exception)
            or the transport / protocol handler fails while connecting.

    """
    if self.session is None:
        msg = "Session is not initialized."
        raise ClientError(msg)

    kwargs: dict[str, Any] = {}

    # Decode URI attributes
    uri_attributes = urlparse(self.session.broker_uri)
    scheme = uri_attributes.scheme
    secure = scheme in ("mqtts", "wss")
    # Credentials already present on the session take precedence over those embedded in the URI.
    self.session.username = (
        self.session.username or (str(uri_attributes.username) if uri_attributes.username else None)
    )
    self.session.password = (
        self.session.password or (str(uri_attributes.password) if uri_attributes.password else None)
    )
    self.session.remote_address = str(uri_attributes.hostname) if uri_attributes.hostname else None
    self.session.remote_port = uri_attributes.port

    if scheme in ("mqtt", "mqtts") and not self.session.remote_port:
        self.session.remote_port = 8883 if scheme == "mqtts" else 1883

    if scheme in ("ws", "wss") and not self.session.remote_port:
        self.session.remote_port = 443 if scheme == "wss" else 80

    if scheme in ("ws", "wss"):
        # Rewrite URI to conform to https://tools.ietf.org/html/rfc6455#section-3
        uri = (
            str(scheme),
            f"{self.session.remote_address}:{self.session.remote_port}",
            str(uri_attributes.path),
            str(uri_attributes.params),
            str(uri_attributes.query),
            str(uri_attributes.fragment),
        )
        self.session.broker_uri = str(urlunparse(uri))

    # A fresh protocol handler is created for every connection attempt.
    self._handler = ClientProtocolHandler(self.plugins_manager)

    connection_timeout = self.config.get("connection_timeout", None)

    if secure:
        # Forward every CA source stored on the session, not only ``cafile``; the ones left
        # as None are simply ignored by ssl.create_default_context().
        sc = ssl.create_default_context(
            ssl.Purpose.SERVER_AUTH,
            cafile=self.session.cafile,
            capath=self.session.capath,
            cadata=self.session.cadata,
        )

        if self.config.connection.certfile and self.config.connection.keyfile:
            sc.load_cert_chain(certfile=self.config.connection.certfile, keyfile=self.config.connection.keyfile)
        if self.config.connection.cafile:
            sc.load_verify_locations(cafile=self.config.connection.cafile)
        if self.config.check_hostname is not None:
            sc.check_hostname = self.config.check_hostname
            sc.verify_mode = ssl.CERT_REQUIRED
        kwargs["ssl"] = sc

    try:
        reader: StreamReaderAdapter | WebSocketsReader | None = None
        writer: StreamWriterAdapter | WebSocketsWriter | None = None
        self._connected_state.clear()

        # Open connection
        if scheme in ("mqtt", "mqtts"):
            conn_reader, conn_writer = await asyncio.wait_for(
                asyncio.open_connection(
                    self.session.remote_address,
                    self.session.remote_port,
                    **kwargs,
                ), timeout=connection_timeout)

            reader = StreamReaderAdapter(conn_reader)
            writer = StreamWriterAdapter(conn_writer)
        elif scheme in ("ws", "wss") and self.session.broker_uri:
            websocket: ClientConnection = await asyncio.wait_for(
                websockets.connect(
                    self.session.broker_uri,
                    subprotocols=[websockets.Subprotocol("mqtt")],
                    additional_headers=self.additional_headers,
                    **kwargs,
                ), timeout=connection_timeout)

            reader = WebSocketsReader(websocket)
            writer = WebSocketsWriter(websocket)
        elif not self.session.broker_uri:
            msg = "missing broker uri"
            raise ClientError(msg)
        else:
            msg = f"incorrect scheme defined in uri: '{scheme!r}'"
            raise ClientError(msg)

        # Start MQTT protocol
        self._handler.attach(self.session, reader, writer)
        return_code: int | None = await self._handler.mqtt_connect()

        # Value comparison: identity (`is not`) on integers only works by accident of int caching.
        if return_code != CONNECTION_ACCEPTED:
            self.session.transitions.disconnect()
            self.logger.warning(f"Connection rejected with code '{return_code}'")
            msg = "Connection rejected by broker"
            exc = ConnectError(msg)
            exc.return_code = return_code
            raise exc
        # Handle MQTT protocol
        await self._handler.start()
        self.session.transitions.connect()
        self._connected_state.set()
        self.logger.debug(f"Connected to {self.session.remote_address}:{self.session.remote_port}")

    except (InvalidURI, InvalidHandshake, ProtocolHandlerError, ConnectionError, OSError, asyncio.TimeoutError) as e:
        self.logger.debug(f"Connection failed : {self.session.broker_uri} [{e!r}]")
        self.session.transitions.disconnect()
        raise ConnectError(e) from e
    return return_code
async def handle_connection_close(self) -> None:
    """Wait for the broker connection to end, then clean up and optionally reconnect.

    Once the handler reports the disconnection, the client API is blocked, the handler is
    stopped and detached and the session is marked as disconnected. With ``auto_reconnect``
    enabled a reconnection is attempted; when it is disabled, or the reconnection fails with
    `ConnectError`, pending client tasks are cancelled and the client is flagged as not
    reconnecting anymore.

    Raises:
        ClientError: if the session or the handler is not initialized.

    """
    if self.session is None:
        msg = "Session is not initialized."
        raise ClientError(msg)
    if self._handler is None:
        msg = "Handler is not initialized."
        raise ClientError(msg)

    def give_up() -> None:
        # Signal that no further connection will happen, then cancel what is still pending (oldest first).
        self._no_more_connections.set()
        while self.client_tasks:
            pending_task = self.client_tasks.popleft()
            if pending_task.done():
                continue
            pending_task.cancel(msg="Connection closed.")

    self.logger.debug("Monitoring broker disconnection")
    # Returns once the broker connection is gone (e.g. connection lost)
    await self._handler.wait_disconnect()
    self.logger.warning("Disconnected from broker")

    # Block client API
    self._connected_state.clear()

    # Stop and clean the handler
    await self._handler.stop()
    self._handler.detach()
    self.session.transitions.disconnect()

    if not self.config.get("auto_reconnect", False):
        give_up()
        return

    self.logger.debug("Auto-reconnecting")
    try:
        await self.reconnect()
    except ConnectError:
        give_up()
def _init_session(
    self,
    uri: str | None = None,
    cleansession: bool | None = None,
    cafile: str | None = None,
    capath: str | None = None,
    cadata: str | None = None,
) -> Session:
    """Build a new MQTT session from the client configuration and the given overrides.

    Arguments that are not None override the matching entry of a copy of the ``connection``
    configuration, except ``cleansession`` which is written to the client configuration itself.

    Raises:
        ClientError: if no broker URI is available.

    """
    broker_conf = self.config.get("connection", {}).copy()

    if uri is not None:
        broker_conf.uri = uri
    if cleansession is not None:
        self.config.cleansession = cleansession
    # Certificate authority overrides
    for option, override in (("cafile", cafile), ("capath", capath), ("cadata", cadata)):
        if override is not None:
            setattr(broker_conf, option, override)

    if not broker_conf.get("uri"):
        msg = "Missing connection parameter 'uri'"
        raise ClientError(msg)

    new_session = Session()
    new_session.broker_uri = broker_conf["uri"]
    new_session.client_id = self.client_id

    new_session.cafile = broker_conf.get("cafile")
    new_session.capath = broker_conf.get("capath")
    new_session.cadata = broker_conf.get("cadata")

    new_session.clean_session = self.config.get("cleansession", True)

    # The session keep-alive is the configured keep-alive shortened by the ping delay.
    new_session.keep_alive = self.config["keep_alive"] - self.config["ping_delay"]

    if "will" in self.config:
        new_session.will_flag = True
        will = self.config["will"]
        new_session.will_retain = will["retain"]
        new_session.will_topic = will["topic"]
        new_session.will_message = will["message"].encode()
        new_session.will_qos = will["qos"]

    return new_session