Coverage for amqtt/mqtt/packet.py: 92%
170 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1from abc import ABC, abstractmethod
2import asyncio
4try:
5 from datetime import UTC, datetime
6except ImportError:
7 from datetime import datetime, timezone
9 UTC = timezone.utc
11from struct import unpack
12from typing import Generic
13from typing_extensions import Self, TypeVar
15from amqtt.adapters import ReaderAdapter, WriterAdapter
16from amqtt.codecs_amqtt import bytes_to_hex_str, decode_packet_id, int_to_bytes, read_or_raise
17from amqtt.errors import CodecError, MQTTError, NoDataError
# MQTT control packet type codes. The code is stored in the upper four bits of
# the first fixed-header byte, so every value lies in the range 0..15.
RESERVED_0 = 0
CONNECT = 1
CONNACK = 2
PUBLISH = 3
PUBACK = 4
PUBREC = 5
PUBREL = 6
PUBCOMP = 7
SUBSCRIBE = 8
SUBACK = 9
UNSUBSCRIBE = 10
UNSUBACK = 11
PINGREQ = 12
PINGRESP = 13
DISCONNECT = 14
RESERVED_15 = 15
class MQTTFixedHeader:
    """Represents the fixed header of an MQTT packet.

    On the wire the fixed header is one byte carrying ``packet_type`` in the
    high nibble and ``flags`` in the low nibble, followed by ``remaining_length``
    encoded as a variable-length integer: seven data bits per byte, least
    significant group first, with the high bit set on every byte except the
    last. At most four length bytes are allowed.
    """

    __slots__ = ("flags", "packet_type", "remaining_length")

    # Largest value four 7-bit groups can hold (0x7F in every group), i.e. 2**28 - 1.
    # This is the MQTT "Remaining Length" ceiling of 268,435,455 bytes.
    _MAX_REMAINING_LENGTH = 0x0FFF_FFFF

    def __init__(self, packet_type: int, flags: int = 0, length: int = 0) -> None:
        self.packet_type = packet_type
        self.flags = flags
        self.remaining_length = length

    def to_bytes(self) -> bytes:
        """Encode the fixed header to bytes.

        Returns:
            One byte holding ``packet_type`` (high nibble) and ``flags`` (low
            nibble), followed by one to four bytes encoding ``remaining_length``.

        Raises:
            CodecError: if ``packet_type`` or ``flags`` does not fit in four bits,
                or ``remaining_length`` is negative or larger than the maximum the
                four-byte length encoding can carry.
        """
        # Without this check an oversized flags value would silently corrupt the
        # packet-type nibble, and an oversized packet_type would escape as a bare
        # ValueError from bytes([...]) instead of a CodecError.
        if not (0 <= self.packet_type <= 0x0F and 0 <= self.flags <= 0x0F):
            msg = (
                f"Fixed header encoding failed: packet_type={self.packet_type}, "
                f"flags={self.flags} must each fit in 4 bits"
            )
            raise CodecError(msg)
        # Values outside this range would otherwise be encoded as garbage (negative)
        # or with more than the four length bytes a receiver will accept.
        if not 0 <= self.remaining_length <= self._MAX_REMAINING_LENGTH:
            msg = (
                f"Fixed header encoding failed: remaining length {self.remaining_length} "
                f"is outside 0..{self._MAX_REMAINING_LENGTH}"
            )
            raise CodecError(msg)

        encoded = bytearray([(self.packet_type << 4) | self.flags])
        length = self.remaining_length
        while True:
            length, digit = divmod(length, 0x80)
            if length:
                digit |= 0x80  # continuation bit: more length bytes follow
            encoded.append(digit)
            if not length:
                return bytes(encoded)

    async def to_stream(self, writer: WriterAdapter) -> None:
        """Write the encoded fixed header to ``writer`` (no drain is awaited here)."""
        writer.write(self.to_bytes())

    @property
    def bytes_length(self) -> int:
        """Number of bytes the encoded fixed header occupies (2 to 5)."""
        return len(self.to_bytes())

    @classmethod
    async def from_stream(cls: type[Self], reader: ReaderAdapter) -> "Self | None":
        """Decode a fixed header from the stream.

        Returns:
            The decoded header, or ``None`` when reading raises ``NoDataError``
            before the header is complete.

        Raises:
            MQTTError: if the remaining length is encoded on more than four bytes.
        """

        async def decode_remaining_length(packet_type: int) -> int:
            """Decode the variable-length remaining-length field from the stream."""
            multiplier = 1
            value = 0
            buffer = bytearray()
            while True:
                # Read through read_or_raise, like the first byte, so an exhausted
                # reader lands in the NoDataError branch below instead of surfacing
                # as a struct.error from unpacking b"". (Assumes read_or_raise raises
                # NoDataError on an empty read, which that branch already relies on.)
                encoded_byte = await read_or_raise(reader, 1)
                byte_value = unpack("!B", encoded_byte)[0]
                buffer.append(byte_value)
                value += (byte_value & 0x7F) * multiplier
                if not byte_value & 0x80:
                    return value
                multiplier *= 128
                # A continuation bit on the fourth byte means a fifth would follow.
                if multiplier > 128**3:
                    msg = f"Invalid remaining length bytes:{bytes_to_hex_str(buffer)}, packet_type={packet_type}"
                    raise MQTTError(msg)

        try:
            byte1 = await read_or_raise(reader, 1)
            int1 = unpack("!B", byte1)[0]
            packet_type = (int1 & 0xF0) >> 4
            flags = int1 & 0x0F
            remaining_length = await decode_remaining_length(packet_type)
            return cls(packet_type, flags, remaining_length)
        except NoDataError:
            return None

    def __repr__(self) -> str:
        """Return a string representation of the MQTTFixedHeader object."""
        return f"{self.__class__.__name__}(packet_type={self.packet_type}, flags={self.flags}, length={self.remaining_length})"


# Type variable for the fixed-header class used by an MQTTPacket; bound to
# MQTTFixedHeader so that it and its subclasses are accepted.
_FH = TypeVar("_FH", bound=MQTTFixedHeader)
class MQTTVariableHeader(ABC):
    """Base class for the variable-header section of an MQTT packet.

    Concrete subclasses provide serialization (``to_bytes``) and parsing
    (``from_stream``); the streaming and size helpers are built on top of them.
    """

    @abstractmethod
    def to_bytes(self) -> bytes | bytearray:
        """Return the wire encoding of this variable header."""

    @classmethod
    @abstractmethod
    async def from_stream(cls: type[Self], reader: ReaderAdapter, fixed_header: MQTTFixedHeader) -> Self:
        """Parse a variable header from ``reader``; ``fixed_header`` is the packet's already-decoded fixed header."""

    async def to_stream(self, writer: asyncio.StreamWriter) -> None:
        """Write the encoded header to ``writer`` and wait for it to drain."""
        data = self.to_bytes()
        writer.write(data)
        await writer.drain()

    @property
    def bytes_length(self) -> int:
        """Size in bytes of the encoded header."""
        encoded = self.to_bytes()
        return len(encoded)
class PacketIdVariableHeader(MQTTVariableHeader):
    """Variable header whose only content is a packet identifier, encoded on two bytes."""

    __slots__ = ("packet_id",)

    def __init__(self, packet_id: int) -> None:
        super().__init__()
        self.packet_id = packet_id

    @classmethod
    async def from_stream(
        cls: type[Self],
        reader: ReaderAdapter,
        _: MQTTFixedHeader | None = None,
    ) -> Self:
        """Read the packet identifier from ``reader``; the fixed header argument is ignored."""
        return cls(await decode_packet_id(reader))

    def to_bytes(self) -> bytes:
        """Encode ``packet_id`` as a two-byte value."""
        return int_to_bytes(self.packet_id, 2)

    def __repr__(self) -> str:
        """Return a debug representation showing the packet identifier."""
        name = type(self).__name__
        return f"{name}(packet_id={self.packet_id})"
# Type variable for the variable-header class paired with a payload or packet;
# the bound includes None for packets without a variable header.
_VH = TypeVar("_VH", bound=MQTTVariableHeader | None)


class MQTTPayload(ABC, Generic[_VH]):
    """Base class for the payload section of an MQTT packet.

    Both encoding and decoding receive the packet's fixed and variable headers,
    so a concrete payload's layout may depend on them.
    """

    @abstractmethod
    def to_bytes(self, fixed_header: MQTTFixedHeader | None = None, variable_header: _VH | None = None) -> bytes | bytearray:
        """Return the wire encoding of this payload."""

    @classmethod
    @abstractmethod
    async def from_stream(
        cls: type[Self],
        reader: asyncio.StreamReader | ReaderAdapter,
        fixed_header: MQTTFixedHeader | None,
        variable_header: _VH | None,
    ) -> Self:
        """Parse a payload from ``reader`` using the packet's already-decoded headers."""

    async def to_stream(self, writer: asyncio.StreamWriter) -> None:
        """Write the encoded payload to ``writer`` and wait for it to drain.

        ``to_bytes`` is called here without any header arguments.
        """
        data = self.to_bytes()
        writer.write(data)
        await writer.drain()
# Type variable for the payload class of an MQTTPacket; the bound includes None
# for packets without a payload.
_P = TypeVar("_P", bound=MQTTPayload[MQTTVariableHeader] | None)


class MQTTPacket(Generic[_VH, _P, _FH]):
    """Represents an MQTT packet: a fixed header plus optional variable header and payload.

    Subclasses describe how they are decoded by setting ``VARIABLE_HEADER`` and
    ``PAYLOAD`` (and, if needed, ``FIXED_HEADER``) to the classes that
    ``from_stream`` uses for each section.
    """

    __slots__ = ("fixed_header", "payload", "protocol_ts", "variable_header")

    VARIABLE_HEADER: type[_VH] | None = None
    PAYLOAD: type[_P] | None = None
    FIXED_HEADER: type[_FH] = MQTTFixedHeader  # type: ignore [assignment]

    def __init__(self, fixed: _FH, variable_header: _VH | None = None, payload: _P | None = None) -> None:
        self.fixed_header = fixed
        self.variable_header = variable_header
        self.payload = payload
        # Set by to_stream / from_stream to the UTC time the packet was sent or decoded.
        self.protocol_ts: datetime | None = None

    async def to_stream(self, writer: WriterAdapter) -> None:
        """Serialize the packet, write it to ``writer``, drain, then record the send time."""
        data = self.to_bytes()
        writer.write(data)
        await writer.drain()
        self.protocol_ts = datetime.now(UTC)

    def to_bytes(self) -> bytes:
        """Serialize the packet into bytes.

        Side effect: the fixed header's ``remaining_length`` is overwritten with the
        combined size of the variable header and payload before it is encoded.
        """
        var_header = self.variable_header
        var_bytes = b"" if var_header is None else var_header.to_bytes()
        body_bytes = b"" if self.payload is None else self.payload.to_bytes(self.fixed_header, var_header)

        head_bytes = b""
        header = self.fixed_header
        if header:
            header.remaining_length = len(var_bytes) + len(body_bytes)
            head_bytes = header.to_bytes()

        return head_bytes + var_bytes + body_bytes

    @classmethod
    async def from_stream(
        cls: type[Self],
        reader: ReaderAdapter,
        fixed_header: _FH | None = None,
        variable_header: _VH | None = None,
    ) -> Self:
        """Decode an MQTT packet from the stream.

        Sections the caller has already decoded (``fixed_header``,
        ``variable_header``) are used as given; missing ones are read with the
        classes configured on ``cls``. ``protocol_ts`` is set to the decode time.
        """
        if fixed_header is None:
            fixed_header = await cls.FIXED_HEADER.from_stream(reader)

        # NOTE(review): FIXED_HEADER.from_stream may return None; in that case the
        # variable header is still read, the payload is skipped, and the instance
        # is built with fixed=None.
        if cls.VARIABLE_HEADER and variable_header is None:
            variable_header = await cls.VARIABLE_HEADER.from_stream(reader, fixed_header)

        payload = None
        if cls.PAYLOAD and fixed_header:
            payload = await cls.PAYLOAD.from_stream(reader, fixed_header, variable_header)

        # Pass only the sections that were decoded. Presumably this lets subclasses
        # whose __init__ takes fewer parameters be constructed; confirm before
        # collapsing these branches into a single call.
        if fixed_header and not payload:
            if not variable_header:
                instance = cls(fixed_header)
            else:
                instance = cls(fixed_header, variable_header)
        else:
            instance = cls(fixed_header, variable_header, payload)
        instance.protocol_ts = datetime.now(UTC)
        return instance

    @property
    def bytes_length(self) -> int:
        """Encoded size of the whole packet (also refreshes ``remaining_length`` via ``to_bytes``)."""
        return len(self.to_bytes())

    def __repr__(self) -> str:
        """Return a string representation of the packet."""
        cls_name = type(self).__name__
        return (
            f"{cls_name}(ts={self.protocol_ts}, fixed={self.fixed_header}, "
            f"variable={self.variable_header}, payload={self.payload})"
        )