Browse Source

feat(utils): Add core `CircuitBreaker` functionality (#74560)

This completes the work, started in https://github.com/getsentry/sentry/pull/74557 and https://github.com/getsentry/sentry/pull/74559, of adding a new, class-and-rate-limit-based circuit breaker implementation to the codebase. In this PR, the core `record_error` and `should_allow_request` methods are added to the `CircuitBreaker` class, along with accompaying tests.
Katie Byers 7 months ago
parent
commit
b5ecab101f
2 changed files with 319 additions and 1 deletions
  1. 96 1
      src/sentry/utils/circuit_breaker2.py
  2. 223 0
      tests/sentry/utils/test_circuit_breaker2.py

+ 96 - 1
src/sentry/utils/circuit_breaker2.py

@@ -12,7 +12,12 @@ from typing import Any, Literal, NotRequired, TypedDict, overload
 
 from django.conf import settings
 
-from sentry.ratelimits.sliding_windows import Quota, RedisSlidingWindowRateLimiter, RequestedQuota
+from sentry.ratelimits.sliding_windows import (
+    GrantedQuota,
+    Quota,
+    RedisSlidingWindowRateLimiter,
+    RequestedQuota,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -182,6 +187,96 @@ class CircuitBreaker:
             )
             self.recovery_duration = default_recovery_duration
 
+    def record_error(self) -> None:
+        """
+        Record a single error towards the breaker's quota, and handle the case where that error puts
+        us over the limit.
+        """
+        now = int(time.time())
+        state, seconds_left_in_state = self._get_state_and_remaining_time()
+
+        if state == CircuitBreakerState.BROKEN:
+            assert seconds_left_in_state is not None  # mypy appeasement
+
+            # If the circuit is BROKEN, and `should_allow_request` is being used correctly, requests
+            # should be blocked and we shouldn't even be here. That said, maybe there was a race
+            # condition, so make sure the circuit hasn't just been tripped before crying foul.
+            seconds_elapsed_in_state = self.broken_state_duration - seconds_left_in_state
+            if seconds_elapsed_in_state > 5:
+                logger.warning(
+                    "Attempt to record circuit breaker error while circuit is in BROKEN state",
+                    extra={"key": self.key, "time_in_state": seconds_elapsed_in_state},
+                )
+            # We shouldn't have made the request, so don't record the error
+            return
+
+        # Even though we're not checking it during RECOVERY, we track errors in the primary quota as
+        # well as in the RECOVERY quota because they still happened, and eventually switching back
+        # to the okay state doesn't make that untrue
+        quotas = (
+            [self.primary_quota, self.recovery_quota]
+            if state == CircuitBreakerState.RECOVERY
+            else [self.primary_quota]
+        )
+        self.limiter.use_quotas(
+            [RequestedQuota(self.key, 1, quotas)], [GrantedQuota(self.key, 1, [])], now
+        )
+
+        # If incrementing has made us hit the current limit, switch to the BROKEN state
+        controlling_quota = self._get_controlling_quota(state)
+        remaining_errors_allowed = self._get_remaining_error_quota(controlling_quota)
+        if remaining_errors_allowed == 0:
+            logger.warning(
+                "Circuit breaker '%s' error limit hit",
+                self.key,
+                extra={
+                    "current_state": state,
+                    "error_limit": controlling_quota.limit,
+                    "error_limit_window": controlling_quota.window_seconds,
+                },
+            )
+
+            # RECOVERY will only start after the BROKEN state has expired, so push out the RECOVERY
+            # expiry time. We'll store the expiry times as our redis values so we can determine how
+            # long we've been in a given state.
+            broken_state_timeout = self.broken_state_duration
+            recovery_state_timeout = self.broken_state_duration + self.recovery_duration
+            broken_state_expiry = now + broken_state_timeout
+            recovery_state_expiry = now + recovery_state_timeout
+
+            # Set reids keys for switching state. While they're both set (starting now) we'll be in
+            # the BROKEN state. Once `broken_state_key` expires in redis we'll switch to RECOVERY,
+            # and then once `recovery_state_key` expires we'll be back to normal.
+            try:
+                self._set_in_redis(
+                    [
+                        (self.broken_state_key, broken_state_expiry, broken_state_timeout),
+                        (self.recovery_state_key, recovery_state_expiry, recovery_state_timeout),
+                    ]
+                )
+
+            # If redis errors, stay in the current state
+            except Exception:
+                logger.exception(
+                    "Couldn't set state-change keys in redis for circuit breaker '%s'",
+                    self.key,
+                    extra={"current_state": state},
+                )
+
+    def should_allow_request(self) -> bool:
+        """
+        Determine, based on the current state of the breaker and the number of allowable errors
+        remaining, whether requests should be allowed through.
+        """
+        state, _ = self._get_state_and_remaining_time()
+
+        if state == CircuitBreakerState.BROKEN:
+            return False
+
+        controlling_quota = self._get_controlling_quota(state)
+
+        return self._get_remaining_error_quota(controlling_quota) > 0
+
     def _get_from_redis(self, keys: list[str]) -> Any:
         for key in keys:
             self.redis_pipeline.get(key)

+ 223 - 0
tests/sentry/utils/test_circuit_breaker2.py

@@ -3,6 +3,7 @@ from typing import Any
 from unittest import TestCase
 from unittest.mock import ANY, MagicMock, patch
 
+import time_machine
 from django.conf import settings
 from redis.client import Pipeline
 
@@ -318,3 +319,225 @@ class CircuitBreakerTest(TestCase):
                 500,
             )
             assert breaker.recovery_duration == 500
+
+
+@freeze_time()
+class RecordErrorTest(TestCase):
+    def setUp(self) -> None:
+        self.config = DEFAULT_CONFIG
+        self.breaker = MockCircuitBreaker("dogs_are_great", self.config)
+
+        # Clear all existing keys from redis
+        self.breaker.redis_pipeline.flushall()
+        self.breaker.redis_pipeline.execute()
+
+    def test_increments_error_count(self):
+        config = self.config
+        breaker = self.breaker
+
+        # The breaker starts with a clean slate
+        assert breaker._get_remaining_error_quota() == config["error_limit"]
+
+        breaker.record_error()
+
+        # The error has been tallied
+        assert breaker._get_remaining_error_quota() == config["error_limit"] - 1
+
+    def test_no_error_recorded_in_broken_state(self):
+        breaker = self.breaker
+
+        breaker._set_breaker_state(CircuitBreakerState.BROKEN)
+        breaker._add_quota_usage(breaker.primary_quota, breaker.error_limit)
+
+        # Because we're in the BROKEN state, we start with the main quota maxed out and the
+        # RECOVERY quota yet to be used
+        assert breaker._get_remaining_error_quota(breaker.primary_quota) == 0
+        assert (
+            breaker._get_remaining_error_quota(breaker.recovery_quota)
+            == breaker.recovery_error_limit
+        )
+
+        breaker.record_error()
+
+        # Neither quota is incremented
+        assert breaker._get_remaining_error_quota(breaker.primary_quota) == 0
+        assert (
+            breaker._get_remaining_error_quota(breaker.recovery_quota)
+            == breaker.recovery_error_limit
+        )
+
+    @patch("sentry.utils.circuit_breaker2.logger")
+    def test_logs_a_warning_in_broken_state(self, mock_logger: MagicMock):
+        breaker = self.breaker
+
+        seconds_ellapsed_since_circuit_break = 2
+        breaker._set_breaker_state(
+            CircuitBreakerState.BROKEN,
+            seconds_left=breaker.broken_state_duration - seconds_ellapsed_since_circuit_break,
+        )
+
+        breaker.record_error()
+
+        # No log - we just switched into BROKEN state, and even though we're not supposed to land in
+        # the `record_error` method in that state, there's a small buffer to account for race
+        # conditions
+        assert mock_logger.warning.call_count == 0
+
+        seconds_ellapsed_since_circuit_break = 20
+        breaker._set_breaker_state(
+            CircuitBreakerState.BROKEN,
+            seconds_left=breaker.broken_state_duration - seconds_ellapsed_since_circuit_break,
+        )
+
+        breaker.record_error()
+
+        # Now we do log a warning, because at this point we can no longer blame a race condition -
+        # it's been too long since the circuit broke
+        mock_logger.warning.assert_called_with(
+            "Attempt to record circuit breaker error while circuit is in BROKEN state",
+            extra={"key": "dogs_are_great", "time_in_state": 20},
+        )
+
+    @patch("sentry.utils.circuit_breaker2.logger")
+    def test_handles_hitting_max_errors_in_non_broken_state(self, mock_logger: MagicMock):
+        config = self.config
+        breaker = self.breaker
+        now = int(time.time())
+
+        for state, quota, limit in [
+            (CircuitBreakerState.OK, breaker.primary_quota, breaker.error_limit),
+            (CircuitBreakerState.RECOVERY, breaker.recovery_quota, breaker.recovery_error_limit),
+        ]:
+
+            breaker._set_breaker_state(state)
+            breaker._add_quota_usage(quota, limit - 1)
+            assert breaker._get_remaining_error_quota(quota) == 1
+            assert breaker._get_controlling_quota() == quota
+
+            breaker.record_error()
+
+            # Hitting the limit puts us into the BROKEN state
+            assert breaker._get_remaining_error_quota(quota) == 0
+            assert breaker._get_controlling_quota() is None
+            assert breaker._get_state_and_remaining_time() == (
+                CircuitBreakerState.BROKEN,
+                breaker.broken_state_duration,
+            )
+            mock_logger.warning.assert_called_with(
+                "Circuit breaker '%s' error limit hit",
+                "dogs_are_great",
+                extra={
+                    "current_state": state,
+                    "error_limit": limit,
+                    "error_limit_window": config["error_limit_window"],
+                },
+            )
+
+            # Now jump to one second after the BROKEN state has expired to see that we're in
+            # RECOVERY
+            with time_machine.travel(now + breaker.broken_state_duration + 1, tick=False):
+                assert breaker._get_controlling_quota() is breaker.recovery_quota
+                assert breaker._get_state_and_remaining_time() == (
+                    CircuitBreakerState.RECOVERY,
+                    breaker.recovery_duration - 1,
+                )
+
+    @patch("sentry.utils.circuit_breaker2.logger")
+    def test_stays_in_current_state_if_redis_call_changing_state_fails(
+        self, mock_logger: MagicMock
+    ):
+        breaker = self.breaker
+
+        for current_state, quota, limit, seconds_left in [
+            # The case where the current state is the BROKEN state isn't included here because the
+            # switch from BROKEN state to RECOVERY state happens passively (by `broken_state_key`
+            # expiring), rather than through an active call to redis
+            (
+                CircuitBreakerState.OK,
+                breaker.primary_quota,
+                breaker.error_limit,
+                None,
+            ),
+            (
+                CircuitBreakerState.RECOVERY,
+                breaker.recovery_quota,
+                breaker.recovery_error_limit,
+                1231,
+            ),
+        ]:
+
+            breaker._set_breaker_state(current_state, seconds_left)
+            breaker._add_quota_usage(quota, limit - 1)
+            assert breaker._get_remaining_error_quota(quota) == 1
+            assert breaker._get_controlling_quota() == quota
+
+            with patch(
+                "sentry.utils.circuit_breaker2.CircuitBreaker._set_in_redis", side_effect=Exception
+            ):
+                breaker.record_error()
+
+            # We've recorded the error, but the state hasn't changed
+            assert breaker._get_remaining_error_quota(quota) == 0
+            assert breaker._get_controlling_quota() == quota
+            assert breaker._get_state_and_remaining_time() == (current_state, seconds_left)
+            mock_logger.exception.assert_called_with(
+                "Couldn't set state-change keys in redis for circuit breaker '%s'",
+                breaker.key,
+                extra={"current_state": current_state},
+            )
+
+
+@freeze_time()
+class ShouldAllowRequestTest(TestCase):
+    def setUp(self) -> None:
+        self.config = DEFAULT_CONFIG
+        self.breaker = MockCircuitBreaker("dogs_are_great", self.config)
+
+        # Clear all existing keys from redis
+        self.breaker.redis_pipeline.flushall()
+        self.breaker.redis_pipeline.execute()
+
+    def test_allows_request_in_non_broken_state_with_quota_remaining(self):
+        breaker = self.breaker
+
+        for state, quota, limit in [
+            (CircuitBreakerState.OK, breaker.primary_quota, breaker.error_limit),
+            (CircuitBreakerState.RECOVERY, breaker.recovery_quota, breaker.recovery_error_limit),
+        ]:
+            breaker._set_breaker_state(state)
+            breaker._add_quota_usage(quota, limit - 5)
+            assert breaker._get_remaining_error_quota(quota) == 5
+
+            assert breaker.should_allow_request() is True
+
+    def test_blocks_request_in_non_broken_state_with_no_quota_remaining(self):
+        breaker = self.breaker
+
+        for state, quota, limit in [
+            (CircuitBreakerState.OK, breaker.primary_quota, breaker.error_limit),
+            (CircuitBreakerState.RECOVERY, breaker.recovery_quota, breaker.recovery_error_limit),
+        ]:
+            breaker._set_breaker_state(state)
+            breaker._add_quota_usage(quota, limit)
+            assert breaker._get_remaining_error_quota(quota) == 0
+
+            assert breaker.should_allow_request() is False
+
+    def test_blocks_request_in_BROKEN_state(self):
+        breaker = self.breaker
+
+        breaker._set_breaker_state(CircuitBreakerState.BROKEN)
+
+        assert breaker.should_allow_request() is False
+
+    @patch("sentry.utils.circuit_breaker2.logger")
+    def test_allows_request_if_redis_call_fails(self, mock_logger: MagicMock):
+        breaker = self.breaker
+
+        with patch(
+            "sentry.utils.circuit_breaker2.CircuitBreaker._get_from_redis", side_effect=Exception
+        ):
+            assert breaker.should_allow_request() is True
+            mock_logger.exception.assert_called_with(
+                "Couldn't get state from redis for circuit breaker '%s'", breaker.key
+            )