Coverage for amqtt/contrib/shadows/states.py: 78%

119 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-08-12 14:35 +0000

1from collections import Counter 

2from collections.abc import MutableMapping 

3from dataclasses import dataclass, field 

4 

try:
    from enum import StrEnum
except ImportError:
    # support for python 3.10: StrEnum only exists in the stdlib from 3.11
    from enum import Enum

    class StrEnum(str, Enum):  # type: ignore[no-redef]
        """Minimal backport of ``enum.StrEnum`` for Python 3.10.

        Members are ``str`` instances. ``__str__`` is overridden so that
        ``str(member)`` yields the member's value, as the 3.11+ stdlib class
        does; a plain ``(str, Enum)`` mixin would instead produce
        ``"ClassName.MEMBER"``, silently changing any text built with ``str()``.
        """

        def __str__(self) -> str:
            return str.__str__(self)

12import time 

13from typing import Any, Generic, TypeVar 

14 

15from mergedeep import merge 

16 

# Type parameter for the leaf values stored in a State's desired/reported
# mappings; in this module it is instantiated as State[dict[str, Any]] for the
# state itself and State[MetaTimestamp] for its metadata.
# NOTE(review): bound=Any imposes no constraint; it is equivalent to an
# unbounded TypeVar.
C = TypeVar("C", bound=Any)

18 

19 

class StateError(Exception):
    """Signal that a shadow state document is malformed.

    The default message reports a missing ``'state'`` field; callers may pass
    a different *msg* to describe another problem.
    """

    def __init__(self, msg: str = "'state' field is required") -> None:
        """Store *msg* as the exception's single argument."""
        super().__init__(msg)

23 

24 

@dataclass
class MetaTimestamp:
    """Timestamp recorded in shadow metadata for a single state key.

    Wraps an ``int`` and implements comparison and arithmetic operators so it
    can be used largely interchangeably with a plain integer. Operands may be
    plain numbers or other ``MetaTimestamp`` instances.
    """

    timestamp: int = 0

    @staticmethod
    def _unwrap(other: Any) -> Any:
        """Return the raw timestamp of a MetaTimestamp, or *other* unchanged."""
        return other.timestamp if isinstance(other, MetaTimestamp) else other

    def __eq__(self, other: object) -> bool:
        """Compare timestamps.

        Equal to an ``int`` or a ``MetaTimestamp`` holding the same value. For
        any other type ``NotImplemented`` is returned, so Python falls back to
        its default (``==`` is False, ``!=`` is True) instead of raising. This
        keeps container comparisons such as
        ``{"k": MetaTimestamp(1)} == {"k": None}`` from crashing.
        """
        if isinstance(other, MetaTimestamp):
            return self.timestamp == other.timestamp
        if isinstance(other, int):
            return self.timestamp == other
        return NotImplemented

    # numeric operations to make this dataclass transparent
    def __abs__(self) -> int:
        """Absolute value of the timestamp."""
        return abs(self.timestamp)

    def __add__(self, other: "int | MetaTimestamp") -> int:
        """Add to a timestamp."""
        return self.timestamp + self._unwrap(other)

    def __radd__(self, other: int) -> int:
        """Support ``number + MetaTimestamp`` (and therefore ``sum()``)."""
        return other + self.timestamp

    def __sub__(self, other: "int | MetaTimestamp") -> int:
        """Subtract from a timestamp."""
        return self.timestamp - self._unwrap(other)

    def __rsub__(self, other: int) -> int:
        """Support ``number - MetaTimestamp``."""
        return other - self.timestamp

    def __mul__(self, other: "int | MetaTimestamp") -> int:
        """Multiply a timestamp."""
        return self.timestamp * self._unwrap(other)

    def __rmul__(self, other: int) -> int:
        """Support ``number * MetaTimestamp``."""
        return other * self.timestamp

    def __float__(self) -> float:
        """Convert timestamp to float."""
        return float(self.timestamp)

    def __int__(self) -> int:
        """Convert timestamp to int."""
        return int(self.timestamp)

    def __lt__(self, other: "int | MetaTimestamp") -> bool:
        """Compare timestamp."""
        return self.timestamp < self._unwrap(other)

    def __le__(self, other: "int | MetaTimestamp") -> bool:
        """Compare timestamp."""
        return self.timestamp <= self._unwrap(other)

    def __gt__(self, other: "int | MetaTimestamp") -> bool:
        """Compare timestamp."""
        return self.timestamp > self._unwrap(other)

    def __ge__(self, other: "int | MetaTimestamp") -> bool:
        """Compare timestamp."""
        return self.timestamp >= self._unwrap(other)

78 

79 

def create_metadata(state: MutableMapping[str, Any], timestamp: int) -> dict[str, Any]:
    """Build a metadata tree that mirrors the keys of *state*.

    Nested ``dict`` values are walked recursively. Keys whose value is
    ``None`` map to ``None``, and every other value (scalars and lists alike)
    maps to ``MetaTimestamp(timestamp)``.
    """

    def _stamp(entry: Any) -> Any:
        # dicts become nested metadata; None stays None; anything else is stamped
        if isinstance(entry, dict):
            return create_metadata(entry, timestamp)
        return None if entry is None else MetaTimestamp(timestamp)

    return {name: _stamp(entry) for name, entry in state.items()}

92 

93 

def _lists_match_unordered(desired_list: list[Any], reported_value: Any) -> bool:
    """Return True if *reported_value* is a list with the same elements as *desired_list*, ignoring order."""
    if not isinstance(reported_value, list) or len(desired_list) != len(reported_value):
        return False
    try:
        return Counter(desired_list) == Counter(reported_value)
    except TypeError:
        # unhashable elements (e.g. nested objects): fall back to pairwise matching
        remaining = list(reported_value)
        for item in desired_list:
            try:
                remaining.remove(item)
            except ValueError:
                return False
        return not remaining


def calculate_delta_update(desired: MutableMapping[str, Any],
                           reported: MutableMapping[str, Any],
                           depth: bool = True,
                           exclude_nones: bool = True,
                           ordered_lists: bool = True) -> dict[str, Any]:
    """Calculate state differences between desired and reported.

    Args:
        desired: the desired state.
        reported: the reported state.
        depth: when True, a nested mapping present on both sides is compared
            recursively and only its differing keys are returned. When False
            (or when the reported value is not a mapping), the nested value is
            compared as a whole.
        exclude_nones: skip keys whose desired value is ``None``.
        ordered_lists: when False, lists are compared as multisets (element
            order ignored).

    Returns:
        A dict with the keys of *desired* whose values are missing from, or
        differ from, *reported*. Keys present only in *reported* are ignored.
        ``depth``, ``exclude_nones`` and ``ordered_lists`` apply at every
        nesting level.
    """
    diff_dict: dict[str, Any] = {}
    for key, value in desired.items():
        if value is None and exclude_nones:
            continue
        # the desired has an element that the reported does not
        if key not in reported:
            diff_dict[key] = value
            continue

        reported_value = reported[key]
        if isinstance(value, list) and not ordered_lists:
            if not _lists_match_unordered(value, reported_value):
                diff_dict[key] = value
        elif depth and isinstance(value, MutableMapping) and isinstance(reported_value, MutableMapping):
            # recurse with the same options; report only when there is a difference
            obj_diff = calculate_delta_update(value, reported_value,
                                              depth=depth,
                                              exclude_nones=exclude_nones,
                                              ordered_lists=ordered_lists)
            if obj_diff:
                diff_dict[key] = obj_diff
        elif value != reported_value:
            diff_dict[key] = value

    return diff_dict

120 

121 

def calculate_iota_update(desired: MutableMapping[str, Any], reported: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
    """Compute the full difference from *reported* to *desired*, including removals.

    The result starts as a shallow delta, computed with ``depth=False`` and
    ``exclude_nones=False``. Nested values are compared as a whole, and
    desired ``None`` values that differ from or are missing in *reported* are
    kept. Every key that exists only in *reported* is then added with the
    value ``None``.
    """
    iota = calculate_delta_update(desired, reported, depth=False, exclude_nones=False)
    # keys the reported side has but the desired side does not
    orphaned = {name: None for name in reported if name not in desired}
    iota.update(orphaned)
    return iota

131 

132 

@dataclass
class State(Generic[C]):
    """A pair of ``desired`` and ``reported`` mappings for a shadow."""

    desired: MutableMapping[str, C] = field(default_factory=dict)
    reported: MutableMapping[str, C] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict[str, C]) -> "State[C]":
        """Build a State from the ``"desired"``/``"reported"`` entries of *data*.

        A missing entry becomes an empty dict. The mappings found in *data*
        are used as-is, not copied.
        """
        desired = data.get("desired", {})
        reported = data.get("reported", {})
        return cls(desired=desired, reported=reported)

    def __bool__(self) -> bool:
        """Return False only when both mappings are empty."""
        return any((self.desired, self.reported))

    def __add__(self, other: "State[C]") -> "State[C]":
        """Return a new State whose sections are ``mergedeep.merge`` of both operands.

        Each section is merged into a fresh dict, so neither operand's
        mapping is used as the merge destination.
        """
        merged_desired = merge({}, self.desired, other.desired)
        merged_reported = merge({}, self.reported, other.reported)
        return State(desired=merged_desired, reported=merged_reported)

156 

157 

@dataclass
class StateDocument:
    """A shadow state document: the state plus per-key timestamp metadata.

    ``metadata`` mirrors the key structure of ``state`` and is derived
    automatically (stamped with the current time) when left empty.
    """

    state: State[dict[str, Any]] = field(default_factory=State)
    metadata: State[MetaTimestamp] = field(default_factory=State)
    version: int | None = None  # only required when generating shadow messages
    timestamp: int | None = None  # only required when generating shadow messages

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "StateDocument":
        """Build a document from a decoded payload.

        An empty payload yields an empty document. A non-empty payload must
        contain a ``"state"`` key, otherwise ``StateError`` is raised. Only
        ``"state"`` is read from *data*; ``version`` and ``timestamp`` stay
        ``None``.
        """
        if data and "state" not in data:
            raise StateError
        # __post_init__ stamps metadata for every key of the parsed state
        return cls(state=State.from_dict(data.get("state", {})))

    def __post_init__(self) -> None:
        """Derive timestamp metadata from ``state`` when none was supplied."""
        if self.metadata:
            return
        stamp = int(time.time())
        self.metadata = State(
            desired=create_metadata(self.state.desired, stamp),
            reported=create_metadata(self.state.reported, stamp),
        )

    def __add__(self, other: "StateDocument") -> "StateDocument":
        """Merge two documents' state and metadata into a new document.

        ``version`` and ``timestamp`` are not carried over; the result has
        both set to ``None``.
        """
        merged_state = self.state + other.state
        merged_metadata = self.metadata + other.metadata
        return StateDocument(state=merged_state, metadata=merged_metadata)

195 

196 

class ShadowOperation(StrEnum):
    """Names of the shadow operations, as string enum members.

    Members compare equal to their string values (``StrEnum``), so they can
    be used directly wherever the operation name is needed as text.
    NOTE(review): the slash-separated values look like MQTT topic suffixes
    (e.g. ``.../update/accepted``); confirm against the code that builds
    shadow topics.
    """

    GET = "get"
    UPDATE = "update"
    GET_ACCEPT = "get/accepted"
    GET_REJECT = "get/rejected"
    UPDATE_ACCEPT = "update/accepted"
    UPDATE_REJECT = "update/rejected"
    UPDATE_DOCUMENTS = "update/documents"
    UPDATE_DELTA = "update/delta"
    UPDATE_IOTA = "update/iota"