Coverage for amqtt/session.py: 87%
141 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1from asyncio import Queue
2from collections import OrderedDict
3import logging
4from math import floor
5import time
6from typing import TYPE_CHECKING, Any, ClassVar
8from transitions import Machine
10from amqtt.errors import AMQTTError
11from amqtt.mqtt.publish import PublishPacket
# Direction markers: stored in the ``direction`` attribute of
# OutgoingApplicationMessage and IncomingApplicationMessage respectively.
OUTGOING = 0
INCOMING = 1
16if TYPE_CHECKING:
17 import ssl
19logger = logging.getLogger(__name__)
class ApplicationMessage:
    """ApplicationMessage and subclasses are used to store published message information flow.

    These objects can contain different information depending on the way they were created (incoming or outgoing)
    and the quality of service used between peers.

    Instances compare equal when their ``packet_id`` values are equal (so two messages whose
    ``packet_id`` is ``None`` compare equal). Because ``__eq__`` is defined without ``__hash__``,
    instances are unhashable.
    """

    __slots__ = (
        "data",
        "packet_id",
        "puback_packet",
        "pubcomp_packet",
        "publish_packet",
        "pubrec_packet",
        "pubrel_packet",
        "qos",
        "retain",
        "topic",
    )

    def __init__(self, packet_id: int | None, topic: str, qos: int | None, data: bytes | bytearray, retain: bool) -> None:
        """Store the publish parameters and reset every protocol packet slot to ``None``.

        :param packet_id: packet identifier of the publish flow, or ``None``
        :param topic: topic the message is published on
        :param qos: quality of service of the message, or ``None``
        :param data: message payload
        :param retain: retain flag of the message
        """
        self.packet_id: int | None = packet_id
        """ Publish message `packet identifier
        <http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718025>`_
        """

        self.topic: str = topic
        """ Publish message topic"""

        self.qos: int | None = qos
        """ Publish message Quality of Service"""

        self.data: bytes | bytearray = data
        """ Publish message payload data"""

        self.retain: bool = retain
        """ Publish message retain flag"""

        self.publish_packet: PublishPacket | None = None
        """ :class:`amqtt.mqtt.publish.PublishPacket` instance corresponding to the
        `PUBLISH <http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718037>`_ packet in the messages flow.
        ``None`` if the PUBLISH packet has not already been received or sent."""

        self.puback_packet: Any | None = None
        """ :class:`amqtt.mqtt.puback.PubackPacket` instance corresponding to the
        `PUBACK <http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718043>`_ packet in the messages flow.
        ``None`` if QoS != QOS_1 or if the PUBACK packet has not already been received or sent."""

        self.pubrec_packet: Any | None = None
        """ :class:`amqtt.mqtt.puback.PubrecPacket` instance corresponding to the
        `PUBREC <http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718048>`_ packet in the messages flow.
        ``None`` if QoS != QOS_2 or if the PUBREC packet has not already been received or sent."""

        self.pubrel_packet: Any | None = None
        """ :class:`amqtt.mqtt.puback.PubrelPacket` instance corresponding to the
        `PUBREL <http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718053>`_ packet in the messages flow.
        ``None`` if QoS != QOS_2 or if the PUBREL packet has not already been received or sent."""

        self.pubcomp_packet: Any | None = None
        """ PUBCOMP packet instance corresponding to the
        `PUBCOMP <http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718058>`_ packet in the messages flow.
        ``None`` if QoS != QOS_2 or if the PUBCOMP packet has not already been received or sent."""

    def build_publish_packet(self, dup: bool = False) -> PublishPacket:
        """Build :class:`amqtt.mqtt.publish.PublishPacket` from attributes.

        The payload is converted with ``bytes()`` so a ``bytearray`` payload is passed on as ``bytes``.

        :param dup: force dup flag
        :return: :class:`amqtt.mqtt.publish.PublishPacket` built from ApplicationMessage instance attributes
        """
        return PublishPacket.build(self.topic, bytes(self.data), self.packet_id, dup, self.qos, self.retain)

    def __eq__(self, other: object) -> bool:
        """Compare two ApplicationMessage instances based on their packet_id.

        :param other: The object to compare with.
        :return: True if the packet_id of both messages are equal, False otherwise.
            ``NotImplemented`` when ``other`` is not an ApplicationMessage, so that Python can try
            the reflected comparison before falling back to ``False``.
        """
        if not isinstance(other, ApplicationMessage):
            # Let the interpreter consult ``other.__eq__`` instead of short-circuiting to False.
            return NotImplemented
        return self.packet_id == other.packet_id
class IncomingApplicationMessage(ApplicationMessage):
    """:class:`~amqtt.session.ApplicationMessage` flagged as incoming.

    Adds a ``direction`` slot which is always set to ``INCOMING``.
    """

    __slots__ = ("direction",)

    def __init__(self, packet_id: int | None, topic: str, qos: int | None, data: bytes, retain: bool) -> None:
        # The direction tag is independent of the slots filled by the base initializer.
        self.direction: int = INCOMING
        super().__init__(
            packet_id=packet_id,
            topic=topic,
            qos=qos,
            data=data,
            retain=retain,
        )
class OutgoingApplicationMessage(ApplicationMessage):
    """:class:`~amqtt.session.ApplicationMessage` flagged as outgoing.

    Adds a ``direction`` slot which is always set to ``OUTGOING``.
    """

    __slots__ = ("direction",)

    def __init__(self, packet_id: int | None, topic: str, qos: int | None, data: bytes | bytearray, retain: bool) -> None:
        # The direction tag is independent of the slots filled by the base initializer.
        self.direction: int = OUTGOING
        super().__init__(
            packet_id=packet_id,
            topic=topic,
            qos=qos,
            data=data,
            retain=retain,
        )
class Session:
    """Per-client MQTT session state.

    A session keeps the connection parameters (remote endpoint, credentials, TLS settings,
    will message settings), the in-flight message maps, two message queues and a small state
    machine with the states ``new``, ``connected`` and ``disconnected``.

    Two sessions compare equal when their ``client_id`` values are equal. Because ``__eq__`` is
    defined without ``__hash__``, instances are unhashable.
    """

    states: ClassVar[list[str]] = ["new", "connected", "disconnected"]

    def __init__(self) -> None:
        """Create a session in the ``new`` state with empty in-flight maps and queues."""
        self._init_states()
        self.remote_address: str | None = None
        self.remote_port: int | None = None
        self.client_id: str | None = None
        self.clean_session: bool | None = None
        self.will_flag: bool = False
        self.will_message: bytes | bytearray | None = None
        self.will_qos: int | None = None
        self.will_retain: bool | None = None
        self.will_topic: str | None = None
        self.keep_alive: int = 0
        self.publish_retry_delay: int = 0
        self.broker_uri: str | None = None
        self.username: str | None = None
        self.password: str | None = None
        self.cafile: str | None = None
        self.capath: str | None = None
        self.cadata: bytes | None = None
        # Last packet identifier handed out by ``next_packet_id`` (0 means none yet).
        self._packet_id: int = 0
        self.parent: int = 0
        # Whole-second timestamps maintained by the state machine callbacks.
        self.last_connect_time: int | None = None
        self.ssl_object: ssl.SSLObject | None = None
        self.last_disconnect_time: int | None = None

        # Used to store outgoing ApplicationMessage while publish protocol flows
        self.inflight_out: OrderedDict[int, OutgoingApplicationMessage] = OrderedDict()

        # Used to store incoming ApplicationMessage while publish protocol flows
        self.inflight_in: OrderedDict[int, IncomingApplicationMessage] = OrderedDict()

        # Stores messages retained for this session (specifically when the client is disconnected)
        self.retained_messages: Queue[ApplicationMessage] = Queue()

        # Stores PUBLISH messages ID received in order and ready for application process
        self.delivered_message_queue: Queue[ApplicationMessage] = Queue()

        # identify anonymous client sessions or clients which didn't identify themselves
        self.is_anonymous: bool = False

    def _init_states(self) -> None:
        """Build the session state machine.

        ``connect`` moves ``new`` or ``disconnected`` to ``connected``; ``disconnect`` moves any
        state to ``disconnected``. Entering ``connected`` / ``disconnected`` runs the matching
        ``_on_enter_*`` callback.
        """
        self.transitions = Machine(states=Session.states, initial="new")
        self.transitions.add_transition(
            trigger="connect",
            source="new",
            dest="connected",
        )
        self.transitions.on_enter_connected(self._on_enter_connected)
        self.transitions.add_transition(
            trigger="connect",
            source="disconnected",
            dest="connected",
        )
        self.transitions.add_transition(
            trigger="disconnect",
            source="connected",
            dest="disconnected",
        )
        self.transitions.on_enter_disconnected(self._on_enter_disconnected)
        self.transitions.add_transition(
            trigger="disconnect",
            source="new",
            dest="disconnected",
        )
        self.transitions.add_transition(
            trigger="disconnect",
            source="disconnected",
            dest="disconnected",
        )

    def _on_enter_connected(self) -> None:
        """Record the connect time and clear the disconnect time; log the offline duration if known."""
        cur_time = floor(time.time())
        if self.last_disconnect_time is not None:
            logger.debug(f"Session reconnected after {cur_time - self.last_disconnect_time} seconds.")

        self.last_connect_time = cur_time
        self.last_disconnect_time = None

    def _on_enter_disconnected(self) -> None:
        """Record the disconnect time; log the connected duration if a connect time is known."""
        cur_time = floor(time.time())
        if self.last_connect_time is not None:
            logger.debug(f"Session disconnected after {cur_time - self.last_connect_time} seconds.")
        self.last_disconnect_time = cur_time

    @property
    def next_packet_id(self) -> int:
        """Advance to and return the next free packet identifier in the range 1..65535.

        Reading this property mutates the session: the internal counter is advanced (wrapping
        from 65535 back to 1) until it reaches an identifier that is in neither in-flight map.

        :raises AMQTTError: if every identifier is currently in flight.
        """
        self._packet_id = (self._packet_id % 65535) + 1
        # Remember the starting candidate so a full wrap-around can be detected.
        limit = self._packet_id
        while self._packet_id in self.inflight_in or self._packet_id in self.inflight_out:
            self._packet_id = (self._packet_id % 65535) + 1
            if self._packet_id == limit:
                msg = "More than 65535 messages pending. No free packet ID"
                raise AMQTTError(msg)

        return self._packet_id

    @property
    def inflight_in_count(self) -> int:
        """Number of incoming messages currently in flight."""
        return len(self.inflight_in)

    @property
    def inflight_out_count(self) -> int:
        """Number of outgoing messages currently in flight."""
        return len(self.inflight_out)

    @property
    def retained_messages_count(self) -> int:
        """Number of messages currently held in the ``retained_messages`` queue."""
        return self.retained_messages.qsize()

    def __repr__(self) -> str:
        """Return a string representation of the session.

        This method is used for debugging and logging purposes.
        It includes the client ID and the current state of the session.
        """
        return type(self).__name__ + f"(clientId={self.client_id}, state={self.transitions.state})"

    def __getstate__(self) -> dict[str, Any]:
        """Return the state of the session for pickling.

        This method is called when pickling the session object.
        It returns a copy of the instance dictionary without the two message queues.
        """
        state = self.__dict__.copy()
        # Remove the unpicklable entries.
        # NOTE(review): ``ssl_object`` and ``transitions`` stay in the state -- confirm they are
        # picklable whenever they are set.
        del state["retained_messages"]
        del state["delivered_message_queue"]
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore the session from its state.

        This method is called when unpickling the session object.
        It restores the session's state and recreates both message queues empty.
        """
        self.__dict__.update(state)
        self.retained_messages = Queue()
        self.delivered_message_queue = Queue()

    def clear_queues(self) -> None:
        """Clear all message queues associated with the session."""
        while not self.retained_messages.empty():
            self.retained_messages.get_nowait()
        while not self.delivered_message_queue.empty():
            self.delivered_message_queue.get_nowait()

    def __eq__(self, other: object) -> bool:
        """Compare two Session instances based on their client_id.

        :param other: The object to compare with.
        :return: True if both sessions have the same ``client_id``, False otherwise.
            ``NotImplemented`` when ``other`` is not a Session, so that Python can try the
            reflected comparison before falling back to ``False``.
        """
        if not isinstance(other, Session):
            # Let the interpreter consult ``other.__eq__`` instead of short-circuiting to False.
            return NotImplemented
        return self.client_id == other.client_id