
feat(rpc): Allow filtering when we can't use VirtualColumnContexts (#83663)

- VirtualColumnContexts aren't available on timeseries, so filtering by a
VCC field (project, device.class) needs custom logic (see the sketch below)
- Following a discussion in Slack, we're dropping device.module support;
that will happen in a follow-up PR
William Mak, 1 month ago
commit 3e48f977d4
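
A minimal sketch of the fallback this commit adds, using toy values rather
than the real resolver internals: when the RPC can't apply a
VirtualColumnContext, the filter value is mapped back through the context's
value_map and the comparison is rewritten against the raw column.

    # Hypothetical values for illustration; the real map comes from
    # VirtualColumnDefinition.constructor(params).
    value_map = {"123": "my-project"}  # raw value -> display value
    reversed_map = {v: k for k, v in value_map.items()}

    # A filter like `project:my-project` resolves back to the raw "123",
    # which project_term_resolver then casts to an int for project.id.
    assert reversed_map["my-project"] == "123"
    assert int(reversed_map["my-project"]) == 123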

+ 16 - 1
src/sentry/search/eap/columns.py

@@ -102,6 +102,21 @@ class FunctionDefinition:
         return [arg for arg in self.arguments if arg.default_arg is None and not arg.ignored]
 
 
+@dataclass
+class VirtualColumnDefinition:
+    constructor: Callable[[SnubaParams], VirtualColumnContext]
+    # Allows additional processing of the term after it's been resolved
+    term_resolver: (
+        Callable[
+            [str | list[str]],
+            int | str | list[int] | list[str],
+        ]
+        | None
+    ) = None
+    filter_column: str | None = None
+    default_value: str | None = None
+
+
 @dataclass(frozen=True, kw_only=True)
 class ResolvedFunction(ResolvedAttribute):
     # The internal rpc alias for this column
@@ -161,5 +176,5 @@ def datetime_processor(datetime_string: str) -> str:
 class ColumnDefinitions:
     functions: dict[str, FunctionDefinition]
     columns: dict[str, ResolvedColumn]
-    contexts: dict[str, Callable[[SnubaParams], VirtualColumnContext]]
+    contexts: dict[str, VirtualColumnDefinition]
     trace_item_type: TraceItemType.ValueType
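
As a hedged sketch of how the new fields interact (the definition below is
illustrative, not one from this commit): constructor builds the
VirtualColumnContext lazily from SnubaParams, term_resolver post-processes a
resolved filter value, and filter_column, when set, redirects the comparison
to a real column.

    toy_definition = VirtualColumnDefinition(
        # Built lazily so the value_map can depend on the query params.
        constructor=lambda params: VirtualColumnContext(
            from_column_name="sentry.project_id",
            to_column_name="project",
            value_map={"123": "my-project"},
        ),
        # Cast the reverse-mapped value(s) into what the raw column expects.
        term_resolver=lambda value: (
            [int(v) for v in value] if isinstance(value, list) else int(value)
        ),
        # Filter against the real column instead of the virtual one.
        filter_column="project.id",
    )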

+ 9 - 0
src/sentry/search/eap/constants.py

@@ -48,6 +48,12 @@ STRING = AttributeKey.TYPE_STRING
 BOOLEAN = AttributeKey.TYPE_BOOLEAN
 DOUBLE = AttributeKey.TYPE_DOUBLE
 INT = AttributeKey.TYPE_INT
+TYPE_TO_STRING_MAP = {
+    STRING: "string",
+    BOOLEAN: "boolean",
+    DOUBLE: "double",
+    INT: "integer",
+}
 
 # TODO: we need a datetime type
 # Maps search types back to types for the proto
@@ -105,3 +111,6 @@ VALID_GRANULARITIES = frozenset(
 TRUTHY_VALUES = {"1", "true"}
 FALSEY_VALUES = {"0", "false"}
 BOOLEAN_VALUES = TRUTHY_VALUES.union(FALSEY_VALUES)
+
+PROJECT_FIELDS = {"project", "project.slug", "project.name"}
+REVERSE_CONTEXT_ERROR = "Unknown value {} for filter {}, expecting one of: {}"
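
For reference, REVERSE_CONTEXT_ERROR renders like this (toy arguments, in the
same order the resolver passes them: value, filter name, joined valid values):

    >>> REVERSE_CONTEXT_ERROR.format("foobar", "project", "sentry, javascript, ...")
    'Unknown value foobar for filter project, expecting one of: sentry, javascript, ...'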

+ 114 - 24
src/sentry/search/eap/resolver.py

@@ -32,7 +32,12 @@ from sentry_protos.snuba.v1.trace_item_filter_pb2 import (
 from sentry.api import event_search
 from sentry.exceptions import InvalidSearchQuery
 from sentry.search.eap import constants
-from sentry.search.eap.columns import ColumnDefinitions, ResolvedColumn, ResolvedFunction
+from sentry.search.eap.columns import (
+    ColumnDefinitions,
+    ResolvedColumn,
+    ResolvedFunction,
+    VirtualColumnDefinition,
+)
 from sentry.search.eap.types import SearchResolverConfig
 from sentry.search.events import constants as qb_constants
 from sentry.search.events import fields
@@ -50,10 +55,10 @@ class SearchResolver:
     params: SnubaParams
     config: SearchResolverConfig
     definitions: ColumnDefinitions
-    _resolved_attribute_cache: dict[str, tuple[ResolvedColumn, VirtualColumnContext | None]] = (
+    _resolved_attribute_cache: dict[str, tuple[ResolvedColumn, VirtualColumnDefinition | None]] = (
         field(default_factory=dict)
     )
-    _resolved_function_cache: dict[str, tuple[ResolvedFunction, VirtualColumnContext | None]] = (
+    _resolved_function_cache: dict[str, tuple[ResolvedFunction, VirtualColumnDefinition | None]] = (
         field(default_factory=dict)
     )
 
@@ -76,7 +81,9 @@ class SearchResolver:
     @sentry_sdk.trace
     def resolve_query(
         self, querystring: str | None
-    ) -> tuple[TraceItemFilter | None, AggregationFilter | None, list[VirtualColumnContext | None]]:
+    ) -> tuple[
+        TraceItemFilter | None, AggregationFilter | None, list[VirtualColumnDefinition | None]
+    ]:
         """Given a query string in the public search syntax eg. `span.description:foo` construct the TraceItemFilter"""
         environment_query = self.__resolve_environment_query()
         where, having, contexts = self.__resolve_query(querystring)
@@ -137,7 +144,9 @@ class SearchResolver:
 
     def __resolve_query(
         self, querystring: str | None
-    ) -> tuple[TraceItemFilter | None, AggregationFilter | None, list[VirtualColumnContext | None]]:
+    ) -> tuple[
+        TraceItemFilter | None, AggregationFilter | None, list[VirtualColumnDefinition | None]
+    ]:
         if querystring is None:
             return None, None, []
         try:
@@ -164,7 +173,9 @@ class SearchResolver:
 
     def _resolve_boolean_conditions(
         self, terms: event_filter.ParsedTerms
-    ) -> tuple[TraceItemFilter | None, AggregationFilter | None, list[VirtualColumnContext | None]]:
+    ) -> tuple[
+        TraceItemFilter | None, AggregationFilter | None, list[VirtualColumnDefinition | None]
+    ]:
         if len(terms) == 0:
             return None, None, []
         elif len(terms) == 1:
@@ -256,14 +267,16 @@ class SearchResolver:
 
     def _resolve_terms(
         self, terms: event_filter.ParsedTerms
-    ) -> tuple[TraceItemFilter | None, AggregationFilter | None, list[VirtualColumnContext | None]]:
+    ) -> tuple[
+        TraceItemFilter | None, AggregationFilter | None, list[VirtualColumnDefinition | None]
+    ]:
         where, where_contexts = self._resolve_where(terms)
         having, having_contexts = self._resolve_having(terms)
         return where, having, where_contexts + having_contexts
 
     def _resolve_where(
         self, terms: event_filter.ParsedTerms
-    ) -> tuple[TraceItemFilter | None, list[VirtualColumnContext | None]]:
+    ) -> tuple[TraceItemFilter | None, list[VirtualColumnDefinition | None]]:
         parsed_terms = []
         resolved_contexts = []
         for item in terms:
@@ -282,7 +295,7 @@ class SearchResolver:
 
     def _resolve_having(
         self, terms: event_filter.ParsedTerms
-    ) -> tuple[AggregationFilter | None, list[VirtualColumnContext | None]]:
+    ) -> tuple[AggregationFilter | None, list[VirtualColumnDefinition | None]]:
         if not self.config.use_aggregate_conditions:
             return None, []
 
@@ -305,15 +318,87 @@ class SearchResolver:
             return parsed_terms[0], resolved_contexts
         return None, []
 
+    def resolve_virtual_context_term(
+        self,
+        term: str,
+        raw_value: str | list[str],
+        resolved_column: ResolvedColumn,
+        context: VirtualColumnDefinition,
+    ) -> list[str] | str:
+        # Convert the term to the expected values
+        final_raw_value: str | list[str] = []
+        resolved_context = context.constructor(self.params)
+        reversed_context = {v: k for k, v in resolved_context.value_map.items()}
+        if isinstance(raw_value, list):
+            new_value = []
+            for raw_iterable in raw_value:
+                if context.default_value and context.default_value == raw_iterable:
+                    # Avoiding this for now; while it could work with the Unknown:"" mapping,
+                    # it won't work once we use VirtualColumnContext.default_value
+                    raise InvalidSearchQuery(
+                        f"Using {raw_iterable} in an IN filter is not currently supported"
+                    )
+                elif raw_iterable not in reversed_context:
+                    valid_values = list(reversed_context.keys())[:5]
+                    if len(reversed_context) > 5:
+                        valid_values.append("...")
+                    raise InvalidSearchQuery(
+                        constants.REVERSE_CONTEXT_ERROR.format(
+                            raw_iterable, term, ", ".join(valid_values)
+                        )
+                    )
+                else:
+                    new_value.append(reversed_context[raw_iterable])
+            final_raw_value = new_value
+        elif raw_value in reversed_context:
+            final_raw_value = reversed_context[raw_value]
+        elif context.default_value and context.default_value == raw_value:
+            # Avoiding this for now; while it could work with the Unknown:"" mapping,
+            # it won't work once we use VirtualColumnContext.default_value
+            raise InvalidSearchQuery(f"Using {raw_value} is not currently supported")
+        else:
+            valid_values = list(reversed_context.keys())[:5]
+            if len(reversed_context) > 5:
+                valid_values.append("...")
+            raise InvalidSearchQuery(
+                constants.REVERSE_CONTEXT_ERROR.format(
+                    raw_value, term, ", ".join(valid_values)
+                )
+            )
+        return final_raw_value
+
     def resolve_term(
         self, term: event_search.SearchFilter
-    ) -> tuple[TraceItemFilter, VirtualColumnContext | None]:
-        resolved_column, context = self.resolve_column(term.key.name)
+    ) -> tuple[TraceItemFilter, VirtualColumnDefinition | None]:
+        resolved_column, context_definition = self.resolve_column(term.key.name)
 
         if not isinstance(resolved_column.proto_definition, AttributeKey):
             raise ValueError(f"{term.key.name} is not valid search term")
 
         raw_value = term.value.raw_value
+
+        if context_definition:
+            if term.value.is_wildcard():
+                # Avoiding this for now, but we could theoretically do a wildcard search on the resolved contexts
+                raise InvalidSearchQuery(f"Cannot use wildcards with {term.key.name}")
+            if isinstance(raw_value, str) or (
+                isinstance(raw_value, list)
+                and all(isinstance(value, str) for value in raw_value)
+            ):
+                raw_value = self.resolve_virtual_context_term(
+                    term.key.name,
+                    raw_value,
+                    resolved_column,
+                    context_definition,
+                )
+            else:
+                raise InvalidSearchQuery(f"{raw_value} not a valid term for {term.key.name}")
+            if context_definition.term_resolver:
+                raw_value = context_definition.term_resolver(raw_value)
+            if context_definition.filter_column is not None:
+                resolved_column, _ = self.resolve_attribute(context_definition.filter_column)
+
         if term.value.is_wildcard():
             if term.operator == "=":
                 operator = ComparisonFilter.OP_LIKE
@@ -343,12 +428,12 @@ class SearchResolver:
                     value=self._resolve_search_value(resolved_column, term.operator, raw_value),
                 )
             ),
-            context,
+            context_definition,
         )
 
     def resolve_aggregate_term(
         self, term: event_search.AggregateFilter
-    ) -> tuple[AggregationFilter, VirtualColumnContext | None]:
+    ) -> tuple[AggregationFilter, VirtualColumnDefinition | None]:
         resolved_column, context = self.resolve_column(term.key.name)
 
         if not isinstance(resolved_column.proto_definition, AttributeAggregation):
@@ -431,18 +516,21 @@ class SearchResolver:
                     bool_value = lowered_value in constants.TRUTHY_VALUES
                     return AttributeValue(val_bool=bool_value)
             raise InvalidSearchQuery(
-                f"{value} is not a valid filter value for {column.public_alias}"
+                f"{value} is not a valid filter value for {column.public_alias}, expecting {constants.TYPE_TO_STRING_MAP[column_type]}, but got a {type(value)}"
             )
         else:
             raise NotImplementedError("Aggregate Queries not implemented yet")
 
-    def clean_contexts(
-        self, resolved_contexts: list[VirtualColumnContext | None]
+    def resolve_contexts(
+        self, context_definitions: list[VirtualColumnDefinition | None]
     ) -> list[VirtualColumnContext]:
         """Given a list of contexts that may have None in them, remove the Nones and remove the dupes"""
         final_contexts = []
         existing_target_columns = set()
-        for context in resolved_contexts:
+        for context_definition in context_definitions:
+            if context_definition is None:
+                continue
+            context = context_definition.constructor(self.params)
             if context is None or context.to_column_name in existing_target_columns:
                 continue
             else:
@@ -453,7 +541,7 @@ class SearchResolver:
     @sentry_sdk.trace
     def resolve_columns(
         self, selected_columns: list[str]
-    ) -> tuple[list[ResolvedColumn | ResolvedFunction], list[VirtualColumnContext | None]]:
+    ) -> tuple[list[ResolvedColumn | ResolvedFunction], list[VirtualColumnDefinition | None]]:
         """Given a list of columns resolve them and get their context if applicable
 
         This function will also dedupe the virtual column contexts if necessary
@@ -488,7 +576,7 @@ class SearchResolver:
 
     def resolve_column(
         self, column: str, match: Match | None = None
-    ) -> tuple[ResolvedColumn | ResolvedFunction, VirtualColumnContext | None]:
+    ) -> tuple[ResolvedColumn | ResolvedFunction, VirtualColumnDefinition | None]:
         """Column is either an attribute or an aggregate, this function will determine which it is and call the relevant
         resolve function"""
         match = fields.is_function(column)
@@ -504,7 +592,7 @@ class SearchResolver:
     @sentry_sdk.trace
     def resolve_attributes(
         self, columns: list[str]
-    ) -> tuple[list[ResolvedColumn], list[VirtualColumnContext | None]]:
+    ) -> tuple[list[ResolvedColumn], list[VirtualColumnDefinition | None]]:
         """Helper function to resolve a list of attributes instead of 1 attribute at a time"""
         resolved_columns = []
         resolved_contexts = []
@@ -514,14 +602,16 @@ class SearchResolver:
             resolved_contexts.append(context)
         return resolved_columns, resolved_contexts
 
-    def resolve_attribute(self, column: str) -> tuple[ResolvedColumn, VirtualColumnContext | None]:
+    def resolve_attribute(
+        self, column: str
+    ) -> tuple[ResolvedColumn, VirtualColumnDefinition | None]:
         """Attributes are columns that aren't 'functions' or 'aggregates', usually this means string or numeric
         attributes (aka. tags), but can also refer to fields like span.description"""
         # If a virtual context is defined the column definition is always the same
         if column in self._resolved_attribute_cache:
             return self._resolved_attribute_cache[column]
         if column in self.definitions.contexts:
-            column_context = self.definitions.contexts[column](self.params)
+            column_context = self.definitions.contexts[column]
             column_definition = ResolvedColumn(
                 public_alias=column, internal_name=column, search_type="string"
             )
@@ -568,7 +658,7 @@ class SearchResolver:
     @sentry_sdk.trace
     def resolve_aggregates(
         self, columns: list[str]
-    ) -> tuple[list[ResolvedFunction], list[VirtualColumnContext | None]]:
+    ) -> tuple[list[ResolvedFunction], list[VirtualColumnDefinition | None]]:
         """Helper function to resolve a list of aggregates instead of 1 attribute at a time"""
         resolved_aggregates, resolved_contexts = [], []
         for column in columns:
@@ -579,7 +669,7 @@ class SearchResolver:
 
     def resolve_aggregate(
         self, column: str, match: Match | None = None
-    ) -> tuple[ResolvedFunction, VirtualColumnContext | None]:
+    ) -> tuple[ResolvedFunction, VirtualColumnDefinition | None]:
         if column in self._resolved_function_cache:
             return self._resolved_function_cache[column]
         # Check if this is a valid function, parse the function name and args out
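
To make the reverse mapping concrete, a hedged, self-contained sketch of what
resolve_virtual_context_term does for device.class (toy value_map; the real
one comes from device_class_context_constructor):

    # The context maps raw stored values to display values...
    value_map = {"1": "low", "2": "medium", "3": "high"}
    # ...so filtering needs the opposite direction: display -> raw.
    reversed_context = {v: k for k, v in value_map.items()}

    # `device.class:low` resolves to the raw "1"; the comparison is then
    # built against filter_column ("sentry.device.class") instead of the
    # virtual column, which the timeseries RPC can't construct.
    assert reversed_context["low"] == "1"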

+ 25 - 5
src/sentry/search/eap/span_columns.py

@@ -10,6 +10,7 @@ from sentry.search.eap.columns import (
     ColumnDefinitions,
     FunctionDefinition,
     ResolvedColumn,
+    VirtualColumnDefinition,
     datetime_processor,
     simple_measurements_field,
     simple_sentry_field,
@@ -337,6 +338,15 @@ def project_context_constructor(column_name: str) -> Callable[[SnubaParams], Vir
     return context_constructor
 
 
+def project_term_resolver(
+    raw_value: str | list[str],
+) -> list[int] | int:
+    if isinstance(raw_value, list):
+        return [int(val) for val in raw_value]
+    else:
+        return int(raw_value)
+
+
 def device_class_context_constructor(params: SnubaParams) -> VirtualColumnContext:
     # EAP defaults to lower case `unknown`, but in querybuilder we used `Unknown`
     value_map = {"": "Unknown"}
@@ -360,12 +370,22 @@ def module_context_constructor(params: SnubaParams) -> VirtualColumnContext:
 
 
 SPAN_VIRTUAL_CONTEXTS = {
-    "project": project_context_constructor("project"),
-    "project.slug": project_context_constructor("project.slug"),
-    "project.name": project_context_constructor("project.name"),
-    "device.class": device_class_context_constructor,
-    "span.module": module_context_constructor,
+    "device.class": VirtualColumnDefinition(
+        constructor=device_class_context_constructor,
+        filter_column="sentry.device.class",
+        # TODO: need to change this so the VCC is using it too, but would require rewriting the term_resolver
+        default_value="Unknown",
+    ),
+    "span.module": VirtualColumnDefinition(
+        constructor=module_context_constructor,
+    ),
 }
+for key in constants.PROJECT_FIELDS:
+    SPAN_VIRTUAL_CONTEXTS[key] = VirtualColumnDefinition(
+        constructor=project_context_constructor(key),
+        term_resolver=project_term_resolver,
+        filter_column="project.id",
+    )
 
 
 SPAN_FUNCTION_DEFINITIONS = {
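
A quick usage note for the project definitions above: the context maps
project ids to slugs for display, so a slug in a filter is reversed back to
an id string, and project_term_resolver casts it for the project.id column.

    assert project_term_resolver("42") == 42
    assert project_term_resolver(["1", "2"]) == [1, 2]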

+ 1 - 3
src/sentry/snuba/spans_rpc.py

@@ -59,7 +59,7 @@ def run_table_query(
     meta = resolver.resolve_meta(referrer=referrer)
     where, having, query_contexts = resolver.resolve_query(query_string)
     columns, column_contexts = resolver.resolve_columns(selected_columns)
-    contexts = resolver.clean_contexts(query_contexts + column_contexts)
+    contexts = resolver.resolve_contexts(query_contexts + column_contexts)
     # We allow orderby function_aliases if they're a selected_column
     # eg. can orderby sum_span_self_time, assuming sum(span.self_time) is selected
     orderby_aliases = {
@@ -185,8 +185,6 @@ def get_timeseries_query(
             if isinstance(groupby.proto_definition, AttributeKey)
         ],
         granularity_secs=granularity_secs,
-        # TODO: need to add this once the RPC supports it
-        # virtual_column_contexts=[context for context in resolver.clean_contexts(query_contexts) if context is not None],
     )
 
 

+ 4 - 2
tests/sentry/search/eap/test_spans.py

@@ -548,7 +548,8 @@ class SearchResolverColumnTest(TestCase):
         assert resolved_column.proto_definition == AttributeKey(
             name="project", type=AttributeKey.Type.TYPE_STRING
         )
-        assert virtual_context == VirtualColumnContext(
+        assert virtual_context is not None
+        assert virtual_context.constructor(self.resolver.params) == VirtualColumnContext(
             from_column_name="sentry.project_id",
             to_column_name="project",
             value_map={str(self.project.id): self.project.slug},
@@ -559,7 +560,8 @@ class SearchResolverColumnTest(TestCase):
         assert resolved_column.proto_definition == AttributeKey(
             name="project.slug", type=AttributeKey.Type.TYPE_STRING
         )
-        assert virtual_context == VirtualColumnContext(
+        assert virtual_context is not None
+        assert virtual_context.constructor(self.resolver.params) == VirtualColumnContext(
             from_column_name="sentry.project_id",
             to_column_name="project.slug",
             value_map={str(self.project.id): self.project.slug},

+ 125 - 0
tests/snuba/api/endpoints/test_organization_events_stats_span_indexed.py

@@ -1025,3 +1025,128 @@ class OrganizationEventsEAPRPCSpanEndpointTest(OrganizationEventsEAPSpanEndpoint
     @pytest.mark.xfail(reason="division by 0 error in snuba")
     def test_handle_nans_from_snuba_top_n(self):
         super().test_handle_nans_from_snuba_top_n()
+
+    def test_project_filters(self):
+        event_counts = [6, 0, 6, 3, 0, 3]
+        spans = []
+        for hour, count in enumerate(event_counts):
+            spans.extend(
+                [
+                    self.create_span(
+                        {"description": "foo", "sentry_tags": {"status": "success"}},
+                        start_ts=self.day_ago + timedelta(hours=hour, minutes=minute),
+                    )
+                    for minute in range(count)
+                ],
+            )
+        self.store_spans(spans, is_eap=self.is_eap)
+
+        for querystring in [f"project:{self.project.slug}", f"project:[{self.project.slug}]"]:
+            response = self._do_request(
+                data={
+                    "start": self.day_ago,
+                    "end": self.day_ago + timedelta(hours=6),
+                    "interval": "1h",
+                    "yAxis": "count()",
+                    "query": querystring,
+                    "project": self.project.id,
+                    "dataset": self.dataset,
+                },
+            )
+            assert response.status_code == 200, response.content
+            data = response.data["data"]
+            assert len(data) == 6
+            assert response.data["meta"]["dataset"] == self.dataset
+
+            rows = data[0:6]
+            for expected_count, row in zip(event_counts, rows):
+                assert row[1][0]["count"] == expected_count
+
+    def test_nonexistent_project_filter(self):
+        response = self._do_request(
+            data={
+                "start": self.day_ago,
+                "end": self.day_ago + timedelta(hours=6),
+                "interval": "1h",
+                "yAxis": "count()",
+                "query": "project:foobar",
+                "project": self.project.id,
+                "dataset": self.dataset,
+            },
+        )
+        assert response.status_code == 400, response.content
+        assert "Unknown value foobar" in response.data["detail"]
+
+    def test_device_class_filter(self):
+        event_counts = [6, 0, 6, 3, 0, 3]
+        spans = []
+        for hour, count in enumerate(event_counts):
+            spans.extend(
+                [
+                    self.create_span(
+                        {
+                            "description": "foo",
+                            "sentry_tags": {"status": "success", "device.class": "1"},
+                        },
+                        start_ts=self.day_ago + timedelta(hours=hour, minutes=minute),
+                    )
+                    for minute in range(count)
+                ],
+            )
+        self.store_spans(spans, is_eap=self.is_eap)
+
+        for querystring in ["device.class:low", "device.class:[low,medium]"]:
+            response = self._do_request(
+                data={
+                    "start": self.day_ago,
+                    "end": self.day_ago + timedelta(hours=6),
+                    "interval": "1h",
+                    "yAxis": "count()",
+                    "query": querystring,
+                    "project": self.project.id,
+                    "dataset": self.dataset,
+                },
+            )
+            assert response.status_code == 200, response.content
+            data = response.data["data"]
+            assert len(data) == 6
+            assert response.data["meta"]["dataset"] == self.dataset
+
+            rows = data[0:6]
+            for expected_count, row in zip(event_counts, rows):
+                assert row[1][0]["count"] == expected_count
+
+    def test_device_class_filter_unknown(self):
+        event_counts = [6, 0, 6, 3, 0, 3]
+        spans = []
+        for hour, count in enumerate(event_counts):
+            spans.extend(
+                [
+                    self.create_span(
+                        {"description": "foo", "sentry_tags": {"status": "success"}},
+                        start_ts=self.day_ago + timedelta(hours=hour, minutes=minute),
+                    )
+                    for minute in range(count)
+                ],
+            )
+        self.store_spans(spans, is_eap=self.is_eap)
+
+        response = self._do_request(
+            data={
+                "start": self.day_ago,
+                "end": self.day_ago + timedelta(hours=6),
+                "interval": "1h",
+                "yAxis": "count()",
+                "query": "device.class:Unknown",
+                "project": self.project.id,
+                "dataset": self.dataset,
+            },
+        )
+        assert response.status_code == 200, response.content
+        data = response.data["data"]
+        assert len(data) == 6
+        assert response.data["meta"]["dataset"] == self.dataset
+
+        rows = data[0:6]
+        for expected_count, row in zip(event_counts, rows):
+            assert row[1][0]["count"] == expected_count