Browse Source

feat(flagpole): Adds identity fields to Eval Context for more stable partial rollouts (#72533)

Gabe Villalobos 9 months ago
parent
commit
d14828cd30

+ 1 - 1
src/flagpole/conditions.py

@@ -206,6 +206,6 @@ class Segment(BaseModel):
                 return False
         # Apply incremental rollout if available.
         if self.rollout is not None and self.rollout < 100:
-            return context.id() % 100 <= self.rollout
+            return context.id % 100 <= self.rollout
 
         return True

+ 46 - 13
src/flagpole/evaluation_context.py

@@ -19,27 +19,36 @@ class EvaluationContext:
     feature conditions.
     """
 
-    def __init__(self, data: EvaluationContextDict):
+    __data: EvaluationContextDict
+    __identity_fields: set[str]
+    __id: int
+
+    def __init__(self, data: EvaluationContextDict, identity_fields: set[str] | None = None):
         self.__data = deepcopy(data)
+        self.__set_identity_fields(identity_fields)
+        self.__id = self.__generate_id()
 
-    def get(self, key: str) -> Any:
-        return self.__data.get(key)
+    def __set_identity_fields(self, identity_fields: set[str] | None = None):
+        trimmed_id_fields = set()
+        if identity_fields is not None:
+            for field in identity_fields:
+                if field in self.__data:
+                    trimmed_id_fields.add(field)
 
-    def has(self, key: str) -> Any:
-        return key in self.__data
+        if not trimmed_id_fields:
+            trimmed_id_fields.update(self.__data.keys())
 
-    def size(self) -> int:
-        return len(self.__data)
+        self.__identity_fields = trimmed_id_fields
 
-    def id(self) -> int:
+    def __generate_id(self) -> int:
         """
-        Return a hashed identifier for this context
+        Generates and returns a hashed identifier for this context
 
         The identifier should be stable for a given context contents.
         Identifiers are used to determine rollout groups deterministically
         and consistently.
         """
-        keys = self.__data.keys()
+        keys = list(self.__identity_fields)
         vector = []
         for key in sorted(keys):
             vector.append(key)
@@ -47,6 +56,23 @@ class EvaluationContext:
         hashed = hashlib.sha1(":".join(vector).encode("utf8"))
         return int.from_bytes(hashed.digest(), byteorder="big")
 
+    @property
+    def id(self) -> int:
+        """
+        Guard against context mutation by using this virtual property as a
+        getter for the private ID field.
+        """
+        return self.__id
+
+    def get(self, key: str) -> Any:
+        return self.__data.get(key)
+
+    def has(self, key: str) -> Any:
+        return key in self.__data
+
+    def size(self) -> int:
+        return len(self.__data)
+
 
 T_CONTEXT_DATA = TypeVar("T_CONTEXT_DATA")
 
@@ -61,20 +87,27 @@ class ContextBuilder(Generic[T_CONTEXT_DATA]):
     >>> from flagpole import ContextBuilder, Feature
     >>> builder = ContextBuilder().add_context_transformer(lambda _dict: dict(foo="bar"))
     >>> feature = Feature.from_feature_dictionary(name="foo", feature_dictionary=dict(), context=builder)
-    >>> feature.match(dict())
+    >>> feature.match(EvaluationContext(dict()))
     """
 
     context_transformers: list[Callable[[T_CONTEXT_DATA], EvaluationContextDict]]
     exception_handler: Callable[[Exception], Any] | None
+    __identity_fields: set[str]
 
     def __init__(self):
         self.context_transformers = []
         self.exception_handler = None
+        self.__identity_fields = set()
 
     def add_context_transformer(
-        self, context_transformer: Callable[[T_CONTEXT_DATA], EvaluationContextDict]
+        self,
+        context_transformer: Callable[[T_CONTEXT_DATA], EvaluationContextDict],
+        identity_fields: list[str] | None = None,
     ) -> ContextBuilder[T_CONTEXT_DATA]:
         self.context_transformers.append(context_transformer)
+        if identity_fields is not None:
+            self.__identity_fields.update(identity_fields)
+
         return self
 
     def add_exception_handler(
@@ -107,4 +140,4 @@ class ContextBuilder(Generic[T_CONTEXT_DATA]):
                 else:
                     raise
 
-        return EvaluationContext(context_data)
+        return EvaluationContext(context_data, self.__identity_fields)

+ 0 - 102
src/flagpole/sentry_flagpole_context.py

@@ -1,102 +0,0 @@
-from dataclasses import dataclass
-
-from django.contrib.auth.models import AnonymousUser
-
-from flagpole.evaluation_context import ContextBuilder, EvaluationContextDict
-from sentry.models.organization import Organization
-from sentry.models.project import Project
-from sentry.models.user import User
-from sentry.services.hybrid_cloud.organization import RpcOrganization
-from sentry.services.hybrid_cloud.project import RpcProject
-from sentry.services.hybrid_cloud.user import RpcUser
-
-
-class InvalidContextDataException(Exception):
-    pass
-
-
-@dataclass()
-class SentryContextData:
-    actor: User | RpcUser | AnonymousUser | None = None
-    organization: Organization | RpcOrganization | None = None
-    project: Project | RpcProject | None = None
-
-
-def organization_context_transformer(data: SentryContextData) -> EvaluationContextDict:
-    context_data: EvaluationContextDict = dict()
-    org = data.organization
-    if org is None:
-        return context_data
-
-    if isinstance(org, Organization):
-        context_data["organization_slug"] = org.slug
-        context_data["organization_name"] = org.name
-        context_data["organization_id"] = org.id
-        early_adopter = bool(org.flags.early_adopter) if org.flags is not None else False
-        context_data["organization_is-early-adopter"] = early_adopter
-
-    elif isinstance(org, RpcOrganization):
-        context_data["organization_slug"] = org.slug
-        context_data["organization_name"] = org.name
-        context_data["organization_id"] = org.id
-        context_data["organization_is-early-adopter"] = org.flags.early_adopter
-    else:
-        raise InvalidContextDataException("Invalid organization object provided")
-
-    return context_data
-
-
-def project_context_transformer(data: SentryContextData) -> EvaluationContextDict:
-    context_data: EvaluationContextDict = dict()
-
-    if (proj := data.project) is not None:
-        if not isinstance(proj, Project):
-            raise InvalidContextDataException("Invalid project object provided")
-
-        context_data["project_slug"] = proj.slug
-        context_data["project_name"] = proj.name
-        context_data["project_id"] = proj.id
-
-    return context_data
-
-
-def user_context_transformer(data: SentryContextData) -> EvaluationContextDict:
-    context_data: EvaluationContextDict = dict()
-    user = data.actor
-    if user is None or isinstance(user, AnonymousUser):
-        return context_data
-
-    if not isinstance(user, User) and not isinstance(user, RpcUser):
-        raise InvalidContextDataException("Invalid actor object provided")
-
-    if user.is_authenticated:
-        context_data["user_id"] = user.id
-        context_data["user_is-superuser"] = user.is_superuser
-        context_data["user_is-staff"] = user.is_staff
-
-    verified_emails: list[str]
-
-    if isinstance(user, RpcUser):
-        verified_emails = list(user.emails)
-    else:
-        verified_emails = user.get_verified_emails().values_list("email", flat=True)
-
-    if user.email in verified_emails:
-        context_data["user_email"] = user.email
-        context_data["user_domain"] = user.email.rsplit("@", 1)[-1]
-
-    return context_data
-
-
-def get_sentry_flagpole_context_builder() -> ContextBuilder[SentryContextData]:
-    """
-    Creates and returns a new sentry flagpole context builder with Organization,
-     User, and Project transformers appended to it.
-    :return:
-    """
-    return (
-        ContextBuilder[SentryContextData]()
-        .add_context_transformer(organization_context_transformer)
-        .add_context_transformer(project_context_transformer)
-        .add_context_transformer(user_context_transformer)
-    )

+ 2 - 2
src/sentry/features/flagpole_context.py

@@ -96,7 +96,7 @@ def get_sentry_flagpole_context_builder() -> ContextBuilder[SentryContextData]:
     """
     return (
         ContextBuilder[SentryContextData]()
-        .add_context_transformer(organization_context_transformer)
-        .add_context_transformer(project_context_transformer)
+        .add_context_transformer(organization_context_transformer, ["organization_id"])
+        .add_context_transformer(project_context_transformer, ["project_id"])
         .add_context_transformer(user_context_transformer)
     )

+ 80 - 2
tests/flagpole/test_evaluation_context.py

@@ -3,11 +3,56 @@ from typing import Any
 
 import pytest
 
-from flagpole.evaluation_context import ContextBuilder, EvaluationContextDict
+from flagpole.evaluation_context import ContextBuilder, EvaluationContext, EvaluationContextDict
 
 
 class TestEvaluationContext:
-    pass
+    # Identity fields tests are mainly upholding that our hashing strategy does
+    # not change in the future, and that we calculate the id using the correct
+    # context values and keys in order.
+    def test_adds_identity_fields(self):
+        eval_context = EvaluationContext({}, set())
+        assert eval_context.id == 1245845410931227995499360226027473197403882391305
+
+        eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo"}, {"foo"})
+        expected_id = 484477975355580460928302712356218993825269143262
+        assert eval_context.id == expected_id
+
+        # Assert that we skip the missing field but still generate the same
+        # context ID.
+        eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo"}, {"foo", "whoops"})
+        assert eval_context.id == expected_id
+
+        eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo"}, {"foo", "baz"})
+        expected_id = 1249805218608667754842212156585681631068251083301
+        assert eval_context.id == expected_id
+
+        # Assert that we use all properties to generate the context when all
+        # identity fields are missing.
+        eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo"}, {"whoops", "test"})
+        assert eval_context.id == expected_id
+
+    def test_no_identity_fields_included(self):
+        eval_context = EvaluationContext({})
+        assert eval_context.id == 1245845410931227995499360226027473197403882391305
+
+        eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo"})
+        expected_id = 1249805218608667754842212156585681631068251083301
+        assert eval_context.id == expected_id
+
+        eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo", "test": "property"})
+        expected_id = 1395427532315258482176540981434194664973697472186
+        assert eval_context.id == expected_id
+
+    def test_get_has_data(self):
+        eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo"}, {"foo"})
+
+        assert eval_context.has("foo") is True
+        assert eval_context.get("foo") == "bar"
+        assert eval_context.has("baz") is True
+        assert eval_context.get("baz") == "barfoo"
+        assert eval_context.has("bar") is False
+        assert eval_context.get("bar") is None
 
 
 @dataclass
@@ -90,3 +135,36 @@ class TestContextBuilder:
 
         with pytest.raises(Exception):
             context_builder.add_exception_handler(lambda _exc: None)
+
+    # This is nearly identical to the evaluation context tests around identity
+    # fields; it just ensures we compile and pass the correct list.
+    def test_identity_fields_passing(self):
+        def transformer_with_data(_data: ContextData) -> dict[str, Any]:
+            return dict(foo="bar", baz="barfoo")
+
+        eval_context = ContextBuilder[ContextData]().build(ContextData(baz=2))
+
+        # This should be an empty dictionary and an empty identity fields list
+        assert eval_context.id == 1245845410931227995499360226027473197403882391305
+
+        eval_context = (
+            ContextBuilder[ContextData]()
+            .add_context_transformer(transformer_with_data, ["foo"])
+            .build(ContextData(baz=2))
+        )
+
+        expected_context_id = 484477975355580460928302712356218993825269143262
+        assert eval_context.id == expected_context_id
+
+        # The full identity_fields list passed into the context should be
+        # ["foo", "baz", "whoops"], but "whoops" will be filtered out by the
+        # context since the field does not exist in the context dict.
+        eval_context = (
+            ContextBuilder[ContextData]()
+            .add_context_transformer(transformer_with_data, ["foo"])
+            .add_context_transformer(transformer_with_data, ["baz", "whoops"])
+            .build(ContextData(baz=2))
+        )
+
+        expected_context_id = 1249805218608667754842212156585681631068251083301
+        assert eval_context.id == expected_context_id