Browse Source

ref: --disallow-any-generics for utils.redis (#68112)

Enables `disallow_any_generics` mypy checking for `sentry.utils.redis`. To support this, the cluster manager's `get(key, decode_responses=False)` form is replaced with an explicit `get_binary(key)` method (returning `RedisCluster[bytes] | StrictRedis[bytes]`, vs. `get(key)` returning the `str`-parameterized clients), callers in `cache`, `monitors`, `relay`, `spans`, and `locking` are updated to the new API and parameterized types, and affected tests are adjusted accordingly.
anthony sottile 11 months ago
parent
commit
92a1a02ff3

+ 1 - 0
pyproject.toml

@@ -641,5 +641,6 @@ module = [
     "tests.sentry.tasks.test_on_demand_metrics",
     "tests.sentry.relay.config.test_metric_extraction",
 ]
+disallow_any_generics = true
 disallow_untyped_defs = true
 # end: stronger typing

+ 1 - 1
src/sentry/cache/redis.py

@@ -67,5 +67,5 @@ RedisCache = RbCache
 class RedisClusterCache(CommonRedisCache):
     def __init__(self, cluster_id: str, **options: object) -> None:
         client = redis_clusters.get(cluster_id)
-        raw_client = redis_clusters.get(cluster_id, decode_responses=False)
+        raw_client = redis_clusters.get_binary(cluster_id)
         super().__init__(client=client, raw_client=raw_client, **options)

+ 9 - 6
src/sentry/monitors/clock_dispatch.py

@@ -18,6 +18,13 @@ MONITOR_TASKS_LAST_TRIGGERED_KEY = "sentry.monitors.last_tasks_ts"
 MONITOR_TASKS_PARTITION_CLOCKS = "sentry.monitors.partition_clocks"
 
 
+def _int_or_none(s: str | None) -> int | None:
+    if s is None:
+        return None
+    else:
+        return int(s)
+
+
 def _dispatch_tasks(ts: datetime):
     """
     Dispatch monitor tasks triggered by the consumer clock.
@@ -74,9 +81,7 @@ def try_monitor_tasks_trigger(ts: datetime, partition: int):
     # timestamp. Use `int()` to keep the timestamp (score) as an int
     slowest_part_ts = int(slowest_partitions[0][1])
 
-    precheck_last_ts = redis_client.get(MONITOR_TASKS_LAST_TRIGGERED_KEY)
-    if precheck_last_ts is not None:
-        precheck_last_ts = int(precheck_last_ts)
+    precheck_last_ts = _int_or_none(redis_client.get(MONITOR_TASKS_LAST_TRIGGERED_KEY))
 
     # If we have the same or an older timestamp from the most recent tick there
     # is nothing to do, we've already handled this tick.
@@ -90,9 +95,7 @@ def try_monitor_tasks_trigger(ts: datetime, partition: int):
 
     # GETSET is atomic. This is critical to avoid another consumer also
     # processing the same tick.
-    last_ts = redis_client.getset(MONITOR_TASKS_LAST_TRIGGERED_KEY, slowest_part_ts)
-    if last_ts is not None:
-        last_ts = int(last_ts)
+    last_ts = _int_or_none(redis_client.getset(MONITOR_TASKS_LAST_TRIGGERED_KEY, slowest_part_ts))
 
     # Another consumer already handled the tick if the first LAST_TRIGGERED
     # timestamp we got is different from the one we just got from the GETSET.

+ 5 - 5
src/sentry/relay/projectconfig_cache/redis.py

@@ -15,10 +15,10 @@ logger = logging.getLogger(__name__)
 class RedisProjectConfigCache(ProjectConfigCache):
     def __init__(self, **options):
         cluster_key = options.get("cluster", "default")
-        self.cluster = redis.redis_clusters.get(cluster_key, decode_responses=False)
+        self.cluster = redis.redis_clusters.get_binary(cluster_key)
 
         read_cluster_key = options.get("read_cluster", cluster_key)
-        self.cluster_read = redis.redis_clusters.get(read_cluster_key, decode_responses=False)
+        self.cluster_read = redis.redis_clusters.get_binary(read_cluster_key)
 
         super().__init__(**options)
 
@@ -57,10 +57,10 @@ class RedisProjectConfigCache(ProjectConfigCache):
         )
 
     def get(self, public_key):
-        rv = self.cluster_read.get(self.__get_redis_key(public_key))
-        if rv is not None:
+        rv_b = self.cluster_read.get(self.__get_redis_key(public_key))
+        if rv_b is not None:
             try:
-                rv = zstandard.decompress(rv).decode()
+                rv = zstandard.decompress(rv_b).decode()
             except (TypeError, zstandard.ZstdError):
                 # assume raw json
                 pass

+ 2 - 2
src/sentry/spans/buffer/redis.py

@@ -9,8 +9,8 @@ from sentry.utils import json, redis
 SEGMENT_TTL = 5 * 60  # 5 min TTL in seconds
 
 
-def get_redis_client() -> RedisCluster | StrictRedis:
-    return redis.redis_clusters.get(settings.SENTRY_SPAN_BUFFER_CLUSTER, decode_responses=False)
+def get_redis_client() -> RedisCluster[bytes] | StrictRedis[bytes]:
+    return redis.redis_clusters.get_binary(settings.SENTRY_SPAN_BUFFER_CLUSTER)
 
 
 def get_segment_key(project_id: str | int, segment_id: str) -> str:

+ 9 - 4
src/sentry/utils/locking/backends/redis.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any
 from uuid import uuid4
 
@@ -14,7 +16,7 @@ delete_lock = redis.load_redis_script("utils/locking/delete_lock.lua")
 class BaseRedisLockBackend(LockBackend):
     def __init__(
         self,
-        cluster: RedisCluster | rb.Cluster | StrictRedis,
+        cluster: rb.Cluster | RedisCluster[str] | StrictRedis[str],
         prefix: str = "l:",
         uuid: str | None = None,
     ):
@@ -77,10 +79,13 @@ class RedisBlasterLockBackend(BaseRedisLockBackend):
 
 
 class RedisClusterLockBackend(BaseRedisLockBackend):
-    cluster: RedisCluster | StrictRedis
+    cluster: RedisCluster[str] | StrictRedis[str]
 
     def __init__(
-        self, cluster: str | RedisCluster | StrictRedis, prefix: str = "l:", uuid: str | None = None
+        self,
+        cluster: str | RedisCluster[str] | StrictRedis[str],
+        prefix: str = "l:",
+        uuid: str | None = None,
     ):
         if isinstance(cluster, str):
             cluster = redis.redis_clusters.get(cluster)
@@ -88,7 +93,7 @@ class RedisClusterLockBackend(BaseRedisLockBackend):
 
     def get_client(
         self, key: str, routing_key: int | str | None = None
-    ) -> RedisCluster | StrictRedis:
+    ) -> RedisCluster[str] | StrictRedis[str]:
         return self.cluster
 
 

+ 77 - 44
src/sentry/utils/redis.py

@@ -4,7 +4,7 @@ import importlib.resources
 import logging
 from copy import deepcopy
 from threading import Lock
-from typing import Any, TypeGuard
+from typing import Any, Literal, TypeGuard, TypeVar, overload
 
 import rb
 from django.utils.functional import SimpleLazyObject
@@ -23,6 +23,8 @@ from sentry.utils.warnings import DeprecatedSettingWarning
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T", str, bytes)
+
 
 _REDIS_DEFAULT_CLIENT_ARGS = {
     # 3 seconds default socket and socket connection timeout avoids blocking on socket till the
@@ -75,10 +77,7 @@ class RBClusterManager:
 
         return rb.Cluster(**config, pool_cls=_shared_pool)
 
-    def get(self, key: str, *, decode_responses: bool = True) -> rb.Cluster:
-        if not decode_responses:
-            raise NotImplementedError("rb does not support decode_responses")
-
+    def get(self, key: str) -> rb.Cluster:
         try:
             return self._clusters[key]
         except KeyError:
@@ -95,15 +94,58 @@ class RBClusterManager:
         return ret
 
 
-class _RedisCluster:
-    def supports(self, config: dict[str, Any]) -> bool:
-        # _RedisCluster supports two configurations:
+class RedisClusterManager:
+    def __init__(self, options_manager: OptionsManager) -> None:
+        self._clusters_bytes: dict[str, RedisCluster[bytes] | StrictRedis[bytes]] = {}
+        self._clusters_str: dict[str, RedisCluster[str] | StrictRedis[str]] = {}
+        self._options_manager = options_manager
+
+    def _supports(self, config: dict[str, Any]) -> bool:
+        # supports two configurations:
         #  * Explicitly configured with is_redis_cluster. This mode is for real redis-cluster.
         #  * No is_redis_cluster, but only 1 host. This represents a singular node Redis running
         #    in non-cluster mode.
         return config.get("is_redis_cluster", False) or len(config.get("hosts", [])) == 1
 
-    def factory(
+    def _cfg(self, key: str) -> dict[str, Any]:
+        # TODO: This would probably be safer with a lock, but I'm not sure
+        # that it's necessary.
+        cfg = self._options_manager.get("redis.clusters", {}).get(key)
+        if cfg is None:
+            raise KeyError(f"Invalid cluster name: {key}")
+
+        if not self._supports(cfg):
+            raise KeyError("Invalid cluster type, expected redis cluster")
+
+        return cfg
+
+    @overload
+    def _factory(
+        self,
+        *,
+        decode_responses: Literal[False],
+        is_redis_cluster: bool = False,
+        readonly_mode: bool = False,
+        hosts: list[dict[Any, Any]] | dict[Any, Any] | None = None,
+        client_args: dict[str, Any] | None = None,
+        **config: Any,
+    ) -> RedisCluster[bytes] | StrictRedis[bytes]:
+        ...
+
+    @overload
+    def _factory(
+        self,
+        *,
+        decode_responses: Literal[True],
+        is_redis_cluster: bool = False,
+        readonly_mode: bool = False,
+        hosts: list[dict[Any, Any]] | dict[Any, Any] | None = None,
+        client_args: dict[str, Any] | None = None,
+        **config: Any,
+    ) -> RedisCluster[str] | StrictRedis[str]:
+        ...
+
+    def _factory(
         self,
         *,
         decode_responses: bool,
@@ -112,7 +154,7 @@ class _RedisCluster:
         hosts: list[dict[Any, Any]] | dict[Any, Any] | None = None,
         client_args: dict[str, Any] | None = None,
         **config: Any,
-    ) -> RetryingRedisCluster | FailoverRedis:
+    ) -> RedisCluster[bytes] | StrictRedis[bytes] | RedisCluster[str] | StrictRedis[str]:
         # StrictRedisCluster expects a list of { host, port } dicts. Coerce the
         # configuration into the correct format if necessary.
         if not hosts:
@@ -129,7 +171,9 @@ class _RedisCluster:
 
         # Redis cluster does not wait to attempt to connect. We'd prefer to not
         # make TCP connections on boot. Wrap the client in a lazy proxy object.
-        def cluster_factory() -> RetryingRedisCluster | FailoverRedis:
+        def cluster_factory() -> RedisCluster[
+            bytes
+        ] | StrictRedis[bytes] | RedisCluster[str] | StrictRedis[str]:
             if is_redis_cluster:
                 return RetryingRedisCluster(
                     # Intentionally copy hosts here because redis-cluster-py
@@ -156,39 +200,26 @@ class _RedisCluster:
         # losing some type safety: SimpleLazyObject acts like the underlying type
         return SimpleLazyObject(cluster_factory)  # type: ignore[return-value]
 
-    def __str__(self) -> str:
-        return "Redis Cluster"
-
-
-class RedisClusterManager:
-    def __init__(self, options_manager: OptionsManager) -> None:
-        self.__clusters: dict[tuple[str, bool], RedisCluster | StrictRedis] = {}
-        self.__options_manager = options_manager
-        self.__cluster_type = _RedisCluster()
-
-    def get(self, key: str, *, decode_responses: bool = True) -> RedisCluster | StrictRedis:
-        cache_key = (key, decode_responses)
+    def get(self, key: str) -> RedisCluster[str] | StrictRedis[str]:
         try:
-            return self.__clusters[cache_key]
+            return self._clusters_str[key]
         except KeyError:
             pass
 
         # Do not access attributes of the `cluster` object to prevent
-        # setup/init of lazy objects. The _RedisCluster type will try to
-        # connect to the cluster during initialization.
-
-        # TODO: This would probably be safer with a lock, but I'm not sure
-        # that it's necessary.
-        cfg = self.__options_manager.get("redis.clusters", {}).get(key)
-        if cfg is None:
-            raise KeyError(f"Invalid cluster name: {key}")
+        # setup/init of lazy objects.
+        ret = self._clusters_str[key] = self._factory(**self._cfg(key), decode_responses=True)
+        return ret
 
-        if not self.__cluster_type.supports(cfg):
-            raise KeyError(f"Invalid cluster type, expected: {self.__cluster_type}")
+    def get_binary(self, key: str) -> RedisCluster[bytes] | StrictRedis[bytes]:
+        try:
+            return self._clusters_bytes[key]
+        except KeyError:
+            pass
 
-        ret = self.__clusters[cache_key] = self.__cluster_type.factory(
-            **cfg, decode_responses=decode_responses
-        )
+        # Do not access attributes of the `cluster` object to prevent
+        # setup/init of lazy objects.
+        ret = self._clusters_bytes[key] = self._factory(**self._cfg(key), decode_responses=False)
         return ret
 
 
@@ -241,7 +272,7 @@ def get_cluster_from_options(
 
 def get_dynamic_cluster_from_options(
     setting: str, config: dict[str, Any]
-) -> tuple[bool, RedisCluster | StrictRedis | rb.Cluster, dict[str, Any]]:
+) -> tuple[bool, RedisCluster[str] | StrictRedis[str] | rb.Cluster, dict[str, Any]]:
     cluster_name = config.get("cluster", "default")
     cluster_opts: dict[str, Any] | None = options.default_manager.get("redis.clusters").get(
         cluster_name
@@ -256,8 +287,8 @@ def get_dynamic_cluster_from_options(
 
 
 def get_cluster_routing_client(
-    cluster: RedisCluster | rb.Cluster, is_redis_cluster: bool
-) -> RedisCluster | rb.RoutingClient:
+    cluster: RedisCluster[T] | rb.Cluster, is_redis_cluster: bool
+) -> RedisCluster[T] | rb.RoutingClient:
     if is_instance_redis_cluster(cluster, is_redis_cluster):
         return cluster
     elif is_instance_rb_cluster(cluster, is_redis_cluster):
@@ -267,18 +298,20 @@ def get_cluster_routing_client(
 
 
 def is_instance_redis_cluster(
-    val: rb.Cluster | RedisCluster, is_redis_cluster: bool
-) -> TypeGuard[RedisCluster]:
+    val: rb.Cluster | RedisCluster[str], is_redis_cluster: bool
+) -> TypeGuard[RedisCluster[str]]:
     return is_redis_cluster
 
 
 def is_instance_rb_cluster(
-    val: rb.Cluster | RedisCluster, is_redis_cluster: bool
+    val: rb.Cluster | RedisCluster[str], is_redis_cluster: bool
 ) -> TypeGuard[rb.Cluster]:
     return not is_redis_cluster
 
 
-def validate_dynamic_cluster(is_redis_cluster: bool, cluster: rb.Cluster | RedisCluster) -> None:
+def validate_dynamic_cluster(
+    is_redis_cluster: bool, cluster: rb.Cluster | RedisCluster[str]
+) -> None:
     try:
         if is_instance_redis_cluster(cluster, is_redis_cluster):
             cluster.ping()

+ 12 - 7
tests/sentry/attachments/test_redis.py

@@ -3,8 +3,9 @@ from unittest import mock
 
 import pytest
 
+from sentry.attachments.redis import RedisClusterAttachmentCache
 from sentry.cache.redis import RedisClusterCache
-from sentry.utils.imports import import_string
+from sentry.utils.redis import redis_clusters
 
 KEY_FMT = "c:1:%s"
 
@@ -24,12 +25,16 @@ def mock_client():
 
 @pytest.fixture
 def mocked_attachment_cache(request, mock_client):
-    with mock.patch(
-        "sentry.utils.redis.redis_clusters.get", return_value=mock_client
-    ) as cluster_get:
-        attachment_cache = import_string("sentry.attachments.redis.RedisClusterAttachmentCache")()
-        cluster_get.assert_any_call("rc-short")
-        assert isinstance(attachment_cache.inner, RedisClusterCache)
+    with (
+        mock.patch.object(redis_clusters, "get", return_value=mock_client) as cluster_get,
+        mock.patch.object(
+            redis_clusters, "get_binary", return_value=mock_client
+        ) as cluster_get_binary,
+    ):
+        attachment_cache = RedisClusterAttachmentCache()
+    cluster_get.assert_called_once_with("rc-short")
+    cluster_get_binary.assert_called_once_with("rc-short")
+    assert isinstance(attachment_cache.inner, RedisClusterCache)
 
     assert attachment_cache.inner._text_client is mock_client
     yield attachment_cache

+ 10 - 6
tests/sentry/buffer/test_redis.py

@@ -19,6 +19,14 @@ from sentry.utils.redis import (
 )
 
 
+def _hgetall_decode_keys(client, key, is_redis_cluster):
+    ret = client.hgetall(key)
+    if not is_redis_cluster:
+        return {k.decode(): v for k, v in ret.items()}
+    else:
+        return ret
+
+
 class TestRedisBuffer:
     @pytest.fixture(params=["cluster", "blaster"])
     def buffer(self, set_sentry_option, request):
@@ -146,9 +154,7 @@ class TestRedisBuffer:
         filters = {"pk": 1, "datetime": now}
         key = self.buf._make_key(model, filters=filters)
         self.buf.incr(model, columns, filters, extra={"foo": "bar", "datetime": now})
-        result = client.hgetall(key)
-        if not self.buf.is_redis_cluster:
-            result = {k.decode(): v for k, v in result.items()}
+        result = _hgetall_decode_keys(client, key, self.buf.is_redis_cluster)
 
         f = result.pop("f")
         if self.buf.is_redis_cluster:
@@ -176,9 +182,7 @@ class TestRedisBuffer:
         else:
             assert pending == [key.encode("utf-8")]
         self.buf.incr(model, columns, filters, extra={"foo": "baz", "datetime": now})
-        result = client.hgetall(key)
-        if not self.buf.is_redis_cluster:
-            result = {k.decode(): v for k, v in result.items()}
+        result = _hgetall_decode_keys(client, key, self.buf.is_redis_cluster)
         f = result.pop("f")
         assert load_values(f) == {"pk": 1, "datetime": now}
         assert load_value(result.pop("e+datetime")) == now