Coverage for amqtt/broker.py: 88%
649 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1import asyncio
2from asyncio import CancelledError, futures
3from collections import deque
4from collections.abc import Generator
5from functools import partial
6import logging
7from math import floor
8import re
9import ssl
10import time
11from typing import Any, ClassVar, TypeAlias
13from transitions import Machine, MachineError
14import websockets.asyncio.server
15from websockets.asyncio.server import ServerConnection
17from amqtt.adapters import (
18 ReaderAdapter,
19 StreamReaderAdapter,
20 StreamWriterAdapter,
21 WebSocketsReader,
22 WebSocketsWriter,
23 WriterAdapter,
24)
25from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig, ListenerType
26from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
27from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
28from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
29from amqtt.utils import format_client_message, gen_client_id
31from .events import BrokerEvents
32from .mqtt.constants import QOS_0, QOS_1, QOS_2
33from .mqtt.disconnect import DisconnectPacket
34from .plugins.manager import PluginManager
# Shape of a broadcast description keyed by field name.
# NOTE(review): presumably the entries placed on the broker's broadcast queue; confirm against `_broadcast_message`.
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]

# Default port numbers, keyed by listener type; used when a listener's `bind` value omits the port
DEFAULT_PORTS = dict(tcp=1883, ws=8883)

# SUBACK return code compared against in `_handle_subscription`: subscriptions whose code differs from
# this value are reported to plugins as CLIENT_SUBSCRIBED. Note that 0x80 is the MQTT 3.1.1 SUBACK
# "Failure" code, so despite its name this value marks a subscription that was NOT accepted.
AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80
class RetainedApplicationMessage(ApplicationMessage):
    """Message the broker keeps as the retained message of a topic.

    Args:
        source_session: session that published the message, or ``None`` when the broker
            itself retains it (e.g. through ``BrokerContext.retain_message``).
        topic: topic name the message is retained under.
        data: message payload.
        qos: quality of service associated with the message, if any.

    """

    __slots__ = ("data", "qos", "source_session", "topic")

    def __init__(self, source_session: Session | None, topic: str, data: bytes | bytearray, qos: int | None = None) -> None:
        super().__init__(None, topic, qos, data, retain=True)
        # populate this class's slot attributes explicitly
        self.source_session = source_session
        self.topic, self.data, self.qos = topic, data, qos
class Server:
    """Used to encapsulate the server associated with a listener. Allows broker to interact with the connection lifecycle.

    A positive ``max_connections`` bounds concurrent connections through a semaphore, so acquiring
    a connection waits while the listener is full; any other value leaves the listener unbounded
    (shown as ``∞`` in the logs).
    """

    def __init__(
        self,
        listener_name: str,
        server_instance: asyncio.Server | websockets.asyncio.server.Server,
        max_connections: int = -1,
    ) -> None:
        self.logger = logging.getLogger(__name__)
        self.instance = server_instance
        self.conn_count = 0
        self.listener_name = listener_name
        self.max_connections = max_connections
        # an unbounded listener has no semaphore at all
        self.semaphore = None if not max_connections > 0 else asyncio.Semaphore(max_connections)

    def _log_connection_count(self) -> None:
        """Log the current connection count against the listener's limit."""
        limit = self.max_connections if self.max_connections > 0 else "∞"
        self.logger.info(f"Listener '{self.listener_name}': {self.conn_count}/{limit} connections acquired")

    async def acquire_connection(self) -> None:
        """Reserve a connection slot, waiting for one to free up when the listener is at capacity."""
        if self.semaphore is not None:
            await self.semaphore.acquire()
        self.conn_count += 1
        self._log_connection_count()

    def release_connection(self) -> None:
        """Give back a slot previously reserved with :meth:`acquire_connection`."""
        if self.semaphore is not None:
            self.semaphore.release()
        self.conn_count -= 1
        self._log_connection_count()

    async def close_instance(self) -> None:
        """Close the wrapped server, if there is one, and wait until it has closed."""
        if not self.instance:
            return
        self.instance.close()
        await self.instance.wait_closed()
class ExternalServer(Server):
    """Stand-in server for external listeners.

    The external implementation accepts and manages the connections itself, so connection
    accounting and closing are no-ops here.
    """

    def __init__(self) -> None:
        # there is no real server instance behind an external listener
        super().__init__("aiohttp", None)  # type: ignore[arg-type]

    async def acquire_connection(self) -> None:
        """Do nothing: no connection limit is enforced for external listeners."""

    def release_connection(self) -> None:
        """Do nothing: counterpart of :meth:`acquire_connection`."""

    async def close_instance(self) -> None:
        """Do nothing: the external implementation closes its own server."""
class BrokerContext(BaseContext):
    """Used to provide the server's context as well as public methods for accessing internal state."""

    def __init__(self, broker: "Broker") -> None:
        super().__init__()
        self.config: BrokerConfig | None = None
        self._broker_instance = broker

    async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
        """Send message to all client sessions subscribing to `topic`."""
        broker = self._broker_instance
        await broker.internal_message_broadcast(topic, data, qos)

    async def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
        """Ask the broker to retain `data` for `topic_name`, without a source session."""
        broker = self._broker_instance
        await broker.retain_message(None, topic_name, data, qos)

    @property
    def sessions(self) -> Generator[Session]:
        """Yield the session object of every entry in the broker's session table."""
        yield from (entry[0] for entry in self._broker_instance.sessions.values())

    def get_session(self, client_id: str) -> Session | None:
        """Return the session associated with `client_id`, if it exists."""
        entry = self._broker_instance.sessions.get(client_id, (None, None))
        return entry[0]

    @property
    def retained_messages(self) -> dict[str, RetainedApplicationMessage]:
        """Broker's retained messages, keyed by topic name (the live dict, not a copy)."""
        return self._broker_instance.retained_messages

    @property
    def subscriptions(self) -> dict[str, list[tuple[Session, int]]]:
        """Broker's subscriptions, keyed by topic filter (the live dict, not a copy)."""
        return self._broker_instance.subscriptions

    async def add_subscription(self, client_id: str, topic: str | None, qos: int | None) -> None:
        """Create a topic subscription for the given `client_id`.

        If a client session doesn't exist for `client_id`, create a disconnected session.
        If `topic` and `qos` are both `None`, only create the client session.
        """
        broker = self._broker_instance
        if client_id not in broker.sessions:
            offline_handler, offline_session = broker.create_offline_session(client_id)
            broker._sessions[client_id] = (offline_session, offline_handler)  # noqa: SLF001

        # nothing to subscribe unless both a topic and a qos were given
        if topic is None or qos is None:
            return
        session, _ = broker.sessions[client_id]
        await broker.add_subscription((topic, qos), session)
class Broker:
    """MQTT 3.1.1 compliant broker implementation.

    Args:
        config: `BrokerConfig` or dictionary of equivalent structure options (see [broker configuration](broker_config.md)).
            defaults to `BrokerConfig()`.
        loop: asyncio loop. defaults to the currently running loop (`asyncio.get_running_loop()`), so the
            broker must be constructed from inside a running event loop when no loop is given.
        plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`.

    Raises:
        BrokerError: problem with broker configuration
        PluginImportError: if importing a plugin from configuration fails
        PluginInitError: if initialization plugin fails
        RuntimeError: if `loop` is not given and no event loop is running

    """

    # lifecycle states of the broker; the allowed transitions between them are registered in `_init_states`
    states: ClassVar[list[str]] = [
        "new",
        "starting",
        "started",
        "not_started",
        "stopping",
        "stopped",
        "not_stopped",
    ]
182 def __init__(
183 self,
184 config: BrokerConfig | dict[str, Any] | None = None,
185 loop: asyncio.AbstractEventLoop | None = None,
186 plugin_namespace: str | None = None,
187 ) -> None:
188 """Initialize the broker."""
189 self.logger = logging.getLogger(__name__)
191 if isinstance(config, dict):
192 self.config = BrokerConfig.from_dict(config)
193 else:
194 self.config = config or BrokerConfig()
196 # listeners are populated from default within BrokerConfig
197 self.listeners_config = self.config.listeners
199 self._loop = loop or asyncio.get_running_loop()
200 self._servers: dict[str, Server] = {}
201 self._init_states()
202 self._sessions: dict[str, tuple[Session, BrokerProtocolHandler]] = {}
203 self._subscriptions: dict[str, list[tuple[Session, int]]] = {}
204 self._retained_messages: dict[str, RetainedApplicationMessage] = {}
206 self._topic_filter_matchers: dict[str, re.Pattern[str]] = {}
208 # Broadcast queue for outgoing messages
209 self._broadcast_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
210 self._broadcast_task: asyncio.Task[Any] | None = None
211 self._broadcast_shutdown_waiter: asyncio.Future[Any] = futures.Future()
213 # Tasks queue for managing broadcasting tasks
214 self._tasks_queue: deque[asyncio.Task[OutgoingApplicationMessage]] = deque()
216 # Task for session monitor
217 self._session_monitor_task: asyncio.Task[Any] | None = None
219 # Initialize plugins manager
221 context = BrokerContext(self)
222 context.config = self.config
223 namespace = plugin_namespace or "amqtt.broker.plugins"
224 self.plugins_manager = PluginManager(namespace, context, self._loop)
226 def _init_states(self) -> None:
227 self.transitions = Machine(states=Broker.states, initial="new")
228 self.transitions.add_transition(trigger="start", source="new", dest="starting", before=self._log_state_change)
229 self.transitions.add_transition(trigger="starting_fail", source="starting", dest="not_started")
230 self.transitions.add_transition(trigger="starting_success", source="starting", dest="started")
231 self.transitions.add_transition(trigger="shutdown", source="started", dest="stopping")
232 self.transitions.add_transition(trigger="stopping_success", source="stopping", dest="stopped")
233 self.transitions.add_transition(trigger="stopping_failure", source="stopping", dest="not_stopped")
234 self.transitions.add_transition(trigger="start", source="stopped", dest="starting")
236 def _log_state_change(self) -> None:
237 self.logger.debug(f"State transition: {self.transitions.state}")
239 async def start(self) -> None:
240 """Start the broker to serve with the given configuration.
242 Start method opens network sockets and will start listening for incoming connections.
243 """
244 try:
245 self._sessions.clear()
246 self._subscriptions.clear()
247 self._retained_messages.clear()
248 self.transitions.start()
249 self.logger.debug("Broker starting")
250 except (MachineError, ValueError) as exc:
251 # Backwards compat: MachineError is raised by transitions < 0.5.0.
252 self.logger.warning(f"[WARN-0001] Invalid method call at this moment: {exc}")
253 msg = f"Broker instance can't be started: {exc}"
254 raise BrokerError(msg) from exc
256 await self.plugins_manager.fire_event(BrokerEvents.PRE_START)
257 try:
258 await self._start_listeners()
259 self.transitions.starting_success()
260 await self.plugins_manager.fire_event(BrokerEvents.POST_START)
261 self._broadcast_task = asyncio.ensure_future(self._broadcast_loop())
262 self._session_monitor_task = asyncio.create_task(self._session_monitor())
263 self.logger.debug("Broker started")
264 except Exception as e:
265 self.logger.exception("Broker startup failed")
266 self.transitions.starting_fail()
267 msg = f"Broker instance can't be started: {e}"
268 raise BrokerError(msg) from e
270 async def _start_listeners(self) -> None:
271 """Start network listeners based on the configuration."""
272 for listener_name, listener in self.listeners_config.items():
273 if "bind" not in listener: 273 ↛ 274line 273 didn't jump to line 274 because the condition on line 273 was never true
274 self.logger.debug(f"Listener configuration '{listener_name}' is not bound")
275 continue
277 max_connections = listener.get("max_connections", -1)
278 ssl_context = self._create_ssl_context(listener) if listener.get("ssl", False) else None
280 # for listeners which are external, don't need to create a server
281 if listener.type == ListenerType.EXTERNAL:
283 # broker still needs to associate a new connection to the listener
284 self.logger.info(f"External listener exists for '{listener_name}' ")
285 self._servers[listener_name] = ExternalServer()
286 else:
287 # for tcp and websockets, start servers to listen for inbound connections
288 try:
289 address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
290 except ValueError as e:
291 msg = f"Invalid port value in bind value: {listener['bind']}"
292 raise BrokerError(msg) from e
294 instance = await self._create_server_instance(listener_name, listener.type, address, port, ssl_context)
295 self._servers[listener_name] = Server(listener_name, instance, max_connections)
297 self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
299 @staticmethod
300 def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext:
301 """Create an SSL context for a listener."""
302 try:
303 ssl_context = ssl.create_default_context(
304 ssl.Purpose.CLIENT_AUTH,
305 cafile=listener.get("cafile"),
306 capath=listener.get("capath"),
307 cadata=listener.get("cadata"),
308 )
309 ssl_context.load_cert_chain(listener["certfile"], listener["keyfile"])
310 ssl_context.verify_mode = ssl.CERT_OPTIONAL
311 except KeyError as ke:
312 msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}"
313 raise BrokerError(msg) from ke
314 except FileNotFoundError as fnfe:
315 msg = f"Can't read cert files '{listener['certfile']}' or '{listener['keyfile']}' : {fnfe}"
316 raise BrokerError(msg) from fnfe
317 return ssl_context
319 async def _create_server_instance(
320 self,
321 listener_name: str,
322 listener_type: ListenerType,
323 address: str | None,
324 port: int,
325 ssl_context: ssl.SSLContext | None,
326 ) -> asyncio.Server | websockets.asyncio.server.Server:
327 """Create a server instance for a listener."""
328 match listener_type:
329 case ListenerType.TCP:
330 return await asyncio.start_server(
331 partial(self.stream_connected, listener_name=listener_name),
332 address,
333 port,
334 reuse_address=True,
335 ssl=ssl_context,
336 )
337 case ListenerType.WS: 337 ↛ 345line 337 didn't jump to line 345 because the pattern on line 337 always matched
338 return await websockets.serve(
339 partial(self.ws_connected, listener_name=listener_name),
340 address,
341 port,
342 ssl=ssl_context,
343 subprotocols=[websockets.Subprotocol("mqtt")],
344 )
345 case _:
346 msg = f"Unsupported listener type: {listener_type}"
347 raise BrokerError(msg)
349 async def _session_monitor(self) -> None:
351 self.logger.info("Starting session expiration monitor.")
353 while True:
355 session_count_before = len(self._sessions)
357 # clean or anonymous sessions don't retain messages (or subscriptions); the session can be filtered out
358 sessions_to_remove = [client_id for client_id, (session, _) in self._sessions.items()
359 if session.transitions.state == "disconnected" and (session.is_anonymous or session.clean_session)]
361 # if session expiration is enabled, check to see if any of the sessions are disconnected and past expiration
362 if self.config.session_expiry_interval is not None:
363 retain_after = floor(time.time() - self.config.session_expiry_interval)
365 sessions_to_remove += [client_id for client_id, (session, _) in self._sessions.items()
366 if session.transitions.state == "disconnected" and
367 session.last_disconnect_time and
368 session.last_disconnect_time < retain_after]
370 for client_id in sessions_to_remove:
371 await self._cleanup_session(client_id)
373 if session_count_before > (session_count_after := len(self._sessions)):
374 self.logger.debug(f"Expired {session_count_before - session_count_after} sessions")
376 await asyncio.sleep(1)
378 async def shutdown(self) -> None:
379 """Stop broker instance."""
380 self.logger.info("Shutting down broker...")
381 # Fire broker_shutdown event to plugins
382 await self.plugins_manager.fire_event(BrokerEvents.PRE_SHUTDOWN)
384 # Cleanup all sessions
385 for client_id in list(self._sessions.keys()):
386 await self._cleanup_session(client_id)
388 # Clear retained messages
389 self.logger.debug(f"Clearing {len(self._retained_messages)} retained messages")
390 self._retained_messages.clear()
392 self.transitions.shutdown()
394 await self._shutdown_broadcast_loop()
395 if self._session_monitor_task: 395 ↛ 398line 395 didn't jump to line 398 because the condition on line 395 was always true
396 self._session_monitor_task.cancel()
398 for server in self._servers.values():
399 await server.close_instance()
401 if not self._broadcast_queue.empty():
402 self.logger.warning(f"{self._broadcast_queue.qsize()} messages not broadcasted")
403 # Clear the broadcast queue
404 while not self._broadcast_queue.empty():
405 self._broadcast_queue.get_nowait()
407 self.logger.info("Broker closed")
408 await self.plugins_manager.fire_event(BrokerEvents.POST_SHUTDOWN)
409 self.transitions.stopping_success()
411 async def _cleanup_session(self, client_id: str) -> None:
412 """Centralized cleanup logic for a session."""
413 session, handler = self._sessions.pop(client_id, (None, None))
415 if handler: 415 ↛ 418line 415 didn't jump to line 418 because the condition on line 415 was always true
416 self.logger.debug(f"Stopping handler for session {client_id}")
417 await self._stop_handler(handler)
418 if session: 418 ↛ exitline 418 didn't return from function '_cleanup_session' because the condition on line 418 was always true
419 self.logger.debug(f"Clearing all subscriptions for session {client_id}")
420 await self._del_all_subscriptions(session)
421 session.clear_queues()
423 async def internal_message_broadcast(self, topic: str, data: bytes, qos: int | None = None) -> None:
424 return await self._broadcast_message(None, topic, data, qos)
426 async def ws_connected(self, websocket: ServerConnection, listener_name: str) -> None:
427 await self._client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))
429 async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
430 await self._client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))
432 async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None:
433 """Engage the broker in handling the data stream to/from an established connection."""
434 await self._client_connected(listener_name, reader, writer)
436 async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
437 """Handle a new client connection."""
438 server = self._servers.get(listener_name)
439 if not server: 439 ↛ 440line 439 didn't jump to line 440 because the condition on line 439 was never true
440 msg = f"Invalid listener name '{listener_name}'"
441 raise BrokerError(msg)
443 await server.acquire_connection()
444 remote_info = writer.get_peer_info()
445 if remote_info is None: 445 ↛ 446line 445 didn't jump to line 446 because the condition on line 445 was never true
446 self.logger.warning("Remote info could not be retrieved from peer info")
447 return
449 remote_address, remote_port = remote_info
450 self.logger.info(f"Connection from {remote_address}:{remote_port} on listener '{listener_name}'")
452 try:
453 handler, client_session = await self._initialize_client_session(reader, writer, remote_address, remote_port)
454 except (AMQTTError, MQTTError, NoDataError) as exc:
455 self.logger.warning(f"Failed to initialize client session: {exc}")
456 server.release_connection()
457 return
459 try:
460 await self._handle_client_session(reader, writer, client_session, handler, server, listener_name)
461 except (AMQTTError, MQTTError, NoDataError) as exc:
462 self.logger.warning(f"Error while handling client session: {exc}")
463 finally:
464 self.logger.debug(f"{client_session.client_id} Client disconnected")
465 server.release_connection()
467 async def _initialize_client_session(
468 self,
469 reader: ReaderAdapter,
470 writer: WriterAdapter,
471 remote_address: str,
472 remote_port: int,
473 ) -> tuple[BrokerProtocolHandler, Session]:
474 """Initialize a client session and protocol handler."""
475 # Wait for first packet and expect a CONNECT
476 try:
477 handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
478 except AMQTTError as exc:
479 self.logger.warning(
480 f"[MQTT-3.1.0-1] {format_client_message(address=remote_address, port=remote_port)}:"
481 f" Can't read first packet as CONNECT: {exc}",
482 )
483 raise AMQTTError(exc) from exc
484 except MQTTError as exc:
485 self.logger.exception(
486 f"Invalid connection from {format_client_message(address=remote_address, port=remote_port)}",
487 )
488 await writer.close()
489 raise MQTTError(exc) from exc
490 except NoDataError as exc:
491 self.logger.error( # noqa: TRY400
492 f"No data from {format_client_message(address=remote_address, port=remote_port)} : {exc}",
493 )
494 raise AMQTTError(exc) from exc
496 if client_session.clean_session:
497 # Delete existing session and create a new one
498 if client_session.client_id is not None and client_session.client_id != "": 498 ↛ 501line 498 didn't jump to line 501 because the condition on line 498 was always true
499 await self._delete_session(client_session.client_id)
500 else:
501 client_session.client_id = gen_client_id()
503 client_session.parent = 0
504 # Get session from cache
505 elif client_session.client_id in self._sessions:
506 self.logger.debug(f"Found old session {self._sessions[client_session.client_id]!r}")
508 # even though the session previously existed, the new connection can bring updated configuration and credentials
509 existing_client_session, _ = self._sessions[client_session.client_id]
510 existing_client_session.will_flag = client_session.will_flag
511 existing_client_session.will_message = client_session.will_message
512 existing_client_session.will_topic = client_session.will_topic
513 existing_client_session.will_qos = client_session.will_qos
514 existing_client_session.keep_alive = client_session.keep_alive
515 existing_client_session.username = client_session.username
516 existing_client_session.password = client_session.password
517 client_session = existing_client_session
518 client_session.parent = 1
519 else:
520 client_session.parent = 0
522 timeout_disconnect_delay = self.config.get("timeout-disconnect-delay", 0)
523 if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int):
524 client_session.keep_alive += timeout_disconnect_delay
526 self.logger.debug(f"Keep-alive timeout={client_session.keep_alive}")
527 return handler, client_session
529 def create_offline_session(self, client_id: str) -> tuple[BrokerProtocolHandler, Session]:
530 session = Session()
531 session.client_id = client_id
533 bph = BrokerProtocolHandler(self.plugins_manager, session)
534 session.transitions.disconnect()
535 return bph, session
537 async def _handle_client_session(
538 self,
539 reader: ReaderAdapter,
540 writer: WriterAdapter,
541 client_session: Session,
542 handler: BrokerProtocolHandler,
543 server: Server,
544 listener_name: str,
545 ) -> None:
546 """Handle the lifecycle of a client session."""
547 authenticated = await self._authenticate(client_session, self.listeners_config[listener_name])
548 if not authenticated:
549 await writer.close()
550 return
552 if client_session.client_id is None: 552 ↛ 553line 552 didn't jump to line 553 because the condition on line 552 was never true
553 msg = "Client ID was not correctly created/set."
554 raise BrokerError(msg)
556 while True:
557 try:
558 client_session.transitions.connect()
559 break
560 except (MachineError, ValueError):
561 if client_session.transitions.is_connected():
562 self.logger.warning(f"Client {client_session.client_id} is already connected, performing take-over.")
563 old_session = self._sessions[client_session.client_id]
564 await old_session[1].handle_connection_closed()
565 await old_session[1].stop()
566 break
567 self.logger.warning(f"Client {client_session.client_id} is reconnecting too quickly, make it wait")
568 await asyncio.sleep(1)
570 handler.attach(client_session, reader, writer)
571 self._sessions[client_session.client_id] = (client_session, handler)
573 await handler.mqtt_connack_authorize(authenticated)
574 await self.plugins_manager.fire_event(BrokerEvents.CLIENT_CONNECTED,
575 client_id=client_session.client_id,
576 client_session=client_session)
578 self.logger.debug(f"{client_session.client_id} Start messages handling")
579 await handler.start()
581 # publish messages that were retained because the client session was disconnected
582 self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}")
583 await self._publish_session_retained_messages(client_session)
585 # if this is not a new session, there are subscriptions associated with them; publish any topic retained messages
586 self.logger.debug("Publish retained messages to a pre-existing session's subscriptions.")
587 for topic in self._subscriptions:
588 await self._publish_retained_messages_for_subscription((topic, QOS_0), client_session)
590 await self._client_message_loop(client_session, handler)
592 async def _client_message_loop(self, client_session: Session, handler: BrokerProtocolHandler) -> None:
593 """Run the main loop to handle client messages."""
594 # Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
595 disconnect_waiter = asyncio.ensure_future(handler.wait_disconnect())
596 subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
597 unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription())
598 wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message())
599 connected = True
601 while connected:
602 try:
603 done, _ = await asyncio.wait(
604 [
605 disconnect_waiter,
606 subscribe_waiter,
607 unsubscribe_waiter,
608 wait_deliver,
609 ],
610 return_when=asyncio.FIRST_COMPLETED,
611 )
613 if disconnect_waiter in done:
614 # handle the disconnection: normal or abnormal result, either way, the client is no longer connected
615 await self._handle_disconnect(client_session, handler, disconnect_waiter)
616 connected = False
618 # no need to reschedule the `disconnect_waiter` since we're exiting the message loop
620 if subscribe_waiter in done:
621 await self._handle_subscription(client_session, handler, subscribe_waiter)
622 subscribe_waiter = asyncio.ensure_future(handler.get_next_pending_subscription())
623 self.logger.debug(repr(self._subscriptions))
625 if unsubscribe_waiter in done:
626 await self._handle_unsubscription(client_session, handler, unsubscribe_waiter)
627 unsubscribe_waiter = asyncio.ensure_future(handler.get_next_pending_unsubscription())
629 if wait_deliver in done:
630 if not await self._handle_message_delivery(client_session, handler, wait_deliver):
631 break
632 wait_deliver = asyncio.ensure_future(handler.mqtt_deliver_next_message())
634 except asyncio.CancelledError:
635 self.logger.debug("Client loop cancelled")
636 break
638 disconnect_waiter.cancel()
639 subscribe_waiter.cancel()
640 unsubscribe_waiter.cancel()
641 wait_deliver.cancel()
643 async def _handle_disconnect(
644 self,
645 client_session: Session,
646 handler: BrokerProtocolHandler,
647 disconnect_waiter: asyncio.Future[Any],
648 ) -> None:
649 """Handle client disconnection.
651 Args:
652 client_session (Session): client session
653 handler (BrokerProtocolHandler): broker protocol handler
654 disconnect_waiter (asyncio.Future[Any]): future to wait for disconnection
656 """
657 # check the disconnected waiter result
658 result = disconnect_waiter.result()
659 self.logger.debug(f"{client_session.client_id} Result from wait_disconnect: {result}")
660 # if the client disconnects abruptly by sending no message or the message isn't a disconnect packet
661 if result is None or not isinstance(result, DisconnectPacket):
662 self.logger.debug(f"Will flag: {client_session.will_flag}")
663 if client_session.will_flag:
664 self.logger.debug(
665 f"Client {format_client_message(client_session)} disconnected abnormally, sending will message",
666 )
667 await self._broadcast_message(
668 client_session,
669 client_session.will_topic,
670 client_session.will_message,
671 client_session.will_qos,
672 )
673 if client_session.will_retain:
674 await self.retain_message(
675 client_session,
676 client_session.will_topic,
677 client_session.will_message,
678 client_session.will_qos,
679 )
681 # normal or not, let's end the client's session
682 self.logger.debug(f"{client_session.client_id} Disconnecting session")
683 await self._stop_handler(handler)
684 client_session.transitions.disconnect()
685 await self.plugins_manager.fire_event(BrokerEvents.CLIENT_DISCONNECTED,
686 client_id=client_session.client_id,
687 client_session=client_session)
689 async def _handle_subscription(
690 self,
691 client_session: Session,
692 handler: BrokerProtocolHandler,
693 subscribe_waiter: asyncio.Future[Any],
694 ) -> None:
695 """Handle client subscription."""
696 self.logger.debug(f"{client_session.client_id} handling subscription")
697 subscriptions = subscribe_waiter.result()
698 return_codes = [await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics]
699 await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
700 for index, subscription in enumerate(subscriptions.topics):
701 if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
702 await self.plugins_manager.fire_event(
703 BrokerEvents.CLIENT_SUBSCRIBED,
704 client_id=client_session.client_id,
705 topic=subscription[0],
706 qos=subscription[1],
707 )
708 await self._publish_retained_messages_for_subscription(subscription, client_session)
710 async def _handle_unsubscription(
711 self,
712 client_session: Session,
713 handler: BrokerProtocolHandler,
714 unsubscribe_waiter: asyncio.Future[Any],
715 ) -> None:
716 """Handle client unsubscription."""
717 self.logger.debug(f"{client_session.client_id} handling unsubscription")
718 unsubscription = unsubscribe_waiter.result()
719 for topic in unsubscription.topics:
720 self._del_subscription(topic, client_session)
721 await self.plugins_manager.fire_event(
722 BrokerEvents.CLIENT_UNSUBSCRIBED,
723 client_id=client_session.client_id,
724 topic=topic,
725 )
726 await handler.mqtt_acknowledge_unsubscription(unsubscription.packet_id)
728 async def _handle_message_delivery(
729 self,
730 client_session: Session,
731 handler: BrokerProtocolHandler,
732 wait_deliver: asyncio.Future[Any],
733 ) -> bool:
734 """Handle message delivery to the client."""
735 self.logger.debug(f"{client_session.client_id} handling message delivery")
736 app_message = wait_deliver.result()
738 # notify of a message's receipt, even if a client isn't necessarily allowed to send it
739 await self.plugins_manager.fire_event(
740 BrokerEvents.MESSAGE_RECEIVED,
741 client_id=client_session.client_id,
742 message=app_message,
743 )
745 if app_message is None: 745 ↛ 746line 745 didn't jump to line 746 because the condition on line 745 was never true
746 self.logger.debug("app_message was empty!")
747 return True
748 if not app_message.topic: 748 ↛ 749line 748 didn't jump to line 749 because the condition on line 748 was never true
749 self.logger.warning(
750 f"[MQTT-4.7.3-1] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
751 )
752 return False
753 if "#" in app_message.topic or "+" in app_message.topic: 753 ↛ 754line 753 didn't jump to line 754 because the condition on line 753 was never true
754 self.logger.warning(
755 f"[MQTT-3.3.2-2] - {client_session.client_id} invalid TOPIC sent in PUBLISH message, closing connection",
756 )
757 return False
758 if app_message.topic.startswith("$"):
759 self.logger.warning(
760 f"[MQTT-4.7.2-1] - {client_session.client_id} cannot use a topic with a leading $ character."
761 )
762 return False
764 permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)
765 if not permitted:
766 self.logger.info(f"{client_session.client_id} not allowed to publish to TOPIC {app_message.topic}.")
767 else:
768 # notify that a received message is valid and is allowed to be distributed to other clients
769 await self.plugins_manager.fire_event(
770 BrokerEvents.MESSAGE_BROADCAST,
771 client_id=client_session.client_id,
772 message=app_message,
773 )
774 await self._broadcast_message(client_session, app_message.topic, app_message.data)
775 if app_message.publish_packet and app_message.publish_packet.retain_flag:
776 await self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
777 return True
async def _init_handler(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> BrokerProtocolHandler:
    """Build a protocol handler for ``session`` and attach it.

    The handler is created with this broker's plugin manager and event loop, then
    ``attach`` is called with the session and the reader/writer adapters.

    :param session: session the handler is attached to
    :param reader: reader adapter passed to ``attach``
    :param writer: writer adapter passed to ``attach``
    :return: the newly attached handler
    """
    protocol_handler = BrokerProtocolHandler(self.plugins_manager, loop=self._loop)
    protocol_handler.attach(session, reader, writer)
    return protocol_handler
async def _stop_handler(self, handler: BrokerProtocolHandler) -> None:
    """Stop a running handler and detach it from the session.

    A failure while stopping one client's handler must not take the whole broker
    down. Any ``Exception`` raised by ``handler.stop()`` (including the previously
    handled ``asyncio.QueueEmpty``) is therefore logged and swallowed.
    ``asyncio.CancelledError`` derives from ``BaseException`` and still propagates.

    :param handler: the handler to stop
    """
    try:
        await handler.stop()
    # a failure in stopping a handler shouldn't cause the broker to fail
    except Exception:  # pylint: disable=W0718
        self.logger.exception("Failed to stop handler")
async def _authenticate(self, session: Session, _: ListenerConfig) -> bool:
    """Ask the registered auth plugins whether ``session`` may connect.

    Each plugin's ``authenticate()`` may answer:
    - ``True``: authentication succeeded
    - ``False``: authentication failed
    - ``None``: the plugin could not decide, so its answer is ignored

    :param session: the session being authenticated
    :return: ``True`` only when at least one plugin gave a non-``None`` answer and
        every such answer is truthy; ``False`` otherwise.
    """
    plugin_results = await self.plugins_manager.map_plugin_auth(session=session)

    decisive: list[Any] = []
    if plugin_results:
        decisive = [outcome for outcome in plugin_results.values() if outcome is not None]

    if not decisive:
        self.logger.debug("Authentication failed: no plugin responded with a boolean")
        return False

    if all(decisive):
        self.logger.debug("Authentication succeeded")
        return True

    # At least one plugin refused: log every plugin's answer to help diagnose which one
    for plugin, result in plugin_results.items():
        self.logger.debug(f"Authentication '{plugin.__class__.__name__}' result: {result}")
    return False
async def retain_message(
    self,
    source_session: Session | None,
    topic_name: str | None,
    data: bytes | bytearray | None,
    qos: int | None = None,
) -> None:
    """Store, replace or clear the broker-level retained message for ``topic_name``.

    - Non-empty ``data`` with a topic stores a new ``RetainedApplicationMessage``
      and fires ``RETAINED_MESSAGE`` for it.
    - Empty/``None`` data on a topic that already has a retained message clears it
      [MQTT-3.3.1-10]. The stored message's payload is blanked, ``RETAINED_MESSAGE``
      is fired with it, and then it is removed.
    - Anything else is a no-op.
    """
    if data and topic_name is not None:
        # If retained flag set, store the message for further subscriptions
        self.logger.debug(f"Retaining message on topic {topic_name}")
        stored = RetainedApplicationMessage(source_session, topic_name, data, qos)
        self._retained_messages[topic_name] = stored
        await self.plugins_manager.fire_event(
            BrokerEvents.RETAINED_MESSAGE,
            client_id=None,
            retained_message=stored,
        )
        return

    # [MQTT-3.3.1-10]
    if topic_name not in self._retained_messages:
        return

    self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
    to_clear = self._retained_messages[topic_name]
    to_clear.data = b""
    await self.plugins_manager.fire_event(
        BrokerEvents.RETAINED_MESSAGE,
        client_id=None,
        retained_message=to_clear,
    )
    # Removed only after the event has been delivered, as before
    del self._retained_messages[topic_name]
async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
    """Validate a ``(topic_filter, qos)`` request and register it for ``session``.

    :param subscription: the requested topic filter and QoS
    :param session: the subscribing session
    :return: the granted QoS (requested QoS capped by the integer ``max-qos`` config
        value, if any), or ``0x80`` (failure) when the filter is malformed or the
        topic filtering plugins refuse it.
    """
    topic_filter, qos = subscription
    levels = topic_filter.split("/")

    # [MQTT-4.7.1-2] '#' must be the last character of the filter and occupy an entire
    # level on its own ("a/#" and "#" are valid; "a#" and "#/a" are not)
    if "#" in topic_filter and (levels[-1] != "#" or topic_filter.count("#") > 1):
        return 0x80
    # [MQTT-4.7.1-3] '+' must occupy an entire level ("a/+/b" is valid; "a+/b" is not)
    if any("+" in level and level != "+" for level in levels):
        return 0x80
    # Check if the client is authorised to connect to the topic
    if not await self._topic_filtering(session, topic_filter, Action.SUBSCRIBE):
        return 0x80

    # Ensure "max-qos" is an integer before using it
    max_qos = self.config.get("max-qos", qos)
    if not isinstance(max_qos, int):
        max_qos = qos

    qos = min(qos, max_qos)
    subscribers = self._subscriptions.setdefault(topic_filter, [])
    if all(s.client_id != session.client_id for s, _ in subscribers):
        subscribers.append((session, qos))
    else:
        self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}")
    return qos
async def _topic_filtering(self, session: Session, topic: str, action: Action) -> bool:
    """Check whether ``session`` may perform ``action`` on ``topic``.

    When topic filtering is disabled in the plugin manager, access is always granted.
    Otherwise the answers returned by ``map_plugin_topic`` are combined with ``all()``.
    Access is granted only if every returned value is truthy, and an empty mapping
    grants access.

    NOTE(review): a ``None`` value in that mapping counts as a refusal here, not as
    "ignored". Confirm whether ``map_plugin_topic`` already drops undecided plugins.

    :param session: client session requesting access
    :param topic: topic the client wants to subscribe to / publish on / receive from
    :param action: the ``Action`` being checked (e.g. SUBSCRIBE, PUBLISH, RECEIVE)
    :return: ``True`` when access is allowed
    """
    if self.plugins_manager.is_topic_filtering_enabled():
        verdicts = await self.plugins_manager.map_plugin_topic(session=session, topic=topic, action=action)
        return all(verdicts.values())
    return True
async def _delete_session(self, client_id: str) -> None:
    """Forget the stored session for ``client_id``.

    Used, for example, when clean session is set in CONNECT. The entry is popped
    from ``self._sessions``, all of its subscriptions are deleted and its queues
    are cleared. An unknown ``client_id`` is only logged.
    """
    removed_session = self._sessions.pop(client_id, (None, None))[0]

    if removed_session is not None:
        self.logger.debug(f"Deleted existing session {removed_session!r}")
        # Delete subscriptions
        self.logger.debug(f"Deleting session {removed_session!r} subscriptions")
        await self._del_all_subscriptions(removed_session)
        removed_session.clear_queues()
    else:
        self.logger.debug(f"Delete session : session {client_id} doesn't exist")
async def _del_all_subscriptions(self, session: Session) -> None:
    """Remove ``session``'s subscription from every topic filter.

    A filter is deleted from ``self._subscriptions`` only if a subscription was
    removed from it and it has no subscribers left afterwards.
    """
    touched_filters: list[str] = []
    for a_filter in self._subscriptions:
        if self._del_subscription(a_filter, session):
            touched_filters.append(a_filter)

    # Deletion happens after the scan: the dict cannot change size while being iterated
    emptied_filters = [a_filter for a_filter in touched_filters if not self._subscriptions[a_filter]]
    for a_filter in emptied_filters:
        del self._subscriptions[a_filter]
def _del_subscription(self, a_filter: str, session: Session) -> int:
    """Delete a session subscription on a given topic filter.

    Subscriptions are matched by ``client_id``; only the first match is removed.

    :param a_filter: The topic filter for the subscription.
    :param session: The session to be unsubscribed.
    :return: The number of deleted subscriptions (0 or 1).
    """
    if a_filter not in self._subscriptions:
        self.logger.debug(f"Unsubscription on topic '{a_filter}' for client {format_client_message(session=session)}")
        return 0

    subscribers = self._subscriptions[a_filter]
    position = next(
        (idx for idx, (sub_session, _qos) in enumerate(subscribers) if sub_session.client_id == session.client_id),
        None,
    )
    if position is None:
        return 0

    self.logger.debug(
        f"Removing subscription on topic '{a_filter}' for client {format_client_message(session=session)}",
    )
    del subscribers[position]
    return 1
async def _broadcast_loop(self) -> None:
    """Run the main loop that dispatches queued broadcast messages.

    Each iteration does two things:
    - It reaps finished publish tasks from the front of ``self._tasks_queue``,
      logging cancellations and failures without stopping the loop.
    - It races one ``_run_broadcast`` step against ``self._broadcast_shutdown_waiter``.
      When the waiter completes, the pending step is cancelled and the loop exits.

    On exit, publish tasks that are still tracked are awaited. Their failures are
    logged rather than raised, so they cannot abort shutdown or mask an exception
    that is already propagating.
    """
    running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]] = self._tasks_queue

    try:
        while True:
            while running_tasks and running_tasks[0].done():
                task = running_tasks.popleft()
                try:
                    task.result()
                except CancelledError:
                    self.logger.info(f"Task has been cancelled: {task}")
                # if a task fails, don't want it to cause the broker to fail
                except Exception:  # pylint: disable=W0718
                    self.logger.exception(f"Task failed and will be skipped: {task}")

            run_broadcast_task = asyncio.ensure_future(self._run_broadcast(running_tasks))

            completed, _ = await asyncio.wait(
                [run_broadcast_task, self._broadcast_shutdown_waiter],
                return_when=asyncio.FIRST_COMPLETED,
            )

            # A failing broadcast step must not stop the loop, but its exception must
            # be retrieved and reported instead of silently disappearing.
            if (
                run_broadcast_task in completed
                and not run_broadcast_task.cancelled()
                and (step_error := run_broadcast_task.exception()) is not None
            ):
                self.logger.error("Broadcast step failed and will be skipped", exc_info=step_error)

            # Shutdown has been triggered by the broker, so stop the loop execution
            if self._broadcast_shutdown_waiter in completed:
                run_broadcast_task.cancel()
                break

    except BaseException:
        self.logger.exception("Broadcast loop stopped by exception")
        raise
    finally:
        # Wait until current broadcasting tasks end. Use a snapshot (the deque must not
        # be iterated while it may change) and collect exceptions instead of letting
        # the first failed publish escape from this cleanup.
        if running_tasks:
            pending = list(running_tasks)
            outcomes = await asyncio.gather(*pending, return_exceptions=True)
            for task, outcome in zip(pending, outcomes):
                if isinstance(outcome, Exception):
                    self.logger.warning(f"Task failed while stopping broadcast loop: {task}: {outcome!r}")
async def _run_broadcast(self, running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]]) -> None:
    """Take one message off the broadcast queue and fan it out to matching subscribers.

    For every subscription whose filter matches the message topic:
    - Subscribers refused by topic filtering (``Action.RECEIVE``) are skipped.
    - Disconnected, non-clean-session, non-anonymous subscribers at QoS 1/2 get the
      message queued via ``_retain_broadcast_message``.
    - Any other disconnected subscriber is skipped.
    - Otherwise ``handler.mqtt_publish`` is scheduled and the task is appended to
      ``running_tasks`` for the broadcast loop to reap.

    The QoS used is the broadcast's ``qos`` entry when present, else the subscription QoS.
    """
    broadcast = await self._broadcast_queue.get()

    self.logger.debug(f"Processing broadcast message: {broadcast}")

    # Iterate over snapshots. The awaits below (plugin topic filtering, retaining)
    # yield to the event loop, and other clients may subscribe or unsubscribe
    # meanwhile. Mutating the live dict or lists would raise "dictionary changed size
    # during iteration" or silently skip subscribers.
    for k_filter, subscriptions in list(self._subscriptions.items()):

        # Skip all subscriptions which do not match the topic
        if not self._matches(broadcast["topic"], k_filter):
            self.logger.debug(f"Topic '{broadcast['topic']}' does not match filter '{k_filter}'")
            continue

        for target_session, sub_qos in list(subscriptions):
            qos = broadcast.get("qos", sub_qos)

            sendable = await self._topic_filtering(target_session, topic=broadcast["topic"], action=Action.RECEIVE)
            if not sendable:
                self.logger.info(
                    f"{target_session.client_id} not allowed to receive messages from TOPIC {broadcast['topic']}.")
                continue

            # Retain all messages which cannot be broadcasted, due to the session not being connected
            # but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4]
            # and, if a client used anonymous authentication, there is no expectation that messages should be retained
            if (target_session.transitions.state != "connected"
                    and not target_session.clean_session
                    and qos in (QOS_1, QOS_2)
                    and not target_session.is_anonymous):
                self.logger.debug(f"Session {target_session.client_id} is not connected, retaining message.")
                await self._retain_broadcast_message(broadcast, qos, target_session)
                continue

            # Only broadcast the message to connected clients
            if target_session.transitions.state != "connected":
                continue

            self.logger.debug(
                f"Broadcasting message from {format_client_message(session=broadcast['session'])}"
                f" on topic '{broadcast['topic']}' to {format_client_message(session=target_session)}",
            )

            handler = self._get_handler(target_session)
            if handler:
                task = asyncio.ensure_future(
                    handler.mqtt_publish(
                        broadcast["topic"],
                        broadcast["data"],
                        qos,
                        retain=False,
                    ),
                )
                running_tasks.append(task)
async def _retain_broadcast_message(self, broadcast: dict[str, Any], qos: int, target_session: Session) -> None:
    """Queue a broadcast on ``target_session.retained_messages`` for later delivery.

    A ``RetainedApplicationMessage`` is built from the broadcast's session, topic and
    data at the given ``qos``. It is put on the session's queue, and
    ``RETAINED_MESSAGE`` is fired with the target client id.
    """
    origin = broadcast["session"]
    topic = broadcast["topic"]

    if self.logger.isEnabledFor(logging.DEBUG):
        self.logger.debug(
            f"retaining application message from {format_client_message(session=origin)}"
            f" on topic '{topic}' to client '{format_client_message(session=target_session)}'",
        )

    pending_message = RetainedApplicationMessage(origin, topic, broadcast["data"], qos)
    await target_session.retained_messages.put(pending_message)

    await self.plugins_manager.fire_event(
        BrokerEvents.RETAINED_MESSAGE,
        client_id=target_session.client_id,
        retained_message=pending_message,
    )

    if self.logger.isEnabledFor(logging.DEBUG):
        self.logger.debug(f"target_session.retained_messages={target_session.retained_messages.qsize()}")
1048 async def _shutdown_broadcast_loop(self) -> None:
1049 if self._broadcast_task and not self._broadcast_shutdown_waiter.done(): 1049 ↛ 1056line 1049 didn't jump to line 1056 because the condition on line 1049 was always true
1050 self._broadcast_shutdown_waiter.set_result(True)
1051 try:
1052 await asyncio.wait_for(self._broadcast_task, timeout=30)
1053 except TimeoutError as e:
1054 self.logger.warning(f"Failed to cleanly shutdown broadcast loop: {e}")
1056 if not self._broadcast_queue.empty():
1057 self.logger.warning(f"{self._broadcast_queue.qsize()} messages not broadcasted")
1059 self._broadcast_shutdown_waiter = asyncio.Future()
async def _broadcast_message(
    self,
    session: Session | None,
    topic: str | None,
    data: bytes | bytearray | None,
    force_qos: int | None = None,
) -> None:
    """Put a message on the broadcast queue for the broadcast loop to dispatch.

    The queued dict always has ``session``, ``topic`` and ``data`` keys. A ``qos``
    key is added only when ``force_qos`` is given; it then overrides the
    per-subscription QoS in ``_run_broadcast``.
    """
    message: _BROADCAST = dict(session=session, topic=topic, data=data)
    if force_qos is not None:
        message["qos"] = force_qos
    await self._broadcast_queue.put(message)
async def _publish_session_retained_messages(self, session: Session) -> None:
    """Drain ``session.retained_messages`` and publish each entry via the session's handler.

    Each queued message is published with ``retain=True`` in its own task, and the
    call returns once all of those tasks have finished. If the session has no
    handler, the queue is left untouched.
    """
    self.logger.debug(
        f"Publishing {session.retained_messages.qsize()}"
        f" messages retained for session {format_client_message(session=session)}",
    )
    handler = self._get_handler(session)
    if not handler:
        return

    outstanding = []
    while not session.retained_messages.empty():
        queued = await session.retained_messages.get()
        publish_coro = handler.mqtt_publish(queued.topic, queued.data, queued.qos, retain=True)
        outstanding.append(asyncio.ensure_future(publish_coro))

    if outstanding:
        await asyncio.wait(outstanding)
async def _publish_retained_messages_for_subscription(self, subscription: tuple[str, int], session: Session) -> None:
    """Send the broker-retained messages that match a new subscription to ``session``.

    Each matching message is published with ``retain=True`` at
    ``min(subscription QoS, retained QoS)``. When the retained message has no QoS
    recorded (``None``), the subscription QoS is used. Nothing is published if the
    session has no handler.

    :param subscription: the ``(topic_filter, qos)`` pair just granted
    :param session: the subscribing session
    """
    self.logger.debug(
        f"Begin broadcasting messages retained due to subscription on '{subscription[0]}'"
        f" from {format_client_message(session=session)}",
    )
    publish_tasks = []
    # No await happens inside the loop, so the handler cannot change while iterating
    handler = self._get_handler(session)

    topic_filter, qos = subscription
    for topic, retained in self._retained_messages.items():
        self.logger.debug(f"matching : {topic} {topic_filter}")
        if not self._matches(topic, topic_filter):
            continue
        self.logger.debug(f"{topic} and {topic_filter} match")
        if not handler:
            continue
        # `retained.qos or qos` would treat a retained QoS 0 as "unset" and upgrade
        # the delivery to the subscription QoS; only None means "unknown".
        delivery_qos = qos if retained.qos is None else min(qos, retained.qos)
        publish_tasks.append(
            asyncio.ensure_future(
                handler.mqtt_publish(retained.topic, retained.data, delivery_qos, retain=True),
            ),
        )
    if publish_tasks:
        await asyncio.wait(publish_tasks)
    self.logger.debug(
        f"End broadcasting messages retained due to subscription on '{subscription[0]}'"
        f" from {format_client_message(session=session)}",
    )
1117 def _matches(self, topic: str, a_filter: str) -> bool:
1118 if topic.startswith("$") and (a_filter.startswith(("+", "#"))):
1119 self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #")
1120 return False
1122 if "#" not in a_filter and "+" not in a_filter:
1123 # if filter doesn't contain wildcard, return exact match
1124 return a_filter == topic
1126 # else use regex (re.compile is an expensive operation, store the matcher for future use)
1127 if a_filter not in self._topic_filter_matchers:
1128 self._topic_filter_matchers[a_filter] = re.compile(re.escape(a_filter)
1129 .replace("\\#", "?.*")
1130 .replace("\\+", "[^/]*")
1131 .lstrip("?"))
1132 match_pattern = self._topic_filter_matchers[a_filter]
1133 return bool(match_pattern.fullmatch(topic))
def _get_handler(self, session: Session) -> BrokerProtocolHandler | None:
    """Return the protocol handler stored for ``session``'s client id.

    Returns ``None`` when the client id is empty or has no entry in ``self._sessions``.
    """
    client_id = session.client_id
    return self._sessions.get(client_id, (None, None))[1] if client_id else None
1141 @classmethod
1142 def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]:
1143 """Split an address:port pair into separate IP address and port. with IPv6 special-case handling.
1145 - Address can be specified using one of the following methods:
1146 - empty string - all interfaces default port
1147 - 1883 - Port number only (listen all interfaces)
1148 - :1883 - Port number only (listen all interfaces)
1149 - 0.0.0.0:1883 - IPv4 address
1150 - [::]:1883 - IPv6 address
1151 """
1153 def _parse_port(port_str: str) -> int:
1154 port_str = port_str.removeprefix(":")
1156 if not port_str:
1157 return default_port
1159 return int(port_str)
1161 if port_str.startswith("["): # IPv6 literal
1162 try:
1163 addr_end = port_str.index("]")
1164 except ValueError as e:
1165 msg = "Expecting '[' to be followed by ']'"
1166 raise ValueError(msg) from e
1168 return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :]))
1170 if ":" in port_str:
1171 address, port_str = port_str.rsplit(":", 1)
1172 return (address or None, _parse_port(port_str))
1174 try:
1175 return (None, _parse_port(port_str))
1176 except ValueError:
1177 return (port_str, default_port)
@property
def subscriptions(self) -> dict[str, list[tuple[Session, int]]]:
    """Live mapping of topic filter to its ``(session, granted qos)`` subscriptions (not a copy)."""
    return self._subscriptions

@property
def retained_messages(self) -> dict[str, RetainedApplicationMessage]:
    """Live mapping of topic name to the broker-level retained message (not a copy)."""
    return self._retained_messages

@property
def sessions(self) -> dict[str, tuple[Session, BrokerProtocolHandler]]:
    """Live mapping of client id to its ``(session, handler)`` pair (not a copy)."""
    return self._sessions