Browse Source

feat(hybrid-cloud): Add logic on when to proxy (#39945)

Request should only be proxied when the endpoint's silo limit doesn't match the current silo mode
Aniket Das 2 years ago
parent
commit
ed0e849d52

+ 11 - 4
src/sentry/api_gateway/api_gateway.py

@@ -7,16 +7,23 @@ from sentry.api_gateway.proxy import proxy_request
 from sentry.silo import SiloMode
 
 
-def _request_should_be_proxied(request: Request, view_kwargs) -> bool:
-    return SiloMode.get_current_mode() != SiloMode.MONOLITH and "organization_slug" in view_kwargs
+def _request_should_be_proxied(request: Request, view_func, view_kwargs) -> bool:
+    view_class = getattr(view_func, "view_class", None)
+    current_silo_mode = SiloMode.get_current_mode()
+    if view_class is not None:
+        endpoint_silo_limit = getattr(view_class, "silo_limit", None)
+        if endpoint_silo_limit is not None:
+            endpoint_silo_set = endpoint_silo_limit.modes
+            return current_silo_mode not in endpoint_silo_set and "organization_slug" in view_kwargs
+    return False
 
 
-def proxy_request_if_needed(request: Request, view_kwargs) -> HttpResponseBase | None:
+def proxy_request_if_needed(request: Request, view_func, view_kwargs) -> HttpResponseBase | None:
     """
     Main execution flow for the API Gateway
     returns None if proxying is not required
     """
-    if not _request_should_be_proxied(request, view_kwargs):
+    if not _request_should_be_proxied(request, view_func, view_kwargs):
         return None
 
     # Request should be proxied at this point

+ 1 - 1
src/sentry/middleware/api_gateway.py

@@ -20,6 +20,6 @@ class ApiGatewayMiddleware:
     def process_view(
         self, request: Request, view_func, view_args, view_kwargs
     ) -> HttpResponseBase | None:
-        proxy_response = proxy_request_if_needed(request, view_kwargs)
+        proxy_response = proxy_request_if_needed(request, view_func, view_kwargs)
         if proxy_response is not None:
             return proxy_response

+ 6 - 4
src/sentry/testutils/helpers/api_gateway.py

@@ -7,7 +7,7 @@ from django.test import override_settings
 from rest_framework.permissions import AllowAny
 from rest_framework.response import Response
 
-from sentry.api.base import Endpoint, control_silo_endpoint, pending_silo_endpoint
+from sentry.api.base import control_silo_endpoint, region_silo_endpoint
 from sentry.api.bases.organization import OrganizationEndpoint
 from sentry.testutils import APITestCase
 from sentry.types.region import Region, RegionCategory
@@ -24,12 +24,14 @@ SENTRY_REGION_CONFIG = [
 
 
 @control_silo_endpoint
-class ControlEndpoint(Endpoint):
-    def get(self, request):
+class ControlEndpoint(OrganizationEndpoint):
+    permission_classes = (AllowAny,)
+
+    def get(self, request, organization):
         return Response({"proxy": False})
 
 
-@pending_silo_endpoint
+@region_silo_endpoint
 class RegionEndpoint(OrganizationEndpoint):
     permission_classes = (AllowAny,)
 

+ 44 - 0
tests/sentry/api_gateway/test_api_gateway.py

@@ -37,3 +37,47 @@ class ApiGatewayTest(ApiGatewayTestCase):
         assert resp.status_code == 200, resp.content
         resp_json = json.loads(b"".join(resp.streaming_content))
         assert resp_json["proxy"]
+
+    @responses.activate
+    @patch("sentry.types.region.get_region_for_organization")
+    def test_proxy_check(self, region_fnc_patch):
+        """Test the logic of when a request should be proxied"""
+        region_fnc_patch.return_value = SENTRY_REGION_CONFIG[0]
+
+        responses.add(
+            responses.GET,
+            f"http://region1.testserver/organizations/{self.organization.slug}/region/",
+            json={"proxy": True},
+        )
+        responses.add(
+            responses.GET,
+            f"http://region1.testserver/organizations/{self.organization.slug}/control/",
+            json={"proxy": True},
+        )
+
+        region_url = reverse(
+            "region-endpoint", kwargs={"organization_slug": self.organization.slug}
+        )
+        control_url = reverse(
+            "control-endpoint", kwargs={"organization_slug": self.organization.slug}
+        )
+
+        with override_settings(SILO_MODE=SiloMode.CONTROL, MIDDLEWARE=tuple(self.middleware)):
+            resp = self.client.get(region_url)
+            assert resp.status_code == 200
+            resp_json = json.loads(b"".join(resp.streaming_content))
+            assert resp_json["proxy"]
+
+            resp = self.client.get(control_url)
+            assert resp.status_code == 200
+            assert not resp.data["proxy"]
+
+        with override_settings(SILO_MODE=SiloMode.REGION, MIDDLEWARE=tuple(self.middleware)):
+            resp = self.client.get(region_url)
+            assert resp.status_code == 200
+            assert not resp.data["proxy"]
+
+            resp = self.client.get(control_url)
+            assert resp.status_code == 200
+            resp_json = json.loads(b"".join(resp.streaming_content))
+            assert resp_json["proxy"]