Coverage for amqtt/adapters.py: 89%

127 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from abc import ABC, abstractmethod 

2from asyncio import StreamReader, StreamWriter 

3from contextlib import suppress 

4import io 

5import logging 

6import ssl 

7from typing import cast 

8 

9from websockets import ConnectionClosed 

10from websockets.asyncio.connection import Connection 

11 

12 

class ReaderAdapter(ABC):
    """Abstract interface for reading from a network protocol.

    Concrete adapters map these read operations onto a specific transport, so the
    protocol code can consume bytes without caring where they come from.
    """

    @abstractmethod
    def feed_eof(self) -> None:
        """Acknowledge end-of-file on the underlying input."""
        raise NotImplementedError

    @abstractmethod
    async def read(self, n: int = -1) -> bytes:
        """Return up to ``n`` bytes from the protocol.

        When ``n`` is omitted or equal to ``-1``, keep reading until EOF and return
        everything that was read. Once EOF has been received and nothing remains
        buffered, an empty ``bytes`` object is returned.

        :param n: maximum number of bytes to read, or ``-1`` to read until EOF.
        :return: the packet data that was read.
        """
        raise NotImplementedError

33 

34 

35class WriterAdapter(ABC): 

36 """Base class for all network protocol writer adapters. 

37 

38 Writer adapters are used to adapt write operations on the network depending on 

39 the protocol used. 

40 """ 

41 

42 @abstractmethod 

43 def write(self, data: bytes) -> None: 

44 """Write some data to the protocol layer.""" 

45 raise NotImplementedError 

46 

47 @abstractmethod 

48 async def drain(self) -> None: 

49 """Let the write buffer of the underlying transport a chance to be flushed.""" 

50 raise NotImplementedError 

51 

52 @abstractmethod 

53 def get_peer_info(self) -> tuple[str, int] | None: 

54 """Return peer socket info (remote address and remote port as tuple).""" 

55 raise NotImplementedError 

56 

57 @abstractmethod 

58 def get_ssl_info(self) -> ssl.SSLObject | None: 

59 """Return peer certificate information (if available) used to establish a TLS session.""" 

60 raise NotImplementedError 

61 

62 @abstractmethod 

63 async def close(self) -> None: 

64 """Close the protocol connection.""" 

65 raise NotImplementedError 

66 

67 

class WebSocketsReader(ReaderAdapter):
    """WebSockets API reader adapter.

    This adapter relies on Connection to read from a WebSocket. Received WebSocket
    messages are accumulated in an in-memory buffer from which ``read`` serves bytes.
    """

    def __init__(self, protocol: Connection) -> None:
        self._protocol = protocol
        # Bytes already received from the WebSocket but not yet returned by read().
        self._stream = io.BytesIO(b"")

    async def read(self, n: int = -1) -> bytes:
        """Return up to ``n`` bytes, receiving WebSocket messages until ``n`` are buffered.

        Fewer than ``n`` bytes are returned if the connection closes first.

        NOTE(review): for ``n <= 0`` (including the default ``-1``) no new message is
        received; only bytes already buffered are returned. This differs from the
        "read until EOF" wording of ``ReaderAdapter.read``.
        """
        await self._feed_buffer(n)
        return self._stream.read(n)

    async def _feed_buffer(self, n: int = 1) -> None:
        """Feed the data buffer by reading WebSocket messages.

        :param n: Optional; feed buffer until it contains at least n bytes. Defaults to 1.
            Stops early when the connection is closed.
        """
        buffer = bytearray(self._stream.read())
        while len(buffer) < n:
            # Reset on every iteration: a suppressed ConnectionClosed must not leave
            # the previous payload in ``message``, or it would be appended again.
            message: str | bytes | None = None
            with suppress(ConnectionClosed):
                message = await self._protocol.recv()
            if message is None:
                break
            if isinstance(message, str):
                # Text frames are converted to bytes so the buffer stays binary.
                message = message.encode("utf-8")
            buffer.extend(message)
        self._stream = io.BytesIO(buffer)

    def feed_eof(self) -> None:
        # NOTE: not implemented?! EOF is instead detected through ConnectionClosed
        # in _feed_buffer.
        pass

101 

102 

class WebSocketsWriter(WriterAdapter):
    """WebSockets API writer adapter.

    This adapter relies on Connection to write to a WebSocket. Written bytes are
    buffered in memory and handed to ``Connection.send`` in a single call on each
    ``drain()``.
    """

    def __init__(self, protocol: Connection) -> None:
        self._protocol = protocol
        # Bytes written since the last completed drain().
        self._stream = io.BytesIO(b"")

    def write(self, data: bytes) -> None:
        """Append ``data`` to the pending outgoing message."""
        self._stream.write(data)

    async def drain(self) -> None:
        """Send everything written since the last drain with one ``send()`` call."""
        data = self._stream.getvalue()
        if data:
            await self._protocol.send(data)
        # Not reached if send() raises, so pending bytes remain buffered in that case.
        self._stream = io.BytesIO(b"")

    def get_peer_info(self) -> tuple[str, int] | None:
        """Return ``(host, port)`` of the remote peer, or ``None`` if it is unavailable."""
        remote_address = self._protocol.remote_address
        if remote_address is None:
            # No peer address is available from the transport (presumably not
            # connected / already torn down); slicing None would raise TypeError.
            return None
        # remote_address can be either a 4-tuple or 2-tuple depending on whether
        # it is an IPv6 or IPv4 address, so we take their shared (host, port)
        # prefix here to present a uniform return value.
        peer: tuple[str, int] = remote_address[:2]
        return peer

    def get_ssl_info(self) -> ssl.SSLObject | None:
        """Return the transport's ``ssl_object`` extra info (``None`` without TLS)."""
        return cast("ssl.SSLObject", self._protocol.transport.get_extra_info("ssl_object"))

    async def close(self) -> None:
        """Close the WebSocket connection."""
        await self._protocol.close()

136 

137 

class StreamReaderAdapter(ReaderAdapter):
    """Reader adapter for the asyncio Streams API.

    Reads from a TCP socket through an ``asyncio.StreamReader``. The two APIs are
    nearly identical, so this class mostly forwards calls.
    """

    def __init__(self, reader: StreamReader) -> None:
        self._reader = reader

    async def read(self, n: int = -1) -> bytes:
        """Read exactly ``n`` bytes, or everything up to EOF when ``n`` is ``-1``.

        :raises asyncio.IncompleteReadError: from ``StreamReader.readexactly`` when
            EOF arrives before ``n`` bytes were read.
        """
        if n != -1:
            return await self._reader.readexactly(n)
        return await self._reader.read(n)

    def feed_eof(self) -> None:
        """Forward the EOF acknowledgement to the wrapped ``StreamReader``."""
        self._reader.feed_eof()

157 

158 

class StreamWriterAdapter(WriterAdapter):
    """Writer adapter for the asyncio Streams API.

    Writes to a TCP socket through an ``asyncio.StreamWriter``. The two APIs are
    nearly identical, so this class mostly forwards calls. It keeps its own
    ``is_closed`` flag so that writes and drains after ``close()`` are ignored.
    """

    def __init__(self, writer: StreamWriter) -> None:
        self.logger = logging.getLogger(__name__)
        self._writer = writer
        self.is_closed = False  # StreamWriter has no test for closed...we use our own

    def write(self, data: bytes) -> None:
        """Write ``data`` to the stream; silently ignored once closed."""
        if not self.is_closed:
            self._writer.write(data)

    async def drain(self) -> None:
        """Wait until the stream's write buffer can be flushed; no-op once closed."""
        if not self.is_closed:
            await self._writer.drain()

    def get_peer_info(self) -> tuple[str, int] | None:
        """Return ``(host, port)`` of the remote peer, or ``None`` if unavailable.

        ``peername`` is ``None`` when the transport has no peer address. It is also
        not a ``(host, port, ...)`` tuple for non-IP sockets (e.g. AF_UNIX). Both cases
        return ``None`` instead of raising ``TypeError``/``IndexError``.
        """
        extra_info = self._writer.get_extra_info("peername")
        if not isinstance(extra_info, (tuple, list)) or len(extra_info) < 2:
            return None
        # IPv6 peernames are 4-tuples; keep only the shared (host, port) prefix.
        return extra_info[0], extra_info[1]

    def get_ssl_info(self) -> ssl.SSLObject | None:
        """Return the transport's ``ssl_object`` extra info (``None`` without TLS)."""
        return cast("ssl.SSLObject", self._writer.get_extra_info("ssl_object"))

    async def close(self) -> None:
        """Flush, send EOF if supported, and close the stream; later calls do nothing.

        The transport is closed even if flushing or sending EOF fails (for example
        because the peer already reset the connection). That error is still
        propagated to the caller.
        """
        if self.is_closed:
            return
        self.is_closed = True  # we first mark this closed so yields below don't cause races with waiting writes
        try:
            await self._writer.drain()
            if self._writer.can_write_eof():
                self._writer.write_eof()
        finally:
            # Always release the transport; otherwise a failed drain() would leak the
            # socket and, since is_closed is already set, it could never be closed.
            self._writer.close()
            with suppress(AttributeError):
                await self._writer.wait_closed()

195 

196 

class BufferReader(ReaderAdapter):
    """Reader adapter over an in-memory byte buffer.

    Every read is served straight from an ``io.BytesIO`` wrapping the given bytes.
    """

    def __init__(self, buffer: bytes) -> None:
        self._stream = io.BytesIO(buffer)

    async def read(self, n: int = -1) -> bytes:
        """Return up to ``n`` bytes; all remaining bytes when ``n`` is ``-1``."""
        chunk = self._stream.read(n)
        return chunk

    def feed_eof(self) -> None:
        """No-op: this adapter does not track an EOF state."""
        # NOTE: not implemented?!
        return None

212 

213 

class BufferWriter(WriterAdapter):
    """Writer adapter over an in-memory byte buffer.

    Written bytes accumulate in an ``io.BytesIO`` and can be read back with
    ``get_buffer()``. Writing starts at offset 0, so bytes written overwrite any
    initial ``buffer`` content rather than being appended after it.
    """

    def __init__(self, buffer: bytes = b"") -> None:
        self._stream = io.BytesIO(buffer)

    def write(self, data: bytes) -> None:
        """Write ``data`` into the in-memory buffer."""
        self._stream.write(data)

    async def drain(self) -> None:
        """No-op: there is no transport to flush."""

    def get_buffer(self) -> bytes:
        """Return the entire contents of the buffer."""
        contents = self._stream.getvalue()
        return contents

    def get_peer_info(self) -> tuple[str, int]:
        """Return a fixed placeholder peer address; there is no real peer."""
        return ("BufferWriter", 0)

    def get_ssl_info(self) -> ssl.SSLObject | None:
        """Return ``None``: an in-memory buffer carries no TLS session."""
        return None

    async def close(self) -> None:
        """Close the buffer; further ``write()``/``get_buffer()`` calls raise ``ValueError``."""
        self._stream.close()