Browse Source

ref(typing): stronger types for buffer/redis (#65920)

Add stronger types for buffer/redis and buffer/base
Yagiz Nizipli 1 year ago
parent
commit
d3b1c9447d
4 changed files with 96 additions and 43 deletions
  1. 9 6
      pyproject.toml
  2. 28 5
      src/sentry/buffer/base.py
  3. 57 31
      src/sentry/buffer/redis.py
  4. 2 1
      tests/sentry/buffer/test_base.py

+ 9 - 6
pyproject.toml

@@ -628,12 +628,15 @@ disable_error_code = [
 # beginning: stronger typing
 [[tool.mypy.overrides]]
 module = [
-  "sentry.utils.redis",
-  "sentry.tasks.on_demand_metrics",
-  "sentry.relay.config.metric_extraction",
-  "sentry.snuba.metrics.extraction",
-  "tests.sentry.tasks.test_on_demand_metrics",
-  "tests.sentry.relay.config.test_metric_extraction",
+    "sentry.buffer.base",
+    "sentry.buffer.redis",
+    "sentry.utils.redis",
+    "sentry.utils.redis_metrics",
+    "sentry.tasks.on_demand_metrics",
+    "sentry.relay.config.metric_extraction",
+    "sentry.snuba.metrics.extraction",
+    "tests.sentry.tasks.test_on_demand_metrics",
+    "tests.sentry.relay.config.test_metric_extraction",
 ]
 disallow_untyped_defs = true
 # end: stronger typing

+ 28 - 5
src/sentry/buffer/base.py

@@ -1,5 +1,9 @@
+from datetime import date, datetime
+from typing import Any
+
 from django.db.models import F
 
+from sentry.db import models
 from sentry.signals import buffer_incr_complete
 from sentry.tasks.process_buffer import process_incr
 from sentry.utils.services import Service
@@ -21,13 +25,25 @@ class Buffer(Service):
 
     __all__ = ("get", "incr", "process", "process_pending", "validate")
 
-    def get(self, model, columns, filters):
+    def get(
+        self,
+        model: type[models.Model],
+        columns: list[str],
+        filters: dict[str, models.Model | str | int],
+    ) -> dict[str, int]:
         """
         We can't fetch values from Celery, so just assume buffer values are all 0 here.
         """
         return {col: 0 for col in columns}
 
-    def incr(self, model, columns, filters, extra=None, signal_only=None):
+    def incr(
+        self,
+        model: type[models.Model],
+        columns: dict[str, int],
+        filters: dict[str, models.Model | str | int],
+        extra: dict[str, Any] | None = None,
+        signal_only: bool | None = None,
+    ) -> None:
         """
         >>> incr(Group, columns={'times_seen': 1}, filters={'pk': group.pk})
         signal_only - added to indicate that `process` should only call the complete
@@ -45,10 +61,17 @@ class Buffer(Service):
             }
         )
 
-    def process_pending(self, partition=None):
-        return []
+    def process_pending(self, partition: int | None = None) -> None:
+        return
 
-    def process(self, model, columns, filters, extra=None, signal_only=None):
+    def process(
+        self,
+        model: type[models.Model],
+        columns: dict[Any, Any],
+        filters: dict[str, str | datetime | date | int | float],
+        extra: dict[str, Any] | None = None,
+        signal_only: bool | None = None,
+    ) -> None:
         from sentry.event_manager import ScoreClause
         from sentry.models.group import Group
 

+ 57 - 31
src/sentry/buffer/redis.py

@@ -5,13 +5,14 @@ import pickle
 import threading
 from datetime import date, datetime, timezone
 from time import time
+from typing import Any
 
 import rb
-from django.db import models
 from django.utils.encoding import force_bytes, force_str
 from rediscluster import RedisCluster
 
 from sentry.buffer.base import Buffer
+from sentry.db import models
 from sentry.tasks.process_buffer import process_incr, process_pending
 from sentry.utils import json, metrics
 from sentry.utils.compat import crc32
@@ -34,7 +35,7 @@ logger = logging.getLogger(__name__)
 _last_validation_log: float | None = None
 
 
-def _validate_json_roundtrip(value, model):
+def _validate_json_roundtrip(value: dict[str, Any], model: type[models.Model]) -> None:
     global _last_validation_log
 
     if _last_validation_log is None or _last_validation_log < time() - 10:
@@ -50,27 +51,27 @@ def _validate_json_roundtrip(value, model):
 
 
 class PendingBuffer:
-    def __init__(self, size):
+    def __init__(self, size: int):
         assert size > 0
-        self.buffer = [None] * size
+        self.buffer: list[str | None] = [None] * size
         self.size = size
         self.pointer = 0
 
-    def full(self):
+    def full(self) -> bool:
         return self.pointer == self.size
 
-    def empty(self):
+    def empty(self) -> bool:
         return self.pointer == 0
 
-    def append(self, item):
+    def append(self, item: str) -> None:
         assert not self.full()
         self.buffer[self.pointer] = item
         self.pointer += 1
 
-    def clear(self):
+    def clear(self) -> None:
         self.pointer = 0
 
-    def flush(self):
+    def flush(self) -> list[str | None]:
         rv = self.buffer[: self.pointer]
         self.clear()
         return rv
@@ -80,7 +81,7 @@ class RedisBuffer(Buffer):
     key_expire = 60 * 60  # 1 hour
     pending_key = "b:p"
 
-    def __init__(self, pending_partitions=1, incr_batch_size=2, **options):
+    def __init__(self, pending_partitions: int = 1, incr_batch_size: int = 2, **options: object):
         self.is_redis_cluster, self.cluster, options = get_dynamic_cluster_from_options(
             "SENTRY_BUFFER_OPTIONS", options
         )
@@ -97,24 +98,26 @@ class RedisBuffer(Buffer):
         else:
             raise AssertionError("unreachable")
 
-    def validate(self):
+    def validate(self) -> None:
         validate_dynamic_cluster(self.is_redis_cluster, self.cluster)
 
-    def _coerce_val(self, value):
+    def _coerce_val(self, value: models.Model | str | int) -> bytes:
         if isinstance(value, models.Model):
             value = value.pk
         return force_bytes(value, errors="replace")
 
-    def _make_key(self, model, filters):
+    def _make_key(
+        self, model: type[models.Model], filters: dict[str, models.Model | str | int]
+    ) -> str:
         """
         Returns a Redis-compatible key for the model given filters.
         """
         md5 = md5_text(
-            "&".join(f"{k}={self._coerce_val(v)}" for k, v in sorted(filters.items()))
+            "&".join(f"{k}={self._coerce_val(v)!r}" for k, v in sorted(filters.items()))
         ).hexdigest()
         return f"b:k:{model._meta}:{md5}"
 
-    def _make_pending_key(self, partition=None):
+    def _make_pending_key(self, partition: int | None = None) -> str:
         """
         Returns the key to be used for the pending buffer.
         When partitioning is enabled, there is a key for each
@@ -125,7 +128,7 @@ class RedisBuffer(Buffer):
         assert partition >= 0
         return "%s:%d" % (self.pending_key, partition)
 
-    def _make_pending_key_from_key(self, key):
+    def _make_pending_key_from_key(self, key: str) -> str:
         """
         Return the pending_key for a given key. This is used
         to route a key into the correct pending buffer. If partitioning
@@ -135,18 +138,18 @@ class RedisBuffer(Buffer):
             return self.pending_key
         return self._make_pending_key(crc32(key) % self.pending_partitions)
 
-    def _make_lock_key(self, key):
+    def _make_lock_key(self, key: str) -> str:
         return f"l:{key}"
 
     @classmethod
-    def _dump_values(cls, values):
+    def _dump_values(cls, values: dict[Any, Any]) -> dict[Any, tuple[str, str]]:
         result = {}
         for k, v in values.items():
             result[k] = cls._dump_value(v)
         return result
 
     @classmethod
-    def _dump_value(cls, value):
+    def _dump_value(cls, value: str | datetime | date | int | float) -> tuple[str, str]:
         if isinstance(value, str):
             type_ = "s"
         elif isinstance(value, datetime):
@@ -161,17 +164,19 @@ class RedisBuffer(Buffer):
             type_ = "f"
         else:
             raise TypeError(type(value))
-        return (type_, str(value))
+        return type_, str(value)
 
     @classmethod
-    def _load_values(cls, payload):
+    def _load_values(
+        cls, payload: dict[str, tuple[str, Any]]
+    ) -> dict[str, str | datetime | date | int | float]:
         result = {}
         for k, (t, v) in payload.items():
             result[k] = cls._load_value((t, v))
         return result
 
     @classmethod
-    def _load_value(cls, payload):
+    def _load_value(cls, payload: tuple[str, Any]) -> str | datetime | date | int | float:
         (type_, value) = payload
         if type_ == "s":
             return force_str(value)
@@ -186,7 +191,12 @@ class RedisBuffer(Buffer):
         else:
             raise TypeError(f"invalid type: {type_}")
 
-    def get(self, model, columns, filters):
+    def get(
+        self,
+        model: type[models.Model],
+        columns: list[str],
+        filters: dict[str, models.Model | str | int],
+    ) -> dict[str, int]:
         """
         Fetches buffered values for a model/filter. Passed columns must be integer columns.
         """
@@ -207,7 +217,15 @@ class RedisBuffer(Buffer):
             col: (int(results[i]) if results[i] is not None else 0) for i, col in enumerate(columns)
         }
 
-    def incr(self, model, columns, filters, extra=None, signal_only=None, return_incr_results=True):
+    def incr(
+        self,
+        model: type[models.Model],
+        columns: dict[str, int],
+        filters: dict[str, models.Model | str | int],
+        extra: dict[str, Any] | None = None,
+        signal_only: bool | None = None,
+        return_incr_results: bool = True,
+    ) -> None:
         """
         Increment the key by doing the following:
 
@@ -265,7 +283,7 @@ class RedisBuffer(Buffer):
             tags={"module": model.__module__, "model": model.__name__},
         )
 
-    def process_pending(self, partition=None):
+    def process_pending(self, partition: int | None = None) -> None:
         if partition is None and self.pending_partitions > 1:
             # If we're using partitions, this one task fans out into
             # N subtasks instead.
@@ -323,20 +341,28 @@ class RedisBuffer(Buffer):
         finally:
             client.delete(lock_key)
 
-    def process(self, key=None, batch_keys=None):
+    def process(self, key: str | None = None, batch_keys: list[str] | None = None) -> None:  # type: ignore[override]
         assert not (key is None and batch_keys is None)
         assert not (key is not None and batch_keys is not None)
 
         if key is not None:
             batch_keys = [key]
 
-        for key in batch_keys:
-            self._process_single_incr(key)
-
-    def _process(self, model, columns, filters, extra=None, signal_only=None):
+        if batch_keys is not None:
+            for key in batch_keys:
+                self._process_single_incr(key)
+
+    def _process(
+        self,
+        model: type[models.Model],
+        columns: dict[str, int],
+        filters: dict[str, str | datetime | date | int | float],
+        extra: dict[str, Any] | None = None,
+        signal_only: bool | None = None,
+    ) -> Any:
         return super().process(model, columns, filters, extra, signal_only)
 
-    def _process_single_incr(self, key):
+    def _process_single_incr(self, key: str) -> None:
         client = self.get_routing_client()
         lock_key = self._make_lock_key(key)
         # prevent a stampede due to the way we use celery etas + duplicate

+ 2 - 1
tests/sentry/buffer/test_base.py

@@ -4,6 +4,7 @@ from unittest import mock
 from django.utils import timezone
 
 from sentry.buffer.base import Buffer
+from sentry.db import models
 from sentry.models.group import Group
 from sentry.models.organization import Organization
 from sentry.models.project import Project
@@ -22,7 +23,7 @@ class BufferTest(TestCase):
     def test_incr_delays_task(self, process_incr):
         model = mock.Mock()
         columns = {"times_seen": 1}
-        filters = {"id": 1}
+        filters: dict[str, models.Model | str | int] = {"id": 1}
         self.buf.incr(model, columns, filters)
         kwargs = dict(model=model, columns=columns, filters=filters, extra=None, signal_only=None)
         process_incr.apply_async.assert_called_once_with(kwargs=kwargs)