Source code for aiobreaker.storage.redis

import calendar
import logging
import time
from datetime import datetime

from aiobreaker.state import CircuitBreakerState
from .base import CircuitBreakerStorage

# Optional dependency guard: the `redis` package is only needed for the
# Redis-backed storage, so a missing install must not break module import.
try:
    from redis.exceptions import RedisError
except ImportError:
    # `redis` is not installed: record that and keep the name defined at
    # module level so later references do not raise NameError.
    HAS_REDIS_SUPPORT = False
    RedisError = None
else:
    HAS_REDIS_SUPPORT = True


class CircuitRedisStorage(CircuitBreakerStorage):
    """
    Implements a `CircuitBreakerStorage` using redis.

    Three keys are kept in redis, each prefixed by the optional namespace
    and `BASE_NAMESPACE`: ``state`` (the state's name), ``fail_counter``
    (an integer) and ``opened_at`` (unix epoch seconds).
    """

    BASE_NAMESPACE = 'aiobreaker'

    logger = logging.getLogger(__name__)

    def __init__(self, state: CircuitBreakerState, redis_object, namespace=None,
                 fallback_circuit_state=CircuitBreakerState.CLOSED):
        """
        Creates a new instance with the given `state` and `redis` object. The
        redis object should be similar to pyredis' StrictRedis class. If there
        are any connection issues with redis, the `fallback_circuit_state` is
        used to determine the state of the circuit.

        :param state: Initial state, written to redis only if no state is
            stored there yet.
        :param redis_object: The redis client used for all reads and writes.
        :param namespace: Optional prefix prepended to every key.
        :param fallback_circuit_state: State reported when redis cannot be
            reached or holds no state.
        :raises ImportError: If the `redis` package is not available.
        """
        # Module does not exist, so this feature is not available
        if not HAS_REDIS_SUPPORT:
            raise ImportError("CircuitRedisStorage can only be used if the required dependencies exist")

        super().__init__('redis')

        # Resolve the exception class at construction time and keep it on
        # the instance; every accessor below catches `self.RedisError`.
        try:
            self.RedisError = __import__('redis').exceptions.RedisError
        except ImportError:
            # Module does not exist, so this feature is not available
            raise ImportError("CircuitRedisStorage can only be used if 'redis' is available")

        self._redis = redis_object
        self._namespace_name = namespace
        self._fallback_circuit_state = fallback_circuit_state
        self._initial_state = state

        self._initialize_redis_state(self._initial_state)

    def _initialize_redis_state(self, state: CircuitBreakerState):
        """
        Creates the ``fail_counter`` and ``state`` keys if they are missing.

        Uses SETNX, so values already stored in redis are left untouched.
        """
        self._redis.setnx(self._namespace('fail_counter'), 0)
        self._redis.setnx(self._namespace('state'), state.name)

    @property
    def state(self):
        """
        Returns the current circuit breaker state.

        If redis cannot be reached, the fallback circuit state is returned.
        If the circuit breaker state on redis is missing, it is re-initialized
        with the fallback circuit state (and the fail counter is created if
        absent), and the fallback circuit state is returned.
        """
        try:
            stored = self._redis.get(self._namespace('state'))
        except self.RedisError:
            self.logger.error('RedisError: falling back to default circuit state', exc_info=True)
            return self._fallback_circuit_state

        if stored is None:
            # State retrieved from redis was missing, so we re-initialize
            # the circuit breaker state on redis. A failure here must not
            # escape: the fallback state is still a valid answer.
            try:
                self._initialize_redis_state(self._fallback_circuit_state)
            except self.RedisError:
                self.logger.error('RedisError: could not re-initialize circuit state', exc_info=True)
            # Return the enum member directly; passing it to getattr() as an
            # attribute name would raise TypeError.
            return self._fallback_circuit_state

        # Clients created with decode_responses=True hand back str, not bytes.
        state_name = stored.decode('utf-8') if isinstance(stored, bytes) else stored
        return getattr(CircuitBreakerState, state_name)

    @state.setter
    def state(self, state):
        """
        Set the current circuit breaker state to `state`.

        Redis errors are logged and otherwise ignored.
        """
        try:
            self._redis.set(self._namespace('state'), state.name)
        except self.RedisError:
            self.logger.error('RedisError', exc_info=True)

    def increment_counter(self):
        """
        Increases the failure counter by one.

        Redis errors are logged and otherwise ignored.
        """
        try:
            self._redis.incr(self._namespace('fail_counter'))
        except self.RedisError:
            self.logger.error('RedisError', exc_info=True)

    def reset_counter(self):
        """
        Sets the failure counter to zero.

        Redis errors are logged and otherwise ignored.
        """
        try:
            self._redis.set(self._namespace('fail_counter'), 0)
        except self.RedisError:
            self.logger.error('RedisError', exc_info=True)

    @property
    def counter(self):
        """
        Returns the current value of the failure counter.

        Returns 0 when the counter is missing or redis cannot be reached.
        """
        try:
            value = self._redis.get(self._namespace('fail_counter'))
        except self.RedisError:
            self.logger.error('RedisError: Assuming no errors', exc_info=True)
            return 0
        return int(value) if value else 0

    @property
    def opened_at(self):
        """
        Returns a datetime object of the most recent value of when the circuit
        was opened.

        The stored epoch seconds are converted with `time.gmtime`, so the
        result is a naive datetime expressed in UTC. Returns None when no
        value is stored or redis cannot be reached.
        """
        try:
            timestamp = self._redis.get(self._namespace('opened_at'))
            if timestamp:
                return datetime(*time.gmtime(int(timestamp))[:6])
        except self.RedisError:
            self.logger.error('RedisError', exc_info=True)
        return None

    @opened_at.setter
    def opened_at(self, now):
        """
        Atomically sets the most recent value of when the circuit was opened
        to `now`. Stored in redis as a simple integer of unix epoch time.
        To avoid timezone issues between different systems, the passed in
        datetime should be in UTC.

        The stored value is only replaced when `now` is later than it, so
        the key never moves backwards in time.
        """
        try:
            key = self._namespace('opened_at')
            # Does not depend on redis, so compute it once rather than on
            # every retry of the transaction callback.
            next_value = int(calendar.timegm(now.timetuple()))

            def set_if_greater(pipe):
                # Read before multi(): the key is watched, so a concurrent
                # write makes the transaction helper re-run this callback.
                current_value = pipe.get(key)
                pipe.multi()
                if not current_value or next_value > int(current_value):
                    pipe.set(key, next_value)

            self._redis.transaction(set_if_greater, key)
        except self.RedisError:
            self.logger.error('RedisError', exc_info=True)

    def _namespace(self, key):
        """
        Returns the full redis key for `key`, of the form
        ``[<namespace>:]aiobreaker:<key>``.
        """
        name_parts = [self.BASE_NAMESPACE, key]
        if self._namespace_name:
            name_parts.insert(0, self._namespace_name)

        return ':'.join(name_parts)