Browse Source

feat(apigateway): Added proxy support for `organization_id_or_slug` path params (#69245)

Raj Joshi 10 months ago
parent
commit
30b79363b4

+ 8 - 3
src/sentry/api/utils.py

@@ -487,19 +487,24 @@ class Timer:
 
 
 def id_or_slug_path_params_enabled(
-    convert_args_class: str, organization_slug: str | None = None
+    convert_args_class: str | None = None, organization_slug: str | None = None
 ) -> bool:
     # GA option
     if options.get("api.id-or-slug-enabled"):
         return True
 
+    # Apigateway
+    if not convert_args_class and organization_slug:
+        # Return True if the organization is in the list of enabled organizations and the apigateway option is enabled
+        return organization_slug in options.get("api.id-or-slug-enabled-ea-org")
+
     # EA option for endpoints where organization is available
     if organization_slug and organization_slug not in options.get("api.id-or-slug-enabled-ea-org"):
         return False
 
     # EA option for endpoints where organization is not available
-    if convert_args_class in options.get("api.id-or-slug-enabled-ea-endpoints"):
-        return True
+    if convert_args_class:
+        return convert_args_class in options.get("api.id-or-slug-enabled-ea-endpoints")
 
     return False
 

+ 5 - 3
src/sentry/hybridcloud/apigateway/apigateway.py

@@ -60,8 +60,10 @@ def proxy_request_if_needed(
     if request.resolver_match:
         url_name = request.resolver_match.url_name or url_name
 
-    if "organization_slug" in view_kwargs:
-        org_slug = view_kwargs["organization_slug"]
+    if "organization_slug" in view_kwargs or "organization_id_or_slug" in view_kwargs:
+        org_id_or_slug: str = view_kwargs.get("organization_slug") or view_kwargs.get(
+            "organization_id_or_slug", ""
+        )
 
         metrics.incr(
             "apigateway.proxy_request",
@@ -70,7 +72,7 @@ def proxy_request_if_needed(
                 "kind": "orgslug",
             },
         )
-        return proxy_request(request, org_slug, url_name)
+        return proxy_request(request, org_id_or_slug, url_name)
 
     if (
         "uuid" in view_kwargs

+ 4 - 4
src/sentry/hybridcloud/apigateway/proxy.py

@@ -87,13 +87,13 @@ class _body_with_length:
         return self.request.read(size)
 
 
-def proxy_request(request: HttpRequest, org_slug: str, url_name: str) -> HttpResponseBase:
-    """Take a django request object and proxy it to a remote location given an org_slug"""
+def proxy_request(request: HttpRequest, org_id_or_slug: str, url_name: str) -> HttpResponseBase:
+    """Take a django request object and proxy it to a remote location given an org_id_or_slug"""
 
     try:
-        region = get_region_for_organization(org_slug)
+        region = get_region_for_organization(org_id_or_slug)
     except RegionResolutionError as e:
-        logger.info("region_resolution_error", extra={"org_slug": org_slug, "error": str(e)})
+        logger.info("region_resolution_error", extra={"org_slug": org_id_or_slug, "error": str(e)})
         return HttpResponse(status=404)
 
     return proxy_region_request(request, region, url_name)

+ 10 - 0
src/sentry/testutils/helpers/apigateway.py

@@ -56,6 +56,16 @@ urlpatterns = [
         RegionEndpoint.as_view(),
         name="region-endpoint",
     ),
+    re_path(
+        r"^organizations/(?P<organization_id_or_slug>[^\/]+)/control/$",
+        ControlEndpoint.as_view(),
+        name="control-endpoint-id-or-slug",
+    ),
+    re_path(
+        r"^organizations/(?P<organization_id_or_slug>[^\/]+)/region/$",
+        RegionEndpoint.as_view(),
+        name="region-endpoint-id-or-slug",
+    ),
 ] + api_urls.urlpatterns
 
 

+ 15 - 3
src/sentry/types/region.py

@@ -263,13 +263,25 @@ def subdomain_is_region(request: HttpRequest) -> bool:
 
 
 @control_silo_function
-def get_region_for_organization(organization_slug: str) -> Region:
+def get_region_for_organization(organization_id_or_slug: str) -> Region:
     """Resolve an organization to the region where its data is stored."""
+    from sentry.api.utils import id_or_slug_path_params_enabled
     from sentry.models.organizationmapping import OrganizationMapping
 
-    mapping = OrganizationMapping.objects.filter(slug=organization_slug).first()
+    if (
+        id_or_slug_path_params_enabled(organization_slug=organization_id_or_slug)
+        and organization_id_or_slug.isnumeric()
+    ):
+        mapping = OrganizationMapping.objects.filter(
+            organization_id=organization_id_or_slug
+        ).first()
+    else:
+        mapping = OrganizationMapping.objects.filter(slug=organization_id_or_slug).first()
+
     if not mapping:
-        raise RegionResolutionError(f"Organization {organization_slug} has no associated mapping.")
+        raise RegionResolutionError(
+            f"Organization {organization_id_or_slug} has no associated mapping."
+        )
 
     return get_region_by_name(name=mapping.region_name)
 

+ 71 - 0
tests/sentry/hybridcloud/apigateway/test_apigateway.py

@@ -8,6 +8,7 @@ from rest_framework.response import Response
 
 from sentry.silo import SiloMode
 from sentry.testutils.helpers.apigateway import ApiGatewayTestCase, verify_request_params
+from sentry.testutils.helpers.options import override_options
 from sentry.testutils.helpers.response import close_streaming_response
 from sentry.testutils.silo import control_silo_test
 from sentry.utils import json
@@ -108,6 +109,76 @@ class ApiGatewayTest(ApiGatewayTestCase):
             assert resp.status_code == 200
             assert resp.data["proxy"] is False
 
+    @responses.activate
+    @override_options({"api.id-or-slug-enabled": True})
+    def test_proxy_check_org_id_or_slug_url_with_params(self):
+        """Test the logic of when a request should be proxied"""
+        responses.add(
+            responses.GET,
+            f"{self.REGION.address}/organizations/{self.organization.slug}/region/",
+            json={"proxy": True},
+        )
+        responses.add(
+            responses.GET,
+            f"{self.REGION.address}/organizations/{self.organization.slug}/control/",
+            json={"proxy": True},
+        )
+        responses.add(
+            responses.GET,
+            f"{self.REGION.address}/organizations/{self.organization.id}/region/",
+            json={"proxy": True},
+        )
+        responses.add(
+            responses.GET,
+            f"{self.REGION.address}/organizations/{self.organization.id}/control/",
+            json={"proxy": True},
+        )
+
+        region_url_slug = reverse(
+            "region-endpoint-id-or-slug", kwargs={"organization_id_or_slug": self.organization.slug}
+        )
+        control_url_slug = reverse(
+            "control-endpoint-id-or-slug",
+            kwargs={"organization_id_or_slug": self.organization.slug},
+        )
+
+        region_url_id = reverse(
+            "region-endpoint-id-or-slug", kwargs={"organization_id_or_slug": self.organization.id}
+        )
+        control_url_id = reverse(
+            "control-endpoint-id-or-slug", kwargs={"organization_id_or_slug": self.organization.id}
+        )
+
+        with override_settings(SILO_MODE=SiloMode.CONTROL, MIDDLEWARE=tuple(self.middleware)):
+            resp = self.client.get(region_url_slug)
+            assert resp.status_code == 200
+            resp_json = json.loads(close_streaming_response(resp))
+            assert resp_json["proxy"] is True
+
+            resp = self.client.get(control_url_slug)
+            assert resp.status_code == 200
+            assert resp.data["proxy"] is False
+
+        with override_settings(SILO_MODE=SiloMode.REGION, MIDDLEWARE=tuple(self.middleware)):
+            resp = self.client.get(region_url_slug)
+            assert resp.status_code == 200
+            assert resp.data["proxy"] is False
+
+        with override_settings(SILO_MODE=SiloMode.CONTROL, MIDDLEWARE=tuple(self.middleware)):
+            resp = self.client.get(region_url_id)
+            assert resp.status_code == 200
+            resp_json = json.loads(close_streaming_response(resp))
+            assert resp_json["proxy"] is True
+
+            resp = self.client.get(control_url_id)
+            assert resp.status_code == 200
+            assert resp.data["proxy"] is False
+
+        with override_settings(SILO_MODE=SiloMode.REGION, MIDDLEWARE=tuple(self.middleware)):
+            resp = self.client.get(region_url_id)
+            assert resp.status_code == 200
+            assert resp.data["proxy"] is False
+
     @responses.activate
     def test_proxy_check_region_pinned_url(self):
         responses.add(