Coverage for amqtt/mqtt/protocol/client_handler.py: 81%
147 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1import asyncio
2from typing import TYPE_CHECKING, Any
4from amqtt.errors import AMQTTError, NoDataError
5from amqtt.events import MQTTEvents
6from amqtt.mqtt.connack import ConnackPacket
7from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader
8from amqtt.mqtt.disconnect import DisconnectPacket
9from amqtt.mqtt.pingreq import PingReqPacket
10from amqtt.mqtt.pingresp import PingRespPacket
11from amqtt.mqtt.protocol.handler import ProtocolHandler
12from amqtt.mqtt.suback import SubackPacket
13from amqtt.mqtt.subscribe import SubscribePacket
14from amqtt.mqtt.unsuback import UnsubackPacket
15from amqtt.mqtt.unsubscribe import UnsubscribePacket
16from amqtt.plugins.manager import PluginManager
17from amqtt.session import Session
19if TYPE_CHECKING:
20 from amqtt.client import ClientContext
class ClientProtocolHandler(ProtocolHandler["ClientContext"]):
    """MQTT protocol handler for the client side of a connection.

    Sends client-originated packets (CONNECT, SUBSCRIBE, UNSUBSCRIBE, PINGREQ,
    DISCONNECT) and routes the broker's acknowledgements (SUBACK, UNSUBACK,
    PINGRESP) back to the coroutines waiting on them.
    """

    def __init__(
        self,
        plugins_manager: PluginManager["ClientContext"],
        session: Session | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
    ) -> None:
        super().__init__(plugins_manager, session, loop=loop)
        # Task running the currently scheduled keep-alive ping, if any.
        self._ping_task: asyncio.Task[Any] | None = None
        # PINGRESP packets pushed by handle_pingresp and consumed by mqtt_ping.
        self._pingresp_queue: asyncio.Queue[PingRespPacket] = asyncio.Queue()
        # Pending SUBSCRIBE / UNSUBSCRIBE requests, keyed by packet identifier.
        self._subscriptions_waiter: dict[int, asyncio.Future[list[int]]] = {}
        self._unsubscriptions_waiter: dict[int, asyncio.Future[Any]] = {}
        # Finished when the connection ends; awaited by wait_disconnect.
        self._disconnect_waiter: asyncio.Future[Any] | None = asyncio.Future()

    async def start(self) -> None:
        """Start the handler, re-arming the disconnect waiter if it is finished.

        The waiter is finished once a connection ends, either resolved by
        ``handle_connection_closed`` or cancelled by ``stop``. A finished waiter
        is replaced so that ``wait_disconnect`` tracks the new connection
        instead of returning immediately. The replacement happens before the
        base handler starts, so a close reported during start-up is not lost.
        """
        if self._disconnect_waiter is not None and self._disconnect_waiter.done():
            self._disconnect_waiter = asyncio.Future()
        await super().start()

    async def stop(self) -> None:
        """Stop the handler, cancelling any in-flight ping and the disconnect waiter."""
        await super().stop()
        # Drop the reference so a later write timeout can schedule a fresh ping.
        ping_task, self._ping_task = self._ping_task, None
        if ping_task is not None and not ping_task.done():
            self.logger.debug("Cancel ping task")
            ping_task.cancel()

        if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
            self._disconnect_waiter.cancel()

    def _build_connect_packet(self) -> ConnectPacket:
        """Build a CONNECT packet from the attached session's settings.

        :return: The CONNECT packet to send.
        :raises AMQTTError: If no session is attached.
        """
        vh = ConnectVariableHeader()
        payload = ConnectPayload()

        if self.session is None:
            msg = "Session is not initialized."
            raise AMQTTError(msg)
        session = self.session

        vh.keep_alive = session.keep_alive
        # Unset (None) session flags are sent as False.
        vh.clean_session_flag = session.clean_session if session.clean_session is not None else False
        vh.will_retain_flag = session.will_retain if session.will_retain is not None else False
        payload.client_id = session.client_id

        if session.username:
            vh.username_flag = True
            payload.username = session.username
        else:
            vh.username_flag = False

        if session.password:
            vh.password_flag = True
            payload.password = session.password
        else:
            vh.password_flag = False

        if session.will_flag:
            vh.will_flag = True
            # Leave the header's default will QoS in place when the session has none.
            if session.will_qos is not None:
                vh.will_qos = session.will_qos
            payload.will_message = session.will_message
            payload.will_topic = session.will_topic
        else:
            vh.will_flag = False

        return ConnectPacket(variable_header=vh, payload=payload)

    async def mqtt_connect(self) -> int | None:
        """Send CONNECT and wait for the broker's CONNACK.

        :return: The CONNACK return code.
        :raises AMQTTError: If the reader or the session is not initialized.
        :raises ConnectionError: If the stream ends before a CONNACK is read.
        """
        # Check before sending: a CONNECT whose answer cannot be read is pointless.
        if self.reader is None:
            msg = "Reader is not initialized."
            raise AMQTTError(msg)

        connect_packet = self._build_connect_packet()
        await self._send_packet(connect_packet)

        try:
            connack = await ConnackPacket.from_stream(self.reader)
        except NoDataError as e:
            raise ConnectionError from e
        await self.plugins_manager.fire_event(MQTTEvents.PACKET_RECEIVED, packet=connack, session=self.session)
        return connack.return_code

    def handle_write_timeout(self) -> None:
        """Schedule a keep-alive PINGREQ unless one is already in flight."""
        try:
            # A task that is done (including one cancelled before it ever ran, which
            # never reaches mqtt_ping's cleanup) must not block future pings.
            if self._ping_task is None or self._ping_task.done():
                self.logger.debug("Scheduling Ping")
                self._ping_task = asyncio.create_task(self.mqtt_ping())
        except asyncio.InvalidStateError as e:
            self.logger.warning(f"Invalid state while scheduling ping task: {e!r}")
        except asyncio.CancelledError as e:
            self.logger.info(f"Ping task was cancelled: {e!r}")

    def handle_read_timeout(self) -> None:
        """Do nothing: the client takes no action on read timeouts."""

    async def mqtt_subscribe(self, topics: list[tuple[str, int]], packet_id: int) -> list[int]:
        """Subscribe to the given topics and wait for the matching SUBACK.

        :param topics: List of ``(topic_filter, qos)`` tuples, e.g. ``[('/a/b', 0x00)]``.
        :param packet_id: Packet identifier for the SUBSCRIBE packet.
        :return: Return codes for the subscription, taken from the SUBACK payload.
        :raises AMQTTError: If the built SUBSCRIBE packet has no variable header.
        """
        subscribe = SubscribePacket.build(topics, packet_id)
        if subscribe.variable_header is None:
            msg = f"Invalid variable header in SUBSCRIBE packet: {subscribe.variable_header}"
            raise AMQTTError(msg)

        waiter_id = subscribe.variable_header.packet_id
        # Register the waiter *before* sending: the SUBACK may be handled while
        # _send_packet is still awaiting, and would otherwise be dropped as unknown.
        waiter: asyncio.Future[list[int]] = asyncio.Future()
        self._subscriptions_waiter[waiter_id] = waiter
        try:
            await self._send_packet(subscribe)
            return await waiter
        finally:
            if self._subscriptions_waiter.get(waiter_id) is waiter:
                del self._subscriptions_waiter[waiter_id]

    async def handle_suback(self, suback: SubackPacket) -> None:
        """Resolve the pending subscription matching the SUBACK's packet id."""
        if suback.variable_header is None:
            msg = "SUBACK packet: variable header not initialized."
            raise AMQTTError(msg)
        if suback.payload is None:
            msg = "SUBACK packet: payload not initialized."
            raise AMQTTError(msg)

        packet_id = suback.variable_header.packet_id

        waiter = self._subscriptions_waiter.get(packet_id)
        if waiter is None:
            self.logger.warning(f"Received SUBACK for unknown pending subscription with Id: {packet_id}")
        elif waiter.done():
            # The subscriber was cancelled but has not yet removed its waiter;
            # set_result would raise InvalidStateError here.
            self.logger.debug(f"Ignoring SUBACK for already finished subscription with Id: {packet_id}")
        else:
            waiter.set_result(suback.payload.return_codes)

    async def mqtt_unsubscribe(self, topics: list[str], packet_id: int) -> None:
        """Unsubscribe from the given topics and wait for the matching UNSUBACK.

        :param topics: List of topics ['/a/b', ...].
        :param packet_id: Packet identifier for the UNSUBSCRIBE packet.
        :raises AMQTTError: If the built UNSUBSCRIBE packet has no variable header.
        """
        unsubscribe = UnsubscribePacket.build(topics, packet_id)

        if unsubscribe.variable_header is None:
            msg = "UNSUBSCRIBE packet: variable header not initialized."
            raise AMQTTError(msg)

        waiter_id = unsubscribe.variable_header.packet_id
        # Register before sending so an early UNSUBACK is not dropped (see mqtt_subscribe).
        waiter: asyncio.Future[Any] = asyncio.Future()
        self._unsubscriptions_waiter[waiter_id] = waiter
        try:
            await self._send_packet(unsubscribe)
            await waiter
        finally:
            if self._unsubscriptions_waiter.get(waiter_id) is waiter:
                del self._unsubscriptions_waiter[waiter_id]

    async def handle_unsuback(self, unsuback: UnsubackPacket) -> None:
        """Resolve the pending unsubscription matching the UNSUBACK's packet id."""
        if unsuback.variable_header is None:
            msg = "UNSUBACK packet: variable header not initialized."
            raise AMQTTError(msg)

        packet_id = unsuback.variable_header.packet_id
        waiter = self._unsubscriptions_waiter.get(packet_id)
        if waiter is None:
            self.logger.warning(f"Received UNSUBACK for unknown pending unsubscription with Id: {packet_id}")
        elif waiter.done():
            self.logger.debug(f"Ignoring UNSUBACK for already finished unsubscription with Id: {packet_id}")
        else:
            waiter.set_result(None)

    async def mqtt_disconnect(self) -> None:
        """Send a DISCONNECT packet."""
        disconnect_packet = DisconnectPacket()
        await self._send_packet(disconnect_packet)

    async def mqtt_ping(self) -> PingRespPacket:
        """Send PINGREQ and return the next PINGRESP delivered by handle_pingresp."""
        ping_packet = PingReqPacket()
        try:
            await self._send_packet(ping_packet)
            return await self._pingresp_queue.get()
        finally:
            # Only clear the reference when this coroutine *is* the scheduled ping
            # task; a direct call must not forget a ping task that is still running.
            if self._ping_task is asyncio.current_task():
                self._ping_task = None

    async def handle_pingresp(self, pingresp: PingRespPacket) -> None:
        """Hand a received PINGRESP to the waiting mqtt_ping call."""
        await self._pingresp_queue.put(pingresp)

    async def handle_connection_closed(self) -> None:
        """Resolve the disconnect waiter when the broker closes the connection."""
        self.logger.debug("Broker closed connection")
        if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
            self._disconnect_waiter.set_result(None)

    async def wait_disconnect(self) -> None:
        """Wait until the current connection ends.

        Returns once handle_connection_closed resolves the waiter. Raises
        ``asyncio.CancelledError`` if stop() cancelled the waiter first.
        """
        if self._disconnect_waiter is not None:
            await self._disconnect_waiter