Browse Source

ref(trace-explorer): Move trace samples queries into separate class (#69925)

In preparation for querying traces from a metric, move all the query
logic into a separate class to make it easier to refactor.
Tony Xiao 10 months ago
parent
commit
fc450c0594
1 changed file with 433 additions and 322 deletions
  1. 433 322
      src/sentry/api/endpoints/organization_traces.py

+ 433 - 322
src/sentry/api/endpoints/organization_traces.py

@@ -1,10 +1,9 @@
 import dataclasses
 from collections import defaultdict
-from collections.abc import Mapping, MutableMapping
+from collections.abc import Callable, Mapping, MutableMapping
 from datetime import datetime, timedelta
 from typing import Any, Literal, TypedDict, cast
 
-import sentry_sdk
 from rest_framework import serializers
 from rest_framework.request import Request
 from rest_framework.response import Response
@@ -18,6 +17,7 @@ from sentry.api.bases import NoProjects, OrganizationEventsV2EndpointBase
 from sentry.api.paginator import GenericOffsetPaginator
 from sentry.api.utils import handle_query_errors
 from sentry.models.organization import Organization
+from sentry.models.project import Project
 from sentry.search.events.builder import QueryBuilder, SpansIndexedQueryBuilder
 from sentry.search.events.types import ParamsType, QueryBuilderConfig, SnubaParams, WhereType
 from sentry.snuba.dataset import Dataset
@@ -81,216 +81,168 @@ class OrganizationTracesEndpoint(OrganizationEventsV2EndpointBase):
             return Response(serializer.errors, status=400)
         serialized = serializer.validated_data
 
-        per_page = self.get_per_page(request)
-
-        def data_fn(offset: int, limit: int):
-            with handle_query_errors():
-                sample_rate = options.get("traces.sample-list.sample-rate")
-                if sample_rate <= 0:
-                    sample_rate = None
-
-                # Filter out empty queries as they do not do anything to change the results.
-                user_queries = [query for query in serialized.get("query", []) if query]
-
-                trace_ids, min_timestamp, max_timestamp = self.get_matching_traces(
-                    cast(ParamsType, params),
-                    snuba_params,
-                    user_queries,
-                    per_page,
-                )
-
-            if not trace_ids:
-                return {"data": [], "meta": {"fields": {}}}
-
-            # TODO: move to use `update_snuba_params_with_timestamp`
-            time_buffer = options.get("performance.traces.trace-explorer-buffer-hours")
-            buffer = timedelta(hours=time_buffer)
-            params["start"] = min_timestamp - buffer
-            params["end"] = max_timestamp + buffer
-            snuba_params.start = min_timestamp - buffer
-            snuba_params.end = max_timestamp + buffer
-
-            all_projects = self.get_projects(
+        executor = TraceSamplesExecutor(
+            params=cast(ParamsType, params),
+            snuba_params=snuba_params,
+            fields=serialized["field"],
+            # Filter out empty queries as they do not do anything to change the results.
+            user_queries=[query.strip() for query in serialized.get("query", []) if query.strip()],
+            suggested_query=serialized.get("suggestedQuery", ""),
+            sort=serialized.get("sort"),
+            limit=self.get_per_page(request),
+            max_spans_per_trace=serialized["maxSpansPerTrace"],
+            get_all_projects=lambda: self.get_projects(
                 request,
                 organization,
                 project_ids={-1},
                 project_slugs=None,
                 include_all_accessible=True,
-            )
-            all_projects_snuba_params = dataclasses.replace(snuba_params, projects=all_projects)
-            all_projects_params = dict(params)
-            all_projects_params["projects"] = all_projects_snuba_params.projects
-            all_projects_params["projects_objects"] = all_projects_snuba_params.projects
-            all_projects_params["projects_id"] = all_projects_snuba_params.project_ids
-
-            traces_breakdowns_query = self.get_traces_breakdowns_query(
-                cast(ParamsType, all_projects_params),
-                all_projects_snuba_params,
-                trace_ids,
-            )
+            ),
+        )
 
-            traces_metas_query = self.get_traces_metas_query(
-                cast(ParamsType, all_projects_params),
-                all_projects_snuba_params,
-                trace_ids,
-            )
+        return self.paginate(
+            request=request,
+            paginator=GenericOffsetPaginator(data_fn=executor.execute),
+            on_results=lambda results: self.handle_results_with_meta(
+                request,
+                organization,
+                params["project_id"],
+                results,
+                standard_meta=True,
+                dataset=Dataset.SpansIndexed,
+            ),
+        )
 
-            traces_errors_query = self.get_traces_errors_query(
-                cast(ParamsType, all_projects_params),
-                all_projects_snuba_params,
-                trace_ids,
+
+class TraceSamplesExecutor:
+    def __init__(
+        self,
+        *,
+        params: ParamsType,
+        snuba_params: SnubaParams,
+        fields: list[str],
+        user_queries: list[str],
+        suggested_query: str,
+        sort: str | None,
+        limit: int,
+        max_spans_per_trace: int,
+        get_all_projects: Callable[[], list[Project]],
+    ):
+        self.params = params
+        self.snuba_params = snuba_params
+        self.fields = fields
+        self.user_queries = user_queries
+        self.suggested_query = suggested_query
+        self.sort = sort
+        self.limit = limit
+        self.max_spans_per_trace = max_spans_per_trace
+        self.get_all_projects = get_all_projects
+        self._all_projects: list[Project] | None = None
+
+    @property
+    def all_projects(self) -> list[Project]:
+        if self._all_projects is None:
+            self._all_projects = self.get_all_projects()
+        return self._all_projects
+
+    def execute(self, offset: int, limit: int):
+        return self._execute()
+
+    def _execute(self):
+        selected_projects_params = self.params
+        selected_projects_snuba_params = self.snuba_params
+
+        with handle_query_errors():
+            min_timestamp, max_timestamp, trace_ids, _ = self.get_traces_matching_conditions(
+                selected_projects_params,
+                selected_projects_snuba_params,
             )
 
-            traces_occurrences_query = self.get_traces_occurrences_query(
-                cast(ParamsType, all_projects_params),
+        self.refine_params(min_timestamp, max_timestamp)
+
+        all_projects_params, all_projects_snuba_params = self.params_with_all_projects()
+
+        if not trace_ids:
+            return {"data": [], "meta": {"fields": {}}}
+
+        with handle_query_errors():
+            all_queries = self.get_all_queries(
+                all_projects_params,
                 all_projects_snuba_params,
                 trace_ids,
             )
 
-            user_spans_query, suggested_spans_query = self.get_matching_spans_query(
-                cast(ParamsType, params),
-                snuba_params,
-                serialized["field"],
-                serialized.get("sort"),
-                user_queries,
-                serialized.get("suggestedQuery", ""),
-                trace_ids,
-                serialized["maxSpansPerTrace"],
+            all_raw_results = bulk_snuba_queries(
+                [query.get_snql_query() for query in all_queries],
+                Referrer.API_TRACE_EXPLORER_TRACES_META.value,
             )
 
-            queries = [
-                query
-                for query in [
-                    traces_breakdowns_query,
-                    traces_metas_query,
-                    traces_errors_query,
-                    traces_occurrences_query,
-                    user_spans_query,
-                    suggested_spans_query,
-                ]
-                if query
+            all_results = [
+                query.process_results(result) for query, result in zip(all_queries, all_raw_results)
             ]
 
-            with handle_query_errors():
-                results = bulk_snuba_queries(
-                    [query.get_snql_query() for query in queries],
-                    Referrer.API_TRACE_EXPLORER_TRACES_META.value,
-                )
-
-                all_results = [
-                    query.process_results(result) for query, result in zip(queries, results)
-                ]
-                traces_breakdowns_results = all_results[0]
-                traces_metas_results = all_results[1]
-                traces_errors_results = all_results[2]
-                traces_occurrences_results = all_results[3]
-                spans_results = all_results[4]
-                suggested_spans_results = all_results[5] if len(all_results) > 5 else None
-
-            fields = spans_results["meta"].get("fields", {})
-            meta = {
-                **spans_results["meta"],
-                "fields": {field: fields[field] for field in serialized["field"]},
-            }
+            meta = self.process_meta_results(all_results)
+            data = self.process_final_results(all_results)
 
-            errors_by_trace: Mapping[str, int] = {
-                row["trace"]: row["count()"] for row in traces_errors_results["data"]
-            }
+        return {"data": data, "meta": meta}
 
-            occurrences_by_trace: Mapping[str, int] = {
-                row["trace"]: row["count()"] for row in traces_occurrences_results["data"]
-            }
+    def refine_params(self, min_timestamp: datetime, max_timestamp: datetime):
+        """
+        Once we have a min/max timestamp for all the traces in the query,
+        refine the params so that it selects a time range that is as small as possible.
+        """
 
-            spans_by_trace: Mapping[str, list[Mapping[str, Any]]] = defaultdict(list)
-            for row in spans_results["data"]:
-                spans_by_trace[row["trace"]].append(row)
+        # TODO: move to use `update_snuba_params_with_timestamp`
+        time_buffer = options.get("performance.traces.trace-explorer-buffer-hours")
+        buffer = timedelta(hours=time_buffer)
 
-            suggested_spans_by_trace: Mapping[str, list[Mapping[str, Any]]] = defaultdict(list)
-            if suggested_spans_results:
-                for row in suggested_spans_results["data"]:
-                    suggested_spans_by_trace[row["trace"]].append(row)
+        self.params["start"] = min_timestamp - buffer
+        self.params["end"] = max_timestamp + buffer
+        self.snuba_params.start = min_timestamp - buffer
+        self.snuba_params.end = max_timestamp + buffer
 
-            try:
-                traces_range = {
-                    row["trace"]: (row["first_seen()"], row["last_seen()"])
-                    for row in traces_metas_results["data"]
-                }
-                breakdowns = process_breakdowns(
-                    traces_breakdowns_results["data"],
-                    traces_range,
-                )
-            except Exception as e:
-                sentry_sdk.capture_exception(e)
-                breakdowns = defaultdict(list)
-
-            names_by_trace: MutableMapping[str, tuple[str, str]] = {}
-            for row in traces_breakdowns_results["data"]:
-                # The underlying column is a Nullable(UInt64) but we write a default of 0 to it.
-                # So make sure to handle both in case something changes.
-                if not row["parent_span"] or int(row["parent_span"], 16) == 0:
-                    names_by_trace[row["trace"]] = (row["project"], row["transaction"])
-
-            traces: list[TraceResult] = [
-                {
-                    "trace": row["trace"],
-                    "numErrors": errors_by_trace.get(row["trace"], 0),
-                    "numOccurrences": occurrences_by_trace.get(row["trace"], 0),
-                    "numSpans": row["count()"],
-                    "project": names_by_trace.get(row["trace"], (None, None))[0],
-                    "name": names_by_trace.get(row["trace"], (None, None))[1],
-                    "duration": row["last_seen()"] - row["first_seen()"],
-                    "start": row["first_seen()"],
-                    "end": row["last_seen()"],
-                    "breakdowns": breakdowns[row["trace"]],
-                    "spans": [
-                        {field: span[field] for field in serialized["field"]}
-                        for span in spans_by_trace[row["trace"]]
-                    ],
-                    "suggestedSpans": [
-                        {field: span[field] for field in serialized["field"]}
-                        for span in suggested_spans_by_trace[row["trace"]]
-                    ],
-                }
-                for row in traces_metas_results["data"]
-            ]
+    def params_with_all_projects(self) -> tuple[ParamsType, SnubaParams]:
+        all_projects_snuba_params = dataclasses.replace(
+            self.snuba_params, projects=self.all_projects
+        )
 
-            return {"data": traces, "meta": meta}
+        all_projects_params = dict(self.params)
+        all_projects_params["projects"] = all_projects_snuba_params.projects
+        all_projects_params["projects_objects"] = all_projects_snuba_params.projects
+        all_projects_params["projects_id"] = all_projects_snuba_params.project_ids
 
-        return self.paginate(
-            request=request,
-            paginator=GenericOffsetPaginator(data_fn=data_fn),
-            on_results=lambda results: self.handle_results_with_meta(
-                request,
-                organization,
-                params["project_id"],
-                results,
-                standard_meta=True,
-                dataset=Dataset.SpansIndexed,
-            ),
-        )
+        return cast(ParamsType, all_projects_params), all_projects_snuba_params
 
-    def get_matching_traces(
+    def get_traces_matching_conditions(
         self,
         params: ParamsType,
         snuba_params: SnubaParams,
-        user_queries: list[str],
-        limit: int,
-    ) -> tuple[list[str], datetime, datetime]:
-        if len(user_queries) < 2:
+    ) -> tuple[datetime, datetime, list[str], list[str] | None]:
+        # For now, this only fetches traces based on span conditions.
+        # Soon, this will support fetching traces based on metric conditions.
+        return self.get_traces_matching_span_conditions(params, snuba_params)
+
+    def get_traces_matching_span_conditions(
+        self,
+        params: ParamsType,
+        snuba_params: SnubaParams,
+    ) -> tuple[datetime, datetime, list[str], list[str] | None]:
+        if len(self.user_queries) < 2:
             # Optimization: If there is only a condition for a single span,
             # we can take the fast path and query without using aggregates.
             timestamp_column = "timestamp"
             builder = SpansIndexedQueryBuilder(
                 Dataset.SpansIndexed,
-                params,
+                params=params,
                 snuba_params=snuba_params,
-                query=user_queries[0] if user_queries else None,
+                query=self.user_queries[0] if self.user_queries else None,
                 selected_columns=["trace", timestamp_column],
                 # The orderby is intentionally `None` here as this query is much faster
                 # if we let Clickhouse decide which order to return the results in.
                 # This also means we cannot order by any columns or paginate.
                 orderby=None,
-                limit=limit,
+                limit=self.limit,
                 limitby=("trace", 1),
                 config=QueryBuilderConfig(
                     transform_alias_to_input_format=True,
@@ -300,7 +252,7 @@ class OrganizationTracesEndpoint(OrganizationEventsV2EndpointBase):
             timestamp_column = "min(timestamp)"
             builder = SpansIndexedQueryBuilder(
                 Dataset.SpansIndexed,
-                params,
+                params=params,
                 snuba_params=snuba_params,
                 query=None,
                 selected_columns=["trace", timestamp_column],
@@ -308,14 +260,14 @@ class OrganizationTracesEndpoint(OrganizationEventsV2EndpointBase):
                 # if we let Clickhouse decide which order to return the results in.
                 # This also means we cannot order by any columns or paginate.
                 orderby=None,
-                limit=limit,
+                limit=self.limit,
                 config=QueryBuilderConfig(
                     auto_aggregations=True,
                     transform_alias_to_input_format=True,
                 ),
             )
 
-            for query in user_queries:
+            for query in self.user_queries:
                 # We want to ignore all the aggregate conditions here because we're strictly
                 # searching on span attributes, not aggregates
                 where, _ = builder.resolve_conditions(query)
@@ -329,8 +281,8 @@ class OrganizationTracesEndpoint(OrganizationEventsV2EndpointBase):
         trace_results = builder.process_results(trace_results)
 
         trace_ids: list[str] = []
-        min_timestamp = snuba_params.end
-        max_timestamp = snuba_params.start
+        min_timestamp = self.snuba_params.end
+        max_timestamp = self.snuba_params.start
         assert min_timestamp is not None
         assert max_timestamp is not None
 
@@ -342,69 +294,225 @@ class OrganizationTracesEndpoint(OrganizationEventsV2EndpointBase):
             if timestamp > max_timestamp:
                 max_timestamp = timestamp
 
-        return trace_ids, min_timestamp, max_timestamp
+        return min_timestamp, max_timestamp, trace_ids, None
 
-    def get_traces_breakdowns_query(
+    def get_all_queries(
         self,
         params: ParamsType,
         snuba_params: SnubaParams,
         trace_ids: list[str],
-    ) -> SpansIndexedQueryBuilder:
-        with handle_query_errors():
-            traces_breakdowns_query = SpansIndexedQueryBuilder(
-                Dataset.SpansIndexed,
-                params,
-                snuba_params=snuba_params,
-                query="is_transaction:1",
-                selected_columns=[
-                    "trace",
-                    "project",
-                    "parent_span",
-                    "transaction",
-                    "precise.start_ts",
-                    "precise.finish_ts",
+    ) -> list[QueryBuilder]:
+        meta_data_queries = self.get_all_meta_data_queries(
+            params,
+            snuba_params,
+            trace_ids,
+        )
+
+        span_samples_queries = self.get_all_span_samples_queries(
+            params,
+            snuba_params,
+            trace_ids,
+        )
+
+        return meta_data_queries + span_samples_queries
+
+    def get_all_meta_data_queries(
+        self,
+        params: ParamsType,
+        snuba_params: SnubaParams,
+        trace_ids: list[str],
+    ) -> list[QueryBuilder]:
+        traces_breakdowns_query = self.get_traces_breakdowns_query(
+            params,
+            snuba_params,
+            trace_ids,
+        )
+
+        traces_metas_query = self.get_traces_metas_query(
+            params,
+            snuba_params,
+            trace_ids,
+        )
+
+        traces_errors_query = self.get_traces_errors_query(
+            params,
+            snuba_params,
+            trace_ids,
+        )
+
+        traces_occurrences_query = self.get_traces_occurrences_query(
+            params,
+            snuba_params,
+            trace_ids,
+        )
+
+        return [
+            traces_breakdowns_query,
+            traces_metas_query,
+            traces_errors_query,
+            traces_occurrences_query,
+        ]
+
+    def get_all_span_samples_queries(
+        self,
+        params: ParamsType,
+        snuba_params: SnubaParams,
+        trace_ids: list[str],
+    ) -> list[QueryBuilder]:
+        user_spans_query = self.get_user_spans_query(
+            params,
+            snuba_params,
+            trace_ids,
+        )
+
+        suggested_spans_query = self.get_suggested_spans_query(
+            params,
+            snuba_params,
+            trace_ids,
+        )
+
+        span_samples_queries = [user_spans_query]
+
+        if suggested_spans_query:
+            span_samples_queries.append(suggested_spans_query)
+
+        return span_samples_queries
+
+    def process_final_results(self, results) -> list[TraceResult]:
+        # the order of these results is defined by the order
+        # of the queries in `get_all_meta_data_queries`
+        traces_breakdowns_results = results[0]
+        traces_metas_results = results[1]
+        traces_errors_results = results[2]
+        traces_occurrences_results = results[3]
+        user_spans_results = results[4]
+        suggested_spans_results = results[5] if len(results) > 5 else None
+
+        # mapping of trace id to a tuple of start/finish times
+        traces_range = {
+            row["trace"]: (row["first_seen()"], row["last_seen()"])
+            for row in traces_metas_results["data"]
+        }
+
+        traces_breakdowns = process_breakdowns(
+            traces_breakdowns_results["data"],
+            traces_range,
+        )
+
+        # mapping of trace id to a tuple of project slug + transaction name
+        traces_names: MutableMapping[str, tuple[str, str]] = {}
+        for row in traces_breakdowns_results["data"]:
+            # The underlying column is a Nullable(UInt64) but we write a default of 0 to it.
+            # So make sure to handle both in case something changes.
+            if not row["parent_span"] or int(row["parent_span"], 16) == 0:
+                traces_names[row["trace"]] = (row["project"], row["transaction"])
+
+        traces_errors: Mapping[str, int] = {
+            row["trace"]: row["count()"] for row in traces_errors_results["data"]
+        }
+
+        traces_occurrences: Mapping[str, int] = {
+            row["trace"]: row["count()"] for row in traces_occurrences_results["data"]
+        }
+
+        traces_user_spans: Mapping[str, list[Mapping[str, Any]]] = defaultdict(list)
+        for row in user_spans_results["data"]:
+            traces_user_spans[row["trace"]].append(row)
+
+        traces_suggested_spans: Mapping[str, list[Mapping[str, Any]]] = defaultdict(list)
+        if suggested_spans_results:
+            for row in suggested_spans_results["data"]:
+                traces_suggested_spans[row["trace"]].append(row)
+
+        return [
+            {
+                "trace": row["trace"],
+                "numErrors": traces_errors.get(row["trace"], 0),
+                "numOccurrences": traces_occurrences.get(row["trace"], 0),
+                "numSpans": row["count()"],
+                "project": traces_names.get(row["trace"], (None, None))[0],
+                "name": traces_names.get(row["trace"], (None, None))[1],
+                "duration": row["last_seen()"] - row["first_seen()"],
+                "start": row["first_seen()"],
+                "end": row["last_seen()"],
+                "breakdowns": traces_breakdowns[row["trace"]],
+                "spans": [
+                    {field: span[field] for field in self.fields}
+                    for span in traces_user_spans[row["trace"]]
                 ],
-                orderby=["precise.start_ts", "precise.finish_ts"],
-                # limit the number of segments we fetch per trace so a single
-                # large trace does not result in the rest being blank
-                limitby=("trace", int(10_000 / len(trace_ids))),
-                limit=10_000,
-                config=QueryBuilderConfig(
-                    transform_alias_to_input_format=True,
-                ),
-            )
-            traces_breakdowns_query.add_conditions(
-                [Condition(Column("trace_id"), Op.IN, trace_ids)]
-            )
-        return traces_breakdowns_query
+                "suggestedSpans": [
+                    {field: span[field] for field in self.fields}
+                    for span in traces_suggested_spans[row["trace"]]
+                ],
+            }
+            for row in traces_metas_results["data"]
+        ]
+
+    def process_meta_results(self, results):
+        user_spans_results = results[4]
+        fields = user_spans_results["meta"].get("fields", {})
+        return {
+            **user_spans_results["meta"],
+            "fields": {field: fields[field] for field in self.fields},
+        }
+
+    def get_traces_breakdowns_query(
+        self,
+        params: ParamsType,
+        snuba_params: SnubaParams,
+        trace_ids: list[str],
+    ) -> QueryBuilder:
+        trace_ids_str = ",".join(trace_ids)
+        trace_ids_condition = f"trace:[{trace_ids_str}]"
+        return SpansIndexedQueryBuilder(
+            Dataset.SpansIndexed,
+            params,
+            snuba_params=snuba_params,
+            query=f"is_transaction:1 {trace_ids_condition}",
+            selected_columns=[
+                "trace",
+                "project",
+                "parent_span",
+                "transaction",
+                "precise.start_ts",
+                "precise.finish_ts",
+            ],
+            orderby=["precise.start_ts", "precise.finish_ts"],
+            # limit the number of segments we fetch per trace so a single
+            # large trace does not result in the rest being blank
+            limitby=("trace", int(10_000 / len(trace_ids))),
+            limit=10_000,
+            config=QueryBuilderConfig(
+                transform_alias_to_input_format=True,
+            ),
+        )
 
     def get_traces_metas_query(
         self,
         params: ParamsType,
         snuba_params: SnubaParams,
         trace_ids: list[str],
-    ) -> SpansIndexedQueryBuilder:
-        with handle_query_errors():
-            traces_metas_query = SpansIndexedQueryBuilder(
-                Dataset.SpansIndexed,
-                params,
-                snuba_params=snuba_params,
-                query=None,
-                selected_columns=[
-                    "trace",
-                    "count()",
-                    # TODO: count if of matching spans
-                    "first_seen()",
-                    "last_seen()",
-                ],
-                limit=len(trace_ids),
-                config=QueryBuilderConfig(
-                    functions_acl=["first_seen", "last_seen"],
-                    transform_alias_to_input_format=True,
-                ),
-            )
-            traces_metas_query.add_conditions([Condition(Column("trace_id"), Op.IN, trace_ids)])
-        return traces_metas_query
+    ) -> QueryBuilder:
+        trace_ids_str = ",".join(trace_ids)
+        trace_ids_condition = f"trace:[{trace_ids_str}]"
+        return SpansIndexedQueryBuilder(
+            Dataset.SpansIndexed,
+            params,
+            snuba_params=snuba_params,
+            query=trace_ids_condition,
+            selected_columns=[
+                "trace",
+                "count()",
+                # TODO: count_if of matching spans
+                "first_seen()",
+                "last_seen()",
+            ],
+            limit=len(trace_ids),
+            config=QueryBuilderConfig(
+                functions_acl=["first_seen", "last_seen"],
+                transform_alias_to_input_format=True,
+            ),
+        )
 
     def get_traces_errors_query(
         self,
@@ -412,20 +520,19 @@ class OrganizationTracesEndpoint(OrganizationEventsV2EndpointBase):
         snuba_params: SnubaParams,
         trace_ids: list[str],
     ) -> QueryBuilder:
-        with handle_query_errors():
-            traces_errors_query = QueryBuilder(
-                Dataset.Events,
-                params,
-                snuba_params=snuba_params,
-                query=None,
-                selected_columns=["trace", "count()"],
-                limit=len(trace_ids),
-                config=QueryBuilderConfig(
-                    transform_alias_to_input_format=True,
-                ),
-            )
-            traces_errors_query.add_conditions([Condition(Column("trace_id"), Op.IN, trace_ids)])
-        return traces_errors_query
+        trace_ids_str = ",".join(trace_ids)
+        trace_ids_condition = f"trace:[{trace_ids_str}]"
+        return QueryBuilder(
+            Dataset.Events,
+            params,
+            snuba_params=snuba_params,
+            query=trace_ids_condition,
+            selected_columns=["trace", "count()"],
+            limit=len(trace_ids),
+            config=QueryBuilderConfig(
+                transform_alias_to_input_format=True,
+            ),
+        )
 
     def get_traces_occurrences_query(
         self,
@@ -433,94 +540,98 @@ class OrganizationTracesEndpoint(OrganizationEventsV2EndpointBase):
         snuba_params: SnubaParams,
         trace_ids: list[str],
     ) -> QueryBuilder:
-        with handle_query_errors():
-            traces_occurrences_query = QueryBuilder(
-                Dataset.IssuePlatform,
-                params,
-                snuba_params=snuba_params,
-                query=None,
-                selected_columns=["trace", "count()"],
-                limit=len(trace_ids),
-                config=QueryBuilderConfig(
-                    transform_alias_to_input_format=True,
-                ),
-            )
-            traces_occurrences_query.add_conditions(
-                [Condition(Column("trace_id"), Op.IN, trace_ids)]
-            )
-        return traces_occurrences_query
+        trace_ids_str = ",".join(trace_ids)
+        trace_ids_condition = f"trace:[{trace_ids_str}]"
+        return QueryBuilder(
+            Dataset.IssuePlatform,
+            params,
+            snuba_params=snuba_params,
+            query=trace_ids_condition,
+            selected_columns=["trace", "count()"],
+            limit=len(trace_ids),
+            config=QueryBuilderConfig(
+                transform_alias_to_input_format=True,
+            ),
+        )
 
-    def get_matching_spans_query(
+    def get_user_spans_query(
         self,
         params: ParamsType,
         snuba_params: SnubaParams,
-        fields: list[str],
-        sort: str | None,
-        user_queries: list[str],
-        suggested_query: str,
         trace_ids: list[str],
-        max_spans_per_trace: int,
-    ) -> tuple[SpansIndexedQueryBuilder, SpansIndexedQueryBuilder | None]:
-        trace_id_condition = Condition(Column("trace_id"), Op.IN, trace_ids)
+    ) -> QueryBuilder:
+        user_spans_query = SpansIndexedQueryBuilder(
+            Dataset.SpansIndexed,
+            params,
+            snuba_params=snuba_params,
+            query=None,  # Note: conditions are added below
+            selected_columns=["trace"] + self.fields,
+            orderby=self.sort,
+            limit=len(trace_ids) * self.max_spans_per_trace,
+            limitby=("trace", self.max_spans_per_trace),
+            config=QueryBuilderConfig(
+                transform_alias_to_input_format=True,
+            ),
+        )
 
-        with handle_query_errors():
-            user_spans_query = SpansIndexedQueryBuilder(
-                Dataset.SpansIndexed,
-                params,
-                snuba_params=snuba_params,
-                query=None,
-                selected_columns=["trace"] + fields,
-                orderby=sort,
-                limit=len(trace_ids) * max_spans_per_trace,
-                limitby=("trace", max_spans_per_trace),
-                config=QueryBuilderConfig(
-                    transform_alias_to_input_format=True,
-                ),
+        # First make sure that we only return spans from one of the traces identified
+        user_spans_query.add_conditions([Condition(Column("trace_id"), Op.IN, trace_ids)])
+
+        # Next we have to turn the user queries into the appropriate conditions in
+        # the SnQL that we produce.
+
+        # There are multiple sets of user conditions that need to be satisfied
+        # and if a span satisfies any of them, it should be considered.
+        #
+        # To handle this use case, we want to OR all the user specified
+        # conditions together in this query.
+        conditions = []
+        for query in self.user_queries:
+            # We want to ignore all the aggregate conditions here because we're strictly
+            # searching on span attributes, not aggregates
+            where, _ = user_spans_query.resolve_conditions(query)
+            if len(where) > 1:
+                conditions.append(BooleanCondition(op=BooleanOp.AND, conditions=where))
+            elif len(where) == 1:
+                conditions.append(where[0])
+
+        if len(conditions) > 1:
+            # More than 1 set of user conditions were specified, we want to show
+            # spans that match any 1 of them so join the conditions with `OR`s.
+            user_spans_query.add_conditions(
+                [BooleanCondition(op=BooleanOp.OR, conditions=conditions)]
             )
+        elif len(conditions) == 1:
+            # Only 1 set of user conditions were specified, simply insert them into
+            # the final query.
+            user_spans_query.add_conditions([conditions[0]])
 
-            # There are multiple sets of user conditions that needs to be satisfied
-            # and if a span satisfy any of them, it should be considered.
-            #
-            # To handle this use case, we want to OR all the user specified
-            # conditions together in this query.
-            conditions = []
-            for query in user_queries:
-                # We want to ignore all the aggregate conditions here because we're strictly
-                # searching on span attributes, not aggregates
-                where, _ = user_spans_query.resolve_conditions(query)
-                if len(where) > 1:
-                    conditions.append(BooleanCondition(op=BooleanOp.AND, conditions=where))
-                elif len(where) == 1:
-                    conditions.append(where[0])
-
-            if len(conditions) > 1:
-                user_spans_query.add_conditions(
-                    [BooleanCondition(op=BooleanOp.OR, conditions=conditions)]
-                )
-            elif len(conditions) == 1:
-                user_spans_query.add_conditions([conditions[0]])
-            user_spans_query.add_conditions([trace_id_condition])
-
-        if all(user_query != suggested_query for user_query in user_queries):
-            with handle_query_errors():
-                suggested_spans_query = SpansIndexedQueryBuilder(
-                    Dataset.SpansIndexed,
-                    params,
-                    snuba_params=snuba_params,
-                    query=suggested_query,
-                    selected_columns=["trace"] + fields,
-                    orderby=sort,
-                    limit=len(trace_ids) * max_spans_per_trace,
-                    limitby=("trace", max_spans_per_trace),
-                    config=QueryBuilderConfig(
-                        transform_alias_to_input_format=True,
-                    ),
-                )
-                suggested_spans_query.add_conditions([trace_id_condition])
-        else:
-            suggested_spans_query = None
+        return user_spans_query
 
-        return user_spans_query, suggested_spans_query
+    def get_suggested_spans_query(
+        self,
+        params: ParamsType,
+        snuba_params: SnubaParams,
+        trace_ids: list[str],
+    ) -> QueryBuilder | None:
+        # If any user query is the same as the suggested query, we don't have to run it
+        if any(user_query == self.suggested_query for user_query in self.user_queries):
+            return None
+        suggested_spans_query = SpansIndexedQueryBuilder(
+            Dataset.SpansIndexed,
+            params,
+            snuba_params=snuba_params,
+            query=self.suggested_query,
+            selected_columns=["trace"] + self.fields,
+            orderby=self.sort,
+            limit=len(trace_ids) * self.max_spans_per_trace,
+            limitby=("trace", self.max_spans_per_trace),
+            config=QueryBuilderConfig(
+                transform_alias_to_input_format=True,
+            ),
+        )
+        suggested_spans_query.add_conditions([Condition(Column("trace_id"), Op.IN, trace_ids)])
+        return suggested_spans_query
 
 
 def process_breakdowns(data, traces_range):