Browse Source

feat(eap-spans): add count_op function (#86823)

1. Adds `conditional_aggregation`, the last type of column supported by the RPC
2. Adds `count_op` function

Closes https://github.com/getsentry/team-visibility/issues/32

---------

Co-authored-by: getsantry[bot] <66042841+getsantry[bot]@users.noreply.github.com>
Dominik Buszowiecki 1 day ago
parent
commit
427b54377a

+ 65 - 1
src/sentry/search/eap/columns.py

@@ -16,6 +16,7 @@ from sentry_protos.snuba.v1.trace_item_attribute_pb2 import (
     Function,
     VirtualColumnContext,
 )
+from sentry_protos.snuba.v1.trace_item_filter_pb2 import TraceItemFilter
 
 from sentry.exceptions import InvalidSearchQuery
 from sentry.search.eap import constants
@@ -169,6 +170,36 @@ class ResolvedAggregate(ResolvedFunction):
         )
 
 
+@dataclass(frozen=True, kw_only=True)
+class ResolvedConditionalAggregate(ResolvedFunction):
+    """A resolved conditional aggregation: aggregates `key` only over rows that
+    pass `filter`, serialized for the RPC via `proto_definition`."""
+
+    # The RPC aggregate function to apply (ex. sum, avg)
+    internal_name: Function.ValueType
+    # Whether to enable extrapolation
+    extrapolation: bool = True
+    # The condition to filter on
+    filter: TraceItemFilter
+    # The attribute to conditionally aggregate on
+    key: AttributeKey
+
+    # Always an aggregate; excluded from __init__ so callers cannot override it
+    is_aggregate: bool = field(default=True, init=False)
+
+    @property
+    def proto_definition(self) -> AttributeConditionalAggregation:
+        """The definition of this function as needed by the RPC"""
+        return AttributeConditionalAggregation(
+            aggregate=self.internal_name,
+            key=self.key,
+            filter=self.filter,
+            label=self.public_alias,
+            extrapolation_mode=(
+                ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED
+                if self.extrapolation
+                else ExtrapolationMode.EXTRAPOLATION_MODE_NONE
+            ),
+        )
+
+
 @dataclass(kw_only=True)
 class FunctionDefinition:
     """
@@ -197,7 +228,7 @@ class FunctionDefinition:
         alias: str,
         search_type: constants.SearchType,
         resolved_argument: AttributeKey | Any | None,
-    ) -> ResolvedFormula | ResolvedAggregate:
+    ) -> ResolvedFormula | ResolvedAggregate | ResolvedConditionalAggregate:
         raise NotImplementedError()
 
 
@@ -219,6 +250,38 @@ class AggregateDefinition(FunctionDefinition):
         )
 
 
+@dataclass(kw_only=True)
+class ConditionalAggregateDefinition(FunctionDefinition):
+    """
+    The definition of a conditional aggregation.
+    Conditionally aggregates the `key`, if it passes the `filter`.
+    The type of aggregation is defined by `internal_function`.
+    The `filter` is returned by the `filter_resolver` function which takes in the args from the user and returns a `TraceItemFilter`.
+    """
+
+    # The type of aggregation (ex. sum, avg)
+    internal_function: Function.ValueType
+    # The attribute to conditionally aggregate on
+    key: AttributeKey
+    # A function that takes in the resolved argument and returns the condition to filter on
+    filter_resolver: Callable[..., TraceItemFilter]
+
+    def resolve(
+        self, alias: str, search_type: constants.SearchType, resolved_argument: AttributeKey | None
+    ) -> ResolvedConditionalAggregate:
+        """Resolve the user-supplied argument into a `ResolvedConditionalAggregate`,
+        building the filter via `filter_resolver(resolved_argument)`."""
+        return ResolvedConditionalAggregate(
+            public_alias=alias,
+            internal_name=self.internal_function,
+            search_type=search_type,
+            internal_type=self.internal_type,
+            filter=self.filter_resolver(resolved_argument),
+            key=self.key,
+            processor=self.processor,
+            extrapolation=self.extrapolation,
+            argument=resolved_argument,
+        )
+
+
 @dataclass(kw_only=True)
 class FormulaDefinition(FunctionDefinition):
     # A function that takes in the resolved argument and returns a Column.BinaryFormula
@@ -299,6 +362,7 @@ def project_term_resolver(
 @dataclass(frozen=True)
 class ColumnDefinitions:
     aggregates: dict[str, AggregateDefinition]
+    conditional_aggregates: dict[str, ConditionalAggregateDefinition]
     formulas: dict[str, FormulaDefinition]
     columns: dict[str, ResolvedAttribute]
     contexts: dict[str, VirtualColumnDefinition]

+ 1 - 0
src/sentry/search/eap/ourlogs/definitions.py

@@ -8,6 +8,7 @@ from sentry.search.eap.ourlogs.attributes import (
 
 OURLOG_DEFINITIONS = ColumnDefinitions(
     aggregates={},
+    conditional_aggregates={},
     formulas={},
     columns=OURLOG_ATTRIBUTE_DEFINITIONS,
     contexts=OURLOG_VIRTUAL_CONTEXTS,

+ 26 - 13
src/sentry/search/eap/resolver.py

@@ -38,9 +38,11 @@ from sentry.search.eap import constants
 from sentry.search.eap.columns import (
     AggregateDefinition,
     ColumnDefinitions,
+    ConditionalAggregateDefinition,
     FormulaDefinition,
     ResolvedAggregate,
     ResolvedAttribute,
+    ResolvedConditionalAggregate,
     ResolvedFormula,
     VirtualColumnDefinition,
 )
@@ -65,7 +67,11 @@ class SearchResolver:
         str, tuple[ResolvedAttribute, VirtualColumnDefinition | None]
     ] = field(default_factory=dict)
     _resolved_function_cache: dict[
-        str, tuple[ResolvedFormula | ResolvedAggregate, VirtualColumnDefinition | None]
+        str,
+        tuple[
+            ResolvedFormula | ResolvedAggregate | ResolvedConditionalAggregate,
+            VirtualColumnDefinition | None,
+        ],
     ] = field(default_factory=dict)
 
     @sentry_sdk.trace
@@ -553,7 +559,9 @@ class SearchResolver:
 
     @sentry_sdk.trace
     def resolve_columns(self, selected_columns: list[str]) -> tuple[
-        list[ResolvedAttribute | ResolvedAggregate | ResolvedFormula],
+        list[
+            ResolvedAttribute | ResolvedAggregate | ResolvedConditionalAggregate | ResolvedFormula
+        ],
         list[VirtualColumnDefinition | None],
     ]:
         """Given a list of columns resolve them and get their context if applicable
@@ -588,10 +596,9 @@ class SearchResolver:
 
         return resolved_columns, resolved_contexts
 
-    def resolve_column(
-        self, column: str, match: Match | None = None
-    ) -> tuple[
-        ResolvedAttribute | ResolvedAggregate | ResolvedFormula, VirtualColumnDefinition | None
+    def resolve_column(self, column: str, match: Match | None = None) -> tuple[
+        ResolvedAttribute | ResolvedAggregate | ResolvedConditionalAggregate | ResolvedFormula,
+        VirtualColumnDefinition | None,
     ]:
         """Column is either an attribute or an aggregate, this function will determine which it is and call the relevant
         resolve function"""
@@ -672,9 +679,10 @@ class SearchResolver:
             raise InvalidSearchQuery(f"Could not parse {column}")
 
     @sentry_sdk.trace
-    def resolve_functions(
-        self, columns: list[str]
-    ) -> tuple[list[ResolvedFormula | ResolvedAggregate], list[VirtualColumnDefinition | None]]:
+    def resolve_functions(self, columns: list[str]) -> tuple[
+        list[ResolvedFormula | ResolvedAggregate | ResolvedConditionalAggregate],
+        list[VirtualColumnDefinition | None],
+    ]:
         """Helper function to resolve a list of functions instead of 1 attribute at a time"""
         resolved_functions, resolved_contexts = [], []
         for column in columns:
@@ -683,9 +691,10 @@ class SearchResolver:
             resolved_contexts.append(context)
         return resolved_functions, resolved_contexts
 
-    def resolve_function(
-        self, column: str, match: Match | None = None
-    ) -> tuple[ResolvedFormula | ResolvedAggregate, VirtualColumnDefinition | None]:
+    def resolve_function(self, column: str, match: Match | None = None) -> tuple[
+        ResolvedFormula | ResolvedAggregate | ResolvedConditionalAggregate,
+        VirtualColumnDefinition | None,
+    ]:
         if column in self._resolved_function_cache:
             return self._resolved_function_cache[column]
         # Check if the column looks like a function (matches a pattern), parse the function name and args out
@@ -700,11 +709,15 @@ class SearchResolver:
         alias = match.group("alias") or column
 
         # Get the function definition
-        function_definition: AggregateDefinition | FormulaDefinition
+        function_definition: (
+            AggregateDefinition | FormulaDefinition | ConditionalAggregateDefinition
+        )
         if function in self.definitions.aggregates:
             function_definition = self.definitions.aggregates[function]
         elif function in self.definitions.formulas:
             function_definition = self.definitions.formulas[function]
+        elif function in self.definitions.conditional_aggregates:
+            function_definition = self.definitions.conditional_aggregates[function]
         else:
             raise InvalidSearchQuery(f"Unknown function {function}")
 

+ 30 - 2
src/sentry/search/eap/spans/aggregates.py

@@ -1,7 +1,12 @@
-from sentry_protos.snuba.v1.trace_item_attribute_pb2 import Function
+from sentry_protos.snuba.v1.trace_item_attribute_pb2 import AttributeKey, AttributeValue, Function
+from sentry_protos.snuba.v1.trace_item_filter_pb2 import ComparisonFilter, TraceItemFilter
 
 from sentry.search.eap import constants
-from sentry.search.eap.columns import AggregateDefinition, ArgumentDefinition
+from sentry.search.eap.columns import (
+    AggregateDefinition,
+    ArgumentDefinition,
+    ConditionalAggregateDefinition,
+)
 
 
 def count_processor(count_value: int | None) -> int:
@@ -11,6 +16,29 @@ def count_processor(count_value: int | None) -> int:
         return count_value
 
 
+def resolve_count_op_filter(op_value: str) -> TraceItemFilter:
+    """Build the `count_op` condition: `sentry.op` equals `op_value` (exact string match)."""
+    return TraceItemFilter(
+        comparison_filter=ComparisonFilter(
+            key=AttributeKey(
+                name="sentry.op",
+                type=AttributeKey.TYPE_STRING,
+            ),
+            op=ComparisonFilter.OP_EQUALS,
+            value=AttributeValue(val_str=op_value),
+        )
+    )
+
+
+# Conditional aggregates available on the spans dataset.
+# `count_op(<op>)` counts spans whose `sentry.op` attribute equals the given op string.
+SPAN_CONDITIONAL_AGGREGATE_DEFINITIONS = {
+    "count_op": ConditionalAggregateDefinition(
+        internal_function=Function.FUNCTION_COUNT,
+        default_search_type="integer",
+        arguments=[ArgumentDefinition(argument_types={"string"}, is_attribute=False)],
+        key=AttributeKey(type=AttributeKey.TYPE_STRING, name="sentry.op"),
+        filter_resolver=resolve_count_op_filter,
+    )
+}
+
 SPAN_AGGREGATE_DEFINITIONS = {
     "sum": AggregateDefinition(
         internal_function=Function.FUNCTION_SUM,

+ 5 - 1
src/sentry/search/eap/spans/definitions.py

@@ -1,12 +1,16 @@
 from sentry_protos.snuba.v1.request_common_pb2 import TraceItemType
 
 from sentry.search.eap.columns import ColumnDefinitions
-from sentry.search.eap.spans.aggregates import SPAN_AGGREGATE_DEFINITIONS
+from sentry.search.eap.spans.aggregates import (
+    SPAN_AGGREGATE_DEFINITIONS,
+    SPAN_CONDITIONAL_AGGREGATE_DEFINITIONS,
+)
 from sentry.search.eap.spans.attributes import SPAN_ATTRIBUTE_DEFINITIONS, SPAN_VIRTUAL_CONTEXTS
 from sentry.search.eap.spans.formulas import SPAN_FORMULA_DEFINITIONS
 
 SPAN_DEFINITIONS = ColumnDefinitions(
     aggregates=SPAN_AGGREGATE_DEFINITIONS,
+    conditional_aggregates=SPAN_CONDITIONAL_AGGREGATE_DEFINITIONS,
     formulas=SPAN_FORMULA_DEFINITIONS,
     columns=SPAN_ATTRIBUTE_DEFINITIONS,
     contexts=SPAN_VIRTUAL_CONTEXTS,

+ 1 - 0
src/sentry/search/eap/uptime_checks/definitions.py

@@ -8,6 +8,7 @@ from sentry.search.eap.uptime_checks.attributes import (
 
 UPTIME_CHECK_DEFINITIONS = ColumnDefinitions(
     aggregates={},
+    conditional_aggregates={},
     formulas={},
     columns=UPTIME_CHECK_ATTRIBUTE_DEFINITIONS,
     contexts=UPTIME_CHECK_VIRTUAL_CONTEXTS,

+ 11 - 2
src/sentry/snuba/rpc_dataset_common.py

@@ -5,7 +5,12 @@ from sentry_protos.snuba.v1.endpoint_trace_item_table_pb2 import Column, TraceIt
 from sentry_protos.snuba.v1.request_common_pb2 import PageToken
 from sentry_protos.snuba.v1.trace_item_attribute_pb2 import AttributeKey
 
-from sentry.search.eap.columns import ResolvedAggregate, ResolvedAttribute, ResolvedFormula
+from sentry.search.eap.columns import (
+    ResolvedAggregate,
+    ResolvedAttribute,
+    ResolvedConditionalAggregate,
+    ResolvedFormula,
+)
 from sentry.search.eap.resolver import SearchResolver
 from sentry.search.eap.types import CONFIDENCES, ConfidenceData, EAPResponse
 from sentry.search.events.fields import get_function_alias
@@ -16,11 +21,15 @@ from sentry.utils.snuba import process_value
 logger = logging.getLogger("sentry.snuba.spans_rpc")
 
 
-def categorize_column(column: ResolvedAttribute | ResolvedAggregate | ResolvedFormula) -> Column:
+def categorize_column(
+    column: ResolvedAttribute | ResolvedAggregate | ResolvedConditionalAggregate | ResolvedFormula,
+) -> Column:
+    """Wrap a resolved column in the RPC `Column` message field matching its type."""
     if isinstance(column, ResolvedFormula):
         return Column(formula=column.proto_definition, label=column.public_alias)
     if isinstance(column, ResolvedAggregate):
         return Column(aggregation=column.proto_definition, label=column.public_alias)
+    if isinstance(column, ResolvedConditionalAggregate):
+        return Column(conditional_aggregation=column.proto_definition, label=column.public_alias)
     else:
         return Column(key=column.proto_definition, label=column.public_alias)
 

+ 11 - 2
src/sentry/snuba/spans_rpc.py

@@ -13,7 +13,12 @@ from sentry_protos.snuba.v1.trace_item_filter_pb2 import AndFilter, OrFilter, Tr
 
 from sentry.api.event_search import SearchFilter, SearchKey, SearchValue
 from sentry.exceptions import InvalidSearchQuery
-from sentry.search.eap.columns import ResolvedAggregate, ResolvedAttribute, ResolvedFormula
+from sentry.search.eap.columns import (
+    ResolvedAggregate,
+    ResolvedAttribute,
+    ResolvedConditionalAggregate,
+    ResolvedFormula,
+)
 from sentry.search.eap.constants import DOUBLE, INT, MAX_ROLLUP_POINTS, STRING, VALID_GRANULARITIES
 from sentry.search.eap.resolver import SearchResolver
 from sentry.search.eap.spans.definitions import SPAN_DEFINITIONS
@@ -76,7 +81,11 @@ def get_timeseries_query(
     config: SearchResolverConfig,
     granularity_secs: int,
     extra_conditions: TraceItemFilter | None = None,
-) -> tuple[TimeSeriesRequest, list[ResolvedFormula | ResolvedAggregate], list[ResolvedAttribute]]:
+) -> tuple[
+    TimeSeriesRequest,
+    list[ResolvedFormula | ResolvedAggregate | ResolvedConditionalAggregate],
+    list[ResolvedAttribute],
+]:
     resolver = get_resolver(params=params, config=config)
     meta = resolver.resolve_meta(referrer=referrer)
     query, _, query_contexts = resolver.resolve_query(query_string)

+ 40 - 0
tests/snuba/api/endpoints/test_organization_events_span_indexed.py

@@ -2827,6 +2827,46 @@ class OrganizationEventsEAPRPCSpanEndpointTest(OrganizationEventsEAPSpanEndpoint
 
         assert meta["dataset"] == self.dataset
 
+    def test_count_op(self):
+        """`count_op(<op>)` returns the number of spans whose `sentry.op` equals the argument."""
+        self.store_spans(
+            [
+                self.create_span(
+                    {"op": "queue.process", "sentry_tags": {"op": "queue.process"}},
+                    start_ts=self.ten_mins_ago,
+                ),
+                self.create_span(
+                    {"op": "queue.process", "sentry_tags": {"op": "queue.process"}},
+                    start_ts=self.ten_mins_ago,
+                ),
+                self.create_span(
+                    {"op": "queue.publish", "sentry_tags": {"op": "queue.publish"}},
+                    start_ts=self.ten_mins_ago,
+                ),
+            ],
+            is_eap=self.is_eap,
+        )
+
+        response = self.do_request(
+            {
+                "field": [
+                    "count_op(queue.process)",
+                    "count_op(queue.publish)",
+                ],
+                "project": self.project.id,
+                "dataset": self.dataset,
+            }
+        )
+
+        assert response.status_code == 200, response.content
+        data = response.data["data"]
+        meta = response.data["meta"]
+
+        # Two queue.process spans and one queue.publish span were stored above.
+        assert len(data) == 1
+        assert data[0]["count_op(queue.process)"] == 2
+        assert data[0]["count_op(queue.publish)"] == 1
+
+        assert meta["dataset"] == self.dataset
+
     def test_ttif_ttfd_contribution_rate(self):
         spans = []
         for _ in range(8):