Coverage for amqtt/mqtt/connect.py: 91%

349 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from asyncio import StreamReader 

2from typing_extensions import Self 

3 

4from amqtt.adapters import ReaderAdapter 

5from amqtt.codecs_amqtt import ( 

6 bytes_to_int, 

7 decode_data_with_length, 

8 decode_string, 

9 encode_data_with_length, 

10 encode_string, 

11 int_to_bytes, 

12 read_or_raise, 

13) 

14from amqtt.errors import AMQTTError, NoDataError 

15from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader, MQTTPacket, MQTTPayload, MQTTVariableHeader 

16from amqtt.utils import gen_client_id 

17 

18 

class ConnectVariableHeader(MQTTVariableHeader):
    """Variable header of an MQTT CONNECT packet.

    Holds the protocol name, the protocol level, the one-byte connect-flags
    field and the keep-alive value. Each connect flag is exposed as a property
    backed by a bit mask applied to ``flags``.
    """

    __slots__ = ("flags", "keep_alive", "proto_level", "proto_name")

    # Bit masks inside the connect-flags byte.
    USERNAME_FLAG = 0x80
    PASSWORD_FLAG = 0x40
    WILL_RETAIN_FLAG = 0x20
    WILL_FLAG = 0x04
    WILL_QOS_MASK = 0x18  # two-bit field occupying bits 3-4
    CLEAN_SESSION_FLAG = 0x02
    RESERVED_FLAG = 0x01

    # Shift that aligns a QoS value (0-3) with WILL_QOS_MASK.
    _WILL_QOS_SHIFT = 3

    def __init__(
        self,
        connect_flags: int = 0x00,
        keep_alive: int = 0,
        proto_name: str = "MQTT",
        proto_level: int = 0x04,
    ) -> None:
        """Create a CONNECT variable header.

        Args:
            connect_flags: raw connect-flags byte.
            keep_alive: keep-alive value; ``to_bytes`` encodes it on two bytes.
            proto_name: protocol name string.
            proto_level: protocol level byte.
        """
        super().__init__()
        self.proto_name = proto_name
        self.proto_level = proto_level
        self.flags = connect_flags
        self.keep_alive = keep_alive

    def __repr__(self) -> str:
        """Return a string representation of the ConnectVariableHeader object."""
        return (
            f"ConnectVariableHeader(proto_name={self.proto_name}, proto_level={self.proto_level},"
            f" flags={hex(self.flags)}, keepalive={self.keep_alive})"
        )

    def _set_flag(self, val: bool, mask: int) -> None:
        """Set (``val`` truthy) or clear (``val`` falsy) the bits in ``mask``."""
        if val:
            self.flags |= mask
        else:
            self.flags &= ~mask

    def _get_flag(self, mask: int) -> bool:
        """Return True if any bit of ``mask`` is set in ``flags``."""
        return bool(self.flags & mask)

    @classmethod
    async def from_stream(cls, reader: ReaderAdapter, _: MQTTFixedHeader) -> Self:
        """Decode the variable header from ``reader``.

        Reads, in order: the protocol name (``decode_string``), one
        protocol-level byte, one connect-flags byte and two keep-alive bytes.
        The fixed header argument is unused.
        """
        protocol_name = await decode_string(reader)

        protocol_level_byte = await read_or_raise(reader, 1)
        protocol_level = bytes_to_int(protocol_level_byte)

        flags_byte = await read_or_raise(reader, 1)
        flags = bytes_to_int(flags_byte)

        keep_alive_byte = await read_or_raise(reader, 2)
        keep_alive = bytes_to_int(keep_alive_byte)

        return cls(flags, keep_alive, protocol_name, protocol_level)

    def to_bytes(self) -> bytes | bytearray:
        """Encode as: protocol name, level byte, flags byte, 2-byte keep-alive."""
        out = bytearray()

        out.extend(encode_string(self.proto_name))
        out.append(self.proto_level)
        out.append(self.flags)
        out.extend(int_to_bytes(self.keep_alive, 2))

        return out

    @property
    def username_flag(self) -> bool:
        """Whether the payload carries a username."""
        return self._get_flag(self.USERNAME_FLAG)

    @username_flag.setter
    def username_flag(self, val: bool) -> None:
        self._set_flag(val, self.USERNAME_FLAG)

    @property
    def password_flag(self) -> bool:
        """Whether the payload carries a password."""
        return self._get_flag(self.PASSWORD_FLAG)

    @password_flag.setter
    def password_flag(self, val: bool) -> None:
        self._set_flag(val, self.PASSWORD_FLAG)

    @property
    def will_retain_flag(self) -> bool:
        """State of the will-retain bit."""
        return self._get_flag(self.WILL_RETAIN_FLAG)

    @will_retain_flag.setter
    def will_retain_flag(self, val: bool) -> None:
        self._set_flag(val, self.WILL_RETAIN_FLAG)

    @property
    def will_flag(self) -> bool:
        """Whether the payload carries a will topic and message."""
        return self._get_flag(self.WILL_FLAG)

    @will_flag.setter
    def will_flag(self, val: bool) -> None:
        self._set_flag(val, self.WILL_FLAG)

    @property
    def clean_session_flag(self) -> bool:
        """State of the clean-session bit."""
        return self._get_flag(self.CLEAN_SESSION_FLAG)

    @clean_session_flag.setter
    def clean_session_flag(self, val: bool) -> None:
        self._set_flag(val, self.CLEAN_SESSION_FLAG)

    @property
    def reserved_flag(self) -> bool:
        """State of the reserved bit (bit 0)."""
        return self._get_flag(self.RESERVED_FLAG)

    @reserved_flag.setter
    def reserved_flag(self, val: bool) -> None:
        self._set_flag(val, self.RESERVED_FLAG)

    @property
    def will_qos(self) -> int:
        """Will QoS value decoded from bits 3-4 of ``flags``."""
        return (self.flags & self.WILL_QOS_MASK) >> self._WILL_QOS_SHIFT

    @will_qos.setter
    def will_qos(self, val: int) -> None:
        """Store ``val`` in the will-QoS bits, leaving the other flags intact.

        Raises:
            ValueError: if ``val`` does not fit the two-bit QoS field (0-3).
                Without this check a larger value would overwrite the
                will-retain and higher bits.
        """
        if not 0 <= val <= 3:
            msg = f"Will QoS must be in range 0-3, got {val}"
            raise ValueError(msg)
        # Clear both QoS bits (and anything above the flags byte), i.e. & 0xE7.
        self.flags &= 0xFF & ~self.WILL_QOS_MASK
        self.flags |= val << self._WILL_QOS_SHIFT

142 

143 

class ConnectPayload(MQTTPayload[ConnectVariableHeader]):
    """Payload of an MQTT CONNECT packet.

    Fields appear on the wire in this order: client identifier, then, depending
    on the variable-header flags, will topic + will message, username and
    password.
    """

    __slots__ = (
        "client_id",
        "client_id_is_random",
        "password",
        "username",
        "will_message",
        "will_topic",
    )

    def __init__(
        self,
        client_id: str | None = None,
        will_topic: str | None = None,
        will_message: bytes | bytearray | None = None,
        username: str | None = None,
        password: str | None = None,
    ) -> None:
        """Create a CONNECT payload; every field defaults to None.

        ``client_id_is_random`` starts False; ``from_stream`` sets it when it
        had to generate a client id.
        """
        super().__init__()
        self.client_id_is_random = False
        self.client_id = client_id
        self.will_topic = will_topic
        self.will_message = will_message
        self.username = username
        self.password = password

    def __repr__(self) -> str:
        """Return a string representation of the ConnectPayload object."""
        # NOTE(review): this includes the password in cleartext; consider
        # masking it if payloads are ever logged.
        return (
            f"ConnectPayload(client_id={self.client_id}, will_topic={self.will_topic},"
            f" will_message={self.will_message!r}, username={self.username}, password={self.password})"
        )

    @classmethod
    async def from_stream(
        cls,
        reader: StreamReader | ReaderAdapter,
        _: MQTTFixedHeader | None,
        variable_header: ConnectVariableHeader | None,
    ) -> Self:
        """Decode a CONNECT payload from ``reader``.

        The client id is always read. If it is missing (``NoDataError``) or
        empty, a replacement from ``gen_client_id()`` is used and
        ``client_id_is_random`` is set to True. The will topic/message,
        username and password are read only when the matching flag is set on
        ``variable_header``. A ``NoDataError`` while reading any of them leaves
        that field as None.
        """
        payload = cls()
        try:
            payload.client_id = await decode_string(reader)
        except NoDataError:
            payload.client_id = None

        if payload.client_id is None or payload.client_id == "":
            # A Server MAY allow a Client to supply a ClientId that has a length of zero bytes
            # [MQTT-3.1.3-6]
            payload.client_id = gen_client_id()
            # Lets the caller reject a generated id when CLEAN_SESSION_FLAG is False.
            payload.client_id_is_random = True

        if variable_header is not None and variable_header.will_flag:
            try:
                payload.will_topic = await decode_string(reader)
                payload.will_message = await decode_data_with_length(reader)
            except NoDataError:
                payload.will_topic = None
                payload.will_message = None

        if variable_header is not None and variable_header.username_flag:
            try:
                payload.username = await decode_string(reader)
            except NoDataError:
                payload.username = None

        if variable_header is not None and variable_header.password_flag:
            try:
                payload.password = await decode_string(reader)
            except NoDataError:
                payload.password = None

        return payload

    def to_bytes(
        self,
        fixed_header: MQTTFixedHeader | None = None,
        variable_header: ConnectVariableHeader | None = None,
    ) -> bytes | bytearray:
        """Encode the payload.

        The client id is always written. The will topic/message, username and
        password are written only when the matching flag is set on
        ``variable_header`` and the field is not None.
        """
        out = bytearray()
        # The ClientId MUST be present as the first payload field [MQTT-3.1.3-3].
        # Omitting it would shift every following field, so a missing id is
        # encoded as a zero-length string instead.
        out.extend(encode_string(self.client_id if self.client_id is not None else ""))

        if variable_header is not None and variable_header.will_flag:
            if self.will_topic is not None:
                out.extend(encode_string(self.will_topic))
            if self.will_message is not None:
                out.extend(encode_data_with_length(self.will_message))

        if variable_header is not None and variable_header.username_flag and self.username is not None:
            out.extend(encode_string(self.username))

        if variable_header is not None and variable_header.password_flag and self.password is not None:
            out.extend(encode_string(self.password))

        return out

244 

245 

class ConnectPacket(MQTTPacket[ConnectVariableHeader, ConnectPayload, MQTTFixedHeader]):  # type: ignore [type-var]
    """MQTT CONNECT control packet.

    The convenience properties forward to the variable header (protocol name
    and level, connect flags, keep-alive) or to the payload (client id, will,
    credentials). Each one raises ``ValueError`` when the part it needs has not
    been set.
    """

    VARIABLE_HEADER = ConnectVariableHeader
    PAYLOAD = ConnectPayload

    def __init__(
        self,
        fixed: MQTTFixedHeader | None = None,
        variable_header: ConnectVariableHeader | None = None,
        payload: ConnectPayload | None = None,
    ) -> None:
        """Create a CONNECT packet.

        Args:
            fixed: fixed header; a CONNECT header with flags 0x00 is built when None.
            variable_header: optional CONNECT variable header.
            payload: optional CONNECT payload.

        Raises:
            AMQTTError: if ``fixed`` is given with a packet type other than CONNECT.
        """
        if fixed is None:
            header = MQTTFixedHeader(CONNECT, 0x00)
        else:
            # Compare by value: identity (``is``) is only reliable for
            # singletons, and constants such as CONNECT are not guaranteed to be.
            if fixed.packet_type != CONNECT:
                msg = f"Invalid fixed packet type {fixed.packet_type} for ConnectPacket init"
                raise AMQTTError(msg)
            header = fixed
        super().__init__(header)
        self.variable_header = variable_header
        self.payload = payload

    def _require_variable_header(self) -> ConnectVariableHeader:
        """Return the variable header, raising ``ValueError`` if it is unset."""
        if self.variable_header is None:
            msg = "Variable header is not set"
            raise ValueError(msg)
        return self.variable_header

    def _require_payload(self) -> ConnectPayload:
        """Return the payload, raising ``ValueError`` if it is unset."""
        if self.payload is None:
            msg = "Payload is not set"
            raise ValueError(msg)
        return self.payload

    # ---- variable-header fields -------------------------------------------

    @property
    def proto_name(self) -> str:
        """Protocol name from the variable header."""
        return self._require_variable_header().proto_name

    @proto_name.setter
    def proto_name(self, name: str) -> None:
        self._require_variable_header().proto_name = name

    @property
    def proto_level(self) -> int:
        """Protocol level from the variable header."""
        return self._require_variable_header().proto_level

    @proto_level.setter
    def proto_level(self, level: int) -> None:
        self._require_variable_header().proto_level = level

    @property
    def username_flag(self) -> bool:
        """Username flag from the variable header."""
        return self._require_variable_header().username_flag

    @username_flag.setter
    def username_flag(self, flag: bool) -> None:
        self._require_variable_header().username_flag = flag

    @property
    def password_flag(self) -> bool:
        """Password flag from the variable header."""
        return self._require_variable_header().password_flag

    @password_flag.setter
    def password_flag(self, flag: bool) -> None:
        self._require_variable_header().password_flag = flag

    @property
    def clean_session_flag(self) -> bool:
        """Clean-session flag from the variable header."""
        return self._require_variable_header().clean_session_flag

    @clean_session_flag.setter
    def clean_session_flag(self, flag: bool) -> None:
        self._require_variable_header().clean_session_flag = flag

    @property
    def will_retain_flag(self) -> bool:
        """Will-retain flag from the variable header."""
        return self._require_variable_header().will_retain_flag

    @will_retain_flag.setter
    def will_retain_flag(self, flag: bool) -> None:
        self._require_variable_header().will_retain_flag = flag

    @property
    def will_qos(self) -> int:
        """Will QoS from the variable header."""
        return self._require_variable_header().will_qos

    @will_qos.setter
    def will_qos(self, flag: int) -> None:
        self._require_variable_header().will_qos = flag

    @property
    def will_flag(self) -> bool:
        """Will flag from the variable header."""
        return self._require_variable_header().will_flag

    @will_flag.setter
    def will_flag(self, flag: bool) -> None:
        self._require_variable_header().will_flag = flag

    @property
    def reserved_flag(self) -> bool:
        """Reserved flag bit from the variable header."""
        return self._require_variable_header().reserved_flag

    @reserved_flag.setter
    def reserved_flag(self, flag: bool) -> None:
        self._require_variable_header().reserved_flag = flag

    # ---- payload fields ---------------------------------------------------

    @property
    def client_id(self) -> str | None:
        """Client identifier from the payload."""
        return self._require_payload().client_id

    @client_id.setter
    def client_id(self, client_id: str) -> None:
        self._require_payload().client_id = client_id

    @property
    def client_id_is_random(self) -> bool:
        """Whether the payload's client id was generated while decoding."""
        return self._require_payload().client_id_is_random

    @client_id_is_random.setter
    def client_id_is_random(self, client_id_is_random: bool) -> None:
        self._require_payload().client_id_is_random = client_id_is_random

    @property
    def will_topic(self) -> str | None:
        """Will topic from the payload."""
        return self._require_payload().will_topic

    @will_topic.setter
    def will_topic(self, will_topic: str) -> None:
        self._require_payload().will_topic = will_topic

    @property
    def will_message(self) -> bytes | bytearray | None:
        """Will message from the payload."""
        return self._require_payload().will_message

    @will_message.setter
    def will_message(self, will_message: bytes | bytearray) -> None:
        self._require_payload().will_message = will_message

    @property
    def username(self) -> str | None:
        """Username from the payload."""
        return self._require_payload().username

    @username.setter
    def username(self, username: str) -> None:
        self._require_payload().username = username

    @property
    def password(self) -> str | None:
        """Password from the payload."""
        return self._require_payload().password

    @password.setter
    def password(self, password: str) -> None:
        self._require_payload().password = password

    # ---- keep-alive (variable header) -------------------------------------

    @property
    def keep_alive(self) -> int:
        """Keep-alive value from the variable header.

        Raises ``ValueError("Variable header is not set")``; the original
        message wrongly said the payload was missing.
        """
        return self._require_variable_header().keep_alive

    @keep_alive.setter
    def keep_alive(self, keep_alive: int) -> None:
        self._require_variable_header().keep_alive = keep_alive