Coverage for amqtt/mqtt/protocol/client_handler.py: 81%

147 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1import asyncio 

2from typing import TYPE_CHECKING, Any 

3 

4from amqtt.errors import AMQTTError, NoDataError 

5from amqtt.events import MQTTEvents 

6from amqtt.mqtt.connack import ConnackPacket 

7from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader 

8from amqtt.mqtt.disconnect import DisconnectPacket 

9from amqtt.mqtt.pingreq import PingReqPacket 

10from amqtt.mqtt.pingresp import PingRespPacket 

11from amqtt.mqtt.protocol.handler import ProtocolHandler 

12from amqtt.mqtt.suback import SubackPacket 

13from amqtt.mqtt.subscribe import SubscribePacket 

14from amqtt.mqtt.unsuback import UnsubackPacket 

15from amqtt.mqtt.unsubscribe import UnsubscribePacket 

16from amqtt.plugins.manager import PluginManager 

17from amqtt.session import Session 

18 

19if TYPE_CHECKING: 

20 from amqtt.client import ClientContext 

21 

22 

class ClientProtocolHandler(ProtocolHandler["ClientContext"]):
    """Client-side MQTT protocol handler.

    Adds to ``ProtocolHandler`` the requests a client initiates (CONNECT,
    SUBSCRIBE, UNSUBSCRIBE, PINGREQ, DISCONNECT) and the handlers that resolve
    them when the broker replies (SUBACK, UNSUBACK, PINGRESP, connection closed).
    """

    def __init__(
        self,
        plugins_manager: PluginManager["ClientContext"],
        session: Session | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
    ) -> None:
        super().__init__(plugins_manager, session, loop=loop)
        # Task running the currently outstanding PINGREQ (see mqtt_ping), if any.
        self._ping_task: asyncio.Task[Any] | None = None
        # PINGRESP packets pushed by handle_pingresp() and consumed by mqtt_ping().
        self._pingresp_queue: asyncio.Queue[PingRespPacket] = asyncio.Queue()
        # Pending SUBSCRIBE / UNSUBSCRIBE requests, keyed by packet id.
        self._subscriptions_waiter: dict[int, asyncio.Future[list[int]]] = {}
        self._unsubscriptions_waiter: dict[int, asyncio.Future[Any]] = {}
        # Resolved by handle_connection_closed(), awaited by wait_disconnect().
        # NOTE(review): created outside any running loop here -- confirm this is
        # acceptable on the supported Python versions.
        self._disconnect_waiter: asyncio.Future[Any] | None = asyncio.Future()

    async def start(self) -> None:
        """Start the base handler and make sure a fresh disconnect waiter exists.

        A waiter that is already done belongs to an earlier connection. It may
        have been cancelled by ``stop()`` or resolved by
        ``handle_connection_closed()``. Reusing it would make
        ``wait_disconnect()`` return immediately, so it is replaced.
        """
        await super().start()
        if self._disconnect_waiter is None or self._disconnect_waiter.done():
            self._disconnect_waiter = asyncio.Future()

    async def stop(self) -> None:
        """Stop the base handler, cancel a pending ping and a pending disconnect waiter."""
        await super().stop()
        if self._ping_task and not self._ping_task.cancelled():
            self.logger.debug("Cancel ping task")
            self._ping_task.cancel()

        if self._disconnect_waiter and not self._disconnect_waiter.done():
            self._disconnect_waiter.cancel()

    def _build_connect_packet(self) -> ConnectPacket:
        """Build a CONNECT packet from the attached session's settings.

        :return: The CONNECT packet (variable header + payload).
        :raises AMQTTError: If no session is attached.
        """
        vh = ConnectVariableHeader()
        payload = ConnectPayload()

        if self.session is None:
            msg = "Session is not initialized."
            raise AMQTTError(msg)

        vh.keep_alive = self.session.keep_alive
        # Unset (None) session flags are sent as False.
        vh.clean_session_flag = self.session.clean_session if self.session.clean_session is not None else False
        vh.will_retain_flag = self.session.will_retain if self.session.will_retain is not None else False
        payload.client_id = self.session.client_id

        if self.session.username:
            vh.username_flag = True
            payload.username = self.session.username
        else:
            vh.username_flag = False

        if self.session.password:
            vh.password_flag = True
            payload.password = self.session.password
        else:
            vh.password_flag = False

        if self.session.will_flag:
            vh.will_flag = True
            # will_qos is only copied when the session defines one.
            if self.session.will_qos is not None:
                vh.will_qos = self.session.will_qos
            payload.will_message = self.session.will_message
            payload.will_topic = self.session.will_topic
        else:
            vh.will_flag = False

        return ConnectPacket(variable_header=vh, payload=payload)

    async def mqtt_connect(self) -> int | None:
        """Send CONNECT, read the CONNACK and return its return code.

        A ``MQTTEvents.PACKET_RECEIVED`` event is fired for the CONNACK.

        :return: ``return_code`` of the received CONNACK.
        :raises AMQTTError: If no reader is attached (checked after CONNECT is sent).
        :raises ConnectionError: If reading the CONNACK raises ``NoDataError``.
        """
        connect_packet = self._build_connect_packet()
        await self._send_packet(connect_packet)

        if self.reader is None:
            msg = "Reader is not initialized."
            raise AMQTTError(msg)
        try:
            connack = await ConnackPacket.from_stream(self.reader)
        except NoDataError as e:
            raise ConnectionError from e
        await self.plugins_manager.fire_event(MQTTEvents.PACKET_RECEIVED, packet=connack, session=self.session)
        return connack.return_code

    def handle_write_timeout(self) -> None:
        """Schedule a PINGREQ unless a ping task is still outstanding.

        A task that is already done is treated as absent. A ping task cancelled
        by ``stop()`` before it ever ran never executes ``mqtt_ping``'s
        ``finally``, so it would otherwise block all future pings.
        """
        try:
            if self._ping_task is None or self._ping_task.done():
                self.logger.debug("Scheduling Ping")
                self._ping_task = asyncio.create_task(self.mqtt_ping())
        except asyncio.InvalidStateError as e:
            self.logger.warning(f"Invalid state while scheduling ping task: {e!r}")
        except asyncio.CancelledError as e:
            self.logger.info(f"Ping task was cancelled: {e!r}")

    def handle_read_timeout(self) -> None:
        """Do nothing: the client takes no action on a read timeout."""

    async def mqtt_subscribe(self, topics: list[tuple[str, int]], packet_id: int) -> list[int]:
        """Send a SUBSCRIBE packet and wait for the matching SUBACK.

        :param topics: List of ``(topic_filter, qos)`` tuples, e.g. ``[("/a/b", 0x00)]``.
        :param packet_id: Packet identifier used to match the SUBACK.
        :return: Return codes carried by the SUBACK payload.
        :raises AMQTTError: If the built SUBSCRIBE packet has no variable header.
        """
        subscribe = SubscribePacket.build(topics, packet_id)
        if subscribe.variable_header is None:
            msg = f"Invalid variable header in SUBSCRIBE packet: {subscribe.variable_header}"
            raise AMQTTError(msg)
        pending_id = subscribe.variable_header.packet_id

        # Register the waiter *before* sending. _send_packet() may yield to the
        # event loop, and a SUBACK handled in that window would otherwise find no
        # waiter, leaving this coroutine waiting forever.
        waiter: asyncio.Future[list[int]] = asyncio.Future()
        self._subscriptions_waiter[pending_id] = waiter
        try:
            await self._send_packet(subscribe)
            return await waiter
        finally:
            # Only remove our own entry (guards against a reused packet id).
            if self._subscriptions_waiter.get(pending_id) is waiter:
                del self._subscriptions_waiter[pending_id]

    async def handle_suback(self, suback: SubackPacket) -> None:
        """Resolve the pending subscription matching the SUBACK's packet id.

        Unknown ids, and waiters that are already done (e.g. a duplicate SUBACK),
        are logged instead of raising ``InvalidStateError``.

        :raises AMQTTError: If the SUBACK has no variable header or payload.
        """
        if suback.variable_header is None:
            msg = "SUBACK packet: variable header not initialized."
            raise AMQTTError(msg)
        if suback.payload is None:
            msg = "SUBACK packet: payload not initialized."
            raise AMQTTError(msg)

        packet_id = suback.variable_header.packet_id

        waiter = self._subscriptions_waiter.get(packet_id)
        if waiter is not None and not waiter.done():
            waiter.set_result(suback.payload.return_codes)
        else:
            self.logger.warning(f"Received SUBACK for unknown pending subscription with Id: {packet_id}")

    async def mqtt_unsubscribe(self, topics: list[str], packet_id: int) -> None:
        """Send an UNSUBSCRIBE packet and wait for the matching UNSUBACK.

        :param topics: List of topic filters, e.g. ``["/a/b", ...]``.
        :param packet_id: Packet identifier used to match the UNSUBACK.
        :raises AMQTTError: If the built UNSUBSCRIBE packet has no variable header.
        """
        unsubscribe = UnsubscribePacket.build(topics, packet_id)

        if unsubscribe.variable_header is None:
            msg = "UNSUBSCRIBE packet: variable header not initialized."
            raise AMQTTError(msg)
        pending_id = unsubscribe.variable_header.packet_id

        # Register before sending so an early UNSUBACK cannot be missed.
        waiter: asyncio.Future[Any] = asyncio.Future()
        self._unsubscriptions_waiter[pending_id] = waiter
        try:
            await self._send_packet(unsubscribe)
            await waiter
        finally:
            if self._unsubscriptions_waiter.get(pending_id) is waiter:
                del self._unsubscriptions_waiter[pending_id]

    async def handle_unsuback(self, unsuback: UnsubackPacket) -> None:
        """Resolve the pending unsubscription matching the UNSUBACK's packet id.

        Unknown ids, and waiters that are already done, are logged.

        :raises AMQTTError: If the UNSUBACK has no variable header.
        """
        if unsuback.variable_header is None:
            msg = "UNSUBACK packet: variable header not initialized."
            raise AMQTTError(msg)

        packet_id = unsuback.variable_header.packet_id
        waiter = self._unsubscriptions_waiter.get(packet_id)
        if waiter is not None and not waiter.done():
            waiter.set_result(None)
        else:
            self.logger.warning(f"Received UNSUBACK for unknown pending unsubscription with Id: {packet_id}")

    async def mqtt_disconnect(self) -> None:
        """Send a DISCONNECT packet."""
        disconnect_packet = DisconnectPacket()
        await self._send_packet(disconnect_packet)

    async def mqtt_ping(self) -> PingRespPacket:
        """Send a PINGREQ and wait for the next PINGRESP from the queue.

        ``_ping_task`` is cleared on exit, whatever the outcome, so that a new
        ping can be scheduled.
        """
        ping_packet = PingReqPacket()
        try:
            await self._send_packet(ping_packet)
            resp = await self._pingresp_queue.get()
        finally:
            self._ping_task = None  # Ensure the task is cleaned up
        return resp

    async def handle_pingresp(self, pingresp: PingRespPacket) -> None:
        """Hand a received PINGRESP over to ``mqtt_ping``."""
        await self._pingresp_queue.put(pingresp)

    async def handle_connection_closed(self) -> None:
        """Resolve the disconnect waiter, if one is pending."""
        self.logger.debug("Broker closed connection")
        if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
            self._disconnect_waiter.set_result(None)

    async def wait_disconnect(self) -> None:
        """Wait until ``handle_connection_closed`` resolves the disconnect waiter.

        Returns immediately when there is no waiter. Raises
        ``asyncio.CancelledError`` if the waiter is cancelled, which ``stop()``
        does to a waiter that is still pending.
        """
        if self._disconnect_waiter is not None:
            await self._disconnect_waiter