Browse Source

fix(api): Properly fetch org from token (#65403)

Requires https://github.com/getsentry/sentry/pull/65403

Uses the new param in sentry app to fetch the relevant
`SentryAppInstallation` from a token id by also checking
`SentryAppInstallationToken`, which will link to the relevant
`SentryAppInstallation`.

Moved test for `get_rate_limit_key` from `test_ratelimit_middleware.py`
to `test_get_ratelimit_key.py`.
Also added individual tests for `ApiTokenReplica` and `OrgAuthToken`
models

Closes
https://github.com/getsentry/team-core-product-foundations/issues/159

---------

Co-authored-by: getsantry[bot] <66042841+getsantry[bot]@users.noreply.github.com>
Seiji Chew 11 months ago
parent
commit
873de128c0

+ 8 - 2
src/sentry/ratelimits/utils.py

@@ -71,12 +71,15 @@ def get_rate_limit_key(
         return None
 
     ip_address = request.META.get("REMOTE_ADDR")
-    request_auth: AuthenticatedToken | ApiToken | None = getattr(request, "auth", None)
+    request_auth: AuthenticatedToken | ApiToken | ApiTokenReplica | None = getattr(
+        request, "auth", None
+    )
     request_user = getattr(request, "user", None)
 
     from django.contrib.auth.models import AnonymousUser
 
     from sentry.auth.system import is_system_auth
+    from sentry.hybridcloud.models.apitokenreplica import ApiTokenReplica
     from sentry.models.apikey import ApiKey
 
     # Don't Rate Limit System Token Requests
@@ -88,6 +91,8 @@ def get_rate_limit_key(
             token_id = request_auth.id
         elif isinstance(request_auth, AuthenticatedToken) and request_auth.entity_id is not None:
             token_id = request_auth.entity_id
+        elif isinstance(request_auth, ApiTokenReplica) and request_auth.apitoken_id is not None:
+            token_id = request_auth.apitoken_id
         else:
             assert False  # Can't happen as asserted by is_api_token_auth check
 
@@ -134,9 +139,10 @@ def get_organization_id_from_token(token_id: int) -> int | None:
     installations = app_service.get_many(
         filter={
             "status": SentryAppInstallationStatus.INSTALLED,
-            "api_token_id": token_id,
+            "api_installation_token_id": token_id,
         }
     )
+
     installation = installations[0] if installations else None
 
     # Return None to avoid collisions caused by tokens not being associated with

+ 9 - 84
tests/sentry/middleware/test_ratelimit_middleware.py

@@ -3,7 +3,6 @@ from functools import cached_property
 from time import sleep, time
 from unittest.mock import patch, sentinel
 
-from django.contrib.auth.models import AnonymousUser
 from django.http.request import HttpRequest
 from django.test import RequestFactory, override_settings
 from django.urls import re_path, reverse
@@ -11,16 +10,13 @@ from rest_framework.permissions import AllowAny
 from rest_framework.response import Response
 
 from sentry.api.base import Endpoint
-from sentry.issues.endpoints.organization_group_index import OrganizationGroupIndexEndpoint
 from sentry.middleware.ratelimit import RatelimitMiddleware
-from sentry.models.apikey import ApiKey
 from sentry.models.user import User
 from sentry.ratelimits.config import RateLimitConfig, get_default_rate_limits_for_group
-from sentry.ratelimits.utils import get_rate_limit_config, get_rate_limit_key, get_rate_limit_value
-from sentry.silo.base import SiloMode
+from sentry.ratelimits.utils import get_rate_limit_config, get_rate_limit_value
 from sentry.testutils.cases import APITestCase, BaseTestCase, TestCase
 from sentry.testutils.helpers.datetime import freeze_time
-from sentry.testutils.silo import all_silo_test, assume_test_silo_mode
+from sentry.testutils.silo import all_silo_test, assume_test_silo_mode_of
 from sentry.types.ratelimit import RateLimit, RateLimitCategory
 
 
@@ -48,29 +44,26 @@ class RatelimitMiddlewareTest(TestCase, BaseTestCase):
     _test_endpoint = TestEndpoint.as_view()
     _test_endpoint_no_rate_limits = TestEndpointNoRateLimits.as_view()
 
-    def populate_sentry_app_request(self, request):
+    def _populate_public_integration_request(self, request) -> None:
         install = self.create_sentry_app_installation(organization=self.organization)
-
         token = install.api_token
 
-        with assume_test_silo_mode(SiloMode.CONTROL):
+        with assume_test_silo_mode_of(User):
             request.user = User.objects.get(id=install.sentry_app.proxy_user_id)
         request.auth = token
 
-    def populate_internal_integration_request(self, request):
+    def _populate_internal_integration_request(self, request) -> None:
         internal_integration = self.create_internal_integration(
             name="my_app",
             organization=self.organization,
             scopes=("project:read",),
             webhook_url="http://example.com",
         )
-
         token = self.create_internal_integration_token(
             user=self.user,
             internal_integration=internal_integration,
         )
-
-        with assume_test_silo_mode(SiloMode.CONTROL):
+        with assume_test_silo_mode_of(User):
             request.user = User.objects.get(id=internal_integration.proxy_user_id)
         request.auth = token
 
@@ -178,81 +171,13 @@ class RatelimitMiddlewareTest(TestCase, BaseTestCase):
         self.middleware.process_view(request, self._test_endpoint, [], {})
         assert request.rate_limit_category == RateLimitCategory.USER
 
-        self.populate_sentry_app_request(request)
+        self._populate_public_integration_request(request)
         self.middleware.process_view(request, self._test_endpoint, [], {})
         assert request.rate_limit_category == RateLimitCategory.ORGANIZATION
 
-        self.populate_internal_integration_request(request)
+        self._populate_internal_integration_request(request)
         self.middleware.process_view(request, self._test_endpoint, [], {})
-        assert request.rate_limit_category == RateLimitCategory.IP
-
-    def test_get_rate_limit_key(self):
-        # Import an endpoint
-
-        view = OrganizationGroupIndexEndpoint.as_view()
-        rate_limit_config = get_rate_limit_config(view.view_class)
-        rate_limit_group = rate_limit_config.group if rate_limit_config else RateLimitConfig().group
-
-        # Test for default IP
-        request = self.factory.get("/")
-        assert (
-            get_rate_limit_key(view, request, rate_limit_group, rate_limit_config)
-            == "ip:default:GET:127.0.0.1"
-        )
-        # Test when IP address is missing
-        request.META["REMOTE_ADDR"] = None
-        assert get_rate_limit_key(view, request, rate_limit_group, rate_limit_config) is None
-
-        # Test when IP addess is IPv6
-        request.META["REMOTE_ADDR"] = "684D:1111:222:3333:4444:5555:6:77"
-        assert (
-            get_rate_limit_key(view, request, rate_limit_group, rate_limit_config)
-            == "ip:default:GET:684D:1111:222:3333:4444:5555:6:77"
-        )
-
-        # Test for users
-        request.session = {}
-        request.user = self.user
-        assert (
-            get_rate_limit_key(view, request, rate_limit_group, rate_limit_config)
-            == f"user:default:GET:{self.user.id}"
-        )
-
-        # Test for user auth tokens
-        token = self.create_user_auth_token(user=self.user, scope_list=["event:read", "org:read"])
-        request.auth = token
-        request.user = self.user
-        assert (
-            get_rate_limit_key(view, request, rate_limit_group, rate_limit_config)
-            == f"user:default:GET:{self.user.id}"
-        )
-
-        # Test for sentryapp auth tokens:
-        self.populate_sentry_app_request(request)
-        assert (
-            get_rate_limit_key(view, request, rate_limit_group, rate_limit_config)
-            == f"org:default:GET:{self.organization.id}"
-        )
-
-        self.populate_internal_integration_request(request)
-        # Fallback to IP address limit if we can't find the organization
-        assert (
-            get_rate_limit_key(view, request, rate_limit_group, rate_limit_config)
-            == "ip:default:GET:684D:1111:222:3333:4444:5555:6:77"
-        )
-
-        # Test for
-        request.user = AnonymousUser()
-        api_key = None
-        with assume_test_silo_mode(SiloMode.CONTROL):
-            api_key = ApiKey.objects.create(
-                organization_id=self.organization.id, scope_list=["project:write"]
-            )
-        request.auth = api_key
-        assert (
-            get_rate_limit_key(view, request, rate_limit_group, rate_limit_config)
-            == "ip:default:GET:684D:1111:222:3333:4444:5555:6:77"
-        )
+        assert request.rate_limit_category == RateLimitCategory.ORGANIZATION
 
     def test_enforce_rate_limit_is_false(self):
         request = self.factory.get("/")

+ 114 - 30
tests/sentry/ratelimits/utils/test_get_ratelimit_key.py

@@ -1,3 +1,4 @@
+from django.contrib.auth.models import AnonymousUser
 from django.contrib.sessions.backends.base import SessionBase
 from django.test import RequestFactory
 from rest_framework.permissions import AllowAny
@@ -5,13 +6,16 @@ from rest_framework.response import Response
 
 from sentry.api.base import Endpoint
 from sentry.auth.system import SystemToken
+from sentry.hybridcloud.models.apitokenreplica import ApiTokenReplica
 from sentry.models.apitoken import ApiToken
+from sentry.models.integrations.sentry_app_installation import SentryAppInstallation
+from sentry.models.integrations.sentry_app_installation_token import SentryAppInstallationToken
 from sentry.models.user import User
 from sentry.ratelimits import get_rate_limit_config, get_rate_limit_key
 from sentry.ratelimits.config import RateLimitConfig
 from sentry.services.hybrid_cloud.auth import AuthenticatedToken
 from sentry.testutils.cases import TestCase
-from sentry.testutils.silo import assume_test_silo_mode_of, region_silo_test
+from sentry.testutils.silo import all_silo_test, assume_test_silo_mode_of, region_silo_test
 from sentry.types.ratelimit import RateLimit, RateLimitCategory
 
 CONCURRENT_RATE_LIMIT = 20
@@ -34,7 +38,7 @@ class APITestEndpoint(Endpoint):
         return Response({"ok": True})
 
 
-@region_silo_test
+@all_silo_test
 class GetRateLimitKeyTest(TestCase):
     def setUp(self) -> None:
         self.view = APITestEndpoint.as_view()
@@ -44,15 +48,39 @@ class GetRateLimitKeyTest(TestCase):
             self.rate_limit_config.group if self.rate_limit_config else RateLimitConfig().group
         )
 
-    def test_default_ip(self):
+    def _populate_public_integration_request(self, request) -> None:
+        install = self.create_sentry_app_installation(organization=self.organization)
+        token = install.api_token
+
+        with assume_test_silo_mode_of(User):
+            request.user = User.objects.get(id=install.sentry_app.proxy_user_id)
+        request.auth = token
+
+    def _populate_internal_integration_request(self, request) -> None:
+        internal_integration = self.create_internal_integration(
+            name="my_app",
+            organization=self.organization,
+            scopes=("project:read",),
+            webhook_url="http://example.com",
+        )
+        token = self.create_internal_integration_token(
+            user=self.user,
+            internal_integration=internal_integration,
+        )
+
+        with assume_test_silo_mode_of(User):
+            request.user = User.objects.get(id=internal_integration.proxy_user_id)
+        request.auth = token
+
+    def test_ips(self):
+        # Test for default IP
         assert (
             get_rate_limit_key(
                 self.view, self.request, self.rate_limit_group, self.rate_limit_config
             )
             == "ip:default:APITestEndpoint:GET:127.0.0.1"
         )
-
-    def test_ip_address_missing(self):
+        # Test when IP address is missing
         self.request.META["REMOTE_ADDR"] = None
         assert (
             get_rate_limit_key(
@@ -60,8 +88,7 @@ class GetRateLimitKeyTest(TestCase):
             )
             is None
         )
-
-    def test_ipv6(self):
+        # Test when IP address is IPv6
         self.request.META["REMOTE_ADDR"] = "684D:1111:222:3333:4444:5555:6:77"
         assert (
             get_rate_limit_key(
@@ -70,51 +97,46 @@ class GetRateLimitKeyTest(TestCase):
             == "ip:default:APITestEndpoint:GET:684D:1111:222:3333:4444:5555:6:77"
         )
 
-    def test_system_token(self):
-        self.request.auth = SystemToken()
+    def test_user(self):
+        self.request.session = SessionBase()
+        self.request.user = self.user
+
         assert (
             get_rate_limit_key(
                 self.view, self.request, self.rate_limit_group, self.rate_limit_config
             )
-            is None
+            == f"user:default:APITestEndpoint:GET:{self.user.id}"
         )
 
-    def test_users(self):
-        user = User(id=1)
-        self.request.session = SessionBase()
-        self.request.user = user
+    def test_system_token(self):
+        self.request.auth = SystemToken()
         assert (
             get_rate_limit_key(
                 self.view, self.request, self.rate_limit_group, self.rate_limit_config
             )
-            == f"user:default:APITestEndpoint:GET:{user.id}"
-        )
-
-    def test_organization(self):
-        self.request.session = SessionBase()
-        sentry_app = self.create_sentry_app(
-            name="Tesla App", published=True, organization=self.organization
-        )
-        install = self.create_sentry_app_installation(
-            slug=sentry_app.slug, organization=self.organization, user=self.user
+            is None
         )
 
-        self.request.user = sentry_app.proxy_user
-
-        self.request.auth = install.api_token
-
+    def test_api_token(self):
+        with assume_test_silo_mode_of(ApiToken):
+            token = ApiToken.objects.create(user=self.user, scope_list=["event:read", "org:read"])
+        self.request.auth = token
+        self.request.user = self.user
         assert (
             get_rate_limit_key(
                 self.view, self.request, self.rate_limit_group, self.rate_limit_config
             )
-            == f"org:default:APITestEndpoint:GET:{install.organization_id}"
+            == f"user:default:APITestEndpoint:GET:{self.user.id}"
         )
 
-    def test_api_token(self):
+    def test_api_token_replica(self):
         with assume_test_silo_mode_of(ApiToken):
             token = ApiToken.objects.create(user=self.user, scope_list=["event:read", "org:read"])
+        with assume_test_silo_mode_of(ApiTokenReplica):
+            token = ApiTokenReplica.objects.get(apitoken_id=token.id)
         self.request.auth = token
         self.request.user = self.user
+
         assert (
             get_rate_limit_key(
                 self.view, self.request, self.rate_limit_group, self.rate_limit_config
@@ -137,6 +159,68 @@ class GetRateLimitKeyTest(TestCase):
             == f"user:default:APITestEndpoint:GET:{self.user.id}"
         )
 
+    def test_api_key(self):
+        self.request.user = AnonymousUser()
+        self.request.auth = self.create_api_key(
+            organization=self.organization, scope_list=["project:write"]
+        )
+
+        assert (
+            get_rate_limit_key(
+                self.view, self.request, self.rate_limit_group, self.rate_limit_config
+            )
+            == "ip:default:APITestEndpoint:GET:127.0.0.1"
+        )
+
+    def test_org_auth_token(self):
+        self.request.user = AnonymousUser()
+        self.request.auth = self.create_org_auth_token(
+            organization_id=self.organization.id, scope_list=["org:ci"]
+        )
+
+        assert (
+            get_rate_limit_key(
+                self.view, self.request, self.rate_limit_group, self.rate_limit_config
+            )
+            == "ip:default:APITestEndpoint:GET:127.0.0.1"
+        )
+
+    def test_user_auth_token(self):
+        token = self.create_user_auth_token(user=self.user, scope_list=["event:read", "org:read"])
+        self.request.auth = token
+        self.request.user = self.user
+
+        assert (
+            get_rate_limit_key(
+                self.view, self.request, self.rate_limit_group, self.rate_limit_config
+            )
+            == f"user:default:APITestEndpoint:GET:{self.user.id}"
+        )
+
+    def test_integration_tokens(self):
+        # Test for PUBLIC Integration api tokens
+        self._populate_public_integration_request(self.request)
+        assert (
+            get_rate_limit_key(
+                self.view, self.request, self.rate_limit_group, self.rate_limit_config
+            )
+            == f"org:default:APITestEndpoint:GET:{self.organization.id}"
+        )
+
+        # Test for INTERNAL Integration api tokens
+        self._populate_internal_integration_request(self.request)
+        with assume_test_silo_mode_of(SentryAppInstallation, SentryAppInstallationToken):
+            # Ensure that the internal integration token lives in
+            # SentryAppInstallationToken instead of SentryAppInstallation
+            assert not SentryAppInstallation.objects.filter(api_token_id=self.request.auth.id)
+            assert SentryAppInstallationToken.objects.filter(api_token_id=self.request.auth.id)
+        assert (
+            get_rate_limit_key(
+                self.view, self.request, self.rate_limit_group, self.rate_limit_config
+            )
+            == f"org:default:APITestEndpoint:GET:{self.organization.id}"
+        )
+
 
 class DummyEndpoint(Endpoint):
     permission_classes = (AllowAny,)