Coverage for amqtt/mqtt/packet.py: 92%

170 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from abc import ABC, abstractmethod 

2import asyncio 

3 

4try: 

5 from datetime import UTC, datetime 

6except ImportError: 

7 from datetime import datetime, timezone 

8 

9 UTC = timezone.utc 

10 

11from struct import unpack 

12from typing import Generic 

13from typing_extensions import Self, TypeVar 

14 

15from amqtt.adapters import ReaderAdapter, WriterAdapter 

16from amqtt.codecs_amqtt import bytes_to_hex_str, decode_packet_id, int_to_bytes, read_or_raise 

17from amqtt.errors import CodecError, MQTTError, NoDataError 

18 

# MQTT control packet types: the 4-bit value stored in the high nibble of the
# first fixed-header byte (see MQTTFixedHeader.from_stream). Listed in order
# 0x00 through 0x0F; the first and last codes are reserved.
(
    RESERVED_0,
    CONNECT,
    CONNACK,
    PUBLISH,
    PUBACK,
    PUBREC,
    PUBREL,
    PUBCOMP,
    SUBSCRIBE,
    SUBACK,
    UNSUBSCRIBE,
    UNSUBACK,
    PINGREQ,
    PINGRESP,
    DISCONNECT,
    RESERVED_15,
) = range(16)

35 

36 

class MQTTFixedHeader:
    """Represents the fixed header of an MQTT packet.

    On the wire the fixed header is one byte holding the packet type (high
    nibble) and flags (low nibble), followed by the remaining length encoded
    as a variable-length integer: 7 data bits per byte, with the high bit set
    on every byte except the last.
    """

    __slots__ = ("flags", "packet_type", "remaining_length")

    # Largest remaining length representable in four length bytes
    # (0xFF 0xFF 0xFF 0x7F), per MQTT 3.1.1 section 2.2.3.
    _MAX_REMAINING_LENGTH = 268_435_455

    def __init__(self, packet_type: int, flags: int = 0, length: int = 0) -> None:
        self.packet_type = packet_type
        self.flags = flags
        self.remaining_length = length

    def to_bytes(self) -> bytes:
        """Encode the fixed header to bytes.

        Returns:
            The type/flags byte followed by the 1-4 byte encoding of
            ``remaining_length``.

        Raises:
            CodecError: if ``packet_type`` or ``flags`` does not fit in four
                bits, or ``remaining_length`` is outside ``0..268435455``.
        """

        def encode_remaining_length(length: int) -> bytes:
            """Encode a non-negative *length* as an MQTT variable-length integer."""
            encoded = bytearray()
            while True:
                length, length_byte = divmod(length, 0x80)
                if length > 0:
                    # More bytes follow: set the continuation bit.
                    length_byte |= 0x80
                encoded.append(length_byte)
                if length == 0:
                    break
            return bytes(encoded)

        # Validate up front: bytes([...]) would otherwise raise a bare
        # ValueError for an out-of-range type/flags byte, and the length
        # encoder would emit an invalid (5+ byte or wrong) length encoding.
        if not (0 <= self.packet_type <= 0x0F and 0 <= self.flags <= 0x0F):
            msg = (
                f"Fixed header encoding failed: packet_type={self.packet_type} "
                f"and flags={self.flags} must be in range 0-15"
            )
            raise CodecError(msg)
        if not 0 <= self.remaining_length <= self._MAX_REMAINING_LENGTH:
            msg = (
                f"Fixed header encoding failed: remaining length {self.remaining_length} "
                f"is outside 0-{self._MAX_REMAINING_LENGTH}"
            )
            raise CodecError(msg)

        packet_type_flags = (self.packet_type << 4) | self.flags
        return bytes([packet_type_flags]) + encode_remaining_length(self.remaining_length)

    async def to_stream(self, writer: WriterAdapter) -> None:
        """Write the fixed header to the stream."""
        writer.write(self.to_bytes())

    @property
    def bytes_length(self) -> int:
        """Size in bytes of the encoded fixed header."""
        return len(self.to_bytes())

    @classmethod
    async def from_stream(cls: type[Self], reader: ReaderAdapter) -> "Self | None":
        """Decode a fixed header from the stream.

        Returns:
            The decoded header, or ``None`` if the stream ends (``NoDataError``)
            before a complete fixed header has been read.

        Raises:
            MQTTError: if the remaining length is encoded with more than four bytes.
        """

        async def decode_remaining_length(packet_type: int) -> int:
            """Decode the variable-length remaining length from the stream."""
            multiplier, value = 1, 0
            buffer = bytearray()
            while True:
                # Use read_or_raise (as for the first byte) so an empty read at
                # end of stream becomes NoDataError rather than a struct.error
                # from unpacking b"".
                encoded_byte = await read_or_raise(reader, 1)
                byte_value = unpack("!B", encoded_byte)[0]
                buffer.append(byte_value)
                value += (byte_value & 0x7F) * multiplier
                if (byte_value & 0x80) == 0:
                    break
                multiplier *= 128
                # A continuation bit on the 4th length byte would require a 5th.
                if multiplier > 128**3:
                    msg = f"Invalid remaining length bytes:{bytes_to_hex_str(buffer)}, packet_type={packet_type}"
                    raise MQTTError(msg)
            return value

        try:
            byte1 = await read_or_raise(reader, 1)
            int1 = unpack("!B", byte1)[0]
            packet_type = (int1 & 0xF0) >> 4
            flags = int1 & 0x0F
            remaining_length = await decode_remaining_length(packet_type)
            return cls(packet_type, flags, remaining_length)
        except NoDataError:
            return None

    def __repr__(self) -> str:
        """Return a string representation of the MQTTFixedHeader object."""
        return f"{self.__class__.__name__}(packet_type={self.packet_type}, flags={self.flags}, length={self.remaining_length})"

115 

116 

# Type variable for the fixed-header class carried by an MQTTPacket;
# bounded to MQTTFixedHeader and its subclasses.
_FH = TypeVar(
    "_FH",
    bound=MQTTFixedHeader,
)

118 

119 

class MQTTVariableHeader(ABC):
    """Base class for the variable-header section of an MQTT packet.

    Subclasses implement ``to_bytes`` and ``from_stream``; streaming and the
    serialized length are both derived from ``to_bytes``.
    """

    async def to_stream(self, writer: asyncio.StreamWriter) -> None:
        """Write the serialized header to *writer* and wait for it to drain."""
        data = self.to_bytes()
        writer.write(data)
        await writer.drain()

    @abstractmethod
    def to_bytes(self) -> bytes | bytearray:
        """Serialize the variable header to bytes."""

    @property
    def bytes_length(self) -> int:
        """Size in bytes of the serialized variable header."""
        serialized = self.to_bytes()
        return len(serialized)

    @classmethod
    @abstractmethod
    async def from_stream(cls: type[Self], reader: ReaderAdapter, fixed_header: MQTTFixedHeader) -> Self:
        """Read and decode a variable header from *reader* given the packet's fixed header."""

139 

140 

class PacketIdVariableHeader(MQTTVariableHeader):
    """Variable header consisting solely of a packet identifier."""

    __slots__ = ("packet_id",)

    def __init__(self, packet_id: int) -> None:
        super().__init__()
        self.packet_id = packet_id

    def to_bytes(self) -> bytes:
        """Serialize ``packet_id`` via ``int_to_bytes`` with a length of 2."""
        return int_to_bytes(self.packet_id, 2)

    @classmethod
    async def from_stream(
        cls: type[Self],
        reader: ReaderAdapter,
        _: MQTTFixedHeader | None = None,
    ) -> Self:
        """Decode the packet identifier from *reader*; the fixed header is ignored."""
        return cls(await decode_packet_id(reader))

    def __repr__(self) -> str:
        """Return a string representation of the PacketIdVariableHeader object."""
        return f"{type(self).__name__}(packet_id={self.packet_id})"

165 

166 

# Type variable for the variable-header type of a packet/payload; a packet
# may also have no variable header at all, hence the ``| None`` bound.
_VH = TypeVar(
    "_VH",
    bound=MQTTVariableHeader | None,
)

168 

169 

class MQTTPayload(ABC, Generic[_VH]):
    """Base class for the payload section of an MQTT packet.

    Generic over the variable-header type so that ``to_bytes`` and
    ``from_stream`` can consult the packet's already-known headers.
    """

    async def to_stream(self, writer: asyncio.StreamWriter) -> None:
        """Serialize the payload without header context and flush it to *writer*."""
        data = self.to_bytes()
        writer.write(data)
        await writer.drain()

    @abstractmethod
    def to_bytes(self, fixed_header: MQTTFixedHeader | None = None, variable_header: _VH | None = None) -> bytes | bytearray:
        """Serialize the payload, optionally using the packet's headers."""

    @classmethod
    @abstractmethod
    async def from_stream(
        cls: type[Self],
        reader: asyncio.StreamReader | ReaderAdapter,
        fixed_header: MQTTFixedHeader | None,
        variable_header: _VH | None,
    ) -> Self:
        """Read and decode a payload from *reader* given the already-parsed headers."""

190 

191 

# Type variable for the payload type of a packet; packets without a payload
# are allowed, hence the ``| None`` bound.
_P = TypeVar(
    "_P",
    bound=MQTTPayload[MQTTVariableHeader] | None,
)

193 

194 

class MQTTPacket(Generic[_VH, _P, _FH]):
    """An MQTT control packet: a fixed header plus optional variable header and payload.

    Subclasses set ``VARIABLE_HEADER`` / ``PAYLOAD`` (and optionally
    ``FIXED_HEADER``) to the section classes that ``from_stream`` decodes.
    """

    __slots__ = ("fixed_header", "payload", "protocol_ts", "variable_header")

    VARIABLE_HEADER: type[_VH] | None = None
    PAYLOAD: type[_P] | None = None
    FIXED_HEADER: type[_FH] = MQTTFixedHeader  # type: ignore [assignment]

    def __init__(self, fixed: _FH, variable_header: _VH | None = None, payload: _P | None = None) -> None:
        self.fixed_header = fixed
        self.variable_header = variable_header
        self.payload = payload
        # Stamped with the current UTC time when the packet is written by
        # to_stream or built by from_stream.
        self.protocol_ts: datetime | None = None

    async def to_stream(self, writer: WriterAdapter) -> None:
        """Write the whole packet to *writer*, drain it, then stamp ``protocol_ts``."""
        encoded = self.to_bytes()
        writer.write(encoded)
        await writer.drain()
        self.protocol_ts = datetime.now(UTC)

    def to_bytes(self) -> bytes:
        """Serialize the packet into bytes.

        Side effect: when a fixed header is present, its ``remaining_length``
        is overwritten with the combined size of the variable header and
        payload before the header is encoded.
        """
        vh = self.variable_header
        vh_part = b"" if vh is None else vh.to_bytes()
        payload_part = b"" if self.payload is None else self.payload.to_bytes(self.fixed_header, vh)

        header_part = b""
        if self.fixed_header:
            self.fixed_header.remaining_length = len(vh_part) + len(payload_part)
            header_part = self.fixed_header.to_bytes()

        return header_part + vh_part + payload_part

    @classmethod
    async def from_stream(
        cls: type[Self],
        reader: ReaderAdapter,
        fixed_header: _FH | None = None,
        variable_header: _VH | None = None,
    ) -> Self:
        """Decode an MQTT packet from the stream.

        Sections supplied by the caller are not read again. Otherwise the fixed
        header is decoded first, then the variable header (if the class declares
        one), then the payload (if declared and a fixed header is available).
        """
        if fixed_header is None:
            fixed_header = await cls.FIXED_HEADER.from_stream(reader)

        if cls.VARIABLE_HEADER and variable_header is None:
            variable_header = await cls.VARIABLE_HEADER.from_stream(reader, fixed_header)

        payload = None
        if cls.PAYLOAD and fixed_header:
            payload = await cls.PAYLOAD.from_stream(reader, fixed_header, variable_header)

        # Pass only the sections that are present — presumably so subclasses
        # whose __init__ accepts fewer arguments can still be built here.
        if fixed_header and not payload:
            if variable_header:
                instance = cls(fixed_header, variable_header)
            else:
                instance = cls(fixed_header)
        else:
            instance = cls(fixed_header, variable_header, payload)

        instance.protocol_ts = datetime.now(UTC)
        return instance

    @property
    def bytes_length(self) -> int:
        """Total size in bytes of the serialized packet."""
        return len(self.to_bytes())

    def __repr__(self) -> str:
        """Return a string representation of the packet."""
        name = type(self).__name__
        return (
            f"{name}(ts={self.protocol_ts}, fixed={self.fixed_header}, "
            f"variable={self.variable_header}, payload={self.payload})"
        )

265 )