Coverage for amqtt/mqtt/protocol/broker_handler.py: 72%

168 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1import asyncio 

2from asyncio import AbstractEventLoop, Queue 

3from typing import TYPE_CHECKING 

4 

5from amqtt.adapters import ReaderAdapter, WriterAdapter 

6from amqtt.errors import MQTTError 

7from amqtt.events import MQTTEvents 

8from amqtt.mqtt.connack import ( 

9 BAD_USERNAME_PASSWORD, 

10 CONNECTION_ACCEPTED, 

11 IDENTIFIER_REJECTED, 

12 NOT_AUTHORIZED, 

13 UNACCEPTABLE_PROTOCOL_VERSION, 

14 ConnackPacket, 

15) 

16from amqtt.mqtt.connect import ConnectPacket 

17from amqtt.mqtt.disconnect import DisconnectPacket 

18from amqtt.mqtt.pingreq import PingReqPacket 

19from amqtt.mqtt.pingresp import PingRespPacket 

20from amqtt.mqtt.protocol.handler import ProtocolHandler 

21from amqtt.mqtt.suback import SubackPacket 

22from amqtt.mqtt.subscribe import SubscribePacket 

23from amqtt.mqtt.unsuback import UnsubackPacket 

24from amqtt.mqtt.unsubscribe import UnsubscribePacket 

25from amqtt.plugins.manager import PluginManager 

26from amqtt.session import Session 

27from amqtt.utils import format_client_message 

28 

29_MQTT_PROTOCOL_LEVEL_SUPPORTED = 4 

30 

31if TYPE_CHECKING: 

32 from amqtt.broker import BrokerContext 

33 

34 

class Subscription:
    """A SUBSCRIBE request waiting to be processed by the broker.

    Attributes:
        packet_id: packet identifier taken from the SUBSCRIBE variable header.
        topics: ``(topic, qos)`` pairs taken from the SUBSCRIBE payload; the list
            object is stored as given, not copied.
    """

    def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:
        self.packet_id, self.topics = packet_id, topics

39 

40 

class UnSubscription:
    """An UNSUBSCRIBE request waiting to be processed by the broker.

    Attributes:
        packet_id: packet identifier taken from the UNSUBSCRIBE variable header.
        topics: topic filters taken from the UNSUBSCRIBE payload; the list object
            is stored as given, not copied.
    """

    def __init__(self, packet_id: int, topics: list[str]) -> None:
        self.packet_id, self.topics = packet_id, topics

45 

46 

class BrokerProtocolHandler(ProtocolHandler["BrokerContext"]):
    """Broker-side MQTT protocol handler for a single client connection.

    SUBSCRIBE and UNSUBSCRIBE packets received from the client are queued and
    handed to the broker through :meth:`get_next_pending_subscription` and
    :meth:`get_next_pending_unsubscription`. The end of the connection is
    reported through :meth:`wait_disconnect`.
    """

    def __init__(
        self,
        plugins_manager: PluginManager["BrokerContext"],
        session: Session | None = None,
        loop: AbstractEventLoop | None = None,
    ) -> None:
        super().__init__(plugins_manager, session, loop)
        # Resolved with the received DISCONNECT packet, or with None when the
        # connection ends any other way. Created by start().
        self._disconnect_waiter: asyncio.Future[DisconnectPacket | None] | None = None
        self._pending_subscriptions: Queue[Subscription] = Queue()
        self._pending_unsubscriptions: Queue[UnSubscription] = Queue()

    async def start(self) -> None:
        """Start the handler and arm a fresh disconnect waiter when needed."""
        await super().start()
        # Replace a missing or already-resolved waiter so that a restarted handler
        # does not immediately report the previous connection's disconnect.
        if self._disconnect_waiter is None or self._disconnect_waiter.done():
            self._disconnect_waiter = asyncio.Future()

    async def stop(self) -> None:
        """Stop the protocol handler, release the disconnect waiter and drop queued requests."""
        await super().stop()
        # Wake up anyone still blocked in wait_disconnect(); the connection is gone.
        if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
            self._disconnect_waiter.set_result(None)
        self._disconnect_waiter = None
        # Discard subscribe/unsubscribe requests the broker never picked up.
        while not self._pending_subscriptions.empty():
            self._pending_subscriptions.get_nowait()
        while not self._pending_unsubscriptions.empty():
            self._pending_unsubscriptions.get_nowait()

    async def wait_disconnect(self) -> DisconnectPacket | None:
        """Wait for a disconnect packet or connection closure.

        Returns:
            The DISCONNECT packet sent by the client, or None when the connection
            ended without one. Also returns None immediately when no waiter is
            armed (handler not started, or already stopped).
        """
        waiter = self._disconnect_waiter
        if waiter is None:
            return None
        return await waiter

    def handle_write_timeout(self) -> None:
        """No-op: the broker handler takes no action on write timeouts."""

    def handle_read_timeout(self) -> None:
        """No-op: the broker handler takes no action on read timeouts."""

    async def handle_disconnect(self, disconnect: DisconnectPacket | None) -> None:
        """Handle a disconnect packet and notify the disconnect waiter.

        Only the first notification is recorded. A later call, for example
        handle_connection_closed() after a DISCONNECT packet, leaves the stored
        result untouched.

        Args:
            disconnect: the received DISCONNECT packet, or None for an abnormal end
                of connection.
        """
        self.logger.debug("Client disconnecting")
        if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
            self.logger.debug(f"Setting disconnect waiter result to {disconnect!r}")
            self._disconnect_waiter.set_result(disconnect)
        # The resolved waiter is deliberately NOT reset to None. A wait_disconnect()
        # call that only starts after this point must still observe the DISCONNECT
        # packet instead of None; otherwise a clean disconnect looks abnormal.
        # start() replaces a done waiter, so it is never reused by a later session.

    async def handle_connection_closed(self) -> None:
        """Handle connection closure and notify the disconnect waiter."""
        await self.handle_disconnect(None)

    async def handle_connect(self, connect: ConnectPacket) -> None:
        """Reject a CONNECT packet received on an already-established connection.

        CONNECT is handled by init_from_connect() when the client connects; a
        second one during message handling is logged as an [MQTT-3.1.0-2]
        violation, and the disconnect waiter is resolved with None (an abnormal
        disconnect).
        """
        self.logger.error(
            f"{self.session.client_id if self.session else None} [MQTT-3.1.0-2] {format_client_message(self.session)} :"
            f" CONNECT message received during messages handling",
        )
        if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
            self._disconnect_waiter.set_result(None)

    async def handle_pingreq(self, pingreq: PingReqPacket) -> None:
        """Answer a PINGREQ with a PINGRESP."""
        await self._send_packet(PingRespPacket.build())

    async def handle_subscribe(self, subscribe: SubscribePacket) -> None:
        """Queue a SUBSCRIBE request for the broker.

        Raises:
            MQTTError: if the packet's variable header or payload is missing.
        """
        if subscribe.variable_header is None:
            msg = "SUBSCRIBE packet: variable header not initialized."
            raise MQTTError(msg)
        if subscribe.payload is None:
            msg = "SUBSCRIBE packet: payload not initialized."
            raise MQTTError(msg)

        subscription = Subscription(subscribe.variable_header.packet_id, subscribe.payload.topics)
        await self._pending_subscriptions.put(subscription)

    async def handle_unsubscribe(self, unsubscribe: UnsubscribePacket) -> None:
        """Queue an UNSUBSCRIBE request for the broker.

        Raises:
            MQTTError: if the packet's variable header or payload is missing.
        """
        if unsubscribe.variable_header is None:
            msg = "UNSUBSCRIBE packet: variable header not initialized."
            raise MQTTError(msg)
        if unsubscribe.payload is None:
            msg = "UNSUBSCRIBE packet: payload not initialized."
            raise MQTTError(msg)

        unsubscription = UnSubscription(unsubscribe.variable_header.packet_id, unsubscribe.payload.topics)
        await self._pending_unsubscriptions.put(unsubscription)

    async def get_next_pending_subscription(self) -> Subscription:
        """Wait for and return the next queued SUBSCRIBE request."""
        return await self._pending_subscriptions.get()

    async def get_next_pending_unsubscription(self) -> UnSubscription:
        """Wait for and return the next queued UNSUBSCRIBE request."""
        return await self._pending_unsubscriptions.get()

    async def mqtt_acknowledge_subscription(self, packet_id: int, return_codes: list[int]) -> None:
        """Send a SUBACK for *packet_id* carrying one return code per requested topic."""
        suback = SubackPacket.build(packet_id, return_codes)
        await self._send_packet(suback)

    async def mqtt_acknowledge_unsubscription(self, packet_id: int) -> None:
        """Send an UNSUBACK for *packet_id*."""
        unsuback = UnsubackPacket.build(packet_id)
        await self._send_packet(unsuback)

    async def mqtt_connack_authorize(self, authorize: bool) -> None:
        """Send the CONNACK answering the client's CONNECT.

        Args:
            authorize: True to accept the connection, False to refuse it with
                NOT_AUTHORIZED.

        Raises:
            MQTTError: if the handler has no session attached.
        """
        if self.session is None:
            msg = "Session is not initialized!"
            raise MQTTError(msg)

        connack = ConnackPacket.build(self.session.parent, CONNECTION_ACCEPTED if authorize else NOT_AUTHORIZED)
        await self._send_packet(connack)

    @staticmethod
    def _connect_refusal(
        connect: ConnectPacket,
        remote_address: str | None,
        remote_port: int | None,
    ) -> tuple[ConnackPacket, str] | None:
        """Return ``(refusing CONNACK, error message)`` for *connect*, or None if it is acceptable."""
        client = format_client_message(address=remote_address, port=remote_port)
        # Every refusal uses session_parent=0 [MQTT-3.2.2-4].
        if connect.proto_level != _MQTT_PROTOCOL_LEVEL_SUPPORTED:
            # Only MQTT 3.1.1 is supported.
            return (
                ConnackPacket.build(0, UNACCEPTABLE_PROTOCOL_VERSION),
                f"Invalid protocol from {client}: {connect.proto_level}",
            )
        if not connect.username_flag and connect.password_flag:
            return (
                ConnackPacket.build(0, BAD_USERNAME_PASSWORD),
                f"[MQTT-3.1.2-22] Password flag set without username flag from {client}",
            )
        if connect.username_flag and connect.username is None:
            return ConnackPacket.build(0, BAD_USERNAME_PASSWORD), f"Invalid username from {client}"
        if connect.password_flag and connect.password is None:
            return ConnackPacket.build(0, BAD_USERNAME_PASSWORD), f"Invalid password from {client}"
        if not connect.clean_session_flag and connect.payload is not None and connect.payload.client_id_is_random:
            return (
                ConnackPacket.build(0, IDENTIFIER_REJECTED),
                f"[MQTT-3.1.3-8] [MQTT-3.1.3-9] {client}: No client Id provided (cleansession=0)",
            )
        return None

    @classmethod
    async def init_from_connect(
        cls,
        reader: ReaderAdapter,
        writer: WriterAdapter,
        plugins_manager: PluginManager["BrokerContext"],
        loop: asyncio.AbstractEventLoop | None = None,
    ) -> tuple["BrokerProtocolHandler", Session]:
        """Initialize from a CONNECT packet and validate the connection.

        Reads a CONNECT packet from *reader* and validates it. On success, returns
        a new handler together with a :class:`Session` populated from the packet.

        Raises:
            MQTTError: if the packet is malformed or violates the protocol. When the
                violation has a CONNACK return code, the refusing CONNACK is sent
                and *writer* is closed before raising.
        """
        connect = await ConnectPacket.from_stream(reader)
        await plugins_manager.fire_event(MQTTEvents.PACKET_RECEIVED, packet=connect)

        if connect.variable_header is None:
            msg = "CONNECT packet: variable header not initialized."
            raise MQTTError(msg)
        if connect.payload is None:
            msg = "CONNECT packet: payload not initialized."
            raise MQTTError(msg)

        # This shouldn't be required anymore since the broker generates a random
        # client_id for each client that does not provide one. [MQTT-3.1.3-6]
        if connect.payload.client_id is None:
            msg = "[[MQTT-3.1.3-3]] : Client identifier must be present"
            raise MQTTError(msg)

        if connect.variable_header.will_flag and (connect.payload.will_topic is None or connect.payload.will_message is None):
            msg = "Will flag set, but will topic/message not present in payload"
            raise MQTTError(msg)

        if connect.variable_header.reserved_flag:
            msg = "[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0"
            raise MQTTError(msg)

        if connect.proto_name != "MQTT":
            msg = f'[MQTT-3.1.2-1] Incorrect protocol name: "{connect.proto_name}"'
            raise MQTTError(msg)

        # Peer info may be unavailable. Validation must still run, and the session
        # then simply records no remote address (previously both were skipped and
        # the session setup below failed on unbound names).
        remote_info = writer.get_peer_info()
        remote_address, remote_port = remote_info if remote_info is not None else (None, None)

        refusal = cls._connect_refusal(connect, remote_address, remote_port)
        if refusal is not None:
            connack, error_msg = refusal
            await plugins_manager.fire_event(MQTTEvents.PACKET_SENT, packet=connack)
            await connack.to_stream(writer)
            await writer.close()
            raise MQTTError(error_msg) from None

        incoming_session = Session()
        incoming_session.client_id = connect.client_id
        incoming_session.clean_session = connect.clean_session_flag
        incoming_session.will_flag = connect.will_flag
        incoming_session.will_retain = connect.will_retain_flag
        incoming_session.will_qos = connect.will_qos
        incoming_session.will_topic = connect.will_topic
        incoming_session.will_message = connect.will_message
        incoming_session.username = connect.username
        incoming_session.password = connect.password
        incoming_session.remote_address = remote_address
        incoming_session.remote_port = remote_port
        incoming_session.ssl_object = writer.get_ssl_info()
        # Negative keep-alive values are clamped to 0.
        incoming_session.keep_alive = max(connect.keep_alive, 0)

        handler = cls(plugins_manager, loop=loop)
        return handler, incoming_session