Coverage for amqtt/mqtt/publish.py: 96%

141 statements  

Report generated by coverage.py v7.8.2 at 2025-08-12 14:35 +0000.

1import asyncio 

2from typing_extensions import Self 

3 

4from amqtt.adapters import ReaderAdapter 

5from amqtt.codecs_amqtt import decode_packet_id, decode_string, encode_string, int_to_bytes 

6from amqtt.errors import AMQTTError, MQTTError 

7from amqtt.mqtt.packet import PUBLISH, MQTTFixedHeader, MQTTPacket, MQTTPayload, MQTTVariableHeader 

8 

9 

class PublishVariableHeader(MQTTVariableHeader):
    """Variable header of a PUBLISH packet: a topic name and an optional packet identifier."""

    __slots__ = ("packet_id", "topic_name")

    def __init__(self, topic_name: str, packet_id: int | None = None) -> None:
        """Create the header after checking the topic name.

        :param topic_name: topic the message is published to; must not contain ``#`` or ``+``.
        :param packet_id: packet identifier, or ``None`` when no identifier is encoded.
        :raises MQTTError: if ``topic_name`` contains a wildcard character.
        """
        super().__init__()
        # Checked in the same order as "#" then "+"; any() short-circuits identically.
        if any(wildcard in topic_name for wildcard in ("#", "+")):
            msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters."
            raise MQTTError(msg)
        self.topic_name = topic_name
        self.packet_id = packet_id

    def __repr__(self) -> str:
        """Return a string representation of the PublishVariableHeader object."""
        return f"{type(self).__name__}(topic={self.topic_name}, packet_id={self.packet_id})"

    def to_bytes(self) -> bytes | bytearray:
        """Serialize the topic name, followed by a 2-byte packet id when one is set."""
        encoded = bytearray(encode_string(self.topic_name))
        if self.packet_id is not None:
            encoded += int_to_bytes(self.packet_id, 2)
        return encoded

    @classmethod
    async def from_stream(cls, reader: ReaderAdapter | asyncio.StreamReader, fixed_header: MQTTFixedHeader) -> Self:
        """Decode the header from ``reader``.

        A packet id is read only when the QoS bits (bits 1-2 of the fixed-header flags)
        are non-zero.
        """
        topic_name = await decode_string(reader)
        qos_bits = (fixed_header.flags & 0x06) >> 1
        packet_id = None
        if qos_bits:
            packet_id = await decode_packet_id(reader)
        return cls(topic_name, packet_id)

38 

39 

class PublishPayload(MQTTPayload[MQTTVariableHeader]):
    """Application message of a PUBLISH packet, held as raw bytes (or ``None``)."""

    __slots__ = ("data",)

    def __init__(self, data: bytes | None = None) -> None:
        """Store the raw message bytes; ``None`` serializes as an empty payload."""
        super().__init__()
        self.data = data

    def to_bytes(
        self,
        fixed_header: MQTTFixedHeader | None = None,
        variable_header: MQTTVariableHeader | None = None,
    ) -> bytes:
        """Return the payload bytes, or ``b""`` when no data is set."""
        return self.data if self.data is not None else b""

    @classmethod
    async def from_stream(
        cls,
        reader: asyncio.StreamReader | ReaderAdapter,
        fixed_header: MQTTFixedHeader | None,
        variable_header: MQTTVariableHeader | None,
    ) -> Self:
        """Read the payload that follows the variable header.

        The payload length is ``fixed_header.remaining_length`` minus
        ``variable_header.bytes_length``. A non-positive length yields an empty payload.

        :raises ValueError: if either header is ``None``.
        :raises asyncio.IncompleteReadError: if the stream ends before the full payload
            has been read.
        """
        if fixed_header is None or variable_header is None:
            msg = "Fixed header or variable header cannot be None"
            raise ValueError(msg)

        data_length = fixed_header.remaining_length - variable_header.bytes_length
        data = bytearray()
        while len(data) < data_length:
            chunk = await reader.read(data_length - len(data))
            if not chunk:
                # read() yields b"" at EOF; without this guard a truncated packet would
                # make this loop spin forever.
                raise asyncio.IncompleteReadError(bytes(data), data_length)
            data.extend(chunk)
        return cls(bytes(data))

    def __repr__(self) -> str:
        """Return a string representation of the PublishPayload object."""
        # Single repr: the previous `repr(...)!r` quoted the data twice.
        return f"{type(self).__name__}(data={self.data!r})"

77 

78 

class PublishPacket(MQTTPacket[PublishVariableHeader, PublishPayload, MQTTFixedHeader]):
    """MQTT PUBLISH packet.

    The DUP, QoS and RETAIN settings are stored as bits of ``fixed_header.flags``:
    DUP is 0x08, QoS occupies 0x06 (bits 1-2) and RETAIN is 0x01.
    """

    VARIABLE_HEADER = PublishVariableHeader
    PAYLOAD = PublishPayload

    DUP_FLAG = 0x08
    RETAIN_FLAG = 0x01
    QOS_FLAG = 0x06

    def __init__(
        self,
        fixed: MQTTFixedHeader | None = None,
        variable_header: PublishVariableHeader | None = None,
        payload: PublishPayload | None = None,
    ) -> None:
        """Create a PUBLISH packet; a fresh PUBLISH fixed header is used when ``fixed`` is None.

        :raises AMQTTError: if ``fixed`` is given but is not a PUBLISH header.
        """
        if fixed is None:
            header = MQTTFixedHeader(PUBLISH, 0x00)
        elif fixed.packet_type != PUBLISH:
            msg = f"Invalid fixed packet type {fixed.packet_type} for PublishPacket init"
            raise AMQTTError(msg)
        else:
            header = fixed

        super().__init__(header)
        self.variable_header = variable_header
        self.payload = payload

    @classmethod
    def build(cls, topic_name: str, message: bytes, packet_id: int | None, dup_flag: bool, qos: int | None, retain: bool) -> Self:
        """Build a complete PUBLISH packet; ``qos=None`` is treated as QoS 0."""
        v_header = PublishVariableHeader(topic_name, packet_id)
        payload = PublishPayload(message)
        packet = cls(variable_header=v_header, payload=payload)
        packet.dup_flag = dup_flag
        packet.retain_flag = retain
        packet.qos = qos or 0
        return packet

    def set_flags(self, dup_flag: bool = False, qos: int = 0, retain_flag: bool = False) -> None:
        """Set the DUP, QoS and RETAIN fixed-header bits in one call."""
        self.dup_flag = dup_flag
        self.retain_flag = retain_flag
        self.qos = qos

    def _set_header_flag(self, val: bool, mask: int) -> None:
        """Set (``val`` truthy) or clear the fixed-header bits selected by ``mask``."""
        if val:
            self.fixed_header.flags |= mask
        else:
            self.fixed_header.flags &= ~mask

    def _get_header_flag(self, mask: int) -> bool:
        """Return True if any fixed-header bit selected by ``mask`` is set."""
        return bool(self.fixed_header.flags & mask)

    @property
    def dup_flag(self) -> bool:
        """DUP bit of the fixed header."""
        return self._get_header_flag(self.DUP_FLAG)

    @dup_flag.setter
    def dup_flag(self, val: bool) -> None:
        self._set_header_flag(val, self.DUP_FLAG)

    @property
    def retain_flag(self) -> bool:
        """RETAIN bit of the fixed header."""
        return self._get_header_flag(self.RETAIN_FLAG)

    @retain_flag.setter
    def retain_flag(self, val: bool) -> None:
        self._set_header_flag(val, self.RETAIN_FLAG)

    @property
    def qos(self) -> int:
        """QoS level encoded in bits 1-2 of the fixed-header flags."""
        return (self.fixed_header.flags & self.QOS_FLAG) >> 1

    @qos.setter
    def qos(self, val: int) -> None:
        """Store the QoS level in the fixed header.

        :raises ValueError: if ``val`` is not 0, 1 or 2. Without this check, larger
            values would spill into the DUP bit and negative values would corrupt all
            flags; QoS 3 is reserved by the MQTT spec.
        """
        if val not in (0, 1, 2):
            msg = f"Invalid QoS level {val!r}; expected 0, 1 or 2"
            raise ValueError(msg)
        self.fixed_header.flags = (self.fixed_header.flags & ~self.QOS_FLAG) | (val << 1)

    @property
    def packet_id(self) -> int | None:
        """Packet identifier from the variable header.

        :raises ValueError: if no variable header is set.
        """
        if self.variable_header is None:
            msg = "Variable header is not set"
            raise ValueError(msg)
        return self.variable_header.packet_id

    @packet_id.setter
    def packet_id(self, val: int | None) -> None:
        if self.variable_header is None:
            msg = "Variable header is not set"
            raise ValueError(msg)
        self.variable_header.packet_id = val

    @property
    def data(self) -> bytes | None:
        """Raw message bytes from the payload.

        :raises ValueError: if no payload is set.
        """
        if self.payload is None:
            msg = "Payload header is not set"
            raise ValueError(msg)
        return self.payload.data

    @data.setter
    def data(self, data: bytes) -> None:
        if self.payload is None:
            msg = "Payload header is not set"
            raise ValueError(msg)
        self.payload.data = data

    @property
    def topic_name(self) -> str | None:
        """Topic name from the variable header.

        :raises ValueError: if no variable header is set.
        """
        if self.variable_header is None:
            msg = "Variable header is not set"
            raise ValueError(msg)
        return self.variable_header.topic_name

    @topic_name.setter
    def topic_name(self, name: str) -> None:
        """Replace the topic name, enforcing the same wildcard rule as PublishVariableHeader.

        :raises ValueError: if no variable header is set.
        :raises MQTTError: if ``name`` contains ``#`` or ``+``.
        """
        if self.variable_header is None:
            msg = "Variable header is not set"
            raise ValueError(msg)
        if "#" in name or "+" in name:
            msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters."
            raise MQTTError(msg)
        self.variable_header.topic_name = name