Browse Source

fix(ratelimit): Store header info with request (#31450)

* fix(ratelimit): Store header info with request

Previously we stored the rate limit header information in a class property of the middleware assuming that each request had it's own instance of the middleware. Turns out that a lot of the requests share the same middleware so every request was thrashing the metadata property.

Instead, I moved the header information to a property of the request. This way the correct header is tied to its request

`before_after` was added for the race condition test
Aniket Das "Tekky 3 years ago
parent
commit
9ffb055577

+ 1 - 0
requirements-dev.txt

@@ -1,3 +1,4 @@
+before_after==1.0.1
 docker>=3.7.0,<3.8.0
 docker>=3.7.0,<3.8.0
 exam>=0.5.1
 exam>=0.5.1
 freezegun==1.1.0
 freezegun==1.1.0

+ 10 - 16
src/sentry/middleware/ratelimit.py

@@ -11,7 +11,7 @@ from sentry.ratelimits import (
     get_rate_limit_key,
     get_rate_limit_key,
     get_rate_limit_value,
     get_rate_limit_value,
 )
 )
-from sentry.types.ratelimit import RateLimitCategory, RateLimitMeta
+from sentry.types.ratelimit import RateLimitCategory
 
 
 DEFAULT_ERROR_MESSAGE = (
 DEFAULT_ERROR_MESSAGE = (
     "You are attempting to use this endpoint too frequently. Limit is "
     "You are attempting to use this endpoint too frequently. Limit is "
@@ -22,8 +22,6 @@ DEFAULT_ERROR_MESSAGE = (
 class RatelimitMiddleware(MiddlewareMixin):
 class RatelimitMiddleware(MiddlewareMixin):
     """Middleware that applies a rate limit to every endpoint."""
     """Middleware that applies a rate limit to every endpoint."""
 
 
-    rate_limit_metadata: RateLimitMeta | None = None
-
     def process_view(self, request: Request, view_func, view_args, view_kwargs) -> Response | None:
     def process_view(self, request: Request, view_func, view_args, view_kwargs) -> Response | None:
         """Check if the endpoint call will violate."""
         """Check if the endpoint call will violate."""
         request.will_be_rate_limited = False
         request.will_be_rate_limited = False
@@ -46,29 +44,25 @@ class RatelimitMiddleware(MiddlewareMixin):
         if rate_limit is None:
         if rate_limit is None:
             return
             return
 
 
-        self.rate_limit_metadata = above_rate_limit_check(key, rate_limit)
-        if self.rate_limit_metadata.is_limited:
+        request.rate_limit_metadata = above_rate_limit_check(key, rate_limit)
+
+        if request.rate_limit_metadata.is_limited:
             request.will_be_rate_limited = True
             request.will_be_rate_limited = True
             enforce_rate_limit = getattr(view_func.view_class, "enforce_rate_limit", False)
             enforce_rate_limit = getattr(view_func.view_class, "enforce_rate_limit", False)
             if enforce_rate_limit:
             if enforce_rate_limit:
                 return HttpResponse(
                 return HttpResponse(
                     {
                     {
                         "detail": DEFAULT_ERROR_MESSAGE.format(
                         "detail": DEFAULT_ERROR_MESSAGE.format(
-                            limit=self.rate_limit_metadata.limit,
-                            window=self.rate_limit_metadata.window,
+                            limit=request.rate_limit_metadata.limit,
+                            window=request.rate_limit_metadata.window,
                         )
                         )
                     },
                     },
                     status=429,
                     status=429,
                 )
                 )
 
 
     def process_response(self, request: Request, response: Response) -> Response:
     def process_response(self, request: Request, response: Response) -> Response:
-        if self.rate_limit_metadata is not None:
-            remaining_count = (
-                self.rate_limit_metadata.limit - self.rate_limit_metadata.current
-                if not self.rate_limit_metadata.is_limited
-                else 0
-            )
-            response["X-Sentry-Rate-Limit-Remaining"] = remaining_count
-            response["X-Sentry-Rate-Limit-Limit"] = self.rate_limit_metadata.limit
-            response["X-Sentry-Rate-Limit-Reset"] = self.rate_limit_metadata.reset_time
+        if hasattr(request, "rate_limit_metadata"):
+            response["X-Sentry-Rate-Limit-Remaining"] = request.rate_limit_metadata.remaining
+            response["X-Sentry-Rate-Limit-Limit"] = request.rate_limit_metadata.limit
+            response["X-Sentry-Rate-Limit-Reset"] = request.rate_limit_metadata.reset_time
         return response
         return response

+ 2 - 0
src/sentry/ratelimits/utils.py

@@ -128,12 +128,14 @@ def above_rate_limit_check(key: str, rate_limit: RateLimit) -> RateLimitMeta:
     is_limited, current, reset_time = ratelimiter.is_limited_with_value(
     is_limited, current, reset_time = ratelimiter.is_limited_with_value(
         key, limit=rate_limit.limit, window=rate_limit.window
         key, limit=rate_limit.limit, window=rate_limit.window
     )
     )
+    remaining = rate_limit.limit - current if not is_limited else 0
     return RateLimitMeta(
     return RateLimitMeta(
         is_limited=is_limited,
         is_limited=is_limited,
         current=current,
         current=current,
         limit=rate_limit.limit,
         limit=rate_limit.limit,
         window=rate_limit.window,
         window=rate_limit.window,
         reset_time=reset_time,
         reset_time=reset_time,
+        remaining=remaining,
     )
     )
 
 
 
 

+ 2 - 0
src/sentry/types/ratelimit.py

@@ -31,6 +31,7 @@ class RateLimitMeta:
     Attributes:
     Attributes:
         is_limited (bool): request is rate limited
         is_limited (bool): request is rate limited
         current (int): number of requests done in the current window
         current (int): number of requests done in the current window
+        remaining (int): number of requests left in the current window
         limit (int): max number of requests per window
         limit (int): max number of requests per window
         window (int): window size in seconds
         window (int): window size in seconds
         reset_time (int): UTC Epoch time in seconds when the current window expires
         reset_time (int): UTC Epoch time in seconds when the current window expires
@@ -38,6 +39,7 @@ class RateLimitMeta:
 
 
     is_limited: bool
     is_limited: bool
     current: int
     current: int
+    remaining: int
     limit: int
     limit: int
     window: int
     window: int
     reset_time: int
     reset_time: int

+ 33 - 1
tests/sentry/middleware/test_ratelimit_middleware.py

@@ -1,10 +1,12 @@
 from time import time
 from time import time
 from unittest.mock import patch
 from unittest.mock import patch
 
 
+from before_after import before
 from django.conf import settings
 from django.conf import settings
 from django.conf.urls import url
 from django.conf.urls import url
 from django.contrib.auth.models import AnonymousUser
 from django.contrib.auth.models import AnonymousUser
 from django.test import RequestFactory, override_settings
 from django.test import RequestFactory, override_settings
+from django.urls import reverse
 from exam import fixture
 from exam import fixture
 from freezegun import freeze_time
 from freezegun import freeze_time
 from rest_framework.permissions import AllowAny
 from rest_framework.permissions import AllowAny
@@ -220,12 +222,27 @@ class RateLimitHeaderTestEndpoint(Endpoint):
     enforce_rate_limit = True
     enforce_rate_limit = True
     rate_limits = {"GET": {RateLimitCategory.IP: RateLimit(2, 100)}}
     rate_limits = {"GET": {RateLimitCategory.IP: RateLimit(2, 100)}}
 
 
+    def inject_call(self):
+        return
+
+    def get(self, request):
+        self.inject_call()
+        return Response({"ok": True})
+
+
+class RaceConditionEndpoint(Endpoint):
+    permission_classes = (AllowAny,)
+
+    enforce_rate_limit = False
+    rate_limits = {"GET": {RateLimitCategory.IP: RateLimit(40, 100)}}
+
     def get(self, request):
     def get(self, request):
         return Response({"ok": True})
         return Response({"ok": True})
 
 
 
 
 urlpatterns = [
 urlpatterns = [
-    url(r"^/ratelimit$", RateLimitHeaderTestEndpoint.as_view(), name="ratelimit-header-endpoint")
+    url(r"^/ratelimit$", RateLimitHeaderTestEndpoint.as_view(), name="ratelimit-header-endpoint"),
+    url(r"^/race-condition$", RaceConditionEndpoint.as_view(), name="race-condition-endpoint"),
 ]
 ]
 
 
 
 
@@ -272,3 +289,18 @@ class TestRatelimitHeader(APITestCase):
         assert "X-Sentry-Rate-Limit-Remaining" not in response._headers
         assert "X-Sentry-Rate-Limit-Remaining" not in response._headers
         assert "X-Sentry-Rate-Limit-Limit" not in response._headers
         assert "X-Sentry-Rate-Limit-Limit" not in response._headers
         assert "X-Sentry-Rate-Limit-Reset" not in response._headers
         assert "X-Sentry-Rate-Limit-Reset" not in response._headers
+
+    def test_header_race_condition(self):
+        """Make sure concurrent requests don't affect each other's rate limit"""
+
+        def parallel_request(*args, **kwargs):
+            self.client.get(reverse("race-condition-endpoint"))
+
+        with before(
+            "tests.sentry.middleware.test_ratelimit_middleware.RateLimitHeaderTestEndpoint.inject_call",
+            parallel_request,
+        ):
+            response = self.get_success_response()
+
+        assert int(response["X-Sentry-Rate-Limit-Remaining"]) == 1
+        assert int(response["X-Sentry-Rate-Limit-Limit"]) == 2

+ 16 - 1
tests/sentry/ratelimits/utils/test_above_rate_limit_check.py

@@ -13,5 +13,20 @@ class RatelimitMiddlewareTest(TestCase):
             expected_reset_time = int(time() + 100)
             expected_reset_time = int(time() + 100)
             return_val = above_rate_limit_check("foo", RateLimit(10, 100))
             return_val = above_rate_limit_check("foo", RateLimit(10, 100))
             assert return_val == RateLimitMeta(
             assert return_val == RateLimitMeta(
-                is_limited=False, current=1, limit=10, window=100, reset_time=expected_reset_time
+                is_limited=False,
+                current=1,
+                limit=10,
+                window=100,
+                reset_time=expected_reset_time,
+                remaining=9,
+            )
+            for _ in range(10):
+                return_val = above_rate_limit_check("foo", RateLimit(10, 100))
+            assert return_val == RateLimitMeta(
+                is_limited=True,
+                current=11,
+                limit=10,
+                window=100,
+                reset_time=expected_reset_time,
+                remaining=0,
             )
             )