Browse Source

types(features): Add types to feature flags (#27114)

Marcos Gaeta 3 years ago
parent
commit
b3afb8e566

+ 1 - 0
mypy.ini

@@ -16,6 +16,7 @@ files = src/sentry/api/bases/external_actor.py,
         src/sentry/api/serializers/models/team.py,
         src/sentry/api/validators/external_actor.py,
         src/sentry/api/validators/notifications.py,
+        src/sentry/features/**/*.py,
         src/sentry/grouping/strategies/base.py,
         src/sentry/grouping/strategies/message.py,
         src/sentry/grouping/result.py,

+ 30 - 7
src/sentry/features/base.py

@@ -1,24 +1,47 @@
 __all__ = ["Feature", "OrganizationFeature", "ProjectFeature", "ProjectPluginFeature"]
 
+import abc
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from sentry.models import Organization, Project
+
 
 class Feature:
-    def __init__(self, name):
+    """Feature is not actually an abstract class even though it only has abstract
+    methods. This is because we need to be able to instantiate it."""
+
+    def __init__(self, name: str, *args: Any, **kwargs: Any) -> None:
+        """
+        `FeatureManager.get()` and `FeatureCheckBatch.get_feature_objects()`
+        expect to be able to pass a `Feature` arbitrary `args` and `kwargs`.
+        """
         self.name = name
 
+    @abc.abstractmethod
+    def get_organization(self) -> "Organization":
+        raise NotImplementedError
+
 
 class OrganizationFeature(Feature):
-    def __init__(self, name, organization):
-        Feature.__init__(self, name)
+    def __init__(self, name: str, organization: "Organization") -> None:
+        super().__init__(name)
         self.organization = organization
 
+    def get_organization(self) -> "Organization":
+        return self.organization
+
 
 class ProjectFeature(Feature):
-    def __init__(self, name, project):
-        Feature.__init__(self, name)
+    def __init__(self, name: str, project: "Project") -> None:
+        super().__init__(name)
         self.project = project
 
+    def get_organization(self) -> "Organization":
+        return self.project.organization
+
 
 class ProjectPluginFeature(ProjectFeature):
-    def __init__(self, name, project, plugin):
-        ProjectFeature.__init__(self, name, project)
+    def __init__(self, name: str, project: "Project", plugin: Any) -> None:
+        super().__init__(name, project=project)
         self.plugin = plugin

+ 1 - 1
src/sentry/features/exceptions.py

@@ -2,7 +2,7 @@ __all__ = ["FeatureNotRegistered"]
 
 
 class FeatureNotRegistered(Exception):
-    def __init__(self, name):
+    def __init__(self, name: str) -> None:
         msg = (
             'The "{}" feature has not been registered. '
             "Ensure that a feature has been added to sentry.features.default_manager"

+ 31 - 9
src/sentry/features/handler.py

@@ -1,25 +1,45 @@
 __all__ = ["FeatureHandler", "BatchFeatureHandler"]
 
+import abc
+from typing import TYPE_CHECKING, Mapping, MutableSet, Optional, Sequence
+
+if TYPE_CHECKING:
+    from sentry.features.base import Feature
+    from sentry.features.manager import FeatureCheckBatch
+    from sentry.models import Organization, Project, User
+
 
 class FeatureHandler:
-    features = set()
+    features: MutableSet[str] = set()
 
-    def __call__(self, feature, actor):
+    def __call__(self, feature: "Feature", actor: "User") -> Optional[bool]:
         if feature.name not in self.features:
             return None
 
         return self.has(feature, actor)
 
-    def has(self, feature, actor):
+    @abc.abstractmethod
+    def has(self, feature: "Feature", actor: "User") -> bool:
         raise NotImplementedError
 
-    def has_for_batch(self, batch):
+    def has_for_batch(self, batch: "FeatureCheckBatch") -> Mapping["Project", bool]:
         # If not overridden, iterate over objects in the batch individually.
         return {
             obj: self.has(feature, batch.actor)
             for (obj, feature) in batch.get_feature_objects().items()
         }
 
+    @abc.abstractmethod
+    def batch_has(
+        self,
+        feature_names: Sequence[str],
+        actor: "User",
+        projects: Optional[Sequence["Project"]] = None,
+        organization: Optional["Organization"] = None,
+        batch: bool = True,
+    ) -> Optional[Mapping[str, Mapping[str, bool]]]:
+        raise NotImplementedError
+
 
 # It is generally better to extend BatchFeatureHandler if it is possible to do
 # the check with no more than the feature name, organization, and actor. If it
@@ -28,13 +48,15 @@ class FeatureHandler:
 
 
 class BatchFeatureHandler(FeatureHandler):
-    def _check_for_batch(self, feature_name, organization, actor):
+    @abc.abstractmethod
+    def _check_for_batch(
+        self, feature_name: str, organization: "Organization", actor: "User"
+    ) -> bool:
         raise NotImplementedError
 
-    def has(self, feature, actor):
-        organization = getattr(feature, "organization", None) or feature.project.organization
-        return self._check_for_batch(feature.name, organization, actor)
+    def has(self, feature: "Feature", actor: "User") -> bool:
+        return self._check_for_batch(feature.name, feature.get_organization(), actor)
 
-    def has_for_batch(self, batch):
+    def has_for_batch(self, batch: "FeatureCheckBatch") -> Mapping["Project", bool]:
         flag = self._check_for_batch(batch.feature_name, batch.organization, batch.actor)
         return {obj: flag for obj in batch.objects}

+ 10 - 4
src/sentry/features/helpers.py

@@ -1,10 +1,14 @@
-from typing import Any, Sequence
+from typing import Any, Callable, Optional, Sequence
 
+from rest_framework.request import Request
 from rest_framework.response import Response
 
 from sentry import features
 from sentry.models import Organization
 
+# TODO(mgaeta): It's not currently possible to type a Callable's args with kwargs.
+EndpointFunc = Callable[..., Response]
+
 
 def any_organization_has_feature(
     feature: str, organizations: Sequence[Organization], **kwargs: Any
@@ -12,7 +16,9 @@ def any_organization_has_feature(
     return any([features.has(feature, organization, **kwargs) for organization in organizations])
 
 
-def requires_feature(feature, any_org=None):
+def requires_feature(
+    feature: str, any_org: Optional[bool] = None
+) -> Callable[[EndpointFunc], EndpointFunc]:
     """
     Require a feature flag to access an endpoint.
 
@@ -31,8 +37,8 @@ def requires_feature(feature, any_org=None):
         >>>     return Response()
     """
 
-    def decorator(func):
-        def wrapped(self, request, *args, **kwargs):
+    def decorator(func: EndpointFunc) -> EndpointFunc:
+        def wrapped(self: Any, request: Request, *args: Any, **kwargs: Any) -> Response:
             # The endpoint is accessible if any of the User's Orgs have the feature
             # flag enabled.
             if any_org:

+ 63 - 18
src/sentry/features/manager.py

@@ -1,6 +1,19 @@
 __all__ = ["FeatureManager"]
 
+import abc
 from collections import defaultdict
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Iterable,
+    List,
+    Mapping,
+    MutableMapping,
+    MutableSet,
+    Optional,
+    Sequence,
+    Type,
+)
 
 import sentry_sdk
 from django.conf import settings
@@ -8,6 +21,10 @@ from django.conf import settings
 from .base import Feature
 from .exceptions import FeatureNotRegistered
 
+if TYPE_CHECKING:
+    from sentry.features.handler import FeatureHandler
+    from sentry.models import Organization, Project, User
+
 
 class RegisteredFeatureManager:
     """
@@ -18,10 +35,10 @@ class RegisteredFeatureManager:
     handler, remove this class entirely
     """
 
-    def __init__(self):
-        self._handler_registry = defaultdict(list)
+    def __init__(self) -> None:
+        self._handler_registry: MutableMapping[str, List["FeatureHandler"]] = defaultdict(list)
 
-    def add_handler(self, handler):
+    def add_handler(self, handler: "FeatureHandler") -> None:
         """
         Register a feature handler.
 
@@ -31,14 +48,29 @@ class RegisteredFeatureManager:
         for feature_name in handler.features:
             self._handler_registry[feature_name].append(handler)
 
-    def _get_handler(self, feature, actor):
+    def _get_handler(self, feature: Feature, actor: "User") -> Optional[bool]:
         for handler in self._handler_registry[feature.name]:
             rv = handler(feature, actor)
             if rv is not None:
                 return rv
         return None
 
-    def has_for_batch(self, name, organization, objects, actor=None):
+    @abc.abstractmethod
+    def _get_feature_class(self, name: str) -> Type[Feature]:
+        """
+        We need this abstract method on this class because the `has_for_batch()`
+        method instantiates a `FeatureCheckBatch` and sets `manager` as `self`
+        as a `RegisteredFeatureManager`.
+        """
+        raise NotImplementedError
+
+    def has_for_batch(
+        self,
+        name: str,
+        organization: "Organization",
+        objects: Sequence["Project"],
+        actor: Optional["User"] = None,
+    ) -> Mapping["Project", bool]:
         """
         Determine in a batch if a feature is enabled.
 
@@ -95,20 +127,20 @@ class RegisteredFeatureManager:
 
 # TODO: Change RegisteredFeatureManager back to object once it can be removed
 class FeatureManager(RegisteredFeatureManager):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
-        self._feature_registry = {}
-        self.entity_features = set()
-        self._entity_handler = None
+        self._feature_registry: MutableMapping[str, Type[Feature]] = {}
+        self.entity_features: MutableSet[str] = set()
+        self._entity_handler: Optional["FeatureHandler"] = None
 
-    def all(self, feature_type=Feature):
+    def all(self, feature_type: Type[Feature] = Feature) -> Mapping[str, Type[Feature]]:
         """
         Get a mapping of feature name -> feature class, optionally specific to a
         particular feature type.
         """
         return {k: v for k, v in self._feature_registry.items() if v == feature_type}
 
-    def add(self, name, cls=Feature, entity_feature=False):
+    def add(self, name: str, cls: Type[Feature] = Feature, entity_feature: bool = False) -> None:
         """
         Register a feature.
 
@@ -121,13 +153,13 @@ class FeatureManager(RegisteredFeatureManager):
             self.entity_features.add(name)
         self._feature_registry[name] = cls
 
-    def _get_feature_class(self, name):
+    def _get_feature_class(self, name: str) -> Type[Feature]:
         try:
             return self._feature_registry[name]
         except KeyError:
             raise FeatureNotRegistered(name)
 
-    def get(self, name: str, *args, **kwargs):
+    def get(self, name: str, *args: Any, **kwargs: Any) -> Feature:
         """
         Lookup a registered feature context scope given the feature name.
 
@@ -136,13 +168,13 @@ class FeatureManager(RegisteredFeatureManager):
         cls = self._get_feature_class(name)
         return cls(name, *args, **kwargs)
 
-    def add_entity_handler(self, handler):
+    def add_entity_handler(self, handler: "FeatureHandler") -> None:
         """
         Registers a handler that doesn't require a feature name match
         """
         self._entity_handler = handler
 
-    def has(self, name: str, *args, **kwargs) -> bool:
+    def has(self, name: str, *args: Any, **kwargs: Any) -> bool:
         """
         Determine if a feature is enabled. If a handler returns None, then the next
         mechanism is used for feature checking.
@@ -193,7 +225,13 @@ class FeatureManager(RegisteredFeatureManager):
         # Features are by default disabled if no plugin or default enables them
         return False
 
-    def batch_has(self, feature_names, actor=None, projects=None, organization=None):
+    def batch_has(
+        self,
+        feature_names: Sequence[str],
+        actor: Optional["User"] = None,
+        projects: Optional[Sequence["Project"]] = None,
+        organization: Optional["Organization"] = None,
+    ) -> Optional[Mapping[str, Mapping[str, bool]]]:
         """
         Determine if multiple features are enabled. Unhandled flags will not be in
         the results if they cannot be handled.
@@ -218,14 +256,21 @@ class FeatureCheckBatch:
     common parent organization.
     """
 
-    def __init__(self, manager, name, organization, objects, actor):
+    def __init__(
+        self,
+        manager: RegisteredFeatureManager,
+        name: str,
+        organization: "Organization",
+        objects: Iterable["Project"],
+        actor: "User",
+    ) -> None:
         self._manager = manager
         self.feature_name = name
         self.organization = organization
         self.objects = objects
         self.actor = actor
 
-    def get_feature_objects(self):
+    def get_feature_objects(self) -> Mapping["Project", Feature]:
         """
         Iterate over individual Feature objects.
 

+ 14 - 19
tests/sentry/api/endpoints/test_auth_login.py

@@ -1,47 +1,42 @@
 from django.urls import reverse
-from exam import fixture
 
 from sentry.testutils import APITestCase
 from sentry.utils.compat.mock import patch
 
 
 class AuthLoginEndpointTest(APITestCase):
-    @fixture
-    def path(self):
-        return reverse("sentry-api-0-auth-login")
+    endpoint = "sentry-api-0-auth-login"
+    method = "post"
 
     def setUp(self):
         # Requests to set the test cookie
         self.client.get(reverse("sentry-api-0-auth-config"))
 
     def test_login_invalid_password(self):
-        resp = self.client.post(self.path, {"username": self.user.username, "password": "bizbar"})
-        assert resp.status_code == 400
-        assert resp.data["errors"]["__all__"] == [
+        response = self.get_error_response(
+            username=self.user.username, password="bizbar", status_code=400
+        )
+        assert response.data["errors"]["__all__"] == [
             "Please enter a correct username and password. Note that both fields may be case-sensitive."
         ]
 
     def test_login_valid_credentials(self):
-        resp = self.client.post(self.path, {"username": self.user.username, "password": "admin"})
-
-        assert resp.status_code == 200
-        assert resp.data["nextUri"] == "/organizations/new/"
+        response = self.get_success_response(username=self.user.username, password="admin")
+        assert response.data["nextUri"] == "/organizations/new/"
 
     def test_must_reactivate(self):
         self.user.update(is_active=False)
 
-        resp = self.client.post(self.path, {"username": self.user.username, "password": "admin"})
-
-        assert resp.status_code == 200
-        assert resp.data["nextUri"] == "/auth/reactivate/"
+        response = self.get_success_response(username=self.user.username, password="admin")
+        assert response.data["nextUri"] == "/auth/reactivate/"
 
     @patch(
         "sentry.api.endpoints.auth_login.ratelimiter.is_limited", autospec=True, return_value=True
     )
     def test_login_ratelimit(self, is_limited):
-        resp = self.client.post(self.path, {"username": self.user.username, "password": "admin"})
-
-        assert resp.status_code == 400
-        assert [str(s) for s in resp.data["errors"]["__all__"]] == [
+        response = self.get_error_response(
+            username=self.user.username, password="admin", status_code=400
+        )
+        assert [str(s) for s in response.data["errors"]["__all__"]] == [
             "You have made too many failed authentication attempts. Please try again later."
         ]

+ 50 - 0
tests/sentry/features/test_manager.py

@@ -1,10 +1,34 @@
+from typing import Any, Mapping, Optional, Union
+
 from django.conf import settings
 
 from sentry import features
+from sentry.features import Feature
+from sentry.models import User
 from sentry.testutils import TestCase
 from sentry.utils.compat import mock
 
 
+class MockBatchHandler(features.BatchFeatureHandler):
+    features = frozenset(["auth:register", "organizations:feature", "projects:feature"])
+
+    def has(
+        self, feature: Feature, actor: User
+    ) -> Union[Optional[bool], Mapping[str, Optional[bool]]]:
+        return {feature.name: True}
+
+    def batch_has(self, feature_names, *args: Any, **kwargs: Any):
+        if isinstance(feature_names, str):
+            return {feature_names: True}
+
+        return {
+            feature_name: True for feature_name in feature_names if feature_name in self.features
+        }
+
+    def _check_for_batch(self, feature_name, organization, actor):
+        return True if feature_name in self.features else None
+
+
 class FeatureManagerTest(TestCase):
     def test_feature_registry(self):
         manager = features.FeatureManager()
@@ -151,3 +175,29 @@ class FeatureManagerTest(TestCase):
         assert after_no_handler.hit_counter == 0
 
         assert null_handler.hit_counter == 2
+
+    def test_batch_has(self):
+        manager = features.FeatureManager()
+        manager.add("auth:register")
+        manager.add("organizations:feature", features.OrganizationFeature)
+        manager.add("projects:feature", features.ProjectFeature)
+        manager.add_entity_handler(MockBatchHandler())
+
+        assert manager.batch_has("auth:register", actor=self.user)["auth:register"]
+        assert manager.batch_has(
+            "organizations:feature", actor=self.user, organization=self.organization
+        )["organizations:feature"]
+        assert manager.batch_has("projects:feature", actor=self.user, projects=[self.project])[
+            "projects:feature"
+        ]
+
+    def test_has(self):
+        manager = features.FeatureManager()
+        manager.add("auth:register")
+        manager.add("organizations:feature", features.OrganizationFeature)
+        manager.add("projects:feature", features.ProjectFeature)
+        manager.add_handler(MockBatchHandler())
+
+        assert manager.has("organizations:feature", actor=self.user, organization=self.organization)
+        assert manager.has("projects:feature", actor=self.user, project=self.project)
+        assert manager.has("auth:register", actor=self.user)