Coverage for amqtt/mqtt/publish.py: 96% of 141 statements
« prev ^ index » next — coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1import asyncio
2from typing_extensions import Self
4from amqtt.adapters import ReaderAdapter
5from amqtt.codecs_amqtt import decode_packet_id, decode_string, encode_string, int_to_bytes
6from amqtt.errors import AMQTTError, MQTTError
7from amqtt.mqtt.packet import PUBLISH, MQTTFixedHeader, MQTTPacket, MQTTPayload, MQTTVariableHeader
class PublishVariableHeader(MQTTVariableHeader):
    """Variable header of a PUBLISH packet: a topic name and an optional packet id.

    On the wire the packet identifier follows the topic name only when the
    QoS bits of the fixed header are non-zero; when decoded from a QoS 0
    packet ``packet_id`` is ``None``.
    """

    __slots__ = ("packet_id", "topic_name")

    def __init__(self, topic_name: str, packet_id: int | None = None) -> None:
        """Store the topic name and packet identifier.

        Args:
            topic_name: Topic the message is published to.
            packet_id: Packet identifier, or ``None`` when none is carried.

        Raises:
            MQTTError: if ``topic_name`` contains a ``#`` or ``+`` wildcard.
        """
        super().__init__()
        if any(wildcard in topic_name for wildcard in ("#", "+")):
            msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters."
            raise MQTTError(msg)
        self.topic_name = topic_name
        self.packet_id = packet_id

    def __repr__(self) -> str:
        """Return a string representation of the PublishVariableHeader object."""
        cls_name = type(self).__name__
        return f"{cls_name}(topic={self.topic_name}, packet_id={self.packet_id})"

    def to_bytes(self) -> bytes | bytearray:
        """Serialize the header: encoded topic name, then a 2-byte packet id if set."""
        encoded = bytearray()
        encoded.extend(encode_string(self.topic_name))
        if self.packet_id is None:
            return encoded
        encoded.extend(int_to_bytes(self.packet_id, 2))
        return encoded

    @classmethod
    async def from_stream(cls, reader: ReaderAdapter | asyncio.StreamReader, fixed_header: MQTTFixedHeader) -> Self:
        """Decode the variable header from ``reader``.

        The packet identifier is read only when either QoS bit (mask 0x06) of
        ``fixed_header.flags`` is set.
        """
        topic_name = await decode_string(reader)
        packet_id = None
        if fixed_header.flags & 0x06:
            packet_id = await decode_packet_id(reader)
        return cls(topic_name, packet_id)
class PublishPayload(MQTTPayload[MQTTVariableHeader]):
    """Application message carried by a PUBLISH packet (raw bytes, possibly absent)."""

    __slots__ = ("data",)

    def __init__(self, data: bytes | None = None) -> None:
        """Wrap ``data`` as the packet payload (``None`` means no payload)."""
        super().__init__()
        self.data = data

    def to_bytes(
        self,
        fixed_header: MQTTFixedHeader | None = None,
        variable_header: MQTTVariableHeader | None = None,
    ) -> bytes:
        """Return the payload bytes, or ``b""`` when no data is set.

        The header arguments are accepted for interface compatibility and are
        not used.
        """
        return self.data if self.data is not None else b""

    @classmethod
    async def from_stream(
        cls,
        reader: asyncio.StreamReader | ReaderAdapter,
        fixed_header: MQTTFixedHeader | None,
        variable_header: MQTTVariableHeader | None,
    ) -> Self:
        """Read the payload that follows the variable header.

        The payload length is the fixed header's remaining length minus the
        size of the already-decoded variable header.

        Raises:
            ValueError: if either header is ``None``.
            asyncio.IncompleteReadError: if the reader reaches end-of-stream
                before the full payload has been received.
        """
        if fixed_header is None or variable_header is None:
            msg = "Fixed header or variable header cannot be None"
            raise ValueError(msg)

        data_length = fixed_header.remaining_length - variable_header.bytes_length
        data = bytearray()
        while len(data) < data_length:
            chunk = await reader.read(data_length - len(data))
            if not chunk:
                # read() returns an empty result at EOF; without this check the
                # loop would spin forever on a connection closed mid-payload.
                raise asyncio.IncompleteReadError(bytes(data), data_length)
            data.extend(chunk)
        return cls(bytes(data))

    def __repr__(self) -> str:
        """Return a string representation of the PublishPayload object."""
        # A single !r: the previous repr(...)!r quoted the repr a second time.
        return f"{type(self).__name__}(data={self.data!r})"
class PublishPacket(MQTTPacket[PublishVariableHeader, PublishPayload, MQTTFixedHeader]):
    """MQTT PUBLISH control packet.

    DUP, QoS and RETAIN are stored in ``fixed_header.flags`` (masks below) and
    exposed as properties. ``topic_name``/``packet_id`` are proxied to the
    variable header and ``data`` to the payload.
    """

    VARIABLE_HEADER = PublishVariableHeader
    PAYLOAD = PublishPayload

    # Bit masks over ``fixed_header.flags``.
    DUP_FLAG = 0x08
    RETAIN_FLAG = 0x01
    QOS_FLAG = 0x06  # two-bit QoS field in bits 1-2

    def __init__(
        self,
        fixed: MQTTFixedHeader | None = None,
        variable_header: PublishVariableHeader | None = None,
        payload: PublishPayload | None = None,
    ) -> None:
        """Create a PUBLISH packet.

        Args:
            fixed: Fixed header to use; a fresh PUBLISH header with all flags
                cleared is created when omitted.
            variable_header: Topic name / packet id header.
            payload: Message payload.

        Raises:
            AMQTTError: if ``fixed`` is given with a packet type other than PUBLISH.
        """
        if fixed is None:
            header = MQTTFixedHeader(PUBLISH, 0x00)
        elif fixed.packet_type != PUBLISH:
            msg = f"Invalid fixed packet type {fixed.packet_type} for PublishPacket init"
            raise AMQTTError(msg) from None
        else:
            header = fixed

        super().__init__(header)
        self.variable_header = variable_header
        self.payload = payload

    @classmethod
    def build(cls, topic_name: str, message: bytes, packet_id: int | None, dup_flag: bool, qos: int | None, retain: bool) -> Self:
        """Build a complete PUBLISH packet from its components.

        A ``qos`` of ``None`` is treated as QoS 0.
        """
        v_header = PublishVariableHeader(topic_name, packet_id)
        payload = PublishPayload(message)
        packet = cls(variable_header=v_header, payload=payload)
        packet.dup_flag = dup_flag
        packet.retain_flag = retain
        packet.qos = qos or 0
        return packet

    def set_flags(self, dup_flag: bool = False, qos: int = 0, retain_flag: bool = False) -> None:
        """Set the DUP, QoS and RETAIN fields in one call."""
        self.dup_flag = dup_flag
        self.retain_flag = retain_flag
        self.qos = qos

    def _set_header_flag(self, val: bool, mask: int) -> None:
        """Set (``val`` truthy) or clear the fixed-header bits selected by ``mask``."""
        if val:
            self.fixed_header.flags |= mask
        else:
            self.fixed_header.flags &= ~mask

    def _get_header_flag(self, mask: int) -> bool:
        """Return whether any fixed-header bit selected by ``mask`` is set."""
        return bool(self.fixed_header.flags & mask)

    @property
    def dup_flag(self) -> bool:
        """DUP flag: whether the packet is flagged as a re-delivery."""
        return self._get_header_flag(self.DUP_FLAG)

    @dup_flag.setter
    def dup_flag(self, val: bool) -> None:
        self._set_header_flag(val, self.DUP_FLAG)

    @property
    def retain_flag(self) -> bool:
        """RETAIN flag of the fixed header."""
        return self._get_header_flag(self.RETAIN_FLAG)

    @retain_flag.setter
    def retain_flag(self, val: bool) -> None:
        self._set_header_flag(val, self.RETAIN_FLAG)

    @property
    def qos(self) -> int:
        """QoS level decoded from bits 1-2 of the fixed-header flags."""
        return (self.fixed_header.flags & self.QOS_FLAG) >> 1

    @qos.setter
    def qos(self, val: int) -> None:
        """Store ``val`` in the QoS field.

        Raises:
            ValueError: if ``val`` is not 0, 1 or 2. Larger values would spill
                into the DUP/higher bits, and 3 is the reserved pattern a PUBLISH
                packet must not carry (MQTT-3.3.1-4).
        """
        if not 0 <= val <= 2:
            msg = f"Invalid QoS value {val!r}: must be 0, 1 or 2"
            raise ValueError(msg)
        # Clear the QoS bits (== 0xF9) before writing the new level.
        self.fixed_header.flags &= 0xFF & ~self.QOS_FLAG
        self.fixed_header.flags |= val << 1

    @property
    def packet_id(self) -> int | None:
        """Packet identifier from the variable header.

        Raises:
            ValueError: if no variable header is set.
        """
        if self.variable_header is None:
            msg = "Variable header is not set"
            raise ValueError(msg)
        return self.variable_header.packet_id

    @packet_id.setter
    def packet_id(self, val: int) -> None:
        if self.variable_header is None:
            msg = "Variable header is not set"
            raise ValueError(msg)
        self.variable_header.packet_id = val

    @property
    def data(self) -> bytes | None:
        """Raw message bytes from the payload.

        Raises:
            ValueError: if no payload is set.
        """
        if self.payload is None:
            msg = "Payload header is not set"
            raise ValueError(msg)
        return self.payload.data

    @data.setter
    def data(self, data: bytes) -> None:
        if self.payload is None:
            msg = "Payload header is not set"
            raise ValueError(msg)
        self.payload.data = data

    @property
    def topic_name(self) -> str | None:
        """Topic name from the variable header.

        Raises:
            ValueError: if no variable header is set.
        """
        if self.variable_header is None:
            msg = "Variable header is not set"
            raise ValueError(msg)
        return self.variable_header.topic_name

    @topic_name.setter
    def topic_name(self, name: str) -> None:
        """Replace the topic name, enforcing the same rule as PublishVariableHeader.

        Raises:
            ValueError: if no variable header is set.
            MQTTError: if ``name`` contains a ``#`` or ``+`` wildcard.
        """
        if self.variable_header is None:
            msg = "Variable header is not set"
            raise ValueError(msg)
        # Mirror the constructor check so the setter cannot produce an invalid packet.
        if "#" in name or "+" in name:
            msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters."
            raise MQTTError(msg)
        self.variable_header.topic_name = name