Coverage for amqtt/contrib/jwt.py: 79%
85 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-08-12 14:35 +0000
1from dataclasses import dataclass
2import logging
3from typing import ClassVar
5import jwt
try:
    from enum import StrEnum
except ImportError:
    # support for python 3.10
    from enum import Enum

    class StrEnum(str, Enum):  # type: ignore[no-redef]
        """Backport of ``enum.StrEnum`` (added in Python 3.11) for Python 3.10.

        Members are ``str`` instances. A plain ``(str, Enum)`` mixin renders
        ``str(member)`` as ``"ClassName.MEMBER"`` on 3.10; overriding
        ``__str__`` makes it return the member's value, as the real
        ``StrEnum`` does (and, because ``__str__`` is overridden, ``format()``
        / f-strings also render the value).
        """

        def __str__(self) -> str:
            # Return the raw string content rather than Enum's "Class.NAME" form.
            return str.__str__(self)
15from amqtt.broker import BrokerContext
16from amqtt.contexts import Action
17from amqtt.plugins import TopicMatcher
18from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin
19from amqtt.session import Session
# Module-level logger named after this module; the plugins below use it for
# debug messages about expired or invalid tokens.
logger = logging.getLogger(__name__)
class Algorithms(StrEnum):
    """JWT signature algorithm identifiers.

    Each member's value is the algorithm's identifier string, and because the
    enum derives from ``StrEnum`` every member is also a ``str`` that compares
    equal to that identifier.
    """

    # NOTE(review): not referenced by the visible code; the plugin Config
    # classes type ``algorithm`` as a plain ``str``. Consider validating the
    # configured value against this enum.
    ES256 = "ES256"
    ES256K = "ES256K"
    ES384 = "ES384"
    ES512 = "ES512"
    ES521 = "ES521"
    EdDSA = "EdDSA"
    HS256 = "HS256"
    HS384 = "HS384"
    HS512 = "HS512"
    PS256 = "PS256"
    PS384 = "PS384"
    PS512 = "PS512"
    RS256 = "RS256"
    RS384 = "RS384"
    RS512 = "RS512"
class UserAuthJwtPlugin(BaseAuthPlugin):
    """Authenticate MQTT clients that present a JWT as their password.

    The client's password is decoded as a JWT and its signature verified with
    ``Config.secret_key`` using ``Config.algorithm``. The client is accepted
    only when the token's ``Config.user_claim`` value equals the client's
    username.
    """

    async def authenticate(self, *, session: Session) -> bool | None:
        """Check the session's JWT password against its username.

        Args:
            session: the connecting client's session; its ``password`` is
                expected to hold the encoded JWT.

        Returns:
            ``None`` when the session has no username or no password (this
            plugin cannot decide), ``True`` when the token verifies and its
            user claim matches the username, ``False`` when the token is
            expired, otherwise invalid, or names a different user.
        """
        if not session.username or not session.password:
            return None

        try:
            # Only accept tokens signed with the configured algorithm; the
            # allow-list must come from configuration, never from the token.
            decoded_payload = jwt.decode(
                session.password,
                self.config.secret_key,
                algorithms=[self.config.algorithm],
            )
        # ExpiredSignatureError subclasses InvalidTokenError, so it must be
        # caught first to get the more specific log message.
        except jwt.ExpiredSignatureError:
            logger.debug("jwt for '%s' is expired", session.username)
            return False
        except jwt.InvalidTokenError:
            logger.debug("jwt for '%s' is invalid", session.username)
            return False

        return decoded_payload.get(self.config.user_claim) == session.username

    @dataclass
    class Config:
        """Configuration for the JWT user authentication."""

        secret_key: str
        """Key used to verify the token signature: the shared secret for HS* algorithms, or the public key
        (e.g. PEM text) for the asymmetric ES*/EdDSA/PS*/RS* algorithms."""
        user_claim: str
        """Payload key holding the user name."""
        algorithm: str = "HS256"
        """Algorithm the token must be signed with: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256',
        'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'"""
class TopicAuthJwtPlugin(BaseTopicPlugin):
    """Authorize topic access using topic-filter lists carried in the client's JWT.

    The client's password is decoded as a JWT and its signature verified with
    ``Config.secret_key`` using ``Config.algorithm``. Depending on the action,
    the claim named by ``Config.publish_claim``, ``Config.subscribe_claim`` or
    ``Config.receive_claim`` holds the topic filters the client may use.
    """

    # Maps each action to the name of the Config attribute holding its claim key.
    _topic_jwt_claims: ClassVar[dict[Action, str]] = {
        Action.PUBLISH: "publish_claim",
        Action.SUBSCRIBE: "subscribe_claim",
        Action.RECEIVE: "receive_claim",
    }

    def __init__(self, context: BrokerContext) -> None:
        super().__init__(context)

        self.topic_matcher = TopicMatcher()

    async def topic_filtering(
        self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
    ) -> bool | None:
        """Decide whether ``session`` may perform ``action`` on ``topic``.

        Args:
            session: the client's session; its ``password`` is expected to
                hold the encoded JWT.
            topic: the topic being published to, subscribed to, or received on.
            action: which operation is being checked.

        Returns:
            ``None`` when any argument or the password is missing (this plugin
            cannot decide), ``True`` when some filter in the action's claim
            allows ``topic``, ``False`` when the token is expired or invalid,
            the claim is malformed, or no filter matches.
        """
        if not session or not topic or not action:
            return None

        if not session.password:
            return None

        try:
            # Only accept tokens signed with the configured algorithm; the
            # allow-list must come from configuration, never from the token.
            decoded_payload = jwt.decode(
                session.password,
                self.config.secret_key,
                algorithms=[self.config.algorithm],
            )
        # ExpiredSignatureError subclasses InvalidTokenError, so it must be
        # caught first to get the more specific log message.
        except jwt.ExpiredSignatureError:
            logger.debug("jwt for '%s' is expired", session.username)
            return False
        except jwt.InvalidTokenError:
            logger.debug("jwt for '%s' is invalid", session.username)
            return False

        claim = getattr(self.config, self._topic_jwt_claims[action])
        topic_filters = decoded_payload.get(claim, [])
        if isinstance(topic_filters, str):
            # A bare string is a single filter. Iterating it would test each
            # character as a filter (a lone "#" would presumably act as the
            # multi-level wildcard and allow everything).
            topic_filters = [topic_filters]
        elif not isinstance(topic_filters, (list, tuple)):
            logger.debug("jwt claim '%s' for '%s' is not a list of topic filters", claim, session.username)
            return False

        return any(
            isinstance(topic_filter, str) and self.topic_matcher.is_topic_allowed(topic, topic_filter)
            for topic_filter in topic_filters
        )

    @dataclass
    class Config:
        """Configuration for the JWT topic authorization."""

        secret_key: str
        """Key used to verify the token signature: the shared secret for HS* algorithms, or the public key
        (e.g. PEM text) for the asymmetric ES*/EdDSA/PS*/RS* algorithms."""
        publish_claim: str
        """Payload key whose value lists the topic filters the client may publish to."""
        subscribe_claim: str
        """Payload key whose value lists the topic filters the client may subscribe to."""
        receive_claim: str
        """Payload key whose value lists the topic filters the client may receive messages on."""
        algorithm: str = "HS256"
        """Algorithm the token must be signed with: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256',
        'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'"""