Coverage for amqtt/contrib/http.py: 89%
130 statements
Generated by coverage.py v7.8.2 at 2025-08-12 14:35 +0000
1from dataclasses import dataclass
try:
    from enum import StrEnum
except ImportError:
    # Python 3.10 has no enum.StrEnum; provide an equivalent backport.
    from enum import Enum

    class StrEnum(str, Enum):  # type: ignore[no-redef]
        """Backport of ``enum.StrEnum`` (added in Python 3.11).

        Members are real ``str`` instances. As with the standard-library
        class, ``str()`` and ``format()`` give the member's value rather
        than ``"ClassName.MEMBER"``, and ``auto()`` produces the lower-cased
        member name. This keeps behavior the same across Python versions.
        """

        # Plain str rendering instead of Enum's "ClassName.MEMBER" form.
        __str__ = str.__str__

        @staticmethod
        def _generate_next_value_(name: str, start: int, count: int, last_values: list) -> str:
            # Mirror stdlib StrEnum: auto() -> lower-cased member name.
            return name.lower()
import logging
from typing import Any

from aiohttp import ClientResponse, ClientSession, ClientTimeout, FormData

from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
from amqtt.session import Session
21logger = logging.getLogger(__name__)
24class ResponseMode(StrEnum):
25 STATUS = "status"
26 JSON = "json"
27 TEXT = "text"
30class RequestMethod(StrEnum):
31 GET = "get"
32 POST = "post"
33 PUT = "put"
36class ParamsMode(StrEnum):
37 JSON = "json"
38 FORM = "form"
class ACLError(Exception):
    """Error type for ACL-related failures."""


# Inclusive bounds of the HTTP status-code classes the plugins inspect.
HTTP_2xx_MIN, HTTP_2xx_MAX = 200, 299  # success
HTTP_4xx_MIN, HTTP_4xx_MAX = 400, 499  # client error
@dataclass
class HttpConfig:
    """Configuration for the HTTP Auth & ACL Plugin.

    Shared by the user-auth and topic-ACL plugin configs, which subclass it
    and append their own URI field. Field order defines the positional
    constructor signature, so it must stay as declared.
    """

    # -- server location (required) --
    host: str
    """hostname of the server for the auth & acl check"""
    port: int
    """port of the server for the auth & acl check"""

    # -- request / response shape --
    request_method: RequestMethod = RequestMethod.GET
    """send the request as a GET, POST or PUT"""
    params_mode: ParamsMode = ParamsMode.JSON  # see docs/plugins/http.md for additional details
    """send the request with `JSON` or `FORM` data. *additional details below*"""
    response_mode: ResponseMode = ResponseMode.JSON  # see docs/plugins/http.md for additional details
    """expected response from the auth/acl server. `STATUS` (code), `JSON`, or `TEXT`. *additional details below*"""

    # -- transport --
    with_tls: bool = False
    """http or https"""
    user_agent: str = "amqtt"
    """the 'User-Agent' header sent along with the request"""

    # NOTE(review): no plugin class in this file reads superuser_uri; confirm it is used elsewhere.
    superuser_uri: str | None = None
    """URI to verify if the user is a superuser (e.g. '/superuser'), `None` if superuser is not supported"""
    timeout: int = 5
    """duration, in seconds, to wait for the HTTP server to respond"""
76class AuthHttpPlugin(BasePlugin[BrokerContext]):
78 def __init__(self, context: BrokerContext) -> None:
79 super().__init__(context)
80 self.http = ClientSession(headers={"User-Agent": self.config.user_agent})
82 match self.config.request_method:
83 case RequestMethod.GET:
84 self.method = self.http.get
85 case RequestMethod.PUT:
86 self.method = self.http.put
87 case _:
88 self.method = self.http.post
90 async def on_broker_pre_shutdown(self) -> None:
91 await self.http.close()
93 @staticmethod
94 def _is_2xx(r: ClientResponse) -> bool:
95 return HTTP_2xx_MIN <= r.status <= HTTP_2xx_MAX
97 @staticmethod
98 def _is_4xx(r: ClientResponse) -> bool:
99 return HTTP_4xx_MIN <= r.status <= HTTP_4xx_MAX
101 def _get_params(self, payload: dict[str, Any]) -> dict[str, Any]:
102 match self.config.params_mode:
103 case ParamsMode.FORM:
104 match self.config.request_method:
105 case RequestMethod.GET:
106 kwargs = {"params": payload}
107 case _: # POST, PUT
108 d: Any = FormData(payload)
109 kwargs = {"data": d}
110 case _: # JSON
111 kwargs = {"json": payload}
112 return kwargs
114 async def _send_request(self, url: str, payload: dict[str, Any]) -> bool | None: # pylint: disable=R0911
116 kwargs = self._get_params(payload)
118 async with self.method(url, **kwargs) as r:
119 logger.debug(f"http request returned {r.status}")
121 match self.config.response_mode:
122 case ResponseMode.TEXT:
123 return self._is_2xx(r) and (await r.text()).lower() == "ok"
124 case ResponseMode.STATUS:
125 if self._is_2xx(r):
126 return True
127 if self._is_4xx(r):
128 return False
129 # any other code
130 return None
131 case _:
132 if not self._is_2xx(r): 132 ↛ 133line 132 didn't jump to line 133 because the condition on line 132 was never true
133 return False
134 data: dict[str, Any] = await r.json()
135 data = {k.lower(): v for k, v in data.items()}
136 return data.get("ok", None)
138 def get_url(self, uri: str) -> str:
139 return f"{'https' if self.config.with_tls else 'http'}://{self.config.host}:{self.config.port}{uri}"
class UserAuthHttpPlugin(AuthHttpPlugin, BaseAuthPlugin):
    """Authenticate MQTT clients by querying an HTTP endpoint (``Config.user_uri``)."""

    async def authenticate(self, *, session: Session) -> bool | None:
        """Ask the HTTP server whether this client's credentials are accepted.

        Sends ``username``, ``password`` and ``client_id`` from ``session`` to
        ``user_uri``. Returns whatever ``_send_request`` yields for the reply.
        """
        credentials = {
            "username": session.username,
            "password": session.password,
            "client_id": session.client_id,
        }
        endpoint = self.get_url(self.config.user_uri)
        return await self._send_request(endpoint, credentials)

    @dataclass
    class Config(HttpConfig):
        """Configuration for the HTTP Auth Plugin."""

        user_uri: str = "/user"
        """URI of the auth check."""
class TopicAuthHttpPlugin(AuthHttpPlugin, BaseTopicPlugin):
    """Authorize topic access by querying an HTTP endpoint (``Config.topic_uri``)."""

    async def topic_filtering(self, *,
                              session: Session | None = None,
                              topic: str | None = None,
                              action: Action | None = None) -> bool | None:
        """Ask the HTTP server whether ``session`` may perform ``action`` on ``topic``.

        The action is sent as a numeric ``acc`` code:
        RECEIVE -> 1, PUBLISH -> 2, SUBSCRIBE -> 4, anything else (incl. None) -> 0.
        Returns None without a request when there is no session; otherwise
        returns whatever ``_send_request`` yields.
        """
        if not session:
            return None

        # Same `==` comparisons a match/case value pattern performs.
        if action == Action.PUBLISH:
            access_code = 2
        elif action == Action.SUBSCRIBE:
            access_code = 4
        elif action == Action.RECEIVE:
            access_code = 1
        else:
            access_code = 0

        request_body = {
            "username": session.username,
            "client_id": session.client_id,
            "topic": topic,
            "acc": access_code,
        }
        endpoint = self.get_url(self.config.topic_uri)
        return await self._send_request(endpoint, request_body)

    @dataclass
    class Config(HttpConfig):
        """Configuration for the HTTP Topic Plugin."""

        topic_uri: str = "/acl"
        """URI of the topic check."""