Browse Source

ref: type api_gateway.proxy and its test (#52046)

<!-- Describe your PR here. -->
anthony sottile 1 year ago
parent
commit
b2f7b26b63

+ 0 - 2
pyproject.toml

@@ -404,7 +404,6 @@ module = [
     "sentry.api.validators.email",
     "sentry.api.validators.project_codeowners",
     "sentry.api.validators.servicehook",
-    "sentry.api_gateway.proxy",
     "sentry.apidocs.schema",
     "sentry.app",
     "sentry.audit_log.manager",
@@ -1173,7 +1172,6 @@ module = [
     "tests.sentry.api.validators.sentry_apps.test_text",
     "tests.sentry.api.validators.sentry_apps.test_textarea",
     "tests.sentry.api_gateway.test_api_gateway",
-    "tests.sentry.api_gateway.test_proxy",
     "tests.sentry.apidocs.test_extensions",
     "tests.sentry.auth.authenticators.test_sms",
     "tests.sentry.auth.authenticators.test_u2f",

+ 16 - 13
src/sentry/api_gateway/proxy.py

@@ -5,12 +5,11 @@ import logging
 from wsgiref.util import is_hop_by_hop
 
 from django.conf import settings
-from django.http import StreamingHttpResponse
+from django.http import HttpRequest, StreamingHttpResponse
 from requests import Response as ExternalResponse
 from requests import request as external_request
 from requests.exceptions import Timeout
 from rest_framework.exceptions import NotFound
-from rest_framework.request import Request
 
 from sentry.api.exceptions import RequestTimeout
 from sentry.silo.util import (
@@ -48,7 +47,7 @@ def _parse_response(response: ExternalResponse, remote_url: str) -> StreamingHtt
     return streamed_response
 
 
-def proxy_request(request: Request, org_slug: str) -> StreamingHttpResponse:
+def proxy_request(request: HttpRequest, org_slug: str) -> StreamingHttpResponse:
     """Take a django request object and proxy it to a remote location given an org_slug"""
     from sentry.types.region import get_region_for_organization
 
@@ -61,20 +60,24 @@ def proxy_request(request: Request, org_slug: str) -> StreamingHttpResponse:
     target_url = region.to_url(request.path)
     header_dict = clean_proxy_headers(request.headers)
     # TODO: use requests session for connection pooling capabilities
+    assert request.method is not None
     query_params = getattr(request, request.method, None)
-    request_args = {
-        "headers": header_dict,
-        "params": dict(query_params) if query_params is not None else None,
-        "files": getattr(request, "FILES", None),
-        "data": getattr(request, "body", None) if not getattr(request, "FILES", None) else None,
-        "stream": True,
-        "timeout": settings.GATEWAY_PROXY_TIMEOUT,
-    }
     try:
-        resp: ExternalResponse = external_request(request.method, url=target_url, **request_args)
+        resp = external_request(
+            request.method,
+            url=target_url,
+            headers=header_dict,
+            params=dict(query_params) if query_params is not None else None,
+            files=getattr(request, "FILES", None),
+            data=getattr(request, "body", None) if not getattr(request, "FILES", None) else None,
+            stream=True,
+            timeout=settings.GATEWAY_PROXY_TIMEOUT,
+        )
     except Timeout:
         # remote silo timeout. Use DRF timeout instead
         raise RequestTimeout()
 
-    resp.headers = clean_outbound_headers(resp.headers)
+    new_headers = clean_outbound_headers(resp.headers)
+    resp.headers.clear()
+    resp.headers.update(new_headers)
     return _parse_response(resp, target_url)

+ 7 - 3
src/sentry/testutils/helpers/response.py

@@ -1,11 +1,15 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+
 from django.http.response import StreamingHttpResponse
 
 
-def close_streaming_response(response: StreamingHttpResponse) -> None:
+def close_streaming_response(response: StreamingHttpResponse) -> bytes:
     """Exhausts the streamed file in a response.
 
     When the file is exahusted, this underlying file descriptor is closed
     avoiding a `ResourceWarning`.
     """
-    for _ in response.streaming_content:
-        pass
+    assert isinstance(response.streaming_content, Iterable)
+    return b"".join(response.streaming_content)

+ 11 - 10
tests/sentry/api_gateway/test_proxy.py

@@ -14,6 +14,7 @@ from sentry.testutils.helpers.api_gateway import (
     verify_request_body,
     verify_request_headers,
 )
+from sentry.testutils.helpers.response import close_streaming_response
 from sentry.utils import json
 
 
@@ -22,7 +23,7 @@ class ProxyTestCase(ApiGatewayTestCase):
     def test_simple(self):
         request = RequestFactory().get("http://sentry.io/get")
         resp = proxy_request(request, self.organization.slug)
-        resp_json = json.loads(b"".join(resp.streaming_content))
+        resp_json = json.loads(close_streaming_response(resp))
         assert resp.status_code == 200
         assert resp_json["proxy"]
         assert resp.has_header("test")
@@ -32,7 +33,7 @@ class ProxyTestCase(ApiGatewayTestCase):
 
         request = RequestFactory().get("http://sentry.io/error")
         resp = proxy_request(request, self.organization.slug)
-        resp_json = json.loads(b"".join(resp.streaming_content))
+        resp_json = json.loads(close_streaming_response(resp))
         assert resp.status_code == 400
         assert resp_json["proxy"]
         assert resp.has_header("test")
@@ -48,7 +49,7 @@ class ProxyTestCase(ApiGatewayTestCase):
         request = RequestFactory().get(f"http://sentry.io/echo?{query_param_str}")
 
         resp = proxy_request(request, self.organization.slug)
-        resp_json = json.loads(b"".join(resp.streaming_content))
+        resp_json = json.loads(close_streaming_response(resp))
 
         assert resp.status_code == 200
         # parse_qs returns everything in a list, including single arguments
@@ -74,7 +75,7 @@ class ProxyTestCase(ApiGatewayTestCase):
             "http://sentry.io/post", data=request_body, content_type="application/json"
         )
         resp = proxy_request(request, self.organization.slug)
-        resp_json = json.loads(b"".join(resp.streaming_content))
+        resp_json = json.loads(close_streaming_response(resp))
 
         assert resp.status_code == 200
         assert resp_json["proxy"]
@@ -94,7 +95,7 @@ class ProxyTestCase(ApiGatewayTestCase):
             "http://sentry.io/put", data=request_body, content_type="application/json"
         )
         resp = proxy_request(request, self.organization.slug)
-        resp_json = json.loads(b"".join(resp.streaming_content))
+        resp_json = json.loads(close_streaming_response(resp))
 
         assert resp.status_code == 200
         assert resp_json["proxy"]
@@ -114,7 +115,7 @@ class ProxyTestCase(ApiGatewayTestCase):
             "http://sentry.io/patch", data=request_body, content_type="application/json"
         )
         resp = proxy_request(request, self.organization.slug)
-        resp_json = json.loads(b"".join(resp.streaming_content))
+        resp_json = json.loads(close_streaming_response(resp))
 
         assert resp.status_code == 200
         assert resp_json["proxy"]
@@ -134,7 +135,7 @@ class ProxyTestCase(ApiGatewayTestCase):
             "http://sentry.io/head", data=request_body, content_type="application/json"
         )
         resp = proxy_request(request, self.organization.slug)
-        resp_json = json.loads(b"".join(resp.streaming_content))
+        resp_json = json.loads(close_streaming_response(resp))
 
         assert resp.status_code == 200
         assert resp_json["proxy"]
@@ -154,7 +155,7 @@ class ProxyTestCase(ApiGatewayTestCase):
             "http://sentry.io/delete", data=request_body, content_type="application/json"
         )
         resp = proxy_request(request, self.organization.slug)
-        resp_json = json.loads(b"".join(resp.streaming_content))
+        resp_json = json.loads(close_streaming_response(resp))
 
         assert resp.status_code == 200
         assert resp_json["proxy"]
@@ -179,7 +180,7 @@ class ProxyTestCase(ApiGatewayTestCase):
             "http://sentry.io/post", data=request_body, format="multipart"
         )
         resp = proxy_request(request, self.organization.slug)
-        resp_json = json.loads(b"".join(resp.streaming_content))
+        resp_json = json.loads(close_streaming_response(resp))
 
         assert resp.status_code == 200
         assert resp_json["proxy"]
@@ -201,7 +202,7 @@ class ProxyTestCase(ApiGatewayTestCase):
             content_type="application/x-www-form-urlencoded",
         )
         resp = proxy_request(request, self.organization.slug)
-        resp_json = json.loads(b"".join(resp.streaming_content))
+        resp_json = json.loads(close_streaming_response(resp))
 
         assert resp.status_code == 200
         assert resp_json["proxy"]