
feat(trace-explorer): Support coming from a metric (#70000)

This supports querying for traces when coming from a metric.

Note: this performs an application-level join and is not performant, as it
requires going back and forth between querying the metric and querying the
spans.
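
For orientation, here is a minimal sketch of that back-and-forth. The helper
names (query_metric_for_trace_ids, query_spans_in_traces, find_traces_from_metric)
are hypothetical stand-ins for the real Snuba queries, not the executor API added
in this change:

def query_metric_for_trace_ids(mri: str, metrics_query: str, limit: int) -> list[str]:
    # Hypothetical stand-in: query the metric's dataset for the trace ids whose
    # samples match the metrics query.
    ...

def query_spans_in_traces(trace_ids: list[str], user_queries: list[str]) -> list[dict]:
    # Hypothetical stand-in: query the spans dataset, restricted to the trace
    # ids found in the first step.
    ...

def find_traces_from_metric(
    mri: str,
    metrics_query: str,
    user_queries: list[str],
    limit: int = 10_000,
) -> list[dict]:
    # Step 1: metric side -> candidate trace ids.
    trace_ids = query_metric_for_trace_ids(mri, metrics_query, limit)
    if not trace_ids:
        return []
    # Step 2: spans side, filtered to the trace ids from step 1. The join
    # happens here in application code rather than in the database.
    return query_spans_in_traces(trace_ids, user_queries)

In the actual change, these two steps correspond roughly to the executor's
get_matching_traces and get_matching_spans_from_traces methods, followed by a
final spans query restricted to the returned trace ids.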
Tony Xiao, 10 months ago
Parent
Commit
bb946426d2

+ 12 - 12
src/sentry/api/endpoints/organization_metrics.py

@@ -563,23 +563,23 @@ class OrganizationMetricsSamplesEndpoint(OrganizationEventsV2EndpointBase):
                 raise ParseError(f"Unsupported sort: {sort} for MRI")
 
         executor = executor_cls(
-            serialized["mri"],
-            params,
-            snuba_params,
-            serialized["field"],
-            serialized.get("operation"),
-            serialized.get("query", ""),
-            serialized.get("min"),
-            serialized.get("max"),
-            serialized.get("sort"),
-            rollup,
-            Referrer.API_ORGANIZATION_METRICS_SAMPLES,
+            mri=serialized["mri"],
+            params=params,
+            snuba_params=snuba_params,
+            fields=serialized["field"],
+            operation=serialized.get("operation"),
+            query=serialized.get("query", ""),
+            min=serialized.get("min"),
+            max=serialized.get("max"),
+            sort=serialized.get("sort"),
+            rollup=rollup,
+            referrer=Referrer.API_ORGANIZATION_METRICS_SAMPLES,
         )
 
         with handle_query_errors():
             return self.paginate(
                 request=request,
-                paginator=GenericOffsetPaginator(data_fn=executor.execute),
+                paginator=GenericOffsetPaginator(data_fn=executor.get_matching_spans),
                 on_results=lambda results: self.handle_results_with_meta(
                     request,
                     organization,

+ 230 - 59
src/sentry/api/endpoints/organization_traces.py

@@ -1,13 +1,15 @@
 import dataclasses
+import math
 from collections import defaultdict
 from collections.abc import Callable, Mapping, MutableMapping
 from datetime import datetime, timedelta
 from typing import Any, Literal, TypedDict, cast
 
 from rest_framework import serializers
+from rest_framework.exceptions import ParseError
 from rest_framework.request import Request
 from rest_framework.response import Response
-from snuba_sdk import BooleanCondition, BooleanOp, Column, Condition, Function, Op
+from snuba_sdk import And, BooleanCondition, BooleanOp, Column, Condition, Function, Op, Or
 
 from sentry import features, options
 from sentry.api.api_owners import ApiOwner
@@ -20,11 +22,15 @@ from sentry.models.organization import Organization
 from sentry.models.project import Project
 from sentry.search.events.builder import QueryBuilder, SpansIndexedQueryBuilder
 from sentry.search.events.types import ParamsType, QueryBuilderConfig, SnubaParams, WhereType
+from sentry.sentry_metrics.querying.samples_list import SpanKey, get_sample_list_executor_cls
 from sentry.snuba.dataset import Dataset
 from sentry.snuba.referrer import Referrer
+from sentry.utils.iterators import chunked
 from sentry.utils.numbers import clip
 from sentry.utils.snuba import bulk_snuba_queries
 
+MAX_SNUBA_RESULTS = 10_000
+
 
 class TraceInterval(TypedDict):
     project: str | None
@@ -51,6 +57,8 @@ class TraceResult(TypedDict):
 class OrganizationTracesSerializer(serializers.Serializer):
     field = serializers.ListField(required=True, allow_empty=False, child=serializers.CharField())
     sort = serializers.ListField(required=False, allow_empty=True, child=serializers.CharField())
+    metricsQuery = serializers.CharField(required=False)
+    mri = serializers.CharField(required=False)
     query = serializers.ListField(
         required=False, allow_empty=True, child=serializers.CharField(allow_blank=True)
     )
@@ -88,6 +96,8 @@ class OrganizationTracesEndpoint(OrganizationEventsV2EndpointBase):
             # Filter out empty queries as they do not do anything to change the results.
             user_queries=[query.strip() for query in serialized.get("query", []) if query.strip()],
             suggested_query=serialized.get("suggestedQuery", ""),
+            metrics_query=serialized.get("metricsQuery", ""),
+            mri=serialized.get("mri"),
             sort=serialized.get("sort"),
             limit=self.get_per_page(request),
             max_spans_per_trace=serialized["maxSpansPerTrace"],
@@ -124,6 +134,8 @@ class TraceSamplesExecutor:
         fields: list[str],
         user_queries: list[str],
         suggested_query: str,
+        metrics_query: str,
+        mri: str | None,
         sort: str | None,
         limit: int,
         max_spans_per_trace: int,
@@ -134,6 +146,8 @@ class TraceSamplesExecutor:
         self.fields = fields
         self.user_queries = user_queries
         self.suggested_query = suggested_query
+        self.metrics_query = metrics_query
+        self.mri = mri
         self.sort = sort
         self.limit = limit
         self.max_spans_per_trace = max_spans_per_trace
@@ -154,7 +168,12 @@ class TraceSamplesExecutor:
         selected_projects_snuba_params = self.snuba_params
 
         with handle_query_errors():
-            min_timestamp, max_timestamp, trace_ids, _ = self.get_traces_matching_conditions(
+            (
+                min_timestamp,
+                max_timestamp,
+                trace_ids,
+                span_keys,
+            ) = self.get_traces_matching_conditions(
                 selected_projects_params,
                 selected_projects_snuba_params,
             )
@@ -171,6 +190,7 @@ class TraceSamplesExecutor:
                 all_projects_params,
                 all_projects_snuba_params,
                 trace_ids,
+                span_keys,
             )
 
             all_raw_results = bulk_snuba_queries(
@@ -218,21 +238,139 @@ class TraceSamplesExecutor:
         self,
         params: ParamsType,
         snuba_params: SnubaParams,
-    ) -> tuple[datetime, datetime, list[str], list[str] | None]:
-        # For now, this only fetches traces based on span conditions.
-        # Soon, this will support fetching traces based on metric conditions.
-        return self.get_traces_matching_span_conditions(params, snuba_params)
+    ) -> tuple[datetime, datetime, list[str], list[SpanKey] | None]:
+        if self.mri is not None:
+            return self.get_traces_matching_metric_conditions(params, snuba_params)
+
+        min_timestamp, max_timestamp, trace_ids = self.get_traces_matching_span_conditions(
+            params, snuba_params
+        )
+        return min_timestamp, max_timestamp, trace_ids, None
+
+    def get_traces_matching_metric_conditions(
+        self,
+        params: ParamsType,
+        snuba_params: SnubaParams,
+    ) -> tuple[datetime, datetime, list[str], list[SpanKey]]:
+        assert self.mri is not None
+
+        executor_cls = get_sample_list_executor_cls(self.mri)
+        if executor_cls is None:
+            raise ParseError(detail=f"Unsupported MRI: {self.mri}")
+
+        executor = executor_cls(
+            mri=self.mri,
+            params=params,
+            snuba_params=snuba_params,
+            fields=["trace"],
+            query=self.metrics_query,
+            referrer=Referrer.API_TRACE_EXPLORER_METRICS_SPANS_LIST,
+        )
+
+        min_timestamp, max_timestamp, trace_ids = executor.get_matching_traces(MAX_SNUBA_RESULTS)
+
+        if not trace_ids:
+            return min_timestamp, max_timestamp, [], []
+
+        self.refine_params(min_timestamp, max_timestamp)
+
+        # If there are user queries, further refine the trace ids by applying them,
+        # leaving us with only the traces where the metric exists and the user
+        # queries also match.
+        if self.user_queries:
+            min_timestamp, max_timestamp, trace_ids = self.get_traces_matching_span_conditions(
+                params, snuba_params, trace_ids
+            )
+
+            if not trace_ids:
+                return min_timestamp, max_timestamp, [], []
+
+            self.refine_params(min_timestamp, max_timestamp)
+
+        span_keys = executor.get_matching_spans_from_traces(
+            trace_ids,
+            self.max_spans_per_trace,
+        )
+
+        if not span_keys:
+            # TODO: log a message that we found traces but no span ids for metrics condition
+            return min_timestamp, max_timestamp, [], []
+
+        return min_timestamp, max_timestamp, trace_ids, span_keys
 
     def get_traces_matching_span_conditions(
         self,
         params: ParamsType,
         snuba_params: SnubaParams,
-    ) -> tuple[datetime, datetime, list[str], list[str] | None]:
+        trace_ids: list[str] | None = None,
+    ) -> tuple[datetime, datetime, list[str]]:
+        all_queries: list[QueryBuilder] = []
+        timestamp_column: str | None = None
+
+        if trace_ids:
+            # Putting all the trace ids into a single query will likely encounter the
+            # max query size limit in ClickHouse. This tries to spread the trace ids
+            # out evenly across N queries up to some limit per query.
+            max_trace_ids_per_chunk = options.get(
+                "performance.traces.trace-explorer-max-trace-ids-per-chunk"
+            )
+            num_chunks = math.ceil(len(trace_ids) / max_trace_ids_per_chunk)
+            chunk_size = math.ceil(len(trace_ids) / num_chunks)
+
+            for chunk in chunked(trace_ids, chunk_size):
+                query, timestamp_column = self.get_traces_matching_span_conditions_query(
+                    params,
+                    snuba_params,
+                )
+
+                # restrict the query to just this subset of trace ids
+                query.add_conditions([Condition(Column("trace_id"), Op.IN, chunk)])
+
+                all_queries.append(query)
+        else:
+            query, timestamp_column = self.get_traces_matching_span_conditions_query(
+                params,
+                snuba_params,
+            )
+            all_queries.append(query)
+
+        assert timestamp_column is not None
+
+        all_raw_results = bulk_snuba_queries(
+            [query.get_snql_query() for query in all_queries],
+            Referrer.API_TRACE_EXPLORER_SPANS_LIST.value,
+        )
+        all_results = [
+            query.process_results(result) for query, result in zip(all_queries, all_raw_results)
+        ]
+
+        matching_trace_ids: list[str] = []
+        min_timestamp = self.snuba_params.end
+        max_timestamp = self.snuba_params.start
+        assert min_timestamp is not None
+        assert max_timestamp is not None
+
+        for trace_results in all_results:
+            for row in trace_results["data"]:
+                matching_trace_ids.append(row["trace"])
+                timestamp = datetime.fromisoformat(row[timestamp_column])
+                if timestamp < min_timestamp:
+                    min_timestamp = timestamp
+                if timestamp > max_timestamp:
+                    max_timestamp = timestamp
+
+        return min_timestamp, max_timestamp, matching_trace_ids
+
+    def get_traces_matching_span_conditions_query(
+        self,
+        params: ParamsType,
+        snuba_params: SnubaParams,
+    ) -> tuple[QueryBuilder, str]:
         if len(self.user_queries) < 2:
             # Optimization: If there is only a condition for a single span,
             # we can take the fast path and query without using aggregates.
             timestamp_column = "timestamp"
-            builder = SpansIndexedQueryBuilder(
+            query = SpansIndexedQueryBuilder(
                 Dataset.SpansIndexed,
                 params=params,
                 snuba_params=snuba_params,
@@ -250,7 +388,7 @@ class TraceSamplesExecutor:
             )
         else:
             timestamp_column = "min(timestamp)"
-            builder = SpansIndexedQueryBuilder(
+            query = SpansIndexedQueryBuilder(
                 Dataset.SpansIndexed,
                 params=params,
                 snuba_params=snuba_params,
@@ -267,40 +405,24 @@ class TraceSamplesExecutor:
                 ),
             )
 
-            for query in self.user_queries:
+            for user_query in self.user_queries:
                 # We want to ignore all the aggregate conditions here because we're strictly
                 # searching on span attributes, not aggregates
-                where, _ = builder.resolve_conditions(query)
+                where, _ = query.resolve_conditions(user_query)
 
                 # Transform the condition into it's aggregate form so it can be used to
                 # match on the trace.
                 new_condition = generate_trace_condition(where)
-                builder.having.append(new_condition)
+                query.having.append(new_condition)
 
-        trace_results = builder.run_query(Referrer.API_TRACE_EXPLORER_SPANS_LIST.value)
-        trace_results = builder.process_results(trace_results)
-
-        trace_ids: list[str] = []
-        min_timestamp = self.snuba_params.end
-        max_timestamp = self.snuba_params.start
-        assert min_timestamp is not None
-        assert max_timestamp is not None
-
-        for row in trace_results["data"]:
-            trace_ids.append(row["trace"])
-            timestamp = datetime.fromisoformat(row[timestamp_column])
-            if timestamp < min_timestamp:
-                min_timestamp = timestamp
-            if timestamp > max_timestamp:
-                max_timestamp = timestamp
-
-        return min_timestamp, max_timestamp, trace_ids, None
+        return query, timestamp_column
 
     def get_all_queries(
         self,
         params: ParamsType,
         snuba_params: SnubaParams,
         trace_ids: list[str],
+        span_keys: list[SpanKey] | None,
     ) -> list[QueryBuilder]:
         meta_data_queries = self.get_all_meta_data_queries(
             params,
@@ -312,6 +434,7 @@ class TraceSamplesExecutor:
             params,
             snuba_params,
             trace_ids,
+            span_keys,
         )
 
         return meta_data_queries + span_samples_queries
@@ -358,11 +481,13 @@ class TraceSamplesExecutor:
         params: ParamsType,
         snuba_params: SnubaParams,
         trace_ids: list[str],
+        span_keys: list[SpanKey] | None,
     ) -> list[QueryBuilder]:
         user_spans_query = self.get_user_spans_query(
             params,
             snuba_params,
             trace_ids,
+            span_keys,
         )
 
         suggested_spans_query = self.get_suggested_spans_query(
@@ -480,8 +605,8 @@ class TraceSamplesExecutor:
             orderby=["precise.start_ts", "precise.finish_ts"],
             # limit the number of segments we fetch per trace so a single
             # large trace does not result in the rest being blank
-            limitby=("trace", int(10_000 / len(trace_ids))),
-            limit=10_000,
+            limitby=("trace", int(MAX_SNUBA_RESULTS / len(trace_ids))),
+            limit=MAX_SNUBA_RESULTS,
             config=QueryBuilderConfig(
                 transform_alias_to_input_format=True,
             ),
@@ -559,6 +684,7 @@ class TraceSamplesExecutor:
         params: ParamsType,
         snuba_params: SnubaParams,
         trace_ids: list[str],
+        span_keys: list[SpanKey] | None,
     ) -> QueryBuilder:
         user_spans_query = SpansIndexedQueryBuilder(
             Dataset.SpansIndexed,
@@ -577,34 +703,75 @@ class TraceSamplesExecutor:
         # First make sure that we only return spans from one of the traces identified
         user_spans_query.add_conditions([Condition(Column("trace_id"), Op.IN, trace_ids)])
 
-        # Next we have to turn the user queries into the appropriate conditions in
-        # the SnQL that we produce.
-
-        # There are multiple sets of user conditions that needs to be satisfied
-        # and if a span satisfy any of them, it should be considered.
-        #
-        # To handle this use case, we want to OR all the user specified
-        # conditions together in this query.
         conditions = []
-        for query in self.user_queries:
-            # We want to ignore all the aggregate conditions here because we're strictly
-            # searching on span attributes, not aggregates
-            where, _ = user_spans_query.resolve_conditions(query)
-            if len(where) > 1:
-                conditions.append(BooleanCondition(op=BooleanOp.AND, conditions=where))
-            elif len(where) == 1:
-                conditions.append(where[0])
-
-        if len(conditions) > 1:
-            # More than 1 set of user conditions were specified, we want to show
-            # spans that match any 1 of them so join the conditions with `OR`s.
-            user_spans_query.add_conditions(
-                [BooleanCondition(op=BooleanOp.OR, conditions=conditions)]
+        if span_keys is not None:
+            assert span_keys
+
+            # Next, if there are known span_keys, we only try to fetch those spans.
+            # These are additional conditions to better take advantage of the ORDER BY
+            # on the spans table. This creates a list of conditions to be `OR`ed together
+            # that will be used by ClickHouse to narrow down the granules.
+            #
+            # The span ids are not in this condition because they are more effective when
+            # specified within the `PREWHERE` clause, so they go in a separate condition.
+            conditions = [
+                And(
+                    [
+                        Condition(user_spans_query.column("span.group"), Op.EQ, key.group),
+                        Condition(
+                            user_spans_query.column("timestamp"),
+                            Op.EQ,
+                            datetime.fromisoformat(key.timestamp),
+                        ),
+                    ]
+                )
+                for key in span_keys
+            ]
+
+            if len(conditions) == 1:
+                order_by_condition = conditions[0]
+            else:
+                order_by_condition = Or(conditions)
+
+            # Using `IN` combined with putting the list in a SnQL "tuple" triggers an optimizer
+            # in snuba where it
+            # 1. moves the condition into the `PREWHERE` clause
+            # 2. maps the ids to the underlying UInt64 and uses the bloom filter index
+            span_id_condition = Condition(
+                user_spans_query.column("id"),
+                Op.IN,
+                Function("tuple", [key.span_id for key in span_keys]),
             )
-        elif len(conditions) == 1:
-            # Only 1 set of user conditions were specified, simply insert them into
-            # the final query.
-            user_spans_query.add_conditions([conditions[0]])
+
+            user_spans_query.add_conditions([order_by_condition, span_id_condition])
+        else:
+            # Next we have to turn the user queries into the appropriate conditions in
+            # the SnQL that we produce.
+
+            # There are multiple sets of user conditions that need to be satisfied
+            # and if a span satisfies any of them, it should be considered.
+            #
+            # To handle this use case, we want to OR all the user specified
+            # conditions together in this query.
+            for query in self.user_queries:
+                # We want to ignore all the aggregate conditions here because we're strictly
+                # searching on span attributes, not aggregates
+                where, _ = user_spans_query.resolve_conditions(query)
+                if len(where) > 1:
+                    conditions.append(BooleanCondition(op=BooleanOp.AND, conditions=where))
+                elif len(where) == 1:
+                    conditions.append(where[0])
+
+            if len(conditions) > 1:
+                # More than 1 set of conditions were specified, we want to show
+                # spans that match any 1 of them so join the conditions with `OR`s.
+                user_spans_query.add_conditions(
+                    [BooleanCondition(op=BooleanOp.OR, conditions=conditions)]
+                )
+            elif len(conditions) == 1:
+                # Only 1 set of user conditions were specified, simply insert them into
+                # the final query.
+                user_spans_query.add_conditions([conditions[0]])
 
         return user_spans_query
 
@@ -615,7 +782,11 @@ class TraceSamplesExecutor:
         trace_ids: list[str],
     ) -> QueryBuilder | None:
         # If any user queries is the same as the suggested query, we don't have to run it
-        if any(user_query == self.suggested_query for user_query in self.user_queries):
+        if (
+            not self.user_queries
+            and not self.suggested_query
+            or any(user_query == self.suggested_query for user_query in self.user_queries)
+        ):
             return None
         suggested_spans_query = SpansIndexedQueryBuilder(
             Dataset.SpansIndexed,

+ 6 - 0
src/sentry/options/defaults.py

@@ -1698,6 +1698,12 @@ register(
     default=1.0,
     flags=FLAG_AUTOMATOR_MODIFIABLE,
 )  # hours
+register(
+    "performance.traces.trace-explorer-max-trace-ids-per-chunk",
+    type=Int,
+    default=2500,
+    flags=FLAG_AUTOMATOR_MODIFIABLE,
+)
 register(
     "performance.traces.span_query_minimum_spans",
     type=Int,

+ 278 - 18
src/sentry/sentry_metrics/querying/samples_list.py

@@ -49,17 +49,18 @@ class AbstractSamplesListExecutor(ABC):
 
     def __init__(
         self,
+        *,
         mri: str,
         params: ParamsType,
         snuba_params: SnubaParams,
-        fields: list[str],
-        operation: str | None,
-        query: str | None,
-        min: float | None,
-        max: float | None,
-        sort: str | None,
-        rollup: int,
         referrer: Referrer,
+        fields: list[str],
+        operation: str | None = None,
+        query: str | None = None,
+        min: float | None = None,
+        max: float | None = None,
+        sort: str | None = None,
+        rollup: int | None = None,
     ):
         self.mri = mri
         self.params = params
@@ -82,19 +83,33 @@ class AbstractSamplesListExecutor(ABC):
     def supports_sort(cls, column: str) -> bool:
         return column in cls.sortable_columns
 
-    def execute(self, offset, limit):
+    @abstractmethod
+    def get_matching_traces(self, limit: int) -> tuple[datetime, datetime, list[str]]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_matching_spans_from_traces(
+        self,
+        trace_ids: list[str],
+        max_spans_per_trace: int,
+    ) -> list[SpanKey]:
+        raise NotImplementedError
+
+    def get_matching_spans(self, offset, limit):
+        assert self.rollup is not None
+
         if self.sort is None:
-            execute_fn = self.execute_unsorted
+            execute_fn = self.get_matching_spans_unsorted
         else:
-            execute_fn = self.execute_sorted
+            execute_fn = self.get_matching_spans_sorted
         return execute_fn(offset, limit)
 
     @abstractmethod
-    def execute_sorted(self, offset, limit):
+    def get_matching_spans_sorted(self, offset, limit):
         raise NotImplementedError
 
     @abstractmethod
-    def execute_unsorted(self, offset, limit):
+    def get_matching_spans_unsorted(self, offset, limit):
         raise NotImplementedError
 
     def get_spans_by_key(
@@ -195,6 +210,92 @@ class SegmentsSamplesListExecutor(AbstractSamplesListExecutor):
     def supports_mri(cls, mri: str) -> bool:
         return cls.mri_to_column(mri) is not None
 
+    def get_matching_traces(self, limit: int) -> tuple[datetime, datetime, list[str]]:
+        column = self.mri_to_column(self.mri)
+        assert column
+
+        builder = SpansIndexedQueryBuilder(
+            Dataset.Transactions,
+            params=self.params,
+            snuba_params=self.snuba_params,
+            query=self.query,
+            selected_columns=["trace", "timestamp"],
+            # The orderby is intentionally `None` here as this query is much faster
+            # if we let Clickhouse decide which order to return the results in.
+            # This also means we cannot order by any columns or paginate.
+            orderby=None,
+            limit=limit,
+            limitby=("trace", 1),
+        )
+
+        additional_conditions = self.get_additional_conditions(builder)
+        min_max_conditions = self.get_min_max_conditions(builder.resolve_column(column))
+        builder.add_conditions([*additional_conditions, *min_max_conditions])
+
+        query_results = builder.run_query(self.referrer.value)
+        results = builder.process_results(query_results)
+
+        matching_trace_ids: list[str] = []
+        min_timestamp = self.snuba_params.end
+        max_timestamp = self.snuba_params.start
+        assert min_timestamp is not None
+        assert max_timestamp is not None
+
+        for row in results["data"]:
+            matching_trace_ids.append(row["trace"])
+            timestamp = datetime.fromisoformat(row["timestamp"])
+            if timestamp < min_timestamp:
+                min_timestamp = timestamp
+            if timestamp > max_timestamp:
+                max_timestamp = timestamp
+
+        return min_timestamp, max_timestamp, matching_trace_ids
+
+    def get_matching_spans_from_traces(
+        self,
+        trace_ids: list[str],
+        max_spans_per_trace: int,
+    ) -> list[SpanKey]:
+        column = self.mri_to_column(self.mri)
+        assert column is not None
+
+        builder = SpansIndexedQueryBuilder(
+            Dataset.Transactions,
+            params=self.params,
+            snuba_params=self.snuba_params,
+            query=self.query,
+            selected_columns=["timestamp", "span_id"],
+            # The orderby is intentionally `None` here as this query is much faster
+            # if we let Clickhouse decide which order to return the results in.
+            # This also means we cannot order by any columns or paginate.
+            orderby=None,
+            limit=len(trace_ids) * max_spans_per_trace,
+            limitby=("trace", 1),
+        )
+
+        trace_id_condition = Condition(Column("trace_id"), Op.IN, trace_ids)
+        additional_conditions = self.get_additional_conditions(builder)
+        min_max_conditions = self.get_min_max_conditions(builder.resolve_column(column))
+        builder.add_conditions(
+            [
+                trace_id_condition,
+                *additional_conditions,
+                *min_max_conditions,
+            ]
+        )
+
+        query_results = builder.run_query(self.referrer.value)
+        results = builder.process_results(query_results)
+
+        return [
+            SpanKey(
+                group="00",  # all segments have a group of `00` currently
+                timestamp=row["timestamp"],
+                span_id=row["span_id"],
+            )
+            for row in results["data"]
+        ]
+
     def _get_spans(
         self,
         span_keys: list[SpanKey],
@@ -221,7 +322,7 @@ class SegmentsSamplesListExecutor(AbstractSamplesListExecutor):
 
         return result
 
-    def execute_sorted(self, offset, limit):
+    def get_matching_spans_sorted(self, offset, limit):
         span_keys, summaries = self.get_sorted_span_keys(offset, limit)
         return self._get_spans(span_keys, summaries)
 
@@ -308,7 +409,7 @@ class SegmentsSamplesListExecutor(AbstractSamplesListExecutor):
 
         return span_keys, summaries
 
-    def execute_unsorted(self, offset, limit):
+    def get_matching_spans_unsorted(self, offset, limit):
         span_keys, summaries = self.get_unsorted_span_keys(offset, limit)
         return self._get_spans(span_keys, summaries)
 
@@ -472,7 +573,93 @@ class SpansSamplesListExecutor(AbstractSamplesListExecutor):
     def supports_mri(cls, mri: str) -> bool:
         return cls.mri_to_column(mri) is not None
 
-    def execute_sorted(self, offset, limit):
+    def get_matching_traces(self, limit: int) -> tuple[datetime, datetime, list[str]]:
+        column = self.mri_to_column(self.mri)
+        assert column is not None
+
+        builder = SpansIndexedQueryBuilder(
+            Dataset.SpansIndexed,
+            params=self.params,
+            snuba_params=self.snuba_params,
+            query=self.query,
+            selected_columns=["trace", "timestamp"],
+            # The orderby is intentionally `None` here as this query is much faster
+            # if we let Clickhouse decide which order to return the results in.
+            # This also means we cannot order by any columns or paginate.
+            orderby=None,
+            limit=limit,
+            limitby=("trace", 1),
+        )
+
+        additional_conditions = self.get_additional_conditions(builder)
+        min_max_conditions = self.get_min_max_conditions(builder.resolve_column(column))
+        builder.add_conditions([*additional_conditions, *min_max_conditions])
+
+        query_results = builder.run_query(self.referrer.value)
+        results = builder.process_results(query_results)
+
+        matching_trace_ids: list[str] = []
+        min_timestamp = self.snuba_params.end
+        max_timestamp = self.snuba_params.start
+        assert min_timestamp is not None
+        assert max_timestamp is not None
+
+        for row in results["data"]:
+            matching_trace_ids.append(row["trace"])
+            timestamp = datetime.fromisoformat(row["timestamp"])
+            if timestamp < min_timestamp:
+                min_timestamp = timestamp
+            if timestamp > max_timestamp:
+                max_timestamp = timestamp
+
+        return min_timestamp, max_timestamp, matching_trace_ids
+
+    def get_matching_spans_from_traces(
+        self,
+        trace_ids: list[str],
+        max_spans_per_trace: int,
+    ) -> list[SpanKey]:
+        column = self.mri_to_column(self.mri)
+        assert column is not None
+
+        builder = SpansIndexedQueryBuilder(
+            Dataset.SpansIndexed,
+            params=self.params,
+            snuba_params=self.snuba_params,
+            query=self.query,
+            selected_columns=["span.group", "timestamp", "id"],
+            # The orderby is intentionally `None` here as this query is much faster
+            # if we let Clickhouse decide which order to return the results in.
+            # This also means we cannot order by any columns or paginate.
+            orderby=None,
+            limit=len(trace_ids) * max_spans_per_trace,
+            limitby=("trace", 1),
+        )
+
+        trace_id_condition = Condition(Column("trace_id"), Op.IN, trace_ids)
+        additional_conditions = self.get_additional_conditions(builder)
+        min_max_conditions = self.get_min_max_conditions(builder.resolve_column(column))
+        builder.add_conditions(
+            [
+                trace_id_condition,
+                *additional_conditions,
+                *min_max_conditions,
+            ]
+        )
+
+        query_results = builder.run_query(self.referrer.value)
+        results = builder.process_results(query_results)
+
+        return [
+            SpanKey(
+                group=row["span.group"],
+                timestamp=row["timestamp"],
+                span_id=row["id"],
+            )
+            for row in results["data"]
+        ]
+
+    def get_matching_spans_sorted(self, offset, limit):
         """
         Since we're already querying the spans table sorted on some column,
         there's no reason to split this into 2 queries. We can go ahead and
@@ -524,7 +711,7 @@ class SpansSamplesListExecutor(AbstractSamplesListExecutor):
 
         return result
 
-    def execute_unsorted(self, offset, limit):
+    def get_matching_spans_unsorted(self, offset, limit):
         span_keys = self.get_unsorted_span_keys(offset, limit)
 
         column = self.mri_to_column(self.mri)
@@ -738,6 +925,79 @@ class CustomSamplesListExecutor(AbstractSamplesListExecutor):
             return True
         return False
 
+    def get_matching_traces(self, limit: int) -> tuple[datetime, datetime, list[str]]:
+        builder = MetricsSummariesQueryBuilder(
+            Dataset.MetricsSummaries,
+            self.params,
+            snuba_params=self.snuba_params,
+            query=self.query,
+            selected_columns=["trace", "timestamp"],
+            # The orderby is intentionally `None` here as this query is much faster
+            # if we let Clickhouse decide which order to return the results in.
+            # This also means we cannot order by any columns or paginate.
+            orderby=None,
+            limit=limit,
+            limitby=("trace", 1),
+        )
+
+        additional_conditions = self.get_additional_conditions(builder)
+        min_max_conditions = self.get_min_max_conditions(builder)
+        builder.add_conditions([*additional_conditions, *min_max_conditions])
+
+        query_results = builder.run_query(self.referrer.value)
+        results = builder.process_results(query_results)
+
+        matching_trace_ids: list[str] = []
+        min_timestamp = self.snuba_params.end
+        max_timestamp = self.snuba_params.start
+        assert min_timestamp is not None
+        assert max_timestamp is not None
+
+        for row in results["data"]:
+            matching_trace_ids.append(row["trace"])
+            timestamp = datetime.fromisoformat(row["timestamp"])
+            if timestamp < min_timestamp:
+                min_timestamp = timestamp
+            if timestamp > max_timestamp:
+                max_timestamp = timestamp
+
+        return min_timestamp, max_timestamp, matching_trace_ids
+
+    def get_matching_spans_from_traces(
+        self,
+        trace_ids: list[str],
+        max_spans_per_trace: int,
+    ) -> list[SpanKey]:
+        builder = MetricsSummariesQueryBuilder(
+            Dataset.MetricsSummaries,
+            self.params,
+            snuba_params=self.snuba_params,
+            query=self.query,
+            selected_columns=["span.group", "timestamp", "id"],
+            # The orderby is intentionally `None` here as this query is much faster
+            # if we let Clickhouse decide which order to return the results in.
+            # This also means we cannot order by any columns or paginate.
+            orderby=None,
+            limit=len(trace_ids) * max_spans_per_trace,
+            limitby=("trace", 1),
+        )
+
+        additional_conditions = self.get_additional_conditions(builder)
+        min_max_conditions = self.get_min_max_conditions(builder)
+        builder.add_conditions([*additional_conditions, *min_max_conditions])
+
+        query_results = builder.run_query(self.referrer.value)
+        results = builder.process_results(query_results)
+
+        return [
+            SpanKey(
+                group=row["span.group"],
+                timestamp=row["timestamp"],
+                span_id=row["id"],
+            )
+            for row in results["data"]
+        ]
+
     def _get_spans(
         self,
         span_keys: list[SpanKey],
@@ -759,7 +1019,7 @@ class CustomSamplesListExecutor(AbstractSamplesListExecutor):
 
         return result
 
-    def execute_sorted(self, offset, limit):
+    def get_matching_spans_sorted(self, offset, limit):
         span_keys, summaries = self.get_sorted_span_keys(offset, limit)
         return self._get_spans(span_keys, summaries)
 
@@ -835,7 +1095,7 @@ class CustomSamplesListExecutor(AbstractSamplesListExecutor):
 
         return span_keys, summaries
 
-    def execute_unsorted(self, offset, limit):
+    def get_matching_spans_unsorted(self, offset, limit):
         span_keys, summaries = self.get_unsorted_span_keys(offset, limit)
         return self._get_spans(span_keys, summaries)
 

+ 1 - 0
src/sentry/snuba/referrer.py

@@ -437,6 +437,7 @@ class Referrer(Enum):
     API_STARFISH_MOBILE_STARTUP_SPAN_TABLE = "api.starfish.mobile-spartup-span-table"
     API_STARFISH_MOBILE_STARTUP_LOADED_LIBRARIES = "api.starfish.mobile-startup-loaded-libraries"
     API_STARFISH_MOBILE_STARTUP_TOTALS = "api.starfish.mobile-startup-totals"
+    API_TRACE_EXPLORER_METRICS_SPANS_LIST = "api.trace-explorer.metrics-spans-list"
     API_TRACE_EXPLORER_SPANS_LIST = "api.trace-explorer.spans-list"
     API_TRACE_EXPLORER_TRACES_META = "api.trace-explorer.traces-meta"
 

+ 6 - 0
src/sentry/testutils/cases.py

@@ -1484,6 +1484,7 @@ class BaseSpansTestCase(SnubaTestCase):
         tags: Mapping[str, Any] | None = None,
         measurements: Mapping[str, int | float] | None = None,
         timestamp: datetime | None = None,
+        store_metrics_summary: Mapping[str, Sequence[Mapping[str, Any]]] | None = None,
     ):
         if span_id is None:
             span_id = self._random_span_id()
@@ -1513,11 +1514,16 @@ class BaseSpansTestCase(SnubaTestCase):
             payload["measurements"] = {
                 measurement: {"value": value} for measurement, value in measurements.items()
             }
+        if store_metrics_summary:
+            payload["_metrics_summary"] = store_metrics_summary
         if parent_span_id:
             payload["parent_span_id"] = parent_span_id
 
         self.store_span(payload)
 
+        if "_metrics_summary" in payload:
+            self.store_metrics_summary(payload)
+
     def store_indexed_span(
         self,
         project_id: int,

+ 256 - 68
tests/sentry/api/endpoints/test_organization_traces.py

@@ -1,3 +1,4 @@
+from datetime import timedelta
 from uuid import uuid4
 
 import pytest
@@ -5,6 +6,7 @@ from django.urls import reverse
 from rest_framework.exceptions import ErrorDetail
 
 from sentry.api.endpoints.organization_traces import process_breakdowns
+from sentry.snuba.metrics.naming_layer.mri import SpanMRI, TransactionMRI
 from sentry.testutils.cases import APITestCase, BaseSpansTestCase
 from sentry.testutils.helpers.datetime import before_now
 from sentry.utils.samples import load_data
@@ -28,6 +30,167 @@ class OrganizationTracesEndpointTest(BaseSpansTestCase, APITestCase):
                 **kwargs,
             )
 
+    def double_write_segment(
+        self,
+        *,
+        project_id,
+        trace_id,
+        transaction_id,
+        span_id,
+        timestamp,
+        duration,
+        **kwargs,
+    ):
+        # first write to the transactions dataset
+        end_timestamp = timestamp + timedelta(microseconds=duration * 1000)
+        data = load_data(
+            "transaction",
+            start_timestamp=timestamp,
+            timestamp=end_timestamp,
+            trace=trace_id,
+            span_id=span_id,
+            spans=[],
+            event_id=transaction_id,
+        )
+        data["measurements"] = {"lcp": {"value": duration}}
+        if tags := kwargs.get("tags", {}):
+            data["tags"] = [[key, val] for key, val in tags.items()]
+
+        self.store_event(
+            data=data,
+            project_id=project_id,
+        )
+
+        self.store_segment(
+            project_id=project_id,
+            trace_id=trace_id,
+            transaction_id=transaction_id,
+            span_id=span_id,
+            timestamp=timestamp,
+            duration=duration,
+            **kwargs,
+        )
+
+    def create_mock_traces(self):
+        project_1 = self.create_project()
+        project_2 = self.create_project()
+
+        # Hack: ensure that no span ids with leading 0s are generated for the test
+        span_ids = ["1" + uuid4().hex[:15] for _ in range(8)]
+        tags = ["", "bar", "bar", "baz", "", "bar", "baz"]
+        timestamps = []
+
+        trace_id_1 = uuid4().hex
+        timestamps.append(before_now(days=0, minutes=10).replace(microsecond=0))
+        self.double_write_segment(
+            project_id=project_1.id,
+            trace_id=trace_id_1,
+            transaction_id=uuid4().hex,
+            span_id=span_ids[0],
+            timestamp=timestamps[-1],
+            transaction="foo",
+            duration=60_100,
+            exclusive_time=60_100,
+        )
+        for i in range(1, 4):
+            timestamps.append(before_now(days=0, minutes=9, seconds=45 - i).replace(microsecond=0))
+            self.double_write_segment(
+                project_id=project_2.id,
+                trace_id=trace_id_1,
+                transaction_id=uuid4().hex,
+                span_id=span_ids[i],
+                parent_span_id=span_ids[0],
+                timestamp=timestamps[-1],
+                transaction="bar",
+                duration=30_000 + i,
+                exclusive_time=30_000 + i,
+                tags={"foo": tags[i]},
+            )
+
+        trace_id_2 = uuid4().hex
+        txn_id_2 = uuid4().hex
+        timestamps.append(before_now(days=0, minutes=20).replace(microsecond=0))
+        self.double_write_segment(
+            project_id=project_1.id,
+            trace_id=trace_id_2,
+            transaction_id=txn_id_2,
+            span_id=span_ids[4],
+            timestamp=timestamps[-1],
+            transaction="bar",
+            duration=90_123,
+            exclusive_time=90_123,
+        )
+        for i in range(5, 7):
+            timestamps.append(before_now(days=0, minutes=19, seconds=55 - i).replace(microsecond=0))
+            self.double_write_segment(
+                project_id=project_2.id,
+                trace_id=trace_id_2,
+                transaction_id=uuid4().hex,
+                span_id=span_ids[i],
+                parent_span_id=span_ids[4],
+                timestamp=timestamps[-1],
+                transaction="baz",
+                duration=20_000 + i,
+                exclusive_time=20_000 + i,
+                tags={"foo": tags[i]},
+            )
+        timestamps.append(before_now(days=0, minutes=20).replace(microsecond=0))
+
+        trace_id_3 = uuid4().hex
+        self.double_write_segment(
+            project_id=project_1.id,
+            trace_id=trace_id_3,
+            transaction_id=uuid4().hex,
+            span_id=span_ids[7],
+            timestamp=timestamps[-1],
+            transaction="qux",
+            duration=40_000,
+            tags={"foo": "qux"},
+            measurements={
+                measurement: 40_000
+                for i, measurement in enumerate(
+                    [
+                        "score.total",
+                        "score.inp",
+                        "score.weight.inp",
+                        "http.response_content_length",
+                        "http.decoded_response_content_length",
+                        "http.response_transfer_size",
+                    ]
+                )
+            },
+            store_metrics_summary={
+                "d:custom/value@millisecond": [
+                    {
+                        "min": 40_000,
+                        "max": 40_000,
+                        "sum": 40_000,
+                        "count": 1,
+                        "tags": {"foo": "qux"},
+                    }
+                ]
+            },
+        )
+
+        error_data = load_data("javascript", timestamp=timestamps[0])
+        error_data["contexts"]["trace"] = {
+            "type": "trace",
+            "trace_id": trace_id_1,
+            "span_id": span_ids[0],
+        }
+        error_data["tags"] = [["transaction", "foo"]]
+        self.store_event(error_data, project_id=project_1.id)
+
+        return (
+            project_1,
+            project_2,
+            trace_id_1,
+            trace_id_2,
+            trace_id_3,
+            timestamps,
+            span_ids,
+        )
+
     def test_no_feature(self):
         query = {
             "field": ["id"],
@@ -115,6 +278,22 @@ class OrganizationTracesEndpointTest(BaseSpansTestCase, APITestCase):
             ),
         }
 
+    def test_unsupported_mri(self):
+        query = {
+            "project": [self.project.id],
+            "field": ["id"],
+            "maxSpansPerTrace": 1,
+            "mri": "d:spans/made_up@none",
+        }
+
+        response = self.do_request(query)
+        assert response.status_code == 400, response.data
+        assert response.data == {
+            "detail": ErrorDetail(
+                string="Unsupported MRI: d:spans/made_up@none", code="parse_error"
+            ),
+        }
+
     def test_no_traces(self):
         query = {
             "project": [self.project.id],
@@ -149,76 +328,15 @@ class OrganizationTracesEndpointTest(BaseSpansTestCase, APITestCase):
         assert response.status_code == 200, response.data
 
     def test_matching_tag(self):
-        project_1 = self.create_project()
-        project_2 = self.create_project()
-
-        # Hack: ensure that no span ids with leading 0s are generated for the test
-        span_ids = ["1" + uuid4().hex[:15] for _ in range(7)]
-        tags = ["", "bar", "bar", "baz", "", "bar", "baz"]
-        timestamps = []
-
-        trace_id_1 = uuid4().hex
-        timestamps.append(before_now(days=0, minutes=10).replace(microsecond=0))
-        self.store_segment(
-            project_1.id,
+        (
+            project_1,
+            project_2,
             trace_id_1,
-            uuid4().hex,
-            span_id=span_ids[0],
-            timestamp=timestamps[-1],
-            transaction="foo",
-            duration=60_100,
-            exclusive_time=60_100,
-        )
-        for i in range(1, 4):
-            timestamps.append(before_now(days=0, minutes=9, seconds=45 - i).replace(microsecond=0))
-            self.store_segment(
-                project_2.id,
-                trace_id_1,
-                uuid4().hex,
-                span_id=span_ids[i],
-                parent_span_id=span_ids[0],
-                timestamp=timestamps[-1],
-                transaction="bar",
-                duration=30_000 + i,
-                exclusive_time=30_000 + i,
-                tags={"foo": tags[i]},
-            )
-
-        trace_id_2 = uuid4().hex
-        timestamps.append(before_now(days=0, minutes=20).replace(microsecond=0))
-        self.store_segment(
-            project_1.id,
             trace_id_2,
-            uuid4().hex,
-            span_id=span_ids[4],
-            timestamp=timestamps[-1],
-            transaction="bar",
-            duration=90_123,
-            exclusive_time=90_123,
-        )
-        for i in range(5, 7):
-            timestamps.append(before_now(days=0, minutes=19, seconds=55 - i).replace(microsecond=0))
-            self.store_segment(
-                project_2.id,
-                trace_id_2,
-                uuid4().hex,
-                span_id=span_ids[i],
-                parent_span_id=span_ids[4],
-                timestamp=timestamps[-1],
-                transaction="baz",
-                duration=20_000 + i,
-                exclusive_time=20_000 + i,
-                tags={"foo": tags[i]},
-            )
-
-        error_data = load_data("javascript", timestamp=timestamps[0])
-        error_data["contexts"]["trace"] = {
-            "type": "trace",
-            "trace_id": trace_id_1,
-            "span_id": span_ids[0],
-        }
-        error_data["tags"] = [["transaction", "foo"]]
-        self.store_event(error_data, project_id=project_1.id)
+            _,
+            timestamps,
+            span_ids,
+        ) = self.create_mock_traces()
 
         for q in [
             [
@@ -358,6 +476,76 @@ class OrganizationTracesEndpointTest(BaseSpansTestCase, APITestCase):
                 key=lambda trace: trace["trace"],
             )
 
+    def test_matching_tag_metrics(self):
+        (
+            project_1,
+            _,
+            _,
+            _,
+            trace_id_3,
+            timestamps,
+            span_ids,
+        ) = self.create_mock_traces()
+
+        for mri in [
+            TransactionMRI.DURATION.value,
+            "d:transactions/measurements.lcp@millisecond",
+            SpanMRI.DURATION.value,
+            SpanMRI.SELF_TIME.value,
+            "d:spans/webvital.score.total@ratio",
+            "d:spans/webvital.score.inp@ratio",
+            "d:spans/webvital.score.weight.inp@ratio",
+            "d:spans/http.response_content_length@byte",
+            "d:spans/http.decoded_response_content_length@byte",
+            "d:spans/http.response_transfer_size@byte",
+            "d:custom/value@millisecond",
+        ]:
+            query = {
+                "mri": mri,
+                "metricsQuery": ["foo:qux"],
+                "project": [project_1.id],
+                "field": ["id", "parent_span", "span.duration"],
+                "query": ["foo:qux"],
+                "suggestedQuery": ["foo:qux"],
+                "maxSpansPerTrace": 3,
+                "sort": ["-span.duration"],
+            }
+
+            response = self.do_request(query)
+            assert response.status_code == 200, (mri, response.data)
+
+            result_data = sorted(response.data["data"], key=lambda trace: trace["trace"])
+
+            assert result_data == [
+                {
+                    "trace": trace_id_3,
+                    "numErrors": 0,
+                    "numOccurrences": 0,
+                    "numSpans": 1,
+                    "project": project_1.slug,
+                    "name": "qux",
+                    "duration": 40_000,
+                    "start": int(timestamps[7].timestamp() * 1000),
+                    "end": int(timestamps[7].timestamp() * 1000) + 40_000,
+                    "breakdowns": [
+                        {
+                            "project": project_1.slug,
+                            "start": int(timestamps[7].timestamp() * 1000),
+                            "end": int(timestamps[7].timestamp() * 1000) + 40_000,
+                            "kind": "project",
+                        },
+                    ],
+                    "spans": [
+                        {
+                            "id": span_ids[7],
+                            "parent_span": "00",
+                            "span.duration": 40_000.0,
+                        },
+                    ],
+                    "suggestedSpans": [],
+                },
+            ], mri
+
 
 @pytest.mark.parametrize(
     ["data", "traces_range", "expected"],