Coverage for amqtt/contrib/jwt.py: 79%

85 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from dataclasses import dataclass 

2import logging 

3from typing import ClassVar 

4 

5import jwt 

6 

try:
    from enum import StrEnum
except ImportError:
    # Python 3.10 has no enum.StrEnum; fall back to a minimal backport.
    from enum import Enum

    class StrEnum(str, Enum):  # type: ignore[no-redef]
        """Backport of :class:`enum.StrEnum` for Python 3.10.

        Members are ``str`` instances. As with the 3.11+ ``StrEnum``,
        ``str(member)`` returns the member's value; a bare ``(str, Enum)`` mixin
        on 3.10 would return ``"ClassName.MEMBER"`` instead.
        """

        __str__ = str.__str__

14 

15from amqtt.broker import BrokerContext 

16from amqtt.contexts import Action 

17from amqtt.plugins import TopicMatcher 

18from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin 

19from amqtt.session import Session 

20 

# Module-level logger shared by the JWT plugins below.
logger: logging.Logger = logging.getLogger(name=__name__)

22 

23 

24class Algorithms(StrEnum): 

25 ES256 = "ES256" 

26 ES256K = "ES256K" 

27 ES384 = "ES384" 

28 ES512 = "ES512" 

29 ES521 = "ES521" 

30 EdDSA = "EdDSA" 

31 HS256 = "HS256" 

32 HS384 = "HS384" 

33 HS512 = "HS512" 

34 PS256 = "PS256" 

35 PS384 = "PS384" 

36 PS512 = "PS512" 

37 RS256 = "RS256" 

38 RS384 = "RS384" 

39 RS512 = "RS512" 

40 

41 

class UserAuthJwtPlugin(BaseAuthPlugin):
    """Authenticate clients that present a JWT as their password.

    The token is verified with ``Config.secret_key`` using ``Config.algorithm``.
    Authentication succeeds when the payload entry named by ``Config.user_claim``
    equals the session's username.
    """

    async def authenticate(self, *, session: Session) -> bool | None:
        """Verify the session's password-token and match it against the username.

        Args:
            session: The connecting client's session. Its ``username`` and its
                ``password`` (the encoded JWT) are read.

        Returns:
            ``None`` if the username or password is missing, so this plugin makes
            no decision. ``True`` if the token verifies and its user claim equals
            the username. ``False`` if the token is expired, invalid, or names a
            different user.
        """
        if not session.username or not session.password:
            return None

        try:
            # Honor the configured algorithm. Hard-coding "HS256" made the
            # ``algorithm`` setting a no-op and rejected every non-HS256 token.
            decoded_payload = jwt.decode(
                session.password, self.config.secret_key, algorithms=[self.config.algorithm]
            )
        except jwt.ExpiredSignatureError:
            # Must be caught before InvalidTokenError, its base class.
            logger.debug("jwt for '%s' is expired", session.username)
            return False
        except jwt.InvalidTokenError:
            logger.debug("jwt for '%s' is invalid", session.username)
            return False

        return bool(decoded_payload.get(self.config.user_claim) == session.username)

    @dataclass
    class Config:
        """Configuration for the JWT user authentication."""

        secret_key: str
        """Key used to verify the token signature: the shared secret for 'HS*' algorithms, the public key for asymmetric ones."""
        user_claim: str
        """Payload key whose value must equal the client's username."""
        algorithm: str = "HS256"
        """Algorithm the token must be signed with: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256',
        'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'"""

70 

71 

class TopicAuthJwtPlugin(BaseTopicPlugin):
    """Authorize topic access using topic-filter lists carried in the client's JWT.

    The token is taken from the session password and verified with
    ``Config.secret_key`` using ``Config.algorithm``. For each action, the
    corresponding ``Config.*_claim`` names the payload entry that holds the
    allowed topic filters.
    """

    # Maps each action to the name of the Config field that holds its claim key.
    _topic_jwt_claims: ClassVar[dict[Action, str]] = {
        Action.PUBLISH: "publish_claim",
        Action.SUBSCRIBE: "subscribe_claim",
        Action.RECEIVE: "receive_claim",
    }

    def __init__(self, context: BrokerContext) -> None:
        super().__init__(context)

        self.topic_matcher = TopicMatcher()

    @staticmethod
    def _claim_filters(value: object) -> list[str]:
        """Normalize a decoded claim value into a list of topic filters.

        A bare string is treated as one filter. Iterating it directly would test
        each character as its own filter, e.g. ``"sensors/#"`` would yield a lone
        ``#``, which in MQTT matches every topic. For a list or tuple, only its
        string entries are kept. Any other value (missing, number, object) yields
        no filters instead of raising ``TypeError``.
        """
        if isinstance(value, str):
            return [value]
        if isinstance(value, (list, tuple)):
            return [item for item in value if isinstance(item, str)]
        return []

    async def topic_filtering(
        self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
    ) -> bool | None:
        """Decide whether ``session`` may perform ``action`` on ``topic``.

        Returns:
            ``None`` when there is nothing to evaluate (missing session, topic,
            action, or password). ``True`` if ``topic`` matches one of the filters
            in the action's claim. ``False`` otherwise, including when the token is
            expired or invalid.
        """
        if not session or not topic or not action:
            return None

        if not session.password:
            return None

        try:
            # Honor the configured algorithm instead of a hard-coded "HS256".
            decoded_payload = jwt.decode(
                session.password.encode(), self.config.secret_key, algorithms=[self.config.algorithm]
            )
        except jwt.ExpiredSignatureError:
            # Must be caught before InvalidTokenError, its base class.
            logger.debug("jwt for '%s' is expired", session.username)
            return False
        except jwt.InvalidTokenError:
            logger.debug("jwt for '%s' is invalid", session.username)
            return False

        claim = getattr(self.config, self._topic_jwt_claims[action])
        filters = self._claim_filters(decoded_payload.get(claim))
        return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in filters)

    @dataclass
    class Config:
        """Configuration for the JWT topic authorization."""

        secret_key: str
        """Key used to verify the token signature: the shared secret for 'HS*' algorithms, the public key for asymmetric ones."""
        publish_claim: str
        """Payload key whose value is the list of topic filters the client may publish to."""
        subscribe_claim: str
        """Payload key whose value is the list of topic filters the client may subscribe to."""
        receive_claim: str
        """Payload key whose value is the list of topic filters the client may receive messages from."""
        algorithm: str = "HS256"
        """Algorithm the token must be signed with: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256',
        'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'"""