Browse Source

fix(flagpole): Adds generic type to context builder, adds AnonymousUser support (#71583)

Gabe Villalobos 9 months ago
parent
commit
d456f85dc5

+ 13 - 15
src/flagpole/evaluation_context.py

@@ -3,9 +3,7 @@ from __future__ import annotations
 import hashlib
 from collections.abc import Callable
 from copy import deepcopy
-from typing import Any, TypeVar
-
-from pydantic import BaseModel
+from typing import Any, Generic, TypeVar
 
 ValidContextTypes = TypeVar(
     "ValidContextTypes",
@@ -50,11 +48,10 @@ class EvaluationContext:
         return int.from_bytes(hashed.digest(), byteorder="big")
 
 
-# A function that generates a new slice of evaluation context data as a dictionary.
-EvaluationContextTransformer = Callable[[dict[str, Any]], EvaluationContextDict]
+T_CONTEXT_DATA = TypeVar("T_CONTEXT_DATA")
 
 
-class ContextBuilder(BaseModel):
+class ContextBuilder(Generic[T_CONTEXT_DATA]):
     """
     Used to build an EvaluationContext instance for use in Flagpole.
     This class aggregates a list of context transformers, each of which are
@@ -67,18 +64,18 @@ class ContextBuilder(BaseModel):
     >>> feature.match(dict())
     """
 
-    context_transformers: list[EvaluationContextTransformer] = []
-    exception_handler: Callable[[Exception], Any] | None
+    context_transformers: list[Callable[[T_CONTEXT_DATA], EvaluationContextDict]] = []
+    exception_handler: Callable[[Exception], Any] | None = None
 
     def add_context_transformer(
-        self, context_transformer: EvaluationContextTransformer
-    ) -> ContextBuilder:
+        self, context_transformer: Callable[[T_CONTEXT_DATA], EvaluationContextDict]
+    ) -> ContextBuilder[T_CONTEXT_DATA]:
         self.context_transformers.append(context_transformer)
         return self
 
     def add_exception_handler(
         self, exception_handler: Callable[[Exception], None]
-    ) -> ContextBuilder:
+    ) -> ContextBuilder[T_CONTEXT_DATA]:
         """
         Add a custom exception handler to the context builder if you need custom handling
         if any of the transformer functions raise an exception. This is useful for swallowing
@@ -92,13 +89,14 @@ class ContextBuilder(BaseModel):
         self.exception_handler = exception_handler
         return self
 
-    def build(self, data: dict[str, Any] | None = None) -> EvaluationContext:
-        builder_data: dict[str, Any] = data or dict()
-        context_data: dict[str, Any] = dict()
+    def build(self, data: T_CONTEXT_DATA | None = None) -> EvaluationContext:
+        context_data: EvaluationContextDict = dict()
+        if data is None:
+            return EvaluationContext(context_data)
 
         for transformer in self.context_transformers:
             try:
-                context_data = {**context_data, **transformer(builder_data)}
+                context_data = {**context_data, **transformer(data)}
             except Exception as e:
                 if self.exception_handler is not None:
                     self.exception_handler(e)

+ 20 - 10
src/flagpole/sentry_flagpole_context.py

@@ -1,10 +1,13 @@
-from typing import Any
+from dataclasses import dataclass
+
+from django.contrib.auth.models import AnonymousUser
 
 from flagpole.evaluation_context import ContextBuilder, EvaluationContextDict
 from sentry.models.organization import Organization
 from sentry.models.project import Project
 from sentry.models.user import User
 from sentry.services.hybrid_cloud.organization import RpcOrganization
+from sentry.services.hybrid_cloud.project import RpcProject
 from sentry.services.hybrid_cloud.user import RpcUser
 
 
@@ -12,9 +15,16 @@ class InvalidContextDataException(Exception):
     pass
 
 
-def organization_context_transformer(data: dict[str, Any]) -> EvaluationContextDict:
+@dataclass()
+class SentryContextData:
+    actor: User | RpcUser | AnonymousUser | None = None
+    organization: Organization | RpcOrganization | None = None
+    project: Project | RpcProject | None = None
+
+
+def organization_context_transformer(data: SentryContextData) -> EvaluationContextDict:
     context_data: EvaluationContextDict = dict()
-    org = data.get("organization", None)
+    org = data.organization
     if org is None:
         return context_data
 
@@ -36,10 +46,10 @@ def organization_context_transformer(data: dict[str, Any]) -> EvaluationContextD
     return context_data
 
 
-def project_context_transformer(data: dict[str, Any]) -> EvaluationContextDict:
+def project_context_transformer(data: SentryContextData) -> EvaluationContextDict:
     context_data: EvaluationContextDict = dict()
 
-    if (proj := data.get("project", None)) is not None:
+    if (proj := data.project) is not None:
         if not isinstance(proj, Project):
             raise InvalidContextDataException("Invalid project object provided")
 
@@ -50,10 +60,10 @@ def project_context_transformer(data: dict[str, Any]) -> EvaluationContextDict:
     return context_data
 
 
-def user_context_transformer(data: dict[str, Any]) -> EvaluationContextDict:
+def user_context_transformer(data: SentryContextData) -> EvaluationContextDict:
     context_data: EvaluationContextDict = dict()
-    user = data.get("actor", None)
-    if user is None:
+    user = data.actor
+    if user is None or isinstance(user, AnonymousUser):
         return context_data
 
     if not isinstance(user, User) and not isinstance(user, RpcUser):
@@ -78,14 +88,14 @@ def user_context_transformer(data: dict[str, Any]) -> EvaluationContextDict:
     return context_data
 
 
-def get_sentry_flagpole_context_builder() -> ContextBuilder:
+def get_sentry_flagpole_context_builder() -> ContextBuilder[SentryContextData]:
     """
     Creates and returns a new sentry flagpole context builder with Organization,
      User, and Project transformers appended to it.
     :return:
     """
     return (
-        ContextBuilder()
+        ContextBuilder[SentryContextData]()
         .add_context_transformer(organization_context_transformer)
         .add_context_transformer(project_context_transformer)
         .add_context_transformer(user_context_transformer)

+ 29 - 15
tests/flagpole/test_evaluation_context.py

@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import Any
 
 import pytest
@@ -9,29 +10,42 @@ class TestEvaluationContext:
     pass
 
 
+@dataclass
+class ContextData:
+    foo: str | None = None
+    baz: int | None = None
+    buzz: dict | set | None = None
+
+
 class TestContextBuilder:
     def test_empty_context_builder(self):
-        context_builder = ContextBuilder()
+        context_builder = ContextBuilder[ContextData]()
         context = context_builder.build()
 
         assert context.size() == 0
 
     def test_static_transformer(self):
-        def static_transformer(_data: dict[str, Any]) -> dict[str, Any]:
+        def static_transformer(_data: ContextData) -> dict[str, Any]:
             return dict(foo="bar", baz=1)
 
-        eval_context = ContextBuilder().add_context_transformer(static_transformer).build()
+        eval_context = (
+            ContextBuilder[ContextData]()
+            .add_context_transformer(static_transformer)
+            .build(ContextData())
+        )
 
         assert eval_context.size() == 2
         assert eval_context.get("foo") == "bar"
         assert eval_context.get("baz") == 1
 
     def test_transformer_with_data(self):
-        def transformer_with_data(data: dict[str, Any]) -> dict[str, Any]:
-            return dict(foo="bar", baz=data.get("baz"))
+        def transformer_with_data(data: ContextData) -> dict[str, Any]:
+            return dict(foo="bar", baz=getattr(data, "baz", None))
 
         eval_context = (
-            ContextBuilder().add_context_transformer(transformer_with_data).build({"baz": 2})
+            ContextBuilder[ContextData]()
+            .add_context_transformer(transformer_with_data)
+            .build(ContextData(baz=2))
         )
 
         assert eval_context.size() == 2
@@ -39,17 +53,17 @@ class TestContextBuilder:
         assert eval_context.get("baz") == 2
 
     def test_multiple_context_transformers(self):
-        def transformer_one(data: dict[str, Any]) -> dict[str, Any]:
-            return dict(foo="overwrite_me", baz=2, buzz=data.get("buzz"))
+        def transformer_one(data: ContextData) -> dict[str, Any]:
+            return dict(foo="overwrite_me", baz=2, buzz=getattr(data, "buzz"))
 
-        def transformer_two(_data: dict[str, Any]) -> dict[str, Any]:
+        def transformer_two(_data: ContextData) -> dict[str, Any]:
             return dict(foo="bar")
 
         eval_context = (
-            ContextBuilder()
+            ContextBuilder[ContextData]()
             .add_context_transformer(transformer_one)
             .add_context_transformer(transformer_two)
-            .build({"foo": "bar", "buzz": {1, 2, 3}})
+            .build(ContextData(foo="bar", buzz={1, 2, 3}))
         )
 
         assert eval_context.size() == 3
@@ -60,19 +74,19 @@ class TestContextBuilder:
     def test_with_exception_handler(self):
         exc_message = "oh noooooo"
 
-        def broken_transformer(_data: dict[str, Any]) -> EvaluationContextDict:
+        def broken_transformer(_data: ContextData) -> EvaluationContextDict:
             raise Exception(exc_message)
 
-        context_builder = ContextBuilder().add_context_transformer(broken_transformer)
+        context_builder = ContextBuilder[ContextData]().add_context_transformer(broken_transformer)
 
         with pytest.raises(Exception) as exc:
-            context_builder.build(dict())
+            context_builder.build(ContextData())
 
         assert exc.match(exc_message)
 
         # Ensure builder doesn't raise an exception
         context_builder.add_exception_handler(lambda _exc: None)
-        context_builder.build(dict())
+        context_builder.build(ContextData())
 
         with pytest.raises(Exception):
             context_builder.add_exception_handler(lambda _exc: None)

+ 8 - 2
tests/flagpole/test_feature.py

@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from datetime import datetime, timezone
 
 import pytest
@@ -6,6 +7,11 @@ from flagpole import ContextBuilder, EvaluationContext, Feature, InvalidFeatureF
 from flagpole.conditions import ConditionOperatorKind
 
 
+@dataclass
+class ContextData:
+    pass
+
+
 class TestParseFeatureConfig:
     def get_is_true_context_builder(self, is_true_value: bool):
         return ContextBuilder().add_context_transformer(lambda _data: dict(is_true=is_true_value))
@@ -104,7 +110,7 @@ class TestParseFeatureConfig:
         )
 
         context_builder = self.get_is_true_context_builder(is_true_value=True)
-        assert feature.match(context_builder.build())
+        assert feature.match(context_builder.build(ContextData()))
 
     def test_disabled_feature(self):
         feature = Feature.from_feature_config_json(
@@ -128,4 +134,4 @@ class TestParseFeatureConfig:
         )
 
         context_builder = self.get_is_true_context_builder(is_true_value=True)
-        assert not feature.match(context_builder.build())
+        assert not feature.match(context_builder.build(ContextData()))

+ 27 - 40
tests/flagpole/test_sentry_flagpole_context.py

@@ -1,7 +1,9 @@
 import pytest
+from django.contrib.auth.models import AnonymousUser
 
 from flagpole.sentry_flagpole_context import (
     InvalidContextDataException,
+    SentryContextData,
     get_sentry_flagpole_context_builder,
     organization_context_transformer,
     project_context_transformer,
@@ -18,7 +20,9 @@ class TestSentryFlagpoleContext(TestCase):
         project = self.create_project(organization=org)
         sentry_flagpole_builder = get_sentry_flagpole_context_builder()
 
-        sentry_context = sentry_flagpole_builder.build(dict(organization=org, project=project))
+        sentry_context = sentry_flagpole_builder.build(
+            SentryContextData(organization=org, project=project)
+        )
 
         assert sentry_context.get("organization_slug") == org.slug
         assert sentry_context.get("organization_slug") == org.slug
@@ -28,15 +32,15 @@ class TestSentryFlagpoleContext(TestCase):
 
 class TestSentryOrganizationContextTransformer(TestCase):
     def test_without_organization_passed(self):
-        context_data = organization_context_transformer(dict())
-        assert context_data == dict()
+        context_data = organization_context_transformer(SentryContextData())
+        assert context_data == {}
 
     def test_with_invalid_organization(self):
         with pytest.raises(InvalidContextDataException):
-            organization_context_transformer(dict(organization=1234))
+            organization_context_transformer(SentryContextData(organization=1234))  # type: ignore[arg-type]
 
         with pytest.raises(InvalidContextDataException):
-            organization_context_transformer(dict(organization=self.create_project()))
+            organization_context_transformer(SentryContextData(organization=self.create_project()))
 
     def test_with_valid_organization(self):
         org = self.create_organization(slug="foobar", name="Foo Bar")
@@ -44,7 +48,7 @@ class TestSentryOrganizationContextTransformer(TestCase):
         org.save()
         assert bool(org.flags.early_adopter) is True
 
-        context_data = organization_context_transformer(dict(organization=org))
+        context_data = organization_context_transformer(SentryContextData(organization=org))
 
         assert context_data == {
             "organization_slug": "foobar",
@@ -56,20 +60,20 @@ class TestSentryOrganizationContextTransformer(TestCase):
 
 class TestProjectContextTransformer(TestCase):
     def test_without_project_passed(self):
-        context_data = project_context_transformer(dict())
-        assert context_data == dict()
+        context_data = project_context_transformer(SentryContextData())
+        assert context_data == {}
 
     def test_with_invalid_project_passed(self):
         with pytest.raises(InvalidContextDataException):
-            project_context_transformer(dict(project=123))
+            project_context_transformer(SentryContextData(project=123))  # type: ignore[arg-type]
 
         with pytest.raises(InvalidContextDataException):
-            project_context_transformer(dict(project=self.create_organization()))
+            project_context_transformer(SentryContextData(project=self.create_organization()))
 
     def test_with_valid_project(self):
         project = self.create_project(slug="foobar", name="Foo Bar")
 
-        context_data = project_context_transformer(dict(project=project))
+        context_data = project_context_transformer(SentryContextData(project=project))
         assert context_data == {
             "project_slug": "foobar",
             "project_name": "Foo Bar",
@@ -80,22 +84,22 @@ class TestProjectContextTransformer(TestCase):
 @control_silo_test
 class TestUserContextTransformer(TestCase):
     def test_without_user_passed(self):
-        context_data = project_context_transformer(dict())
-        assert context_data == dict()
+        context_data = project_context_transformer(SentryContextData())
+        assert context_data == {}
 
     def test_with_invalid_user_passed(self):
         with pytest.raises(InvalidContextDataException):
-            user_context_transformer(dict(actor=123))
+            user_context_transformer(SentryContextData(actor=123))  # type: ignore[arg-type]
 
         with pytest.raises(InvalidContextDataException):
-            user_context_transformer(dict(actor=self.create_organization()))
+            user_context_transformer(SentryContextData(actor=self.create_organization()))
 
     def test_with_valid_user(self):
         user = self.create_user(email="foobar@example.com")
         # Create a new, unverified email to ensure we don't list it
         self.create_useremail(user=user, email="unverified_email@example.com")
 
-        context_data = user_context_transformer(dict(actor=user))
+        context_data = user_context_transformer(SentryContextData(actor=user))
         assert context_data == {
             "user_email": "foobar@example.com",
             "user_domain": "example.com",
@@ -110,7 +114,7 @@ class TestUserContextTransformer(TestCase):
         user_email.is_verified = False
         user_email.save()
 
-        context_data = user_context_transformer(dict(actor=user))
+        context_data = user_context_transformer(SentryContextData(actor=user))
         assert context_data == {
             "user_id": user.id,
             "user_is-superuser": False,
@@ -119,7 +123,7 @@ class TestUserContextTransformer(TestCase):
 
     def test_with_super_user_and_staff(self):
         user = self.create_user(email="super_user_admin_person@sentry.io", is_superuser=True)
-        context_data = user_context_transformer(dict(actor=user))
+        context_data = user_context_transformer(SentryContextData(actor=user))
         assert context_data == {
             "user_email": "super_user_admin_person@sentry.io",
             "user_domain": "sentry.io",
@@ -131,7 +135,7 @@ class TestUserContextTransformer(TestCase):
         user.is_staff = True
         user.is_superuser = False
         user.save()
-        context_data = user_context_transformer(dict(actor=user))
+        context_data = user_context_transformer(SentryContextData(actor=user))
         assert context_data == {
             "user_email": "super_user_admin_person@sentry.io",
             "user_domain": "sentry.io",
@@ -140,24 +144,7 @@ class TestUserContextTransformer(TestCase):
             "user_is-staff": True,
         }
 
-
-class TestTeamContextTransformer(TestCase):
-    pass
-    # def test_with_missing_team(self):
-    #     context_data = team_context_transformer(dict())
-    #     assert context_data == dict()
-    #
-    # def test_with_invalid_team(self):
-    #     with pytest.raises(InvalidContextDataException):
-    #         team_context_transformer(dict(team="invalid"))
-    #
-    #     with pytest.raises(InvalidContextDataException):
-    #         team_context_transformer(dict(team=self.create_organization()))
-    #
-    # def test_with_valid_team(self):
-    #     team = self.create_team(organization=self.create_organization())
-    #
-    #     context_data = team_context_transformer(dict(team=team))
-    #     assert context_data == {
-    #         "team_id": team.id
-    #     }
+    def test_with_anonymous_user(self):
+        user = AnonymousUser()
+        context_data = user_context_transformer(SentryContextData(actor=user))
+        assert context_data == {}