Browse Source

fix(api): Fallback to IP rate limiting if missing org (#66839)

Fallback to IP ratelimiting when we fail to fetch the organization from
an integration token.
Seiji Chew 1 year ago
parent
commit
4d5381c2b7

+ 10 - 7
src/sentry/ratelimits/utils.py

@@ -1,8 +1,6 @@
 from __future__ import annotations
 
 import logging
-import random
-import string
 from collections.abc import Callable, Mapping
 from typing import TYPE_CHECKING, Any
 
@@ -96,6 +94,11 @@ def get_rate_limit_key(
         if request_user.is_sentry_app:
             category = "org"
             id = get_organization_id_from_token(token_id)
+
+            # Fallback to IP address limit if we can't find the organization
+            if id is None and ip_address is not None:
+                category = "ip"
+                id = ip_address
         else:
             category = "user"
             id = request_auth.user_id
@@ -125,7 +128,7 @@ def get_rate_limit_key(
         return f"{category}:{rate_limit_group}:{http_method}:{id}"
 
 
-def get_organization_id_from_token(token_id: int) -> Any:
+def get_organization_id_from_token(token_id: int) -> int | None:
     from sentry.services.hybrid_cloud.app import app_service
 
     installations = app_service.get_many(
@@ -134,13 +137,13 @@ def get_organization_id_from_token(token_id: int) -> Any:
             "api_token_id": token_id,
         }
     )
-    installation = installations[0] if len(installations) > 0 else None
+    installation = installations[0] if installations else None
 
-    # Return a random uppercase/lowercase letter to avoid collisions caused by tokens not being
-    # associated with a SentryAppInstallation. This is a temporary fix while we solve the root cause
+    # Return None to avoid collisions caused by tokens not being associated with
+    # a SentryAppInstallation. We fallback to IP address rate limiting in this case.
     if not installation:
         logger.info("installation.not_found", extra={"token_id": token_id})
-        return random.choice(string.ascii_letters)
+        return None
 
     return installation.organization_id
 

+ 9 - 8
tests/sentry/middleware/test_ratelimit_middleware.py

@@ -1,4 +1,3 @@
-import re
 from concurrent.futures import ThreadPoolExecutor
 from functools import cached_property
 from time import sleep, time
@@ -118,8 +117,9 @@ class RatelimitMiddlewareTest(TestCase, BaseTestCase):
     def test_positive_rate_limit_response_headers(self, default_rate_limit_mock):
         request = self.factory.get("/")
 
-        with freeze_time("2000-01-01"), patch.object(
-            RatelimitMiddlewareTest.TestEndpoint, "enforce_rate_limit", True
+        with (
+            freeze_time("2000-01-01"),
+            patch.object(RatelimitMiddlewareTest.TestEndpoint, "enforce_rate_limit", True),
         ):
             default_rate_limit_mock.return_value = RateLimit(0, 100)
             response = self.middleware.process_view(request, self._test_endpoint, [], {})
@@ -184,7 +184,7 @@ class RatelimitMiddlewareTest(TestCase, BaseTestCase):
 
         self.populate_internal_integration_request(request)
         self.middleware.process_view(request, self._test_endpoint, [], {})
-        assert request.rate_limit_category == RateLimitCategory.ORGANIZATION
+        assert request.rate_limit_category == RateLimitCategory.IP
 
     def test_get_rate_limit_key(self):
         # Import an endpoint
@@ -235,10 +235,11 @@ class RatelimitMiddlewareTest(TestCase, BaseTestCase):
         )
 
         self.populate_internal_integration_request(request)
-        key_pattern = re.compile(r"^org:default:OrganizationGroupIndexEndpoint:GET:[a-zA-Z]$")
-        key = get_rate_limit_key(view, request, rate_limit_group, rate_limit_config)
-        assert key
-        assert key_pattern.match(key)
+        # Fallback to IP address limit if we can't find the organization
+        assert (
+            get_rate_limit_key(view, request, rate_limit_group, rate_limit_config)
+            == "ip:default:OrganizationGroupIndexEndpoint:GET:684D:1111:222:3333:4444:5555:6:77"
+        )
 
         # Test for
         request.user = AnonymousUser()