Просмотр исходного кода

fix(ratelimit): Store header info with request (#31450)

* fix(ratelimit): Store header info with request

Previously we stored the rate limit header information in a class property of the middleware assuming that each request had it's own instance of the middleware. Turns out that a lot of the requests share the same middleware so every request was thrashing the metadata property.

Instead, I moved the header information to a property of the request. This way the correct header is tied to its request

`before_after` was added for the race condition test
Aniket Das "Tekky 3 лет назад
Родитель
Сommit
9ffb055577

+ 1 - 0
requirements-dev.txt

@@ -1,3 +1,4 @@
+before_after==1.0.1
 docker>=3.7.0,<3.8.0
 exam>=0.5.1
 freezegun==1.1.0

+ 10 - 16
src/sentry/middleware/ratelimit.py

@@ -11,7 +11,7 @@ from sentry.ratelimits import (
     get_rate_limit_key,
     get_rate_limit_value,
 )
-from sentry.types.ratelimit import RateLimitCategory, RateLimitMeta
+from sentry.types.ratelimit import RateLimitCategory
 
 DEFAULT_ERROR_MESSAGE = (
     "You are attempting to use this endpoint too frequently. Limit is "
@@ -22,8 +22,6 @@ DEFAULT_ERROR_MESSAGE = (
 class RatelimitMiddleware(MiddlewareMixin):
     """Middleware that applies a rate limit to every endpoint."""
 
-    rate_limit_metadata: RateLimitMeta | None = None
-
     def process_view(self, request: Request, view_func, view_args, view_kwargs) -> Response | None:
         """Check if the endpoint call will violate."""
         request.will_be_rate_limited = False
@@ -46,29 +44,25 @@ class RatelimitMiddleware(MiddlewareMixin):
         if rate_limit is None:
             return
 
-        self.rate_limit_metadata = above_rate_limit_check(key, rate_limit)
-        if self.rate_limit_metadata.is_limited:
+        request.rate_limit_metadata = above_rate_limit_check(key, rate_limit)
+
+        if request.rate_limit_metadata.is_limited:
             request.will_be_rate_limited = True
             enforce_rate_limit = getattr(view_func.view_class, "enforce_rate_limit", False)
             if enforce_rate_limit:
                 return HttpResponse(
                     {
                         "detail": DEFAULT_ERROR_MESSAGE.format(
-                            limit=self.rate_limit_metadata.limit,
-                            window=self.rate_limit_metadata.window,
+                            limit=request.rate_limit_metadata.limit,
+                            window=request.rate_limit_metadata.window,
                         )
                     },
                     status=429,
                 )
 
     def process_response(self, request: Request, response: Response) -> Response:
-        if self.rate_limit_metadata is not None:
-            remaining_count = (
-                self.rate_limit_metadata.limit - self.rate_limit_metadata.current
-                if not self.rate_limit_metadata.is_limited
-                else 0
-            )
-            response["X-Sentry-Rate-Limit-Remaining"] = remaining_count
-            response["X-Sentry-Rate-Limit-Limit"] = self.rate_limit_metadata.limit
-            response["X-Sentry-Rate-Limit-Reset"] = self.rate_limit_metadata.reset_time
+        if hasattr(request, "rate_limit_metadata"):
+            response["X-Sentry-Rate-Limit-Remaining"] = request.rate_limit_metadata.remaining
+            response["X-Sentry-Rate-Limit-Limit"] = request.rate_limit_metadata.limit
+            response["X-Sentry-Rate-Limit-Reset"] = request.rate_limit_metadata.reset_time
         return response

+ 2 - 0
src/sentry/ratelimits/utils.py

@@ -128,12 +128,14 @@ def above_rate_limit_check(key: str, rate_limit: RateLimit) -> RateLimitMeta:
     is_limited, current, reset_time = ratelimiter.is_limited_with_value(
         key, limit=rate_limit.limit, window=rate_limit.window
     )
+    remaining = rate_limit.limit - current if not is_limited else 0
     return RateLimitMeta(
         is_limited=is_limited,
         current=current,
         limit=rate_limit.limit,
         window=rate_limit.window,
         reset_time=reset_time,
+        remaining=remaining,
     )
 
 

+ 2 - 0
src/sentry/types/ratelimit.py

@@ -31,6 +31,7 @@ class RateLimitMeta:
     Attributes:
         is_limited (bool): request is rate limited
         current (int): number of requests done in the current window
+        remaining (int): number of requests left in the current window
         limit (int): max number of requests per window
         window (int): window size in seconds
         reset_time (int): UTC Epoch time in seconds when the current window expires
@@ -38,6 +39,7 @@ class RateLimitMeta:
 
     is_limited: bool
     current: int
+    remaining: int
     limit: int
     window: int
     reset_time: int

+ 33 - 1
tests/sentry/middleware/test_ratelimit_middleware.py

@@ -1,10 +1,12 @@
 from time import time
 from unittest.mock import patch
 
+from before_after import before
 from django.conf import settings
 from django.conf.urls import url
 from django.contrib.auth.models import AnonymousUser
 from django.test import RequestFactory, override_settings
+from django.urls import reverse
 from exam import fixture
 from freezegun import freeze_time
 from rest_framework.permissions import AllowAny
@@ -220,12 +222,27 @@ class RateLimitHeaderTestEndpoint(Endpoint):
     enforce_rate_limit = True
     rate_limits = {"GET": {RateLimitCategory.IP: RateLimit(2, 100)}}
 
+    def inject_call(self):
+        return
+
+    def get(self, request):
+        self.inject_call()
+        return Response({"ok": True})
+
+
+class RaceConditionEndpoint(Endpoint):
+    permission_classes = (AllowAny,)
+
+    enforce_rate_limit = False
+    rate_limits = {"GET": {RateLimitCategory.IP: RateLimit(40, 100)}}
+
     def get(self, request):
         return Response({"ok": True})
 
 
 urlpatterns = [
-    url(r"^/ratelimit$", RateLimitHeaderTestEndpoint.as_view(), name="ratelimit-header-endpoint")
+    url(r"^/ratelimit$", RateLimitHeaderTestEndpoint.as_view(), name="ratelimit-header-endpoint"),
+    url(r"^/race-condition$", RaceConditionEndpoint.as_view(), name="race-condition-endpoint"),
 ]
 
 
@@ -272,3 +289,18 @@ class TestRatelimitHeader(APITestCase):
         assert "X-Sentry-Rate-Limit-Remaining" not in response._headers
         assert "X-Sentry-Rate-Limit-Limit" not in response._headers
         assert "X-Sentry-Rate-Limit-Reset" not in response._headers
+
+    def test_header_race_condition(self):
+        """Make sure concurrent requests don't affect each other's rate limit"""
+
+        def parallel_request(*args, **kwargs):
+            self.client.get(reverse("race-condition-endpoint"))
+
+        with before(
+            "tests.sentry.middleware.test_ratelimit_middleware.RateLimitHeaderTestEndpoint.inject_call",
+            parallel_request,
+        ):
+            response = self.get_success_response()
+
+        assert int(response["X-Sentry-Rate-Limit-Remaining"]) == 1
+        assert int(response["X-Sentry-Rate-Limit-Limit"]) == 2

+ 16 - 1
tests/sentry/ratelimits/utils/test_above_rate_limit_check.py

@@ -13,5 +13,20 @@ class RatelimitMiddlewareTest(TestCase):
             expected_reset_time = int(time() + 100)
             return_val = above_rate_limit_check("foo", RateLimit(10, 100))
             assert return_val == RateLimitMeta(
-                is_limited=False, current=1, limit=10, window=100, reset_time=expected_reset_time
+                is_limited=False,
+                current=1,
+                limit=10,
+                window=100,
+                reset_time=expected_reset_time,
+                remaining=9,
+            )
+            for _ in range(10):
+                return_val = above_rate_limit_check("foo", RateLimit(10, 100))
+            assert return_val == RateLimitMeta(
+                is_limited=True,
+                current=11,
+                limit=10,
+                window=100,
+                reset_time=expected_reset_time,
+                remaining=0,
             )