Browse Source

ref: stronger typing for a few more utils modules (#70659)

<!-- Describe your PR here. -->
anthony sottile 10 months ago
parent
commit
1deb0c9e69

+ 5 - 0
pyproject.toml

@@ -617,17 +617,22 @@ module = [
     "sentry.utils.env",
     "sentry.utils.event",
     "sentry.utils.files",
+    "sentry.utils.geo",
+    "sentry.utils.imports",
     "sentry.utils.iterators",
     "sentry.utils.javascript",
     "sentry.utils.locking.backends.redis",
     "sentry.utils.migrations",
+    "sentry.utils.numbers",
     "sentry.utils.otp",
     "sentry.utils.performance_issues.detectors.*",
     "sentry.utils.performance_issues.performance_detection",
+    "sentry.utils.pubsub",
     "sentry.utils.redis",
     "sentry.utils.redis_metrics",
     "sentry.utils.sentry_apps.*",
     "sentry.utils.sms",
+    "sentry.utils.urls",
     "sentry.utils.uwsgi",
     "sentry.utils.zip",
     "sentry_plugins.base",

+ 3 - 3
src/sentry/utils/geo.py

@@ -10,8 +10,8 @@ logger = logging.getLogger(__name__)
 
 
 # default is no-op
-def geo_by_addr(ip):
-    pass
+def geo_by_addr(ip: str) -> dict[str, Any] | None:
+    return None
 
 
 rust_geoip: None | GeoIpLookup = None
@@ -48,7 +48,7 @@ def _init_geoip() -> None:
     geo_by_addr = _geo_by_addr
 
 
-def _init_geoip_rust():
+def _init_geoip_rust() -> None:
     global rust_geoip
 
     from sentry_relay.processing import GeoIpLookup

+ 6 - 3
src/sentry/utils/imports.py

@@ -1,5 +1,8 @@
-class ModuleProxyCache(dict):
-    def __missing__(self, key):
+from typing import Any
+
+
+class ModuleProxyCache(dict[str, object]):
+    def __missing__(self, key: str) -> object:
         if "." not in key:
             return __import__(key)
 
@@ -17,7 +20,7 @@ class ModuleProxyCache(dict):
 _cache = ModuleProxyCache()
 
 
-def import_string(path: str):
+def import_string(path: str) -> Any:
     """
     Path must be module.path.ClassName
 

+ 5 - 3
src/sentry/utils/numbers.py

@@ -62,7 +62,9 @@ def base36_decode(s: str) -> int:
 DEFAULT_UNITS = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
 
 
-def format_bytes(number, units=DEFAULT_UNITS, decimal_places=2):
+def format_bytes(
+    number: float, units: tuple[str, ...] = DEFAULT_UNITS, decimal_places: int = 2
+) -> str:
     block = 1024.0
     if number < block:
         return f"{number} {units[0]}"
@@ -72,7 +74,7 @@ def format_bytes(number, units=DEFAULT_UNITS, decimal_places=2):
     while number >= block and u < max_unit:
         number /= block
         u += 1
-    return ("{:.%df} {}" % (decimal_places,)).format(number, units[u])
+    return f"{number:.{decimal_places}f} {units[u]}"
 
 
 def format_grouped_length(length: int, steps: list[int] | None = None) -> str:
@@ -92,7 +94,7 @@ def format_grouped_length(length: int, steps: list[int] | None = None) -> str:
     return f">{steps[-1]}"
 
 
-def validate_bigint(value):
+def validate_bigint(value: object) -> bool:
     return isinstance(value, int) and value >= 0 and value.bit_length() <= 63
 
 

+ 4 - 2
src/sentry/utils/pubsub.py

@@ -1,12 +1,14 @@
+from typing import Any
+
 from confluent_kafka import Producer
 
 
 class KafkaPublisher:
-    def __init__(self, connection, asynchronous=True):
+    def __init__(self, connection: dict[str, Any], asynchronous: bool = True) -> None:
         self.producer = Producer(connection or {})
         self.asynchronous = asynchronous
 
-    def publish(self, channel, value, key=None):
+    def publish(self, channel: str, value: str, key: str | None = None) -> None:
         self.producer.produce(topic=channel, value=value, key=key)
         if self.asynchronous:
             self.producer.poll(0)

+ 3 - 3
src/sentry/utils/urls.py

@@ -1,11 +1,11 @@
 import re
-from collections.abc import MutableMapping, Sequence
+from collections.abc import Mapping, MutableMapping, Sequence
 from urllib.parse import parse_qs, parse_qsl, urlencode, urljoin, urlparse, urlsplit, urlunparse
 
 _scheme_re = re.compile(r"^([a-zA-Z0-9-+]+://)(.*)$")
 
 
-def non_standard_url_join(base, to_join):
+def non_standard_url_join(base: str, to_join: str | None) -> str:
     """A version of url join that can deal with unknown protocols."""
     # joins to an absolute url are willing by default
     if not to_join:
@@ -31,7 +31,7 @@ def non_standard_url_join(base, to_join):
     return rv
 
 
-def add_params_to_url(url, params):
+def add_params_to_url(url: str, params: Mapping[str, str]) -> str:
     url_parts = urlparse(url)
     query = dict(parse_qsl(url_parts.query))
     query.update(params)

+ 1 - 1
tests/sentry/utils/test_numbers.py

@@ -287,7 +287,7 @@ def test_format_bytes():
     assert format_bytes(3000000000) == "2.79 GB"
     assert format_bytes(3000000000000) == "2.73 TB"
 
-    assert format_bytes(3000000000000, units=["B", "KB", "MB", "GB"]) == "2793.97 GB"
+    assert format_bytes(3000000000000, units=("B", "KB", "MB", "GB")) == "2793.97 GB"
 
 
 def test_format_grouped_length():