Coverage for amqtt/contrib/http.py: 89%

130 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from dataclasses import dataclass 

2 

try:
    from enum import StrEnum
except ImportError:
    # support for python 3.10
    from enum import Enum

    class StrEnum(str, Enum):  # type: ignore[no-redef]
        """Minimal stand-in for ``enum.StrEnum``, which only exists from Python 3.11.

        Members are ``str`` instances that compare equal to their values. ``__str__``
        is overridden to return the plain value, as the stdlib class does; a bare
        ``(str, Enum)`` mix-in would otherwise render as ``ClassName.MEMBER`` under
        ``str()`` while still formatting as the value in f-strings.
        """

        def __str__(self) -> str:
            return str.__str__(self)

10 

11import logging 

12from typing import Any 

13 

14from aiohttp import ClientResponse, ClientSession, FormData 

15 

16from amqtt.broker import BrokerContext 

17from amqtt.contexts import Action 

18from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin 

19from amqtt.session import Session 

20 

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

22 

23 

24class ResponseMode(StrEnum): 

25 STATUS = "status" 

26 JSON = "json" 

27 TEXT = "text" 

28 

29 

30class RequestMethod(StrEnum): 

31 GET = "get" 

32 POST = "post" 

33 PUT = "put" 

34 

35 

36class ParamsMode(StrEnum): 

37 JSON = "json" 

38 FORM = "form" 

39 

40 

class ACLError(Exception):
    """Error type for access-control (ACL) failures.

    NOTE(review): nothing in this module raises it; presumably kept for callers.
    """

43 

44 

# Inclusive status-code bounds used to classify HTTP responses.
HTTP_2xx_MIN, HTTP_2xx_MAX = 200, 299  # success

HTTP_4xx_MIN, HTTP_4xx_MAX = 400, 499  # client error

50 

51 

@dataclass
class HttpConfig:
    """Settings shared by the HTTP auth and HTTP topic (ACL) plugins."""

    host: str
    """hostname of the HTTP server that answers the auth & acl checks"""
    port: int
    """TCP port of that HTTP server"""
    request_method: RequestMethod = RequestMethod.GET
    """HTTP verb for the check request: GET, POST or PUT"""
    params_mode: ParamsMode = ParamsMode.JSON  # more in docs/plugins/http.md
    """encode the request parameters as `JSON` or as `FORM` data. *additional details below*"""
    response_mode: ResponseMode = ResponseMode.JSON  # more in docs/plugins/http.md
    """how the server's reply is read: `STATUS` (code), `JSON`, or `TEXT`. *additional details below*"""
    with_tls: bool = False
    """use https instead of http"""
    user_agent: str = "amqtt"
    """value of the 'User-Agent' header sent with every request"""
    superuser_uri: str | None = None
    """URI to verify if the user is a superuser (e.g. '/superuser'); `None` when superuser checks are not supported"""
    timeout: int = 5
    """number of seconds to wait for the HTTP server to respond"""

74 

75 

76class AuthHttpPlugin(BasePlugin[BrokerContext]): 

77 

78 def __init__(self, context: BrokerContext) -> None: 

79 super().__init__(context) 

80 self.http = ClientSession(headers={"User-Agent": self.config.user_agent}) 

81 

82 match self.config.request_method: 

83 case RequestMethod.GET: 

84 self.method = self.http.get 

85 case RequestMethod.PUT: 

86 self.method = self.http.put 

87 case _: 

88 self.method = self.http.post 

89 

90 async def on_broker_pre_shutdown(self) -> None: 

91 await self.http.close() 

92 

93 @staticmethod 

94 def _is_2xx(r: ClientResponse) -> bool: 

95 return HTTP_2xx_MIN <= r.status <= HTTP_2xx_MAX 

96 

97 @staticmethod 

98 def _is_4xx(r: ClientResponse) -> bool: 

99 return HTTP_4xx_MIN <= r.status <= HTTP_4xx_MAX 

100 

101 def _get_params(self, payload: dict[str, Any]) -> dict[str, Any]: 

102 match self.config.params_mode: 

103 case ParamsMode.FORM: 

104 match self.config.request_method: 

105 case RequestMethod.GET: 

106 kwargs = {"params": payload} 

107 case _: # POST, PUT 

108 d: Any = FormData(payload) 

109 kwargs = {"data": d} 

110 case _: # JSON 

111 kwargs = {"json": payload} 

112 return kwargs 

113 

114 async def _send_request(self, url: str, payload: dict[str, Any]) -> bool | None: # pylint: disable=R0911 

115 

116 kwargs = self._get_params(payload) 

117 

118 async with self.method(url, **kwargs) as r: 

119 logger.debug(f"http request returned {r.status}") 

120 

121 match self.config.response_mode: 

122 case ResponseMode.TEXT: 

123 return self._is_2xx(r) and (await r.text()).lower() == "ok" 

124 case ResponseMode.STATUS: 

125 if self._is_2xx(r): 

126 return True 

127 if self._is_4xx(r): 

128 return False 

129 # any other code 

130 return None 

131 case _: 

132 if not self._is_2xx(r): 132 ↛ 133line 132 didn't jump to line 133 because the condition on line 132 was never true

133 return False 

134 data: dict[str, Any] = await r.json() 

135 data = {k.lower(): v for k, v in data.items()} 

136 return data.get("ok", None) 

137 

138 def get_url(self, uri: str) -> str: 

139 return f"{'https' if self.config.with_tls else 'http'}://{self.config.host}:{self.config.port}{uri}" 

140 

141 

class UserAuthHttpPlugin(AuthHttpPlugin, BaseAuthPlugin):
    """Authentication plugin that asks an HTTP server to check a client's credentials."""

    async def authenticate(self, *, session: Session) -> bool | None:
        """Send the session's username, password and client id to ``config.user_uri``.

        Returns whatever ``_send_request`` derives from the server's reply.
        """
        credentials = {
            "username": session.username,
            "password": session.password,
            "client_id": session.client_id,
        }
        endpoint = self.get_url(self.config.user_uri)
        return await self._send_request(endpoint, credentials)

    @dataclass
    class Config(HttpConfig):
        """Configuration for the HTTP Auth Plugin."""

        user_uri: str = "/user"
        """URI queried for the authentication check."""

154 

155 

class TopicAuthHttpPlugin(AuthHttpPlugin, BaseTopicPlugin):
    """Topic (ACL) plugin that asks an HTTP server whether an action on a topic is allowed."""

    async def topic_filtering(self, *,
                              session: Session | None = None,
                              topic: str | None = None,
                              action: Action | None = None) -> bool | None:
        """Query ``config.topic_uri`` for the given session, topic and action.

        Returns None without a request when no session is given. Otherwise the action
        is sent as a numeric ``acc`` code (PUBLISH=2, SUBSCRIBE=4, RECEIVE=1, anything
        else 0) and the result of ``_send_request`` is returned.
        """
        if not session:
            return None

        if action == Action.PUBLISH:
            access = 2
        elif action == Action.SUBSCRIBE:
            access = 4
        elif action == Action.RECEIVE:
            access = 1
        else:
            access = 0

        request_body = {
            "username": session.username,
            "client_id": session.client_id,
            "topic": topic,
            "acc": access,
        }
        endpoint = self.get_url(self.config.topic_uri)
        return await self._send_request(endpoint, request_body)

    @dataclass
    class Config(HttpConfig):
        """Configuration for the HTTP Topic Plugin."""

        topic_uri: str = "/acl"
        """URI queried for the topic (ACL) check."""