
fix(trends): Cross query pollution in trends endpoint (#67878)

When there are 2 projects that share a transaction name and the 2
transactions are split into 2 different parallel queries, each query
returns more data than expected and pollutes the results with
duplicates once the 2 queries are merged. This results in a
validation error in Seer.
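To make the failure mode concrete, here is a minimal sketch with hypothetical project ids and transaction names, reconstructing the old transaction-only filter from the removed lines in the diff below:

# Hypothetical top events: two projects sharing the transaction names
# "foo" and "bar", split into two sub-queries of EVENTS_PER_QUERY = 2.
bucket_1 = [{"project_id": 1, "transaction": "foo"}, {"project_id": 2, "transaction": "bar"}]
bucket_2 = [{"project_id": 2, "transaction": "foo"}, {"project_id": 1, "transaction": "bar"}]

def old_query(events):
    # Old behaviour: filter on transaction names only, ignoring the project.
    names = ", ".join(f'"{e["transaction"]}"' for e in events)
    return f"transaction:[{names}]"

print(old_query(bucket_1))  # transaction:["foo", "bar"]
print(old_query(bucket_2))  # transaction:["foo", "bar"]
# Both sub-queries end up with the same filter, so each one matches all four
# (project, transaction) series and the merged result contains duplicates.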
Tony Xiao 11 months ago
Parent
Commit
930c9729e8

+ 25 - 41
src/sentry/api/endpoints/organization_events_trends_v2.py

@@ -1,7 +1,6 @@
 import logging
 import re
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any
 
 import sentry_sdk
 from rest_framework.exceptions import ParseError
@@ -22,6 +21,7 @@ from sentry.snuba.discover import create_result_key, zerofill
 from sentry.snuba.metrics_performance import query as metrics_query
 from sentry.snuba.referrer import Referrer
 from sentry.types.ratelimit import RateLimit, RateLimitCategory
+from sentry.utils.iterators import chunked
 from sentry.utils.snuba import SnubaTSResult
 
 logger = logging.getLogger(__name__)
@@ -90,8 +90,6 @@ class OrganizationEventsNewTrendsStatsEndpoint(OrganizationEventsV2EndpointBase)
 
         query = request.GET.get("query")
 
-        top_trending_transactions = {}
-
         def get_top_events(user_query, params, event_limit, referrer):
             top_event_columns = selected_columns[:]
             top_event_columns.append("count()")
@@ -113,21 +111,22 @@ class OrganizationEventsNewTrendsStatsEndpoint(OrganizationEventsV2EndpointBase)
             )
 
         def generate_top_transaction_query(events):
-            top_transaction_names = [
-                re.sub(r'"', '\\"', event.get("transaction")) for event in events
+            pairs = [
+                (event["project_id"], re.sub(r'"', '\\"', event["transaction"])) for event in events
             ]
-            top_transaction_as_str = ", ".join(
-                f'"{transaction}"' for transaction in top_transaction_names
-            )
-            return f"transaction:[{top_transaction_as_str}]"
+            conditions = [
+                f'(project_id:{project_id} transaction:"{transaction}")'
+                for project_id, transaction in pairs
+            ]
+            return " OR ".join(conditions)
 
         def get_timeseries(top_events, _, rollup, zerofill_results):
             # Split top events into multiple queries for bulk timeseries query
             data = top_events["data"]
-            split_top_events = [
-                data[i : i + EVENTS_PER_QUERY] for i in range(0, len(data), EVENTS_PER_QUERY)
+
+            queries = [
+                generate_top_transaction_query(chunk) for chunk in chunked(data, EVENTS_PER_QUERY)
             ]
-            queries = [generate_top_transaction_query(t_e) for t_e in split_top_events]
 
             timeseries_columns = selected_columns[:]
             timeseries_columns.append(trend_function)
@@ -162,7 +161,7 @@ class OrganizationEventsNewTrendsStatsEndpoint(OrganizationEventsV2EndpointBase)
             translated_groupby = ["project_id", "transaction"]
             results = {}
             formatted_results = {}
-            for index, item in enumerate(top_events["data"]):
+            for index, item in enumerate(data):
                 result_key = create_result_key(item, translated_groupby, {})
                 results[result_key] = {
                     "order": index,
@@ -212,7 +211,6 @@ class OrganizationEventsNewTrendsStatsEndpoint(OrganizationEventsV2EndpointBase)
             )
 
             # Fetch transactions names with the highest event count
-            nonlocal top_trending_transactions
             top_trending_transactions = get_top_events(
                 user_query=user_query,
                 params=params,
@@ -243,36 +241,22 @@ class OrganizationEventsNewTrendsStatsEndpoint(OrganizationEventsV2EndpointBase)
             return data
 
         def get_trends_data(stats_data, request):
-            trend_function = request.GET.get("trendFunction", "p50()")
-
-            trends_request: dict[str, Any] = {
-                "data": {},
-                "sort": None,
-                "trendFunction": None,
-            }
-
-            trends_request["sort"] = (
-                "" if trend_type == ANY else request.GET.get("sort", "trend_percentage()")
-            )
-            trends_request["trendFunction"] = trend_function
-
-            # list of requests to send to microservice async
-            trends_requests = []
-
             stats_data = dict(
                 [format_start_end(data) for data in list(stats_data.items()) if data[1] is not None]
             )
 
-            # split the txns data into multiple dictionaries
-            split_transactions_data = [
-                dict(list(stats_data.items())[i : i + EVENTS_PER_QUERY])
-                for i in range(0, len(stats_data), EVENTS_PER_QUERY)
-            ]
+            trend_sort = "" if trend_type == ANY else request.GET.get("sort", "trend_percentage()")
+            trend_function = request.GET.get("trendFunction", "p50()")
 
-            for i in range(len(split_transactions_data)):
-                trends_request = trends_request.copy()
-                trends_request["data"] = split_transactions_data[i]
-                trends_requests.append(trends_request)
+            # list of requests to send to microservice async
+            trends_requests = [
+                {
+                    "data": dict(chunk),
+                    "sort": trend_sort,
+                    "trendFunction": trend_function,
+                }
+                for chunk in chunked(stats_data.items(), EVENTS_PER_QUERY)
+            ]
 
             # send the data to microservice
             results = list(_query_thread_pool.map(detect_breakpoints, trends_requests))
@@ -284,9 +268,9 @@ class OrganizationEventsNewTrendsStatsEndpoint(OrganizationEventsV2EndpointBase)
                 trend_results += output_dict
 
             # sort the results into trending events list
-            if trends_request["sort"] == "trend_percentage()":
+            if trend_sort == "trend_percentage()":
                 trending_events = sorted(trend_results, key=lambda d: d["trend_percentage"])
-            elif trends_request["sort"] == "-trend_percentage()":
+            elif trend_sort == "-trend_percentage()":
                 trending_events = sorted(
                     trend_results, key=lambda d: d["trend_percentage"], reverse=True
                 )
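To illustrate the hunks above, here is a runnable sketch of the new per-project filter. The chunked helper below is a local stand-in approximating sentry.utils.iterators.chunked, and the ids and names are invented (the real code also escapes embedded quotes, which this sketch omits):

from itertools import islice

def chunked(iterable, size):
    # Stand-in for sentry.utils.iterators.chunked: yield lists of at most `size` items.
    it = iter(iterable)
    while chunk := list(islice(it, size)):
        yield chunk

def generate_top_transaction_query(events):
    # New behaviour: scope every transaction name to its own project.
    conditions = [
        f'(project_id:{e["project_id"]} transaction:"{e["transaction"]}")' for e in events
    ]
    return " OR ".join(conditions)

data = [
    {"project_id": 1, "transaction": "foo"},
    {"project_id": 2, "transaction": "bar"},
    {"project_id": 2, "transaction": "foo"},
    {"project_id": 1, "transaction": "bar"},
]
for chunk in chunked(data, 2):
    print(generate_top_transaction_query(chunk))
# (project_id:1 transaction:"foo") OR (project_id:2 transaction:"bar")
# (project_id:2 transaction:"foo") OR (project_id:1 transaction:"bar")

Each sub-query now matches only its own (project, transaction) pairs, so merging the parallel results cannot reintroduce duplicate series. Dropping the nonlocal top_trending_transactions and the shared trends_request dict also removes the mutable state that the parallel requests previously had to share.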

+ 77 - 0
tests/sentry/api/endpoints/test_organization_events_trends_v2.py

@@ -344,3 +344,80 @@ class OrganizationEventsTrendsStatsV2EndpointTest(MetricsAPIBaseTestCase):
 
         assert len(trends_call_args_data.get(f"{project1.id},bar")) > 0
         assert len(trends_call_args_data.get(f"{project2.id},bar")) > 0
+
+    @mock.patch("sentry.api.endpoints.organization_events_trends_v2.detect_breakpoints")
+    @mock.patch("sentry.api.endpoints.organization_events_trends_v2.EVENTS_PER_QUERY", 2)
+    def test_two_projects_same_transaction_split_queries(self, mock_detect_breakpoints):
+        project1 = self.create_project(organization=self.org)
+        project2 = self.create_project(organization=self.org)
+
+        # force these 2 transactions from different projects
+        # to fall into the FIRST bucket when querying
+        for i in range(2):
+            self.store_performance_metric(
+                name=TransactionMRI.DURATION.value,
+                tags={"transaction": "foo bar"},
+                org_id=self.org.id,
+                project_id=project1.id,
+                value=2,
+                hours_before_now=2,
+            )
+            self.store_performance_metric(
+                name=TransactionMRI.DURATION.value,
+                tags={"transaction": '"foo/bar"'},
+                org_id=self.org.id,
+                project_id=project2.id,
+                value=2,
+                hours_before_now=2,
+            )
+        # force these 2 transactions from different projects
+        # to fall into the SECOND bucket when querying
+        self.store_performance_metric(
+            name=TransactionMRI.DURATION.value,
+            tags={"transaction": "foo bar"},
+            org_id=self.org.id,
+            project_id=project2.id,
+            value=2,
+            hours_before_now=2,
+        )
+        self.store_performance_metric(
+            name=TransactionMRI.DURATION.value,
+            tags={"transaction": '"foo/bar"'},
+            org_id=self.org.id,
+            project_id=project1.id,
+            value=2,
+            hours_before_now=2,
+        )
+
+        with self.feature([*self.features, "organizations:global-views"]):
+            response = self.client.get(
+                self.url,
+                format="json",
+                data={
+                    "end": iso_format(self.now),
+                    "start": iso_format(self.now - timedelta(days=1)),
+                    "interval": "1h",
+                    "field": ["project", "transaction"],
+                    "query": "event.type:transaction",
+                    "project": [project1.id, project2.id],
+                    "trendFunction": "p95(transaction.duration)",
+                    "topEvents": 4,
+                    "statsPeriod": "3h",
+                },
+            )
+
+        assert response.status_code == 200, response.content
+
+        trends_call_args_data_1 = mock_detect_breakpoints.call_args_list[0][0][0]["data"]
+        trends_call_args_data_2 = mock_detect_breakpoints.call_args_list[1][0][0]["data"]
+
+        assert len(trends_call_args_data_1[f"{project1.id},foo bar"]) > 0
+        assert len(trends_call_args_data_1[f'{project2.id},"foo/bar"']) > 0
+        assert len(trends_call_args_data_2[f'{project1.id},"foo/bar"']) > 0
+        assert len(trends_call_args_data_2[f"{project2.id},foo bar"]) > 0
+
+        for trends_call_args_data in [trends_call_args_data_1, trends_call_args_data_2]:
+            for k, v in trends_call_args_data.items():
+                for entry in v["data"]:
+                    # each entry should have exactly 1 data point
+                    assert len(entry[1]) == 1
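For reference, a hypothetical sketch of the payload shape the assertions above inspect; the keys follow the "project_id,transaction" convention from create_result_key, and the timestamps and counts are invented:

payload = {
    "data": {
        "1,foo bar": {"data": [[1710000000, [{"count": 2.0}]]]},
        '2,"foo/bar"': {"data": [[1710000000, [{"count": 2.0}]]]},
    },
    "sort": "trend_percentage()",
    "trendFunction": "p95(transaction.duration)",
}
# Before the fix, a polluted sub-query could contribute a second aggregate
# dict for the same timestamp; the final loop in the test guards against that.
for series in payload["data"].values():
    for entry in series["data"]:
        assert len(entry[1]) == 1  # exactly one data point per timestamp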