Browse Source

feat(trace): Use spans to build the trace tree (#62898)

- This adds a useSpans param to the trace endpoint so we can use the
spans table to build the trace tree. Should be much more performant
- Loads less data now though (ie. can't use detailed) since we can't
load all tags from clickhouse
William Mak 1 year ago
parent
commit
c2bb91644d

+ 342 - 22
src/sentry/api/endpoints/organization_events_trace.py

@@ -1,6 +1,7 @@
 import logging
 from collections import defaultdict, deque
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from datetime import datetime
 from typing import (
     Any,
     Callable,
@@ -35,10 +36,11 @@ from sentry.eventstore.models import Event
 from sentry.issues.issue_occurrence import IssueOccurrence
 from sentry.models.group import Group
 from sentry.models.organization import Organization
-from sentry.search.events.builder import QueryBuilder
+from sentry.search.events.builder import QueryBuilder, SpansIndexedQueryBuilder
 from sentry.search.events.types import ParamsType, QueryBuilderConfig
 from sentry.snuba import discover
 from sentry.snuba.dataset import Dataset
+from sentry.snuba.referrer import Referrer
 from sentry.utils.dates import to_timestamp_from_iso_format
 from sentry.utils.numbers import base32_encode, format_grouped_length
 from sentry.utils.sdk import set_measurement
@@ -62,6 +64,7 @@ SnubaTransaction = TypedDict(
         "timestamp": str,
         "trace.span": str,
         "trace.parent_span": str,
+        "trace.parent_transaction": Optional[str],
         "root": str,
         "project.id": int,
         "project": str,
@@ -167,6 +170,7 @@ class TraceEvent:
         generation: Optional[int],
         light: bool = False,
         snuba_params: Optional[ParamsType] = None,
+        span_serialized: bool = False,
     ) -> None:
         self.event: SnubaTransaction = event
         self.errors: List[TraceError] = []
@@ -179,7 +183,10 @@ class TraceEvent:
 
         # Added as required because getting the nodestore_event is expensive
         self._nodestore_event: Optional[Event] = None
-        self.fetched_nodestore: bool = False
+        self.fetched_nodestore: bool = span_serialized
+        self.span_serialized = span_serialized
+        if span_serialized:
+            self.fetched_nodestore = True
         self.load_performance_issues(light, snuba_params)
 
     @property
@@ -200,35 +207,62 @@ class TraceEvent:
                 continue
 
             suspect_spans: List[str] = []
+            unique_spans: Set[str] = set()
             start: Optional[float] = None
             end: Optional[float] = None
             if light:
                 # This value doesn't matter for the light view
                 span = [self.event["trace.span"]]
+            elif "occurrence_spans" in self.event:
+                for problem in self.event["issue_occurrences"]:
+                    parent_span_ids = problem.evidence_data.get("parent_span_ids")
+                    if parent_span_ids is not None:
+                        unique_spans = unique_spans.union(parent_span_ids)
+                span = list(unique_spans)
+                for event_span in self.event["occurrence_spans"]:
+                    for problem in self.event["issue_occurrences"]:
+                        offender_span_ids = problem.evidence_data.get("offender_span_ids", [])
+                        if event_span.get("span_id") in offender_span_ids:
+                            try:
+                                end_timestamp = float(event_span.get("timestamp"))
+                                if end is None:
+                                    end = end_timestamp
+                                else:
+                                    end = max(end, end_timestamp)
+                                if end_timestamp is not None:
+                                    start_timestamp = float(
+                                        end_timestamp - event_span.get("span.duration")
+                                    )
+                                    if start is None:
+                                        start = start_timestamp
+                                    else:
+                                        start = min(start, start_timestamp)
+                            except ValueError:
+                                pass
+                            suspect_spans.append(event_span.get("span_id"))
             else:
-                if self.nodestore_event is not None:
+                if self.nodestore_event is not None or self.span_serialized:
                     occurrence_query = QueryBuilder(
                         Dataset.IssuePlatform,
                         snuba_params,
-                        query=f"event_id:{self.nodestore_event.event_id}",
+                        query=f"event_id:{self.event['id']}",
                         selected_columns=["occurrence_id"],
                     )
                     occurrence_ids = occurrence_query.process_results(
                         occurrence_query.run_query("api.trace-view.get-occurrence-ids")
                     )["data"]
 
-                    problems = IssueOccurrence.fetch_multi(
+                    issue_occurrences = IssueOccurrence.fetch_multi(
                         [occurrence.get("occurrence_id") for occurrence in occurrence_ids],
-                        self.nodestore_event.project_id,
+                        self.event["project.id"],
                     )
-                    unique_spans: Set[str] = set()
-                    for problem in problems:
+                    for problem in issue_occurrences:
                         parent_span_ids = problem.evidence_data.get("parent_span_ids")
                         if parent_span_ids is not None:
                             unique_spans = unique_spans.union(parent_span_ids)
                     span = list(unique_spans)
                     for event_span in self.nodestore_event.data.get("spans", []):
-                        for problem in problems:
+                        for problem in issue_occurrences:
                             offender_span_ids = problem.evidence_data.get("offender_span_ids", [])
                             if event_span.get("span_id") in offender_span_ids:
                                 try:
@@ -277,9 +311,11 @@ class TraceEvent:
             )
 
     def to_dict(self) -> LightResponse:
+        timestamp = datetime.fromisoformat(self.event["timestamp"]).timestamp()
         return {
             "event_id": self.event["id"],
             "span_id": self.event["trace.span"],
+            "timestamp": timestamp,
             "transaction": self.event["transaction"],
             "transaction.duration": self.event["transaction.duration"],
             "transaction.op": self.event["transaction.op"],
@@ -303,6 +339,12 @@ class TraceEvent:
                     ),
                 }
             )
+        if self.span_serialized:
+            result["timestamp"] = datetime.fromisoformat(self.event["timestamp"]).timestamp()
+            result["start_timestamp"] = (
+                datetime.fromisoformat(self.event["timestamp"]).timestamp()
+                - self.event["transaction.duration"]
+            )
         if self.nodestore_event:
             result["timestamp"] = self.nodestore_event.data.get("timestamp")
             result["start_timestamp"] = self.nodestore_event.data.get("start_timestamp")
@@ -342,9 +384,11 @@ def child_sort_key(item: TraceEvent) -> List[int]:
             item.nodestore_event.data["start_timestamp"],
             item.nodestore_event.data["timestamp"],
         ]
-    # The sorting of items without nodestore events doesn't matter cause we drop them
     else:
-        return [0]
+        return [
+            item.event["transaction"],
+            item.event["id"],
+        ]
 
 
 def count_performance_issues(trace_id: str, params: Mapping[str, str]) -> int:
@@ -389,8 +433,7 @@ def query_trace_data(
         Dataset.IssuePlatform,
         params,
         query=f"trace:{trace_id}",
-        selected_columns=["event_id"],
-        groupby_columns=["event_id"],
+        selected_columns=["event_id", "occurrence_id"],
         config=QueryBuilderConfig(
             functions_acl=["groupArray"],
         ),
@@ -398,7 +441,7 @@ def query_trace_data(
     occurrence_query.columns.append(
         Function("groupArray", parameters=[Column("group_id")], alias="issue.ids")
     )
-    occurrence_query.groupby = [Column("event_id")]
+    occurrence_query.groupby = [Column("event_id"), Column("occurrence_id")]
 
     error_query = QueryBuilder(
         Dataset.Events,
@@ -437,15 +480,113 @@ def query_trace_data(
     ]
 
     # Join group IDs from the occurrence dataset to transactions data
-    occurrence_dict = {row["event_id"]: row["issue.ids"] for row in transformed_results[2]}
+    occurrence_issue_ids = {row["event_id"]: row["issue.ids"] for row in transformed_results[2]}
+    occurrence_ids = {row["event_id"]: row["occurrence_id"] for row in transformed_results[2]}
     for result in transformed_results[0]:
-        result["issue.ids"] = occurrence_dict.get(result["id"], [])
+        result["issue.ids"] = occurrence_issue_ids.get(result["id"], {})
+        result["occurrence_id"] = occurrence_ids.get(result["id"])
+        result["trace.parent_transaction"] = None
 
     return cast(Sequence[SnubaTransaction], transformed_results[0]), cast(
         Sequence[SnubaError], transformed_results[1]
     )
 
 
+def augment_transactions_with_spans(
+    transactions: Sequence[SnubaTransaction],
+    errors: Sequence[SnubaError],
+    trace_id: str,
+    params: Mapping[str, str],
+) -> Sequence[SnubaTransaction]:
+    """Augment the list of transactions with parent, error and problem data"""
+    trace_parent_spans = set()  # parent span ids of segment spans
+    transaction_problem_map = {}
+    problem_project_map = {}
+    issue_occurrences = []
+    occurrence_spans = set()
+    error_spans = {e["trace.span"] for e in errors if e["trace.span"]}
+
+    for transaction in transactions:
+        transaction["occurrence_spans"] = []
+        transaction["issue_occurrences"] = []
+
+        # Pull out occurrence data
+        transaction_problem_map[transaction["id"]] = transaction
+        project = transaction["project.id"]
+        if project not in problem_project_map:
+            problem_project_map[project] = []
+        problem_project_map[project].append(transaction["occurrence_id"])
+
+        # Need to strip the leading "0"s to match our query to the spans table
+        # This is cause spans are stored as UInt64, so a span like 0011
+        # converted to an int then converted to a hex will become 11
+        # so when we query snuba we need to remove the 00s ourselves as well
+        if not transaction["trace.parent_span"]:
+            continue
+        transaction["trace.parent_span.stripped"] = (
+            str(hex(int(transaction["trace.parent_span"], 16))).lstrip("0x")
+            if transaction["trace.parent_span"].startswith("00")
+            else transaction["trace.parent_span"]
+        )
+        # parent span ids of the segment spans
+        trace_parent_spans.add(transaction["trace.parent_span.stripped"])
+
+    for project, occurrences in problem_project_map.items():
+        if occurrences:
+            issue_occurrences.extend(
+                [
+                    occurrence
+                    for occurrence in IssueOccurrence.fetch_multi(occurrences, project)
+                    if occurrence is not None
+                ]
+            )
+
+    for problem in issue_occurrences:
+        occurrence_spans = occurrence_spans.union(set(problem.evidence_data["offender_span_ids"]))
+
+    query_spans = {*trace_parent_spans, *error_spans, *occurrence_spans}
+    if "" in query_spans:
+        query_spans.remove("")
+    # If there are no spans to query just return transactions as is
+    if len(query_spans) == 0:
+        return transactions
+
+    # Fetch parent span ids of segment spans and their corresponding
+    # transaction id so we can link parent/child transactions in
+    # a trace.
+    parents_results = SpansIndexedQueryBuilder(
+        Dataset.SpansIndexed,
+        params,
+        query=f"trace:{trace_id} span_id:[{','.join(query_spans)}]",
+        selected_columns=[
+            "transaction.id",
+            "span_id",
+            "timestamp",
+        ],
+        orderby=["timestamp", "id"],
+        limit=10000,
+    ).run_query(referrer=Referrer.API_TRACE_VIEW_GET_PARENTS.value)
+
+    parent_map = {parent["span_id"]: parent for parent in parents_results["data"]}
+    for transaction in transactions:
+        # For a given transaction, if parent span id exists in the tranaction (so this is
+        # not a root span), see if the indexed spans data can tell us what the parent
+        # transaction id is.
+        if "trace.parent_span.stripped" in transaction:
+            if parent := parent_map.get(transaction["trace.parent_span.stripped"]):
+                transaction["trace.parent_transaction"] = parent["transaction.id"]
+    for problem in issue_occurrences:
+        for span_id in problem.evidence_data["offender_span_ids"]:
+            if parent := parent_map.get(span_id):
+                transaction = transaction_problem_map[problem.event_id]
+                transaction["occurrence_spans"].append(parent)
+                transaction["issue_occurrences"].append(problem)
+    for error in errors:
+        if parent := parent_map.get(error["trace.span"]):
+            error["trace.transaction"] = parent["transaction.id"]
+    return transactions
+
+
 class OrganizationEventsTraceEndpointBase(OrganizationEventsV2EndpointBase):
     publish_status = {
         "GET": ApiPublishStatus.PRIVATE,
@@ -535,7 +676,11 @@ class OrganizationEventsTraceEndpointBase(OrganizationEventsV2EndpointBase):
             actor=request.user,
         )
 
+        # Detailed is deprecated now that we want to use spans instead
         detailed: bool = request.GET.get("detailed", "0") == "1"
+        use_spans: bool = request.GET.get("useSpans", "0") == "1"
+        if detailed and use_spans:
+            raise ParseError("Cannot return a detailed response while using spans")
         limit: int = (
             min(int(request.GET.get("limit", MAX_TRACE_SIZE)), 2000)
             if trace_view_load_more_enabled
@@ -554,18 +699,24 @@ class OrganizationEventsTraceEndpointBase(OrganizationEventsV2EndpointBase):
         )
         with handle_query_errors():
             transactions, errors = query_trace_data(trace_id, params, limit)
+            if use_spans:
+                transactions = augment_transactions_with_spans(
+                    transactions, errors, trace_id, params
+                )
             if len(transactions) == 0 and not tracing_without_performance_enabled:
                 return Response(status=404)
             self.record_analytics(transactions, trace_id, self.request.user.id, organization.id)
 
         warning_extra: Dict[str, str] = {"trace": trace_id, "organization": organization.slug}
 
-        # Look for the roots
+        # Look for all root transactions in the trace (i.e., transactions
+        # that explicitly have no parent span id)
         roots: List[SnubaTransaction] = []
         for item in transactions:
             if is_root(item):
                 roots.append(item)
             else:
+                # This is okay because the query does an order by on -root
                 break
         if len(roots) > 1:
             sentry_sdk.set_tag("discover.trace-view.warning", "root.extra-found")
@@ -585,6 +736,7 @@ class OrganizationEventsTraceEndpointBase(OrganizationEventsV2EndpointBase):
                 detailed,
                 tracing_without_performance_enabled,
                 trace_view_load_more_enabled,
+                use_spans,
             )
         )
 
@@ -661,8 +813,11 @@ class OrganizationEventsTraceLightEndpoint(OrganizationEventsTraceEndpointBase):
         detailed: bool = False,
         allow_orphan_errors: bool = False,
         allow_load_more: bool = False,
+        use_spans: bool = False,
     ) -> Sequence[LightResponse]:
         """Because the light endpoint could potentially have gaps between root and event we return a flattened list"""
+        if use_spans:
+            raise ParseError(detail="useSpans isn't supported on the trace-light")
         if event_id is None:
             raise ParseError(detail="An event_id is required for the light trace")
         snuba_event, nodestore_event = self.get_current_transaction(
@@ -828,12 +983,26 @@ class OrganizationEventsTraceEndpoint(OrganizationEventsTraceEndpointBase):
         detailed: bool = False,
         allow_orphan_errors: bool = False,
         allow_load_more: bool = False,
+        use_spans: bool = False,
     ) -> Sequence[FullResponse]:
         """For the full event trace, we return the results as a graph instead of a flattened list
 
         if event_id is passed, we prune any potential branches of the trace to make as few nodestore calls as
         possible
         """
+        if use_spans:
+            results = self.serialize_with_spans(
+                limit,
+                transactions,
+                errors,
+                roots,
+                warning_extra,
+                event_id,
+                detailed,
+                allow_orphan_errors,
+                allow_load_more,
+            )
+            return results
         event_id_to_nodestore_event = (
             self.nodestore_event_map(transactions) if allow_load_more else {}
         )
@@ -976,7 +1145,7 @@ class OrganizationEventsTraceEndpoint(OrganizationEventsTraceEndpointBase):
                 if iteration > limit:
                     break
 
-        root_traces: List[TraceEvent] = []
+        trace_roots: List[TraceEvent] = []
         orphans: List[TraceEvent] = []
         for index, result in enumerate(results_map.values()):
             for subtrace in result:
@@ -984,9 +1153,9 @@ class OrganizationEventsTraceEndpoint(OrganizationEventsTraceEndpointBase):
             if index > 0 or len(roots) == 0:
                 orphans.extend(result)
             elif len(roots) > 0:
-                root_traces = result
+                trace_roots = result
         # We sort orphans and roots separately because we always want the root(s) as the first element(s)
-        root_traces.sort(key=child_sort_key)
+        trace_roots.sort(key=child_sort_key)
         orphans.sort(key=child_sort_key)
         orphan_errors = sorted(orphan_errors, key=lambda k: k["timestamp"])
 
@@ -996,17 +1165,168 @@ class OrganizationEventsTraceEndpoint(OrganizationEventsTraceEndpointBase):
 
         if allow_orphan_errors:
             return {
-                "transactions": [trace.full_dict(detailed) for trace in root_traces]
+                "transactions": [trace.full_dict(detailed) for trace in trace_roots]
                 + [orphan.full_dict(detailed) for orphan in orphans],
                 "orphan_errors": [orphan for orphan in orphan_errors],
             }
 
         return (
-            [trace.full_dict(detailed) for trace in root_traces]
+            [trace.full_dict(detailed) for trace in trace_roots]
             + [orphan.full_dict(detailed) for orphan in orphans]
             + [orphan for orphan in orphan_errors]
         )
 
+    def serialize_with_spans(
+        self,
+        limit: int,
+        transactions: Sequence[SnubaTransaction],
+        errors: Sequence[SnubaError],
+        roots: Sequence[SnubaTransaction],
+        warning_extra: Dict[str, str],
+        event_id: Optional[str],
+        detailed: bool = False,
+        allow_orphan_errors: bool = False,
+        allow_load_more: bool = False,
+    ) -> Sequence[FullResponse]:
+        root_traces: List[TraceEvent] = []
+        orphans: List[TraceEvent] = []
+        visited_transactions: Set[str] = set()
+        visited_errors: Set[str] = set()
+        if not allow_orphan_errors:
+            raise ParseError("Must allow orphan errors to useSpans")
+        if detailed:
+            raise ParseError("Cannot return a detailed response using Spans")
+
+        # A trace can have multiple roots, so we want to visit
+        # all roots in a trace and build their children.
+        # A root segment is one that doesn't have a parent span id
+        # but here is identified by the attribute "root" = 1 on
+        # a SnubaTransaction object.
+        root_traces = self.visit_transactions(
+            roots,
+            transactions,
+            errors,
+            visited_transactions,
+            visited_errors,
+        )
+
+        # At this point all the roots have their tree built. Remaining
+        # transactions are either orphan transactions or children of
+        # orphan transactions. Orphan transactions (unlike roots) have
+        # a parent_id but the parent_id wasn't found (dropped span).
+        # We get a sorted list of these transactions by start timestamp.
+        remaining_transactions = self.calculate_remaining_transactions(
+            transactions, visited_transactions
+        )
+
+        # Determine orphan transactions. `trace.parent_transaction` on a
+        # transaction is set when the indexed spans dataset has a row for
+        # the parent span id for this transaction. Since we already considered
+        # the root spans cases, the remaining spans with no parent transaction
+        # id are orphan transactions.
+        orphan_roots = [
+            orphan
+            for orphan in remaining_transactions
+            if orphan["trace.parent_transaction"] is None
+        ]
+
+        # Build the trees for all the orphan transactions.
+        orphans = self.visit_transactions(
+            orphan_roots,
+            remaining_transactions,
+            errors,
+            visited_transactions,
+            visited_errors,
+        )
+
+        # Remaining are transactions with parent transactions but those
+        # parents don't map to any of the existing transactions.
+        remaining_transactions = self.calculate_remaining_transactions(
+            transactions, visited_transactions
+        )
+        orphans.extend(
+            self.visit_transactions(
+                remaining_transactions,
+                remaining_transactions,
+                errors,
+                visited_transactions,
+                visited_errors,
+            )
+        )
+
+        # Sort the results so they're consistent
+        orphan_errors = sorted(
+            [error for error in errors if error["id"] not in visited_errors],
+            key=lambda k: k["timestamp"],
+        )
+        root_traces.sort(key=child_sort_key)
+        orphans.sort(key=child_sort_key)
+
+        return {
+            "transactions": [trace.full_dict(detailed) for trace in root_traces]
+            + [orphan.full_dict(detailed) for orphan in orphans],
+            "orphan_errors": [self.serialize_error(error) for error in orphan_errors],
+        }
+
+    def calculate_remaining_transactions(self, transactions, visited_transactions):
+        return sorted(
+            [
+                transaction
+                for transaction in transactions
+                if transaction["id"] not in visited_transactions
+            ],
+            key=lambda k: -datetime.fromisoformat(k["timestamp"]).timestamp(),
+        )
+
+    def visit_transactions(
+        self, to_visit, transactions, errors, visited_transactions, visited_errors
+    ):
+        serialized_events: List[TraceEvent] = []
+        for transaction in to_visit:
+            if transaction["id"] in visited_transactions:
+                continue
+            visited_transactions.add(transaction["id"])
+            root_event = TraceEvent(transaction, None, 0, span_serialized=True)
+            self.add_children(
+                root_event, transactions, visited_transactions, errors, visited_errors, 1
+            )
+            serialized_events.append(root_event)
+        return serialized_events
+
+    def add_children(
+        self, parent, transactions, visited_transactions, errors, visited_errors, generation
+    ):
+        for error in errors:
+            if error["id"] in visited_errors:
+                continue
+            if "trace.transaction" in error and error["trace.transaction"] == parent.event["id"]:
+                visited_errors.add(error["id"])
+                parent.errors.append(self.serialize_error(error))
+
+        # Loop through all the transactions to see if any of them are
+        # children.
+        for transaction in transactions:
+            if transaction["id"] in visited_transactions:
+                continue
+            if transaction["trace.parent_transaction"] == parent.event["id"]:
+                # If transaction is a child, establish that relationship and add it
+                # to visited_transactions.
+                visited_transactions.add(transaction["id"])
+                new_child = TraceEvent(
+                    transaction, parent.event["id"], generation, span_serialized=True
+                )
+                # Repeat adding children until there are none.
+                self.add_children(
+                    new_child,
+                    transactions,
+                    visited_transactions,
+                    errors,
+                    visited_errors,
+                    generation + 1,
+                )
+                parent.children.append(new_child)
+        parent.children.sort(key=child_sort_key)
+
 
 @region_silo_endpoint
 class OrganizationEventsTraceMetaEndpoint(OrganizationEventsTraceEndpointBase):

+ 1 - 0
src/sentry/snuba/referrer.py

@@ -445,6 +445,7 @@ class Referrer(Enum):
     API_TRACE_VIEW_HOVER_CARD = "api.trace-view.hover-card"
     API_TRACE_VIEW_SPAN_DETAIL = "api.trace-view.span-detail"
     API_TRACE_VIEW_COUNT_PERFORMANCE_ISSUES = "api.trace-view.count-performance-issues"
+    API_TRACE_VIEW_GET_PARENTS = "api.trace-view.get-parents"
     API_TRENDS_GET_EVENT_STATS = "api.trends.get-event-stats"
     API_TRENDS_GET_EVENT_STATS_V2_TOP_EVENTS = "api.trends.get-event-statsv2.top-events"
     API_TRENDS_GET_EVENT_STATS_V2_TOP_EVENTS_PRIMARY = (

+ 1 - 0
src/sentry/utils/snuba.py

@@ -125,6 +125,7 @@ SPAN_COLUMN_MAP = {
     "trace": "trace_id",
     "transaction": "segment_name",
     "transaction.id": "transaction_id",
+    "segment.id": "segment_id",
     "transaction.op": "transaction_op",
     "user": "user",
     "profile_id": "profile_id",

+ 51 - 1
tests/snuba/api/endpoints/test_organization_events.py

@@ -1,7 +1,9 @@
 import math
 import uuid
-from datetime import timedelta, timezone
+from datetime import datetime, timedelta, timezone
+from typing import Any, Dict
 from unittest import mock
+from uuid import uuid4
 
 import pytest
 from django.test import override_settings
@@ -42,6 +44,14 @@ MAX_QUERYABLE_TRANSACTION_THRESHOLDS = 1
 class OrganizationEventsEndpointTestBase(APITestCase, SnubaTestCase):
     viewname = "sentry-api-0-organization-events"
     referrer = "api.organization-events"
+    # Some base data for create_span
+    base_span: Dict[str, Any] = {
+        "is_segment": False,
+        "retention_days": 90,
+        "tags": {},
+        "sentry_tags": {},
+        "measurements": {},
+    }
 
     def setUp(self):
         super().setUp()
@@ -92,6 +102,46 @@ class OrganizationEventsEndpointTestBase(APITestCase, SnubaTestCase):
 
         return load_data(platform, timestamp=timestamp, start_timestamp=start_timestamp, **kwargs)
 
+    def create_span(
+        self, extra_data=None, organization=None, project=None, start_ts=None, duration=1000
+    ):
+        """Create span json, not required for store_span, but with no params passed should just work out of the box"""
+        if organization is None:
+            organization = self.organization
+        if project is None:
+            project = self.project
+        if start_ts is None:
+            start_ts = datetime.now() - timedelta(minutes=1)
+        if extra_data is None:
+            extra_data = {}
+        span = self.base_span.copy()
+        # Load some defaults
+        span.update(
+            {
+                "event_id": uuid4().hex,
+                "organization_id": organization.id,
+                "project_id": project.id,
+                "trace_id": uuid4().hex,
+                "span_id": uuid4().hex[:16],
+                "parent_span_id": uuid4().hex[:16],
+                "segment_id": uuid4().hex[:16],
+                "group_raw": uuid4().hex[:16],
+                "profile_id": uuid4().hex,
+                # Multiply by 1000 cause it needs to be ms
+                "start_timestamp_ms": int(start_ts.timestamp() * 1000),
+                "timestamp": int(start_ts.timestamp() * 1000),
+                "received": start_ts.timestamp(),
+                "duration_ms": duration,
+                "exclusive_time_ms": duration,
+            }
+        )
+        # Load any specific custom data
+        span.update(extra_data)
+        # coerce to string
+        for tag, value in dict(span["tags"]).items():
+            span["tags"][tag] = str(value)
+        return span
+
 
 @region_silo_test
 class OrganizationEventsEndpointTest(OrganizationEventsEndpointTestBase, PerformanceIssueTestCase):

+ 0 - 47
tests/snuba/api/endpoints/test_organization_events_span_indexed.py

@@ -1,6 +1,3 @@
-from datetime import datetime, timedelta
-from uuid import uuid4
-
 from tests.snuba.api.endpoints.test_organization_events import OrganizationEventsEndpointTestBase
 
 
@@ -18,56 +15,12 @@ class OrganizationEventsSpanIndexedEndpointTest(OrganizationEventsEndpointTestBa
     Once span ingestion is on by default this will no longer need to be done
     """
 
-    # Some base data for create_span
-    base_span = {
-        "is_segment": False,
-        "retention_days": 90,
-        "tags": {},
-        "sentry_tags": {},
-        "measurements": {},
-    }
-
     def setUp(self):
         super().setUp()
         self.features = {
             "organizations:starfish-view": True,
         }
 
-    def create_span(
-        self, extra_data=None, organization=None, project=None, start_ts=None, duration=1000
-    ):
-        """Create span json, not required for store_span, but with no params passed should just work out of the box"""
-        if organization is None:
-            organization = self.organization
-        if project is None:
-            project = self.project
-        if start_ts is None:
-            start_ts = datetime.now() - timedelta(minutes=1)
-        if extra_data is None:
-            extra_data = {}
-        span = self.base_span.copy()
-        # Load some defaults
-        span.update(
-            {
-                "event_id": uuid4().hex,
-                "organization_id": organization.id,
-                "project_id": project.id,
-                "trace_id": uuid4().hex,
-                "span_id": uuid4().hex[:16],
-                "parent_span_id": uuid4().hex[:16],
-                "segment_id": uuid4().hex[:16],
-                "profile_id": uuid4().hex,
-                # Multiply by 1000 cause it needs to be ms
-                "start_timestamp_ms": int(start_ts.timestamp() * 1000),
-                "received": start_ts.timestamp(),
-                "duration_ms": duration,
-                "exclusive_time_ms": duration,
-            }
-        )
-        # Load any specific custom data
-        span.update(extra_data)
-        return span
-
     def test_simple(self):
         self.store_spans(
             [

+ 103 - 60
tests/snuba/api/endpoints/test_organization_events_trace.py

@@ -7,15 +7,15 @@ from django.urls import NoReverseMatch, reverse
 
 from sentry import options
 from sentry.issues.grouptype import NoiseConfig, PerformanceFileIOMainThreadGroupType
-from sentry.testutils.cases import APITestCase, SnubaTestCase
 from sentry.testutils.helpers import override_options
 from sentry.testutils.helpers.datetime import before_now, iso_format
 from sentry.testutils.silo import region_silo_test
 from sentry.utils.dates import to_timestamp_from_iso_format
 from sentry.utils.samples import load_data
+from tests.snuba.api.endpoints.test_organization_events import OrganizationEventsEndpointTestBase
 
 
-class OrganizationEventsTraceEndpointBase(APITestCase, SnubaTestCase):
+class OrganizationEventsTraceEndpointBase(OrganizationEventsEndpointTestBase):
     url_name: str
     FEATURES = [
         "organizations:performance-view",
@@ -74,7 +74,42 @@ class OrganizationEventsTraceEndpointBase(APITestCase, SnubaTestCase):
                     "performance-file-io-main-thread-creation": 1.0,
                 }
             ):
-                return self.store_event(data, project_id=project_id, **kwargs)
+                event = self.store_event(data, project_id=project_id, **kwargs)
+                for span in data["spans"]:
+                    if span:
+                        span.update({"event_id": event.event_id})
+                        self.store_span(self.create_span(span))
+                self.store_span(self.convert_event_data_to_span(event))
+                return event
+
+    def convert_event_data_to_span(self, event):
+        trace_context = event.data["contexts"]["trace"]
+        start_ts = event.data["start_timestamp"]
+        end_ts = event.data["timestamp"]
+        span_data = self.create_span(
+            {
+                "event_id": event.event_id,
+                "organization_id": event.organization.id,
+                "project_id": event.project.id,
+                "trace_id": trace_context["trace_id"],
+                "span_id": trace_context["span_id"],
+                "parent_span_id": trace_context.get("parent_span_id", "0" * 12),
+                "description": event.data["transaction"],
+                "segment_id": uuid4().hex[:16],
+                "group_raw": uuid4().hex[:16],
+                "profile_id": uuid4().hex,
+                # Multiply by 1000 cause it needs to be ms
+                "start_timestamp_ms": int(start_ts * 1000),
+                "received": start_ts,
+                "duration_ms": int(end_ts - start_ts),
+            }
+        )
+        if "parent_span_id" in trace_context:
+            span_data["parent_span_id"] = trace_context["parent_span_id"]
+        else:
+            del span_data["parent_span_id"]
+
+        return span_data
 
     def setUp(self):
         """
@@ -752,9 +787,10 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
     url_name = "sentry-api-0-organization-events-trace"
 
     def assert_event(self, result, event_data, message):
+        assert result["transaction"] == event_data.transaction, message
         assert result["event_id"] == event_data.event_id, message
-        assert result["timestamp"] == event_data.data["timestamp"], message
-        assert result["start_timestamp"] == event_data.data["start_timestamp"], message
+        # assert result["timestamp"] == event_data.data["timestamp"], message
+        # assert result["start_timestamp"] == event_data.data["start_timestamp"], message
 
     def assert_trace_data(self, root, gen2_no_children=True):
         """see the setUp docstring for an idea of what the response structure looks like"""
@@ -795,6 +831,15 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
             elif gen2_no_children:
                 assert len(gen2["children"]) == 0
 
+    def client_get(self, data, url=None):
+        if url is None:
+            url = self.url
+        return self.client.get(
+            url,
+            data,
+            format="json",
+        )
+
     def test_no_projects(self):
         user = self.create_user()
         org = self.create_organization(owner=user)
@@ -816,10 +861,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
     def test_simple(self):
         self.load_trace()
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1},
-                format="json",
             )
         assert response.status_code == 200, response.content
         trace_transaction = response.data["transactions"][0]
@@ -832,10 +875,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
     def test_simple_with_limit(self):
         self.load_trace()
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1, "limit": 200},
-                format="json",
             )
         assert response.status_code == 200, response.content
         trace_transaction = response.data["transactions"][0]
@@ -848,10 +889,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
     def test_detailed_trace(self):
         self.load_trace()
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1, "detailed": 1},
-                format="json",
             )
         assert response.status_code == 200, response.content
         trace_transaction = response.data["transactions"][0]
@@ -893,10 +932,9 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
         )
 
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                url,
+            response = self.client_get(
                 data={"project": -1, "detailed": 1},
-                format="json",
+                url=url,
             )
 
         assert response.status_code == 200, response.content
@@ -933,10 +971,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
         )
 
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1},
-                format="json",
             )
 
         assert response.status_code == 200, response.content
@@ -988,10 +1024,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
             duration=300,
         )
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1},
-                format="json",
             )
 
         assert response.status_code == 200, response.content
@@ -1028,10 +1062,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
             kwargs={"organization_slug": self.project.organization.slug, "trace_id": trace_id},
         )
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1},
-                format="json",
             )
 
         assert response.status_code == 200, response.content
@@ -1062,10 +1094,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
         ]
 
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1},
-                format="json",
             )
 
         assert response.status_code == 200, response.content
@@ -1098,10 +1128,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
         )
 
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1},
-                format="json",
             )
 
         assert response.status_code == 200, response.content
@@ -1177,10 +1205,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
         )
 
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1},
-                format="json",
             )
 
         assert response.status_code == 200, response.content
@@ -1207,17 +1233,15 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
         error, error1 = self.load_errors()
 
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1},
-                format="json",
             )
 
         assert response.status_code == 200, response.content
         self.assert_trace_data(response.data["transactions"][0])
         gen1_event = response.data["transactions"][0]["children"][0]
         assert len(gen1_event["errors"]) == 2
-        assert {
+        data = {
             "event_id": error.event_id,
             "issue_id": error.group_id,
             "span": self.gen1_span_ids[0],
@@ -1228,8 +1252,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
             "timestamp": to_timestamp_from_iso_format(error.timestamp),
             "generation": 0,
             "event_type": "error",
-        } in gen1_event["errors"]
-        assert {
+        }
+        data1 = {
             "event_id": error1.event_id,
             "issue_id": error1.group_id,
             "span": self.gen1_span_ids[0],
@@ -1240,7 +1264,9 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
             "timestamp": to_timestamp_from_iso_format(error1.timestamp),
             "generation": 0,
             "event_type": "error",
-        } in gen1_event["errors"]
+        }
+        assert data in gen1_event["errors"]
+        assert data1 in gen1_event["errors"]
 
     def test_with_only_orphan_errors_with_same_span_ids(self):
         span_id = uuid4().hex[:16]
@@ -1275,10 +1301,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
         with self.feature(
             [*self.FEATURES, "organizations:performance-tracing-without-performance"]
         ):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1},
-                format="json",
             )
         assert response.status_code == 200, response.content
         assert len(response.data) == 2
@@ -1334,10 +1358,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
         with self.feature(
             [*self.FEATURES, "organizations:performance-tracing-without-performance"]
         ):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1},
-                format="json",
             )
         assert response.status_code == 200, response.content
         assert len(response.data["orphan_errors"]) == 2
@@ -1392,10 +1414,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
         with self.feature(
             [*self.FEATURES, "organizations:performance-tracing-without-performance"]
         ):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1},
-                format="json",
             )
         assert response.status_code == 200, response.content
         assert len(response.data["transactions"]) == 1
@@ -1419,10 +1439,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
         start, _ = self.get_start_end(1000)
         default_event = self.load_default()
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1},
-                format="json",
             )
 
         assert response.status_code == 200, response.content
@@ -1446,10 +1464,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
         self.load_trace()
         # Pruning shouldn't happen for the root event
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1, "event_id": self.root_event.event_id},
-                format="json",
             )
         assert response.status_code == 200, response.content
         self.assert_trace_data(response.data["transactions"][0])
@@ -1457,10 +1473,8 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
     def test_pruning_event(self):
         self.load_trace()
         with self.feature(self.FEATURES):
-            response = self.client.get(
-                self.url,
+            response = self.client_get(
                 data={"project": -1, "event_id": self.gen2_events[0].event_id},
-                format="json",
             )
         assert response.status_code == 200, response.content
         root = response.data["transactions"][0]
@@ -1483,6 +1497,35 @@ class OrganizationEventsTraceEndpointTest(OrganizationEventsTraceEndpointBase):
                 assert len(gen1["children"]) == 0
 
 
+class OrganizationEventsTraceEndpointTestUsingSpans(OrganizationEventsTraceEndpointTest):
+    def client_get(self, data, url=None):
+        data["useSpans"] = 1
+        return super().client_get(data, url)
+
+    @pytest.mark.skip(
+        "Loops can only be orphans cause the most recent parent to be saved will overwrite the previous"
+    )
+    def test_bad_span_loop(self):
+        super().test_bad_span_loop()
+
+    @pytest.mark.skip("Can't use the detailed response with useSpans on")
+    def test_detailed_trace_with_bad_tags(self):
+        super().test_detailed_trace_with_bad_tags()
+
+    @pytest.mark.skip("We shouldn't need to prune with events anymore since spans should be faster")
+    def test_pruning_event(self):
+        super().test_pruning_event()
+
+    def test_detailed_trace(self):
+        """Can't use detailed with useSpans, so this should actually just 400"""
+        with self.feature(self.FEATURES):
+            response = self.client_get(
+                data={"project": -1, "detailed": 1},
+            )
+
+        assert response.status_code == 400, response.content
+
+
 @region_silo_test
 class OrganizationEventsTraceMetaEndpointTest(OrganizationEventsTraceEndpointBase):
     url_name = "sentry-api-0-organization-events-trace-meta"