import calendar
import logging
import time
from datetime import datetime
from aiobreaker.state import CircuitBreakerState
from .base import CircuitBreakerStorage
# redis is an optional dependency: record whether it could be imported so
# that CircuitRedisStorage can refuse to be constructed without it.
try:
    from redis.exceptions import RedisError
except ImportError:
    HAS_REDIS_SUPPORT = False
    # Keep the name defined so that importing this module never fails.
    RedisError = None
else:
    HAS_REDIS_SUPPORT = True
class CircuitRedisStorage(CircuitBreakerStorage):
    """
    Implements a `CircuitBreakerStorage` using redis.

    Three keys are used: ``state``, ``fail_counter`` and ``opened_at``. Each
    is prefixed with `BASE_NAMESPACE` and, when one was given, with the
    instance namespace (see `_namespace`).

    After construction, redis errors are logged and never raised: reads fall
    back to a default value and writes are dropped.
    """

    BASE_NAMESPACE = 'aiobreaker'
    logger = logging.getLogger(__name__)

    def __init__(self, state: CircuitBreakerState, redis_object, namespace=None,
                 fallback_circuit_state=CircuitBreakerState.CLOSED):
        """
        Creates a new instance with the given `state` and `redis` object. The
        redis object should be similar to pyredis' StrictRedis class. If there
        are any connection issues with redis, the `fallback_circuit_state` is
        used to determine the state of the circuit.

        :param state: Initial state, written to redis only if no state is
            stored there yet.
        :param redis_object: The redis client used for every operation.
        :param namespace: Optional extra prefix for the redis keys.
        :param fallback_circuit_state: State reported when redis cannot be
            reached or holds no state.
        :raises ImportError: If the ``redis`` package is not installed.

        Note that the initial write to redis is not guarded: a redis error
        raised while storing the initial state propagates to the caller.
        """
        # Module does not exist, so this feature is not available
        if not HAS_REDIS_SUPPORT:
            raise ImportError("CircuitRedisStorage can only be used if the required dependencies exist")

        super().__init__('redis')

        # The guard above guarantees the module-level import succeeded, so
        # the exception class can be taken from there directly.
        self.RedisError = RedisError

        self._redis = redis_object
        self._namespace_name = namespace
        self._fallback_circuit_state = fallback_circuit_state
        self._initial_state = state

        self._initialize_redis_state(self._initial_state)

    def _initialize_redis_state(self, state: CircuitBreakerState):
        """
        Creates the failure counter (as 0) and the state key (as the name of
        `state`) on redis. ``setnx`` is used, so values that already exist
        are left untouched.
        """
        self._redis.setnx(self._namespace('fail_counter'), 0)
        self._redis.setnx(self._namespace('state'), state.name)

    @property
    def state(self):
        """
        Returns the current circuit breaker state.

        If redis cannot be reached, the fallback circuit state is returned.
        If the circuit breaker state on Redis is missing, it is re-initialized
        with the fallback circuit state (the fail counter is created as 0 if
        it is missing as well) and the fallback circuit state is returned.

        :raises AttributeError: If the stored value is not the name of a
            `CircuitBreakerState` member.
        """
        try:
            raw_state = self._redis.get(self._namespace('state'))
        except self.RedisError:
            self.logger.error('RedisError: falling back to default circuit state', exc_info=True)
            return self._fallback_circuit_state

        if raw_state is None:
            # state retrieved from redis was missing, so we re-initialize
            # the circuit breaker state on redis
            try:
                self._initialize_redis_state(self._fallback_circuit_state)
            except self.RedisError:
                self.logger.error('RedisError', exc_info=True)
            # The fallback is already a state object, not a name, so it must
            # be returned as-is rather than looked up by name.
            return self._fallback_circuit_state

        # The stored value is the member name; it may arrive as bytes or as
        # str depending on how the redis client decodes responses.
        if isinstance(raw_state, bytes):
            state_name = raw_state.decode('utf-8')
        else:
            state_name = raw_state
        return getattr(CircuitBreakerState, state_name)

    @state.setter
    def state(self, state):
        """
        Set the current circuit breaker state to `state`.

        The name of `state` is what gets stored. Redis errors are logged and
        the write is dropped.
        """
        try:
            self._redis.set(self._namespace('state'), state.name)
        except self.RedisError:
            self.logger.error('RedisError', exc_info=True)

    def increment_counter(self):
        """
        Increases the failure counter by one.

        Redis errors are logged and the increment is dropped.
        """
        try:
            self._redis.incr(self._namespace('fail_counter'))
        except self.RedisError:
            self.logger.error('RedisError', exc_info=True)

    def reset_counter(self):
        """
        Sets the failure counter to zero.

        Redis errors are logged and the reset is dropped.
        """
        try:
            self._redis.set(self._namespace('fail_counter'), 0)
        except self.RedisError:
            self.logger.error('RedisError', exc_info=True)

    @property
    def counter(self):
        """
        Returns the current value of the failure counter.

        Returns 0 when the counter is missing or when redis cannot be
        reached.
        """
        try:
            value = self._redis.get(self._namespace('fail_counter'))
        except self.RedisError:
            self.logger.error('RedisError: Assuming no errors', exc_info=True)
            return 0
        return int(value) if value else 0

    @property
    def opened_at(self):
        """
        Returns a datetime object of the most recent value of when the circuit
        was opened.

        The stored value is a unix timestamp; it is converted to a naive
        datetime expressed in UTC. Returns None when no value is stored or
        when redis cannot be reached.
        """
        try:
            timestamp = self._redis.get(self._namespace('opened_at'))
        except self.RedisError:
            self.logger.error('RedisError', exc_info=True)
            return None

        if timestamp:
            return datetime(*time.gmtime(int(timestamp))[:6])
        return None

    @opened_at.setter
    def opened_at(self, now):
        """
        Atomically sets the most recent value of when the circuit was opened
        to `now`. Stored in redis as a simple integer of unix epoch time.
        To avoid timezone issues between different systems, the passed in
        datetime should be in UTC.

        A naive datetime is taken to be UTC already; a timezone-aware one is
        converted to UTC first. The stored value is only replaced when `now`
        is later than it. Redis errors are logged and the write is dropped.
        """
        key = self._namespace('opened_at')
        # utctimetuple() equals timetuple() for naive datetimes and converts
        # aware ones to UTC, which is what calendar.timegm expects.
        next_value = calendar.timegm(now.utctimetuple())

        def set_if_greater(pipe):
            # NOTE(review): relies on the client's transaction() helper
            # watching `key` before calling this, so that the read below and
            # the queued write are applied atomically - confirm for the
            # client in use.
            current_value = pipe.get(key)
            pipe.multi()
            if not current_value or next_value > int(current_value):
                pipe.set(key, next_value)

        try:
            self._redis.transaction(set_if_greater, key)
        except self.RedisError:
            self.logger.error('RedisError', exc_info=True)

    def _namespace(self, key):
        """
        Returns the full redis key for `key`: the optional instance
        namespace, `BASE_NAMESPACE` and `key`, joined with ``:``.
        """
        name_parts = [self.BASE_NAMESPACE, key]
        if self._namespace_name:
            name_parts.insert(0, self._namespace_name)
        return ':'.join(name_parts)