Browse Source

ref: fix typing for sentry.features (#78557)

<!-- Describe your PR here. -->
anthony sottile 5 months ago
parent
commit
b0c9b0aee0

+ 0 - 3
pyproject.toml

@@ -205,8 +205,6 @@ module = [
     "sentry.db.router",
     "sentry.discover.endpoints.discover_key_transactions",
     "sentry.eventstore.models",
-    "sentry.features.handler",
-    "sentry.features.manager",
     "sentry.grouping.strategies.legacy",
     "sentry.identity.bitbucket.provider",
     "sentry.identity.github_enterprise.provider",
@@ -348,7 +346,6 @@ module = [
     "sentry.templatetags.sentry_plugins",
     "sentry.testutils.cases",
     "sentry.testutils.fixtures",
-    "sentry.testutils.helpers.features",
     "sentry.testutils.helpers.notifications",
     "sentry.testutils.helpers.slack",
     "sentry.utils.auth",

+ 18 - 9
src/sentry/features/handler.py

@@ -1,9 +1,7 @@
 from __future__ import annotations
 
-__all__ = ["FeatureHandler", "BatchFeatureHandler"]
-
 import abc
-from collections.abc import Mapping, MutableSet, Sequence
+from collections.abc import Sequence
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
@@ -17,6 +15,9 @@ if TYPE_CHECKING:
     from sentry.users.services.user import RpcUser
 
 
+__all__ = ["FeatureHandler", "BatchFeatureHandler"]
+
+
 class FeatureHandler:
     """
     Base class for defining custom logic for feature decisions.
@@ -28,7 +29,7 @@ class FeatureHandler:
     as we don't programatically release features in self-hosted.
     """
 
-    features: MutableSet[str] = set()
+    features: set[str] = set()
 
     def __call__(self, feature: Feature, actor: User) -> bool | None:
         if feature.name not in self.features:
@@ -45,7 +46,7 @@ class FeatureHandler:
     ) -> bool | None:
         raise NotImplementedError
 
-    def has_for_batch(self, batch: FeatureCheckBatch) -> Mapping[Project, bool | None]:
+    def has_for_batch(self, batch: FeatureCheckBatch) -> dict[Project, bool | None]:
         # If not overridden, iterate over objects in the batch individually.
         return {
             obj: self.has(feature, batch.actor)
@@ -60,7 +61,7 @@ class FeatureHandler:
         projects: Sequence[Project] | None = None,
         organization: Organization | None = None,
         batch: bool = True,
-    ) -> Mapping[str, Mapping[str, bool | None]] | None:
+    ) -> dict[str, dict[str, bool | None]] | None:
         raise NotImplementedError
 
 
@@ -80,13 +81,21 @@ class BatchFeatureHandler(FeatureHandler):
 
     @abc.abstractmethod
     def _check_for_batch(
-        self, feature_name: str, entity: Organization | User, actor: User
+        self,
+        feature_name: str,
+        entity: Organization | User | None,
+        actor: User | RpcUser | AnonymousUser | None,
     ) -> bool | None:
         raise NotImplementedError
 
-    def has(self, feature: Feature, actor: User, skip_entity: bool | None = False) -> bool | None:
+    def has(
+        self,
+        feature: Feature,
+        actor: User | RpcUser | AnonymousUser | None,
+        skip_entity: bool | None = False,
+    ) -> bool | None:
         return self._check_for_batch(feature.name, feature.get_subject(), actor)
 
-    def has_for_batch(self, batch: FeatureCheckBatch) -> Mapping[Project, bool | None]:
+    def has_for_batch(self, batch: FeatureCheckBatch) -> dict[Project, bool | None]:
         flag = self._check_for_batch(batch.feature_name, batch.subject, batch.actor)
         return {obj: flag for obj in batch.objects}

+ 16 - 16
src/sentry/features/manager.py

@@ -6,7 +6,7 @@ __all__ = ["FeatureManager"]
 
 import abc
 from collections import defaultdict
-from collections.abc import Iterable, Mapping, MutableMapping, MutableSet, Sequence
+from collections.abc import Iterable, Sequence
 from typing import TYPE_CHECKING, Any
 
 import sentry_sdk
@@ -44,7 +44,7 @@ class RegisteredFeatureManager:
     """
 
     def __init__(self) -> None:
-        self._handler_registry: MutableMapping[str, list[FeatureHandler]] = defaultdict(list)
+        self._handler_registry: dict[str, list[FeatureHandler]] = defaultdict(list)
 
     def add_handler(self, handler: FeatureHandler) -> None:
         """
@@ -78,7 +78,7 @@ class RegisteredFeatureManager:
         organization: Organization,
         objects: Sequence[Project],
         actor: User | None = None,
-    ) -> Mapping[Project, bool]:
+    ) -> dict[Project, bool | None]:
         """
         Determine if a feature is enabled for a batch of objects.
 
@@ -100,7 +100,7 @@ class RegisteredFeatureManager:
         >>> FeatureManager.has_for_batch('projects:feature', organization, [project1, project2], actor=request.user)
         """
 
-        result = dict()
+        result: dict[Project, bool | None] = {}
         remaining = set(objects)
 
         handlers = self._handler_registry[name]
@@ -143,17 +143,17 @@ FLAGPOLE_OPTION_PREFIX = "feature"
 class FeatureManager(RegisteredFeatureManager):
     def __init__(self) -> None:
         super().__init__()
-        self._feature_registry: MutableMapping[str, type[Feature]] = {}
+        self._feature_registry: dict[str, type[Feature]] = {}
         # Deprecated: Remove entity_features once flagr has been removed.
-        self.entity_features: MutableSet[str] = set()
-        self.exposed_features: MutableSet[str] = set()
-        self.option_features: MutableSet[str] = set()
-        self.flagpole_features: MutableSet[str] = set()
+        self.entity_features: set[str] = set()
+        self.exposed_features: set[str] = set()
+        self.option_features: set[str] = set()
+        self.flagpole_features: set[str] = set()
         self._entity_handler: FeatureHandler | None = None
 
     def all(
         self, feature_type: type[Feature] = Feature, api_expose_only: bool = False
-    ) -> Mapping[str, type[Feature]]:
+    ) -> dict[str, type[Feature]]:
         """
         Get a mapping of feature name -> feature class, optionally specific to a
         particular feature type.
@@ -328,7 +328,7 @@ class FeatureManager(RegisteredFeatureManager):
         actor: User | RpcUser | AnonymousUser | None = None,
         projects: Sequence[Project] | None = None,
         organization: Organization | None = None,
-    ) -> Mapping[str, Mapping[str, bool | None]] | None:
+    ) -> dict[str, dict[str, bool | None]] | None:
         """
         Determine if multiple features are enabled. Unhandled flags will not be in
         the results if they cannot be handled.
@@ -346,7 +346,7 @@ class FeatureManager(RegisteredFeatureManager):
                 # Fall back to default handler if no entity handler available.
                 project_features = [name for name in feature_names if name.startswith("projects:")]
                 if projects and project_features:
-                    results: MutableMapping[str, Mapping[str, bool]] = {}
+                    results: dict[str, dict[str, bool | None]] = {}
                     for project in projects:
                         proj_results = results[f"project:{project.id}"] = {}
                         for feature_name in project_features:
@@ -357,7 +357,7 @@ class FeatureManager(RegisteredFeatureManager):
 
                 org_features = filter(lambda name: name.startswith("organizations:"), feature_names)
                 if organization and org_features:
-                    org_results = {}
+                    org_results: dict[str, bool | None] = {}
                     for feature_name in org_features:
                         org_results[feature_name] = self.has(
                             feature_name, organization, actor=actor
@@ -370,7 +370,7 @@ class FeatureManager(RegisteredFeatureManager):
                     feature_names,
                 )
                 if unscoped_features:
-                    unscoped_results = {}
+                    unscoped_results: dict[str, bool | None] = {}
                     for feature_name in unscoped_features:
                         unscoped_results[feature_name] = self.has(feature_name, actor=actor)
                     return {"unscoped": unscoped_results}
@@ -417,7 +417,7 @@ class FeatureCheckBatch:
         self.objects = objects
         self.actor = actor
 
-    def get_feature_objects(self) -> Mapping[Project, Feature]:
+    def get_feature_objects(self) -> dict[Project, Feature]:
         """
         Iterate over individual Feature objects.
 
@@ -429,5 +429,5 @@ class FeatureCheckBatch:
         return {obj: cls(self.feature_name, obj) for obj in self.objects}
 
     @property
-    def subject(self) -> Organization | User:
+    def subject(self) -> Organization | User | None:
         return self.organization or self.actor

+ 13 - 9
src/sentry/testutils/helpers/features.py

@@ -2,7 +2,7 @@ __all__ = ["Feature", "with_feature", "apply_feature_flag_on_cls"]
 
 import functools
 import logging
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Mapping, Sequence
 from contextlib import contextmanager
 from unittest.mock import patch
 
@@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
 
 
 @contextmanager
-def Feature(names):
+def Feature(names: str | Sequence[str] | dict[str, bool]) -> Generator[None]:
     """
     Control whether a feature is enabled.
 
@@ -102,14 +102,18 @@ def Feature(names):
                 logger.info("Flag defaulting to %s: %s", default_value, repr(name))
             return default_value
 
-    def batch_features_override(_feature_names, projects=None, organization=None, *args, **kwargs):
+    def batch_features_override(
+        _feature_names: Sequence[str], projects=None, organization=None, *args, **kwargs
+    ):
         feature_results = {name: names[name] for name in _feature_names if name in names}
         default_feature_names = [name for name in _feature_names if name not in names]
-        default_feature_results = {}
+        default_feature_results: dict[str, dict[str, bool | None]] = {}
         if default_feature_names:
-            default_feature_results = default_batch_has(
+            defaults = default_batch_has(
                 default_feature_names, projects=projects, organization=organization, **kwargs
             )
+            if defaults:
+                default_feature_results.update(defaults)
 
         if projects:
             results = {}
@@ -122,13 +126,13 @@ def Feature(names):
             return results
         elif organization:
             result_key = f"organization:{organization.id}"
-            results = {**feature_results, **default_feature_results[result_key]}
-            results = {
+            results_for_org = {**feature_results, **default_feature_results[result_key]}
+            results_for_org = {
                 name: resolve_feature_name_value_for_org(organization, val)
-                for name, val in results.items()
+                for name, val in results_for_org.items()
                 if name.startswith("organization")
             }
-            return {result_key: results}
+            return {result_key: results_for_org}
 
     with patch("sentry.features.has") as features_has:
         features_has.side_effect = features_override

+ 2 - 2
tests/sentry/tasks/test_on_demand_metrics.py

@@ -364,7 +364,7 @@ def project(organization: Organization) -> Project:
 )
 @django_db_all
 def test_schedule_on_demand_check(
-    feature_flags: set[str],
+    feature_flags: dict[str, bool],
     option_enable: bool,
     option_rollout: bool,
     option_batch_size: float,
@@ -460,7 +460,7 @@ def test_schedule_on_demand_check(
 def test_process_widget_specs(
     raw_snql_query: Any,
     _set_cardinality_cache: Any,
-    feature_flags: set[str],
+    feature_flags: dict[str, bool],
     option_enable: bool,
     widget_query_ids: Sequence[int],
     set_high_cardinality: bool,