Browse Source

ref: remove unused outfile argument from fetch_file (#65930)

best viewed with whitespace-only changes hidden:
[?w=1](https://github.com/getsentry/sentry/pull/65930/files?w=1)

Removes the unused `outfile` argument from `fetch_file`: the function now always buffers the response body into an internal `BytesIO` and returns it on `UrlResult.body`. Also converts `UrlResult` from a `namedtuple` to a typed `NamedTuple`, adds type annotations to the signature, and switches the fetch deadline to `time.monotonic()`.
anthony sottile 1 year ago
parent
commit
605c24c358
2 changed files with 81 additions and 96 deletions
  1. 77 87
      src/sentry/http.py
  2. 4 9
      tests/sentry/test_http.py

+ 77 - 87
src/sentry/http.py

@@ -1,11 +1,11 @@
 from __future__ import annotations
 
+import contextlib
 import logging
 import time
 import warnings
-from collections import namedtuple
 from io import BytesIO
-from typing import Any
+from typing import Any, NamedTuple
 from urllib.parse import urlparse
 
 from django.conf import settings
@@ -31,8 +31,13 @@ logger = logging.getLogger(__name__)
 # fetched
 MAX_URL_LENGTH = 150
 
-# UrlResult.body **must** be bytes
-UrlResult = namedtuple("UrlResult", ["url", "headers", "body", "status", "encoding"])
+
+class UrlResult(NamedTuple):
+    url: str
+    headers: dict[str, str]
+    body: bytes
+    status: int
+    encoding: str | None
 
 
 class BadSource(Exception):
@@ -143,15 +148,14 @@ def lock_domain(url, error=None):
 
 
 def fetch_file(
-    url,
-    domain_lock_enabled=True,
-    outfile=None,
-    headers=None,
-    allow_redirects=True,
-    verify_ssl=False,
-    timeout=settings.SENTRY_SOURCE_FETCH_SOCKET_TIMEOUT,
-    **kwargs,
-):
+    url: str,
+    domain_lock_enabled: bool = True,
+    headers: dict[str, str] | None = None,
+    allow_redirects: bool = True,
+    verify_ssl: bool = False,
+    timeout: float = settings.SENTRY_SOURCE_FETCH_SOCKET_TIMEOUT,
+    **kwargs: Any,
+) -> UrlResult:
     """
     Pull down a URL, returning a UrlResult object.
     """
@@ -165,13 +169,13 @@ def fetch_file(
 
     logger.debug("Fetching %r from the internet", url)
 
-    with SafeSession() as http_session:
-        response = None
+    with contextlib.ExitStack() as ctx:
+        http_session = ctx.enter_context(SafeSession())
 
         try:
-            try:
-                start = time.time()
-                response = http_session.get(
+            start = time.monotonic()
+            response = ctx.enter_context(
+                http_session.get(
                     url,
                     allow_redirects=allow_redirects,
                     verify=verify_ssl,
@@ -180,74 +184,60 @@ def fetch_file(
                     stream=True,
                     **kwargs,
                 )
+            )
 
-                try:
-                    cl = int(response.headers["content-length"])
-                except (LookupError, ValueError):
-                    cl = 0
-                if cl > settings.SENTRY_SOURCE_FETCH_MAX_SIZE:
-                    raise OverflowError()
-
-                return_body = False
-                if outfile is None:
-                    outfile = BytesIO()
-                    return_body = True
-
+            try:
+                cl = int(response.headers["content-length"])
+            except (LookupError, ValueError):
                 cl = 0
-
-                # Only need to even attempt to read the response body if we
-                # got a 200 OK
-                if response.status_code == 200:
-                    for chunk in response.iter_content(16 * 1024):
-                        if time.time() - start > settings.SENTRY_SOURCE_FETCH_TIMEOUT:
-                            raise Timeout()
-                        outfile.write(chunk)
-                        cl += len(chunk)
-                        if cl > settings.SENTRY_SOURCE_FETCH_MAX_SIZE:
-                            raise OverflowError()
-
-            except Exception as exc:
-                logger.debug("Unable to fetch %r", url, exc_info=True)
-                if isinstance(exc, RestrictedIPAddress):
-                    error: dict[str, Any] = {"type": EventError.RESTRICTED_IP}
-                elif isinstance(exc, SuspiciousOperation):
-                    error = {"type": EventError.SECURITY_VIOLATION}
-                elif isinstance(exc, (Timeout, ReadTimeout)):
-                    error = {
-                        "type": EventError.FETCH_TIMEOUT,
-                        "timeout": settings.SENTRY_SOURCE_FETCH_TIMEOUT,
-                    }
-                elif isinstance(exc, OverflowError):
-                    error = {
-                        "type": EventError.FETCH_TOO_LARGE,
-                        # We want size in megabytes to format nicely
-                        "max_size": float(settings.SENTRY_SOURCE_FETCH_MAX_SIZE) / 1024 / 1024,
-                    }
-                elif isinstance(exc, RequestException):
-                    error = {
-                        "type": EventError.FETCH_GENERIC_ERROR,
-                        "value": f"{type(exc)}",
-                    }
-                else:
-                    logger.exception(str(exc))
-                    error = {"type": EventError.UNKNOWN_ERROR}
-
-                # TODO(dcramer): we want to be less aggressive on disabling domains
-                if domain_lock_enabled:
-                    lock_domain(url, error)
-                raise CannotFetch(error)
-
-            headers = {k.lower(): v for k, v in response.headers.items()}
-            encoding = response.encoding
-
-            body = None
-            if return_body:
-                body = outfile.getvalue()
-                outfile.close()  # we only want to close StringIO
-
-            result = (headers, body, response.status_code, encoding)
-        finally:
-            if response is not None:
-                response.close()
-
-        return UrlResult(url, result[0], result[1], result[2], result[3])
+            if cl > settings.SENTRY_SOURCE_FETCH_MAX_SIZE:
+                raise OverflowError()
+
+            outfile = BytesIO()
+
+            cl = 0
+
+            # Only need to even attempt to read the response body if we
+            # got a 200 OK
+            if response.status_code == 200:
+                for chunk in response.iter_content(16 * 1024):
+                    if time.monotonic() - start > settings.SENTRY_SOURCE_FETCH_TIMEOUT:
+                        raise Timeout()
+                    outfile.write(chunk)
+                    cl += len(chunk)
+                    if cl > settings.SENTRY_SOURCE_FETCH_MAX_SIZE:
+                        raise OverflowError()
+
+        except Exception as exc:
+            logger.debug("Unable to fetch %r", url, exc_info=True)
+            if isinstance(exc, RestrictedIPAddress):
+                error: dict[str, Any] = {"type": EventError.RESTRICTED_IP}
+            elif isinstance(exc, SuspiciousOperation):
+                error = {"type": EventError.SECURITY_VIOLATION}
+            elif isinstance(exc, (Timeout, ReadTimeout)):
+                error = {
+                    "type": EventError.FETCH_TIMEOUT,
+                    "timeout": settings.SENTRY_SOURCE_FETCH_TIMEOUT,
+                }
+            elif isinstance(exc, OverflowError):
+                error = {
+                    "type": EventError.FETCH_TOO_LARGE,
+                    # We want size in megabytes to format nicely
+                    "max_size": float(settings.SENTRY_SOURCE_FETCH_MAX_SIZE) / 1024 / 1024,
+                }
+            elif isinstance(exc, RequestException):
+                error = {
+                    "type": EventError.FETCH_GENERIC_ERROR,
+                    "value": f"{type(exc)}",
+                }
+            else:
+                logger.exception(str(exc))
+                error = {"type": EventError.UNKNOWN_ERROR}
+
+            # TODO(dcramer): we want to be less aggressive on disabling domains
+            if domain_lock_enabled:
+                lock_domain(url, error)
+            raise CannotFetch(error)
+
+    headers = {k.lower(): v for k, v in response.headers.items()}
+    return UrlResult(url, headers, outfile.getvalue(), response.status_code, response.encoding)

+ 4 - 9
tests/sentry/test_http.py

@@ -1,4 +1,3 @@
-import io
 import platform
 from unittest.mock import patch
 
@@ -72,10 +71,8 @@ def test_fetch_file():
         responses.GET, "http://example.com", body="foo bar", content_type="application/json"
     )
 
-    temp = io.BytesIO()
-    result = http.fetch_file(url="http://example.com", domain_lock_enabled=False, outfile=temp)
-    assert result.body is None
-    assert temp.getvalue() == b"foo bar"
+    result = http.fetch_file(url="http://example.com", domain_lock_enabled=False)
+    assert result.body == b"foo bar"
 
 
 @responses.activate
@@ -89,7 +86,5 @@ def test_fetch_file_brotli():
         adding_headers={"Content-Encoding": "br"},
     )
 
-    temp = io.BytesIO()
-    result = http.fetch_file(url="http://example.com", domain_lock_enabled=False, outfile=temp)
-    assert result.body is None
-    assert temp.getvalue() == b"foo bar"
+    result = http.fetch_file(url="http://example.com", domain_lock_enabled=False)
+    assert result.body == b"foo bar"