Coverage for tubthumper/_retry_factory.py: 100%

89 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-09 04:23 +0000

1"""Module defining the retry_factory function""" 

2 

3import asyncio 

4import random 

5import time 

6from dataclasses import dataclass 

7from functools import update_wrapper 

8from typing import Awaitable, Callable, overload 

9 

10from tubthumper import _types as tub_types 

11 

# Star-import surface of this module: only RetryError is exported by
# `from ... import *`; every other name must be imported explicitly.
__all__ = ["RetryError"]

13 

14 

class RetryError(Exception):
    """
    Error signalling that retrying has been abandoned.

    Raised by the retry handler once the configured retry limit or time
    limit is hit, unless the configuration asks for the originally caught
    exception to be re-raised instead.
    """

17 

18 

@dataclass(frozen=True)
class RetryConfig:
    """
    Immutable bundle of the settings that drive the retry logic.

    The field order is part of the generated constructor signature and is
    therefore kept exactly as declared.
    """

    # Exception type(s) caught around the wrapped call to trigger a retry
    exceptions: tub_types.Exceptions
    # Retrying stops once the number of failures exceeds this value
    retry_limit: tub_types.RetryLimit
    # Seconds added to time.perf_counter() at the start of a call; retrying
    # stops when the next backoff would end after that deadline
    time_limit: tub_types.Duration
    # Backoff (before jitter) used after the first failure
    init_backoff: tub_types.Duration
    # Factor the un-jittered backoff is multiplied by after every failure
    exponential: tub_types.Exponential
    # When truthy, each backoff is scaled by random.random()
    jitter: tub_types.Jitter
    # When truthy, giving up re-raises the caught exception instead of
    # raising RetryError
    reraise: tub_types.Reraise
    # Level handed to logger.log() for every retry message
    log_level: tub_types.LogLevel
    # Object whose .log() method is called before each retry
    logger: tub_types.Logger

32 

33 

class _RetryHandler:
    """
    Keeps the state of a retry loop and decides what to do on each failure.

    An instance stores mutable per-run state (failure count, current backoff
    and deadline) that is reset by ``start()``, so a single instance can only
    track one retry loop at a time.
    """

    exceptions: tub_types.Exceptions
    _retry_config: RetryConfig
    _timeout: tub_types.Duration
    _count: int
    _backoff: tub_types.Duration
    _unjittered_backoff: tub_types.Duration

    def __init__(self, retry_config: RetryConfig):
        self.exceptions = retry_config.exceptions
        self._retry_config = retry_config

        # The backoff strategy is selected once, based on the jitter flag
        self._calc_backoff: Callable[[], tub_types.Duration] = (
            self._jittered_backoff if retry_config.jitter else self._plain_backoff
        )

    def _jittered_backoff(self) -> tub_types.Duration:
        """Return the current un-jittered backoff scaled by random.random()"""
        return self._unjittered_backoff * random.random()

    def _plain_backoff(self) -> tub_types.Duration:
        """Return the current un-jittered backoff unchanged"""
        return self._unjittered_backoff

    def start(self) -> None:
        """Reset the deadline, the failure count and the backoff for a new run"""
        config = self._retry_config
        self._timeout = time.perf_counter() + config.time_limit
        self._count = 0
        self._unjittered_backoff = config.init_backoff

    def handle(self, exc: Exception) -> tub_types.Duration:
        """
        Process a caught exception and decide whether another try is allowed.

        Either raises (RetryError, or the given exception when re-raising is
        configured) because a limit was hit, or logs the failure and returns
        the number of seconds to sleep before the next try.

        The log call uses ``exc_info=True``, so this must be invoked while
        the exception is being handled (inside the ``except`` block).
        """
        self._increment()
        self._check_retry_limit(exc)
        self._check_time_limit(exc)

        config = self._retry_config
        message = (
            f"Function threw exception below on try {self._count}, "
            f"retrying in {self._backoff:n} seconds"
        )
        config.logger.log(config.log_level, message, exc_info=True)
        return self._backoff

    def _increment(self) -> None:
        """Count the failure and advance the backoff schedule by one step"""
        self._count += 1
        # The backoff for this failure is taken before the exponential growth
        # is applied, so the first failure uses the initial backoff
        self._backoff = self._calc_backoff()
        self._unjittered_backoff *= self._retry_config.exponential

    def _check_retry_limit(self, exc: Exception) -> None:
        """Raise if the failure count has gone past the retry limit"""
        limit = self._retry_config.retry_limit
        if not self._count > limit:
            return
        if self._retry_config.reraise:
            raise exc
        raise RetryError(f"Retry limit {limit} reached") from exc

    def _check_time_limit(self, exc: Exception) -> None:
        """Raise if sleeping for the current backoff would pass the deadline"""
        wake_up_at = time.perf_counter() + self._backoff
        if not wake_up_at > self._timeout:
            return
        if self._retry_config.reraise:
            raise exc
        raise RetryError(
            f"Time limit {self._retry_config.time_limit} exceeded"
        ) from exc

98 

99 

# Typing overload: a function returning an awaitable keeps that shape
@overload
def retry_factory(
    func: Callable[tub_types.P, Awaitable[tub_types.T]],
    retry_config: RetryConfig,
) -> Callable[tub_types.P, Awaitable[tub_types.T]]:
    ...


# Typing overload: any other callable keeps its plain return type
@overload
def retry_factory(
    func: Callable[tub_types.P, tub_types.T],
    retry_config: RetryConfig,
) -> Callable[tub_types.P, tub_types.T]:
    ...


def retry_factory(func, retry_config):  # type: ignore
    """
    Build a retrying version of ``func`` governed by ``retry_config``.

    Coroutine functions (as reported by ``asyncio.iscoroutinefunction``) get
    an async wrapper, everything else a synchronous one. The wrapper is given
    the metadata of ``func`` via ``functools.update_wrapper``.
    """
    handler = _RetryHandler(retry_config)
    build_wrapper = (
        _async_retry_factory
        if asyncio.iscoroutinefunction(func)
        else _sync_retry_factory
    )
    wrapped = build_wrapper(func, handler)
    update_wrapper(wrapped, func)
    return wrapped

128 

129 

def _async_retry_factory(
    func: Callable[tub_types.P, Awaitable[tub_types.T]],
    retry_handler: _RetryHandler,
) -> Callable[tub_types.P, Awaitable[tub_types.T]]:
    """
    Wrap the coroutine function ``func`` in an async retry loop.

    ``retry_handler`` supplies the retry configuration. Each invocation of
    the returned coroutine function works on its own handler built from that
    configuration, so overlapping invocations (for example several awaits
    scheduled together on the event loop) do not reset or advance each
    other's retry count, backoff and deadline.
    """
    # The handler keeps its config in a private attribute of this module
    retry_config = retry_handler._retry_config  # pylint: disable=protected-access

    async def retry_func(
        *args: tub_types.P.args, **kwargs: tub_types.P.kwargs
    ) -> tub_types.T:
        # Fresh state per call: sharing one handler would mix up the
        # bookkeeping of concurrently running invocations
        handler = _RetryHandler(retry_config)
        handler.start()
        while True:
            try:
                return await func(*args, **kwargs)
            except handler.exceptions as exc:
                # handle() raises when a limit is hit, otherwise it returns
                # how long to wait before the next try
                backoff = handler.handle(exc)
                await asyncio.sleep(backoff)

    return retry_func

146 

147 

def _sync_retry_factory(
    func: Callable[tub_types.P, tub_types.T],
    retry_handler: _RetryHandler,
) -> Callable[tub_types.P, tub_types.T]:
    """
    Wrap the callable ``func`` in a blocking retry loop.

    ``retry_handler`` supplies the retry configuration. Each invocation of
    the returned function works on its own handler built from that
    configuration, so recursive calls or calls from several threads do not
    reset or advance each other's retry count, backoff and deadline.
    """
    # The handler keeps its config in a private attribute of this module
    retry_config = retry_handler._retry_config  # pylint: disable=protected-access

    def retry_func(
        *args: tub_types.P.args, **kwargs: tub_types.P.kwargs
    ) -> tub_types.T:
        # Fresh state per call: sharing one handler would mix up the
        # bookkeeping of overlapping invocations
        handler = _RetryHandler(retry_config)
        handler.start()
        while True:
            try:
                return func(*args, **kwargs)
            except handler.exceptions as exc:
                # handle() raises when a limit is hit, otherwise it returns
                # how long to wait before the next try
                backoff = handler.handle(exc)
                time.sleep(backoff)

    return retry_func