Coverage for amqtt/mqtt/protocol/broker_handler.py: 72%
168 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1import asyncio
2from asyncio import AbstractEventLoop, Queue
3from typing import TYPE_CHECKING
5from amqtt.adapters import ReaderAdapter, WriterAdapter
6from amqtt.errors import MQTTError
7from amqtt.events import MQTTEvents
8from amqtt.mqtt.connack import (
9 BAD_USERNAME_PASSWORD,
10 CONNECTION_ACCEPTED,
11 IDENTIFIER_REJECTED,
12 NOT_AUTHORIZED,
13 UNACCEPTABLE_PROTOCOL_VERSION,
14 ConnackPacket,
15)
16from amqtt.mqtt.connect import ConnectPacket
17from amqtt.mqtt.disconnect import DisconnectPacket
18from amqtt.mqtt.pingreq import PingReqPacket
19from amqtt.mqtt.pingresp import PingRespPacket
20from amqtt.mqtt.protocol.handler import ProtocolHandler
21from amqtt.mqtt.suback import SubackPacket
22from amqtt.mqtt.subscribe import SubscribePacket
23from amqtt.mqtt.unsuback import UnsubackPacket
24from amqtt.mqtt.unsubscribe import UnsubscribePacket
25from amqtt.plugins.manager import PluginManager
26from amqtt.session import Session
27from amqtt.utils import format_client_message
29_MQTT_PROTOCOL_LEVEL_SUPPORTED = 4
31if TYPE_CHECKING:
32 from amqtt.broker import BrokerContext
class Subscription:
    """A SUBSCRIBE request waiting to be processed.

    Attributes are plain and public:
      packet_id -- packet identifier taken from the SUBSCRIBE variable header.
      topics    -- list of ``(topic filter, qos)`` tuples from the payload.
    """

    def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:
        """Store the packet identifier and the requested topics as given.

        The ``topics`` list is kept by reference, not copied.
        """
        self.packet_id = packet_id
        self.topics = topics
class UnSubscription:
    """An UNSUBSCRIBE request waiting to be processed.

    Attributes are plain and public:
      packet_id -- packet identifier taken from the UNSUBSCRIBE variable header.
      topics    -- list of topic filters from the payload.
    """

    def __init__(self, packet_id: int, topics: list[str]) -> None:
        """Store the packet identifier and the topic filters as given.

        The ``topics`` list is kept by reference, not copied.
        """
        self.packet_id = packet_id
        self.topics = topics
class BrokerProtocolHandler(ProtocolHandler["BrokerContext"]):
    """Broker-side protocol handler for one client connection.

    On top of the base ``ProtocolHandler`` this class keeps:

    * a future (``_disconnect_waiter``) that is resolved when the client
      disconnects, the connection closes, or the handler is stopped;
    * two queues holding SUBSCRIBE / UNSUBSCRIBE requests until they are
      fetched through ``get_next_pending_subscription`` /
      ``get_next_pending_unsubscription``.
    """

    def __init__(
        self,
        plugins_manager: PluginManager["BrokerContext"],
        session: Session | None = None,
        loop: AbstractEventLoop | None = None,
    ) -> None:
        """Create the handler; all arguments are forwarded to the base class."""
        super().__init__(plugins_manager, session, loop)
        self._disconnect_waiter: asyncio.Future[DisconnectPacket | None] | None = None
        self._pending_subscriptions: Queue[Subscription] = Queue()
        self._pending_unsubscriptions: Queue[UnSubscription] = Queue()

    async def start(self) -> None:
        """Start the base handler and make sure a fresh disconnect waiter exists."""
        await super().start()
        # Ensure the disconnect waiter is reset: a missing or already-resolved
        # future cannot be awaited for the next disconnect.
        if self._disconnect_waiter is None or self._disconnect_waiter.done():
            self._disconnect_waiter = asyncio.Future()

    async def stop(self) -> None:
        """Stop the protocol handler and reset the disconnect waiter.

        A still-pending waiter is resolved with ``None`` so that anything
        awaiting ``wait_disconnect`` wakes up. Both pending queues are drained.
        """
        await super().stop()
        if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
            self._disconnect_waiter.set_result(None)
        self._disconnect_waiter = None  # Reset the disconnect waiter

        # Clear pending subscriptions and unsubscriptions
        while not self._pending_subscriptions.empty():
            self._pending_subscriptions.get_nowait()
        while not self._pending_unsubscriptions.empty():
            self._pending_unsubscriptions.get_nowait()

    async def wait_disconnect(self) -> DisconnectPacket | None:
        """Wait for a disconnect packet or connection closure.

        Returns the DISCONNECT packet the waiter was resolved with, or ``None``
        when it was resolved without one. Returns ``None`` immediately when no
        waiter is currently set.
        """
        if self._disconnect_waiter is not None:
            return await self._disconnect_waiter
        return None

    def handle_write_timeout(self) -> None:
        """Do nothing: no action is taken here on a write timeout."""

    def handle_read_timeout(self) -> None:
        """Do nothing: no action is taken here on a read timeout."""

    async def handle_disconnect(self, disconnect: DisconnectPacket | None) -> None:
        """Handle a disconnect packet and notify the disconnect waiter.

        ``disconnect`` is ``None`` when the connection was closed without a
        DISCONNECT packet (see ``handle_connection_closed``).
        """
        self.logger.debug("Client disconnecting")
        if self._disconnect_waiter and not self._disconnect_waiter.done():
            self.logger.debug(f"Setting disconnect waiter result to {disconnect!r}")
            self._disconnect_waiter.set_result(disconnect)
            self._disconnect_waiter = None  # Reset the disconnect waiter to avoid reuse

    async def handle_connection_closed(self) -> None:
        """Handle connection closure and notify the disconnect waiter."""
        await self.handle_disconnect(None)

    async def handle_connect(self, connect: ConnectPacket) -> None:
        """Log an unexpected CONNECT and resolve the disconnect waiter with ``None``."""
        # Broker handler shouldn't receive CONNECT message during messages handling
        # as CONNECT messages are managed by the broker on client connection
        self.logger.error(
            f"{self.session.client_id if self.session else None} [MQTT-3.1.0-2] {format_client_message(self.session)} :"
            f" CONNECT message received during messages handling",
        )
        if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
            self._disconnect_waiter.set_result(None)

    async def handle_pingreq(self, pingreq: PingReqPacket) -> None:
        """Answer a PINGREQ with a PINGRESP."""
        await self._send_packet(PingRespPacket.build())

    async def handle_subscribe(self, subscribe: SubscribePacket) -> None:
        """Queue a SUBSCRIBE request for later processing.

        Raises:
            MQTTError: if the packet has no variable header or no payload.
        """
        if subscribe.variable_header is None:
            msg = "SUBSCRIBE packet: variable header not initialized."
            raise MQTTError(msg)
        if subscribe.payload is None:
            msg = "SUBSCRIBE packet: payload not initialized."
            raise MQTTError(msg)

        subscription: Subscription = Subscription(subscribe.variable_header.packet_id, subscribe.payload.topics)
        await self._pending_subscriptions.put(subscription)

    async def handle_unsubscribe(self, unsubscribe: UnsubscribePacket) -> None:
        """Queue an UNSUBSCRIBE request for later processing.

        Raises:
            MQTTError: if the packet has no variable header or no payload.
        """
        if unsubscribe.variable_header is None:
            msg = "UNSUBSCRIBE packet: variable header not initialized."
            raise MQTTError(msg)
        if unsubscribe.payload is None:
            msg = "UNSUBSCRIBE packet: payload not initialized."
            raise MQTTError(msg)

        unsubscription: UnSubscription = UnSubscription(unsubscribe.variable_header.packet_id, unsubscribe.payload.topics)
        await self._pending_unsubscriptions.put(unsubscription)

    async def get_next_pending_subscription(self) -> Subscription:
        """Wait for and return the next queued SUBSCRIBE request."""
        return await self._pending_subscriptions.get()

    async def get_next_pending_unsubscription(self) -> UnSubscription:
        """Wait for and return the next queued UNSUBSCRIBE request."""
        return await self._pending_unsubscriptions.get()

    async def mqtt_acknowledge_subscription(self, packet_id: int, return_codes: list[int]) -> None:
        """Send a SUBACK for ``packet_id`` carrying ``return_codes``."""
        suback = SubackPacket.build(packet_id, return_codes)
        await self._send_packet(suback)

    async def mqtt_acknowledge_unsubscription(self, packet_id: int) -> None:
        """Send an UNSUBACK for ``packet_id``."""
        unsuback = UnsubackPacket.build(packet_id)
        await self._send_packet(unsuback)

    async def mqtt_connack_authorize(self, authorize: bool) -> None:
        """Send the CONNACK that accepts or refuses the connection.

        The return code is ``CONNECTION_ACCEPTED`` when ``authorize`` is true
        and ``NOT_AUTHORIZED`` otherwise.

        Raises:
            MQTTError: if no session is attached to the handler.
        """
        if self.session is None:
            msg = "Session is not initialized!"
            raise MQTTError(msg)

        connack = ConnackPacket.build(self.session.parent, CONNECTION_ACCEPTED if authorize else NOT_AUTHORIZED)
        await self._send_packet(connack)

    @classmethod
    async def init_from_connect(
        cls,
        reader: ReaderAdapter,
        writer: WriterAdapter,
        plugins_manager: PluginManager["BrokerContext"],
        loop: asyncio.AbstractEventLoop | None = None,
    ) -> tuple["BrokerProtocolHandler", Session]:
        """Initialize from a CONNECT packet and validates the connection.

        Reads one CONNECT packet from ``reader``, fires ``PACKET_RECEIVED`` for
        it and validates it. Malformed packets raise ``MQTTError`` directly.
        Packets that must be refused with a CONNACK (unsupported protocol
        level, bad username/password flags, missing client id with
        cleansession=0) get that CONNACK written to ``writer``, the writer is
        closed and ``MQTTError`` is raised.

        Returns:
            ``(handler, session)``: a new handler and a ``Session`` filled from
            the CONNECT packet. The session is not attached to the handler here.

        Raises:
            MQTTError: on any validation failure described above.
        """
        connect = await ConnectPacket.from_stream(reader)
        await plugins_manager.fire_event(MQTTEvents.PACKET_RECEIVED, packet=connect)

        if connect.variable_header is None:
            msg = "CONNECT packet: variable header not initialized."
            raise MQTTError(msg)
        if connect.payload is None:
            msg = "CONNECT packet: payload not initialized."
            raise MQTTError(msg)

        # this shouldn't be required anymore since broker generates for each client a random client_id if not provided
        # [MQTT-3.1.3-6]
        if connect.payload.client_id is None:
            msg = "[[MQTT-3.1.3-3]] : Client identifier must be present"
            raise MQTTError(msg)

        if connect.variable_header.will_flag and (connect.payload.will_topic is None or connect.payload.will_message is None):
            msg = "Will flag set, but will topic/message not present in payload"
            raise MQTTError(msg)

        if connect.variable_header.reserved_flag:
            msg = "[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0"
            raise MQTTError(msg)

        if connect.proto_name != "MQTT":
            msg = f'[MQTT-3.1.2-1] Incorrect protocol name: "{connect.proto_name}"'
            raise MQTTError(msg)

        # Default the peer address so the names are always bound: the session
        # below reads them even when the writer cannot report peer info.
        remote_address = None
        remote_port = None
        remote_info = writer.get_peer_info()
        if remote_info is not None:
            remote_address, remote_port = remote_info

        # The checks below only use the peer address for the error text, so
        # they run whether or not peer info is available.
        connack = None
        error_msg = None
        if connect.proto_level != _MQTT_PROTOCOL_LEVEL_SUPPORTED:
            # only MQTT 3.1.1 supported
            error_msg = (
                f"Invalid protocol from {format_client_message(address=remote_address, port=remote_port)}:"
                f" {connect.proto_level}"
            )
            connack = ConnackPacket.build(0, UNACCEPTABLE_PROTOCOL_VERSION)  # [MQTT-3.2.2-4] session_parent=0
        elif not connect.username_flag and connect.password_flag:
            # This branch used to leave error_msg unset, raising MQTTError(None).
            error_msg = (
                "[MQTT-3.1.2-22] Password flag set without username flag from"
                f" {format_client_message(address=remote_address, port=remote_port)}"
            )
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)  # [MQTT-3.1.2-22]
        elif connect.username_flag and connect.username is None:
            error_msg = f"Invalid username from {format_client_message(address=remote_address, port=remote_port)}"
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)  # [MQTT-3.2.2-4] session_parent=0
        elif connect.password_flag and connect.password is None:
            error_msg = f"Invalid password from {format_client_message(address=remote_address, port=remote_port)}"
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)  # [MQTT-3.2.2-4] session_parent=0
        elif connect.clean_session_flag is False and connect.payload.client_id_is_random:
            error_msg = (
                f"[MQTT-3.1.3-8] [MQTT-3.1.3-9] {format_client_message(address=remote_address, port=remote_port)}:"
                " No client Id provided (cleansession=0)"
            )
            connack = ConnackPacket.build(0, IDENTIFIER_REJECTED)

        if connack is not None:
            # Tell the client why it is refused, then drop the connection.
            await plugins_manager.fire_event(MQTTEvents.PACKET_SENT, packet=connack)
            await connack.to_stream(writer)
            await writer.close()
            raise MQTTError(error_msg) from None

        incoming_session = Session()
        incoming_session.client_id = connect.client_id
        incoming_session.clean_session = connect.clean_session_flag
        incoming_session.will_flag = connect.will_flag
        incoming_session.will_retain = connect.will_retain_flag
        incoming_session.will_qos = connect.will_qos
        incoming_session.will_topic = connect.will_topic
        incoming_session.will_message = connect.will_message
        incoming_session.username = connect.username
        incoming_session.password = connect.password
        incoming_session.remote_address = remote_address
        incoming_session.remote_port = remote_port
        incoming_session.ssl_object = writer.get_ssl_info()

        # A non-positive keep-alive is stored as 0.
        incoming_session.keep_alive = max(connect.keep_alive, 0)

        handler = cls(plugins_manager, loop=loop)
        return handler, incoming_session