Browse Source

feat(discover-snql): Add Snql support for compare_numeric_agg (#27960)

Add support for compare_numeric_aggregate in SnQL.
Shruthi 3 years ago
parent
commit
c314e1eff5
2 changed files with 95 additions and 3 deletions
  1. 39 3
      src/sentry/search/events/fields.py
  2. 56 0
      tests/sentry/snuba/test_discover.py

+ 39 - 3
src/sentry/search/events/fields.py

@@ -2157,6 +2157,25 @@ for alias, name in FUNCTION_ALIASES.items():
 FUNCTION_ALIAS_PATTERN = re.compile(r"^({}).*".format("|".join(list(FUNCTIONS.keys()))))
 
 
+def normalize_percentile_alias(args: Mapping[str, str]) -> str:
+    # The compare_numeric_aggregate SnQL function accepts a percentile
+    # alias which is resolved to the percentile function call here
+    # to maintain backward compatibility with the legacy compare_numeric_aggregate
+    # function signature. This function only accepts percentile
+    # aliases.
+    aggregate_alias = args["aggregate_alias"]
+    match = re.match(r"(p\d{2,3})_(\w+)", aggregate_alias)
+
+    if not match:
+        raise InvalidFunctionArgument("Aggregate alias must be a percentile function.")
+
+    # Translating an arg of the pattern `measurements_lcp`
+    # to `measurements.lcp`.
+    aggregate_arg = ".".join(match.group(2).split("_"))
+
+    return f"{match.group(1)}({aggregate_arg})"
+
+
 class SnQLFunction(DiscoverFunction):
     def __init__(self, *args, **kwargs):
         self.snql_aggregate = kwargs.pop("snql_aggregate", None)
@@ -2622,6 +2641,26 @@ class QueryFields(QueryBase):
                     optional_args=[IntervalDefault("interval", 1, None)],
                     default_result_type="number",
                 ),
+                SnQLFunction(
+                    "compare_numeric_aggregate",
+                    required_args=[
+                        FunctionAliasArg("aggregate_alias"),
+                        ConditionArg("condition"),
+                        NumberRange("value", 0, None),
+                    ],
+                    calculated_args=[
+                        {
+                            "name": "aggregate_function",
+                            "fn": normalize_percentile_alias,
+                        }
+                    ],
+                    snql_aggregate=lambda args, alias: Function(
+                        args["condition"],
+                        [self.resolve_function(args["aggregate_function"]), args["value"]],
+                        alias,
+                    ),
+                    default_result_type="number",
+                ),
                 # TODO: implement these
                 SnQLFunction("array_join", snql_aggregate=self._resolve_unimplemented_function),
                 SnQLFunction("histogram", snql_aggregate=self._resolve_unimplemented_function),
@@ -2629,9 +2668,6 @@ class QueryFields(QueryBase):
                 SnQLFunction("t_test", snql_aggregate=self._resolve_unimplemented_function),
                 SnQLFunction("minus", snql_aggregate=self._resolve_unimplemented_function),
                 SnQLFunction("absolute_delta", snql_aggregate=self._resolve_unimplemented_function),
-                SnQLFunction(
-                    "compare_numeric_aggregate", snql_aggregate=self._resolve_unimplemented_function
-                ),
             ]
         }
         for alias, name in FUNCTION_ALIASES.items():

+ 56 - 0
tests/sentry/snuba/test_discover.py

@@ -1240,6 +1240,62 @@ class QueryIntegrationTest(SnubaTestCase, TestCase):
                 for index, count in enumerate(data):
                     assert count["count"] == expected_counts[index], use_snql
 
+    def test_compare_numeric_aggregate_function(self):
+        project = self.create_project()
+
+        for i in range(6):
+            data = load_data(
+                "transaction",
+                timestamp=before_now(minutes=3),
+                start_timestamp=before_now(minutes=4 + i),
+            )
+            data["transaction"] = "/percentile"
+            self.store_event(data, project_id=project.id)
+
+        fields = [
+            (
+                [
+                    "transaction",
+                    "p50(measurements.lcp)",
+                    "compare_numeric_aggregate(p50_measurements_lcp,greater,2000)",
+                ],
+                "",
+            ),
+            (
+                [
+                    "transaction",
+                    "p50(measurements.lcp)",
+                    "compare_numeric_aggregate(p50_measurements_lcp,less,2000)",
+                ],
+                "",
+            ),
+        ]
+
+        expected_results = [
+            ("compare_numeric_aggregate_p50_measurements_lcp_greater_2000", 1),
+            ("compare_numeric_aggregate_p50_measurements_lcp_less_2000", 0),
+        ]
+
+        for i, test_case in enumerate(fields):
+            for use_snql in [False, True]:
+                selected, query = test_case
+                result = discover.query(
+                    selected_columns=selected,
+                    query=query,
+                    orderby="transaction",
+                    params={
+                        "start": before_now(minutes=10),
+                        "end": before_now(minutes=2),
+                        "project_id": [project.id],
+                    },
+                    use_aggregate_conditions=True,
+                    use_snql=use_snql,
+                )
+                alias, expected_value = expected_results[i]
+                data = result["data"]
+
+                assert data[0][alias] == expected_value, use_snql
+
     def test_last_seen(self):
         project = self.create_project()