Coverage for amqtt/adapters.py: 89%
127 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1from abc import ABC, abstractmethod
2from asyncio import StreamReader, StreamWriter
3from contextlib import suppress
4import io
5import logging
6import ssl
7from typing import cast
9from websockets import ConnectionClosed
10from websockets.asyncio.connection import Connection
class ReaderAdapter(ABC):
    """Base class for all network protocol reader adapters.

    Reader adapters are used to adapt read operations on the network depending on the
    protocol used.
    """

    @abstractmethod
    async def read(self, n: int = -1) -> bytes:
        """Read up to n bytes.

        If n is not provided, or set to -1, read until EOF and return all read bytes.
        If the EOF was received and the internal buffer is empty, return an empty
        bytes object.

        :param n: maximum number of bytes to read; -1 means "until EOF".
        :return: packet read as bytes data.
        """
        raise NotImplementedError

    @abstractmethod
    def feed_eof(self) -> None:
        """Acknowledge EOF."""
        raise NotImplementedError
35class WriterAdapter(ABC):
36 """Base class for all network protocol writer adapters.
38 Writer adapters are used to adapt write operations on the network depending on
39 the protocol used.
40 """
42 @abstractmethod
43 def write(self, data: bytes) -> None:
44 """Write some data to the protocol layer."""
45 raise NotImplementedError
47 @abstractmethod
48 async def drain(self) -> None:
49 """Let the write buffer of the underlying transport a chance to be flushed."""
50 raise NotImplementedError
52 @abstractmethod
53 def get_peer_info(self) -> tuple[str, int] | None:
54 """Return peer socket info (remote address and remote port as tuple)."""
55 raise NotImplementedError
57 @abstractmethod
58 def get_ssl_info(self) -> ssl.SSLObject | None:
59 """Return peer certificate information (if available) used to establish a TLS session."""
60 raise NotImplementedError
62 @abstractmethod
63 async def close(self) -> None:
64 """Close the protocol connection."""
65 raise NotImplementedError
class WebSocketsReader(ReaderAdapter):
    """WebSockets API reader adapter.

    This adapter relies on Connection to read from a WebSocket.
    Received messages are accumulated in an in-memory byte stream from which
    ``read`` serves the requested number of bytes.
    """

    def __init__(self, protocol: Connection) -> None:
        self._protocol = protocol
        # Bytes already received from the WebSocket but not yet consumed by read().
        self._stream = io.BytesIO(b"")

    async def read(self, n: int = -1) -> bytes:
        """Read up to n bytes from the WebSocket.

        The internal buffer is first topped up until it holds at least n bytes
        (or the connection is closed), then up to n bytes are returned from it.

        NOTE: with a negative n no new message is awaited; only the bytes that
        are already buffered are returned.

        :param n: number of bytes wanted.
        :return: the bytes read; may be shorter than n if the connection closed.
        """
        await self._feed_buffer(n)
        return self._stream.read(n)

    async def _feed_buffer(self, n: int = 1) -> None:
        """Feed the data buffer by reading WebSocket messages.

        :param n: Optional; feed buffer until it contains at least n bytes. Defaults to 1.
        """
        # Start from whatever is still unread in the current stream.
        buffer = bytearray(self._stream.read())
        while len(buffer) < n:
            try:
                message: str | bytes | None = await self._protocol.recv()
            except ConnectionClosed:
                # Stop feeding once the peer is gone. Merely suppressing the
                # exception would leave the previous message bound and append
                # it to the buffer a second time.
                break
            if message is None:
                break
            # Text frames are delivered as str; the buffer only holds bytes.
            buffer.extend(message.encode("utf-8") if isinstance(message, str) else message)
        self._stream = io.BytesIO(buffer)

    def feed_eof(self) -> None:
        """Acknowledge EOF; intentionally a no-op for this adapter."""
        # NOTE: not implemented?!
class WebSocketsWriter(WriterAdapter):
    """WebSockets API writer adapter.

    This adapter relies on Connection to write to a WebSocket.
    Data passed to ``write`` is buffered and sent as one message by ``drain``.
    """

    def __init__(self, protocol: Connection) -> None:
        self._protocol = protocol
        # Pending outgoing bytes, flushed by drain().
        self._stream = io.BytesIO(b"")

    def write(self, data: bytes) -> None:
        """Write some data to the protocol layer."""
        self._stream.write(data)

    async def drain(self) -> None:
        """Send the buffered data, if any, over the WebSocket and reset the buffer."""
        data = self._stream.getvalue()
        if data:
            await self._protocol.send(data)
        self._stream = io.BytesIO(b"")

    def get_peer_info(self) -> tuple[str, int] | None:
        """Return the remote (host, port) pair, or None when no remote address is available."""
        remote_address = self._protocol.remote_address
        if remote_address is None:
            # No peer address to report; slicing None would raise TypeError.
            return None
        # remote_address can be either a 4-tuple or 2-tuple depending on whether
        # it is an IPv6 or IPv4 address, so we take their shared (host, port)
        # prefix here to present a uniform return value.
        peer: tuple[str, int] = remote_address[:2]
        return peer

    def get_ssl_info(self) -> ssl.SSLObject | None:
        """Return the transport's "ssl_object" extra info."""
        return cast("ssl.SSLObject", self._protocol.transport.get_extra_info("ssl_object"))

    async def close(self) -> None:
        """Close the WebSocket connection."""
        await self._protocol.close()
class StreamReaderAdapter(ReaderAdapter):
    """Asyncio Streams API protocol adapter.

    This adapter relies on StreamReader to read from a TCP socket.
    Because API is very close, this class is trivial.
    """

    def __init__(self, reader: StreamReader) -> None:
        self._reader = reader

    async def read(self, n: int = -1) -> bytes:
        """Read from the wrapped StreamReader.

        :param n: -1 delegates to ``StreamReader.read(-1)``; any other value
            delegates to ``StreamReader.readexactly(n)``.
        :return: the bytes read.
        """
        if n == -1:
            return await self._reader.read(n)
        return await self._reader.readexactly(n)

    def feed_eof(self) -> None:
        """Forward the EOF acknowledgement to the wrapped StreamReader."""
        self._reader.feed_eof()
class StreamWriterAdapter(WriterAdapter):
    """Asyncio Streams API protocol adapter.

    This adapter relies on StreamWriter to write to a TCP socket.
    Because API is very close, this class is trivial.
    """

    def __init__(self, writer: StreamWriter) -> None:
        self.logger = logging.getLogger(__name__)
        self._writer = writer
        self.is_closed = False  # StreamWriter has no test for closed...we use our own

    def write(self, data: bytes) -> None:
        """Write data to the stream unless this adapter has been closed."""
        if not self.is_closed:
            self._writer.write(data)

    async def drain(self) -> None:
        """Drain the stream's write buffer unless this adapter has been closed."""
        if not self.is_closed:
            await self._writer.drain()

    def get_peer_info(self) -> tuple[str, int] | None:
        """Return the remote (host, port) pair, or None when the transport has no peername."""
        extra_info = self._writer.get_extra_info("peername")
        if extra_info is None:
            # get_extra_info returns its default (None) when the info is unavailable.
            return None
        return extra_info[0], extra_info[1]

    def get_ssl_info(self) -> ssl.SSLObject | None:
        """Return the transport's "ssl_object" extra info."""
        return cast("ssl.SSLObject", self._writer.get_extra_info("ssl_object"))

    async def close(self) -> None:
        """Flush pending data, send EOF when supported, and close the stream.

        Calling this more than once is harmless: only the first call acts.
        Errors raised while flushing or writing EOF still propagate, but the
        underlying writer is closed regardless.
        """
        if self.is_closed:
            return
        self.is_closed = True  # we first mark this closed so yields below don't cause races with waiting writes
        try:
            await self._writer.drain()
            if self._writer.can_write_eof():
                self._writer.write_eof()
        finally:
            # is_closed is already set, so no later call would get another
            # chance to release the transport if the flush above failed.
            self._writer.close()
        with suppress(AttributeError):
            await self._writer.wait_closed()
class BufferReader(ReaderAdapter):
    """Byte Buffer reader adapter.

    This adapter simply adapts reading a byte buffer.
    """

    def __init__(self, buffer: bytes) -> None:
        self._stream = io.BytesIO(buffer)

    async def read(self, n: int = -1) -> bytes:
        """Read up to n bytes from the buffer; a negative n reads everything left.

        :param n: maximum number of bytes to return.
        :return: the bytes read, empty once the buffer is exhausted.
        """
        return self._stream.read(n)

    def feed_eof(self) -> None:
        """Acknowledge EOF; intentionally a no-op for this adapter."""
        # NOTE: not implemented?!
class BufferWriter(WriterAdapter):
    """ByteBuffer writer adapter.

    This adapter simply adapts writing to a byte buffer.
    """

    def __init__(self, buffer: bytes = b"") -> None:
        self._stream = io.BytesIO(buffer)
        # BytesIO starts at position 0; move to the end so that writes append
        # to the initial content instead of overwriting it.
        self._stream.seek(0, io.SEEK_END)

    def write(self, data: bytes) -> None:
        """Write some data to the protocol layer."""
        self._stream.write(data)

    async def drain(self) -> None:
        """Do nothing: data is written straight into the in-memory buffer."""

    def get_buffer(self) -> bytes:
        """Return everything held in the buffer so far."""
        return self._stream.getvalue()

    def get_peer_info(self) -> tuple[str, int]:
        """Return a fixed placeholder peer, since there is no real connection."""
        return "BufferWriter", 0

    def get_ssl_info(self) -> ssl.SSLObject | None:
        """Return None: a byte buffer has no TLS session."""
        return None

    async def close(self) -> None:
        """Close the underlying in-memory stream."""
        self._stream.close()