Coverage for amqtt/session.py: 87%

141 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from asyncio import Queue 

2from collections import OrderedDict 

3import logging 

4from math import floor 

5import time 

6from typing import TYPE_CHECKING, Any, ClassVar 

7 

8from transitions import Machine 

9 

10from amqtt.errors import AMQTTError 

11from amqtt.mqtt.publish import PublishPacket 

12 

# Values stored in the ``direction`` attribute of the Incoming/Outgoing
# ApplicationMessage subclasses below.
OUTGOING, INCOMING = 0, 1

if TYPE_CHECKING:
    # ``ssl`` is only referenced in a type annotation (``Session.ssl_object``),
    # so it is imported for static type checkers only.
    import ssl

logger = logging.getLogger(name=__name__)

20 

21 

class ApplicationMessage:
    """ApplicationMessage and subclasses are used to store published message information flow.

    These objects can contain different information depending on the way they were created (incoming or outgoing)
    and the quality of service used between peers.

    Equality is based on ``packet_id`` only (see :meth:`__eq__`). Because ``__eq__`` is defined
    without ``__hash__``, instances are not hashable.
    """

    __slots__ = (
        "data",
        "packet_id",
        "puback_packet",
        "pubcomp_packet",
        "publish_packet",
        "pubrec_packet",
        "pubrel_packet",
        "qos",
        "retain",
        "topic",
    )

    def __init__(self, packet_id: int | None, topic: str, qos: int | None, data: bytes | bytearray, retain: bool) -> None:
        self.packet_id: int | None = packet_id
        """ Publish message packet identifier
        `<http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718025>`_
        """

        self.topic: str = topic
        """ Publish message topic"""

        self.qos: int | None = qos
        """ Publish message Quality of Service"""

        self.data: bytes | bytearray = data
        """ Publish message payload data"""

        self.retain: bool = retain
        """ Publish message retain flag"""

        self.publish_packet: PublishPacket | None = None
        """ :class:`amqtt.mqtt.publish.PublishPacket` instance corresponding to the
        `PUBLISH <http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718037>`_ packet in the messages flow.
        ``None`` if the PUBLISH packet has not already been received or sent."""

        self.puback_packet: Any | None = None
        """ :class:`amqtt.mqtt.puback.PubackPacket` instance corresponding to the
        `PUBACK <http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718043>`_ packet in the messages flow.
        ``None`` if QoS != QOS_1 or if the PUBACK packet has not already been received or sent."""

        self.pubrec_packet: Any | None = None
        """ :class:`amqtt.mqtt.pubrec.PubrecPacket` instance corresponding to the
        `PUBREC <http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718048>`_ packet in the messages flow.
        ``None`` if QoS != QOS_2 or if the PUBREC packet has not already been received or sent."""

        self.pubrel_packet: Any | None = None
        """ :class:`amqtt.mqtt.pubrel.PubrelPacket` instance corresponding to the
        `PUBREL <http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718053>`_ packet in the messages flow.
        ``None`` if QoS != QOS_2 or if the PUBREL packet has not already been received or sent."""

        self.pubcomp_packet: Any | None = None
        """ :class:`amqtt.mqtt.pubcomp.PubcompPacket` instance corresponding to the
        `PUBCOMP <http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718058>`_ packet in the messages flow.
        ``None`` if QoS != QOS_2 or if the PUBCOMP packet has not already been received or sent."""

    def build_publish_packet(self, dup: bool = False) -> PublishPacket:
        """Build :class:`amqtt.mqtt.publish.PublishPacket` from attributes.

        :param dup: force dup flag
        :return: :class:`amqtt.mqtt.publish.PublishPacket` built from ApplicationMessage instance attributes
        """
        # ``data`` may be a bytearray; the packet is always built from an immutable bytes copy.
        return PublishPacket.build(self.topic, bytes(self.data), self.packet_id, dup, self.qos, self.retain)

    def __eq__(self, other: object) -> bool:
        """Compare two ApplicationMessage instances based on their packet_id.

        Only ``packet_id`` is compared: topic, payload and QoS are ignored, and two messages
        that both have ``packet_id=None`` compare equal.

        :param other: The other ApplicationMessage instance to compare with.
        :return: True if the packet_id of both messages are equal, False otherwise
            (always False when ``other`` is not an ApplicationMessage).
        """
        if not isinstance(other, ApplicationMessage):
            return False
        return self.packet_id == other.packet_id

103 

104 

class IncomingApplicationMessage(ApplicationMessage):
    """Incoming :class:`~amqtt.session.ApplicationMessage`.

    Adds a ``direction`` slot that is always set to ``INCOMING``.
    """

    __slots__ = ("direction",)

    # ``data`` accepts ``bytes | bytearray`` like the base class and OutgoingApplicationMessage.
    def __init__(self, packet_id: int | None, topic: str, qos: int | None, data: bytes | bytearray, retain: bool) -> None:
        super().__init__(packet_id, topic, qos, data, retain)
        self.direction: int = INCOMING

113 

114 

class OutgoingApplicationMessage(ApplicationMessage):
    """Outgoing :class:`~amqtt.session.ApplicationMessage`.

    Adds a ``direction`` slot that is always set to ``OUTGOING``.
    """

    __slots__ = ("direction",)

    def __init__(self, packet_id: int | None, topic: str, qos: int | None, data: bytes | bytearray, retain: bool) -> None:
        super().__init__(packet_id, topic, qos, data, retain)
        self.direction: int = OUTGOING

123 

124 

class Session:
    """State of a single MQTT client session.

    Stores the connection parameters (client id, will message, keep-alive, credentials,
    CA file/path/data), the application messages whose publish protocol flow is still in
    progress in each direction, and two asyncio queues of messages. The session lifecycle
    (``new``, ``connected``, ``disconnected``) is tracked by a ``transitions.Machine``
    stored in :attr:`transitions`.
    """

    states: ClassVar[list[str]] = ["new", "connected", "disconnected"]

    # Upper bound of the identifiers handed out by next_packet_id (they cycle 1..65535).
    _MAX_PACKET_ID: ClassVar[int] = 65535

    def __init__(self) -> None:
        self._init_states()
        self.remote_address: str | None = None
        self.remote_port: int | None = None
        self.client_id: str | None = None
        self.clean_session: bool | None = None
        self.will_flag: bool = False
        self.will_message: bytes | bytearray | None = None
        self.will_qos: int | None = None
        self.will_retain: bool | None = None
        self.will_topic: str | None = None
        self.keep_alive: int = 0
        self.publish_retry_delay: int = 0
        self.broker_uri: str | None = None
        self.username: str | None = None
        self.password: str | None = None
        self.cafile: str | None = None
        self.capath: str | None = None
        self.cadata: bytes | None = None
        self._packet_id: int = 0
        self.parent: int = 0
        # Whole-second timestamps (floor of time.time()) maintained by the state callbacks.
        self.last_connect_time: int | None = None
        self.ssl_object: ssl.SSLObject | None = None
        self.last_disconnect_time: int | None = None

        # Used to store outgoing ApplicationMessage while publish protocol flows
        self.inflight_out: OrderedDict[int, OutgoingApplicationMessage] = OrderedDict()

        # Used to store incoming ApplicationMessage while publish protocol flows
        self.inflight_in: OrderedDict[int, IncomingApplicationMessage] = OrderedDict()

        # Stores messages retained for this session (specifically when the client is disconnected)
        self.retained_messages: Queue[ApplicationMessage] = Queue()

        # Stores PUBLISH messages ID received in order and ready for application process
        self.delivered_message_queue: Queue[ApplicationMessage] = Queue()

        # identify anonymous client sessions or clients which didn't identify themselves
        self.is_anonymous: bool = False

    def _init_states(self) -> None:
        """Create the lifecycle state machine and register its transitions and enter callbacks."""
        self.transitions = Machine(states=Session.states, initial="new")
        # connect: new -> connected, disconnected -> connected
        self.transitions.add_transition(
            trigger="connect",
            source="new",
            dest="connected",
        )
        self.transitions.on_enter_connected(self._on_enter_connected)
        self.transitions.add_transition(
            trigger="connect",
            source="disconnected",
            dest="connected",
        )
        # disconnect: allowed from every state and always ends in "disconnected"
        self.transitions.add_transition(
            trigger="disconnect",
            source="connected",
            dest="disconnected",
        )
        self.transitions.on_enter_disconnected(self._on_enter_disconnected)
        self.transitions.add_transition(
            trigger="disconnect",
            source="new",
            dest="disconnected",
        )
        # NOTE(review): this is a reflexive (not internal) transition, so presumably
        # _on_enter_disconnected runs again and refreshes last_disconnect_time on a
        # repeated disconnect - confirm that is intended.
        self.transitions.add_transition(
            trigger="disconnect",
            source="disconnected",
            dest="disconnected",
        )

    def _on_enter_connected(self) -> None:
        """Record the connect time and log how long the session was disconnected, if it was."""
        cur_time = floor(time.time())
        if self.last_disconnect_time is not None:
            # Lazy %-formatting: the message is only built when DEBUG is enabled.
            logger.debug("Session reconnected after %s seconds.", cur_time - self.last_disconnect_time)

        self.last_connect_time = cur_time
        self.last_disconnect_time = None

    def _on_enter_disconnected(self) -> None:
        """Record the disconnect time and log how long the session was connected, if it was."""
        cur_time = floor(time.time())
        if self.last_connect_time is not None:
            logger.debug("Session disconnected after %s seconds.", cur_time - self.last_connect_time)
        self.last_disconnect_time = cur_time

    @property
    def next_packet_id(self) -> int:
        """Allocate and return the next free packet identifier.

        Identifiers cycle through 1..65535. Values currently used as keys of
        ``inflight_in`` or ``inflight_out`` are skipped. Reading this property
        advances the internal counter.

        :raises AMQTTError: if every identifier in 1..65535 is in flight.
        """
        max_id = self._MAX_PACKET_ID
        self._packet_id = (self._packet_id % max_id) + 1
        first_candidate = self._packet_id
        while self._packet_id in self.inflight_in or self._packet_id in self.inflight_out:
            self._packet_id = (self._packet_id % max_id) + 1
            # Wrapped around to the first candidate: every identifier is taken.
            if self._packet_id == first_candidate:
                msg = "More than 65535 messages pending. No free packet ID"
                raise AMQTTError(msg)

        return self._packet_id

    @property
    def inflight_in_count(self) -> int:
        """Number of incoming messages whose publish flow is still in progress."""
        return len(self.inflight_in)

    @property
    def inflight_out_count(self) -> int:
        """Number of outgoing messages whose publish flow is still in progress."""
        return len(self.inflight_out)

    @property
    def retained_messages_count(self) -> int:
        """Number of messages currently waiting in ``retained_messages``."""
        return self.retained_messages.qsize()

    def __repr__(self) -> str:
        """Return a string representation of the session.

        This method is used for debugging and logging purposes.
        It includes the client ID and the current state of the session.
        """
        return type(self).__name__ + f"(clientId={self.client_id}, state={self.transitions.state})"

    def __getstate__(self) -> dict[str, Any]:
        """Return the state of the session for pickling.

        This method is called when pickling the session object. It returns a copy of the
        instance ``__dict__`` without the message queues; ``__setstate__`` recreates them empty.
        A queue attribute that is absent is simply skipped instead of raising ``KeyError``.
        """
        state = self.__dict__.copy()
        # Remove the unpicklable entries.
        for queue_attr in ("retained_messages", "delivered_message_queue"):
            state.pop(queue_attr, None)
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore the session from its state.

        This method is called when unpickling the session object.
        It restores the session's state and reinitializes the queues (empty).
        """
        self.__dict__.update(state)
        self.retained_messages = Queue()
        self.delivered_message_queue = Queue()

    def clear_queues(self) -> None:
        """Clear all message queues associated with the session.

        The existing queue objects are drained in place rather than replaced.
        """
        while not self.retained_messages.empty():
            self.retained_messages.get_nowait()
        while not self.delivered_message_queue.empty():
            self.delivered_message_queue.get_nowait()

    def __eq__(self, other: object) -> bool:
        """Compare two Session instances based on their client_id.

        Returns False when ``other`` is not a Session. Defining ``__eq__`` without
        ``__hash__`` makes Session instances unhashable.
        """
        if not isinstance(other, Session):
            return False
        return self.client_id == other.client_id