
ref(metrics): Update metrics layer to support Formulas (#59911)

The metrics layer needed to be updated to properly resolve Formula queries, since those can contain more than one Timeseries.
Evan Hicks, 1 year ago · commit eab0e8b5e8
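
For reference, a minimal sketch of the kind of Formula query the updated layer is meant to resolve, modeled on the failure-rate test added in tests/snuba/test_metrics_layer.py below; the org/project ids, app_id, referrer, and time range are placeholders.

    from datetime import datetime, timedelta

    from snuba_sdk import (
        ArithmeticOperator,
        Formula,
        Metric,
        MetricsQuery,
        MetricsScope,
        Request,
        Rollup,
        Timeseries,
    )

    from sentry.snuba.metrics.naming_layer import TransactionMRI
    from sentry.snuba.metrics_layer.query import run_query

    now = datetime.utcnow()

    # A Formula combining two Timeseries. Each Timeseries is resolved on its own,
    # and all of them must belong to the same use case.
    query = MetricsQuery(
        query=Formula(
            ArithmeticOperator.DIVIDE,
            [
                Timeseries(metric=Metric(mri=TransactionMRI.DURATION.value), aggregate="count"),
                Timeseries(metric=Metric(mri=TransactionMRI.DURATION.value), aggregate="count"),
            ],
        ),
        start=now - timedelta(hours=1),
        end=now,
        rollup=Rollup(interval=60, totals=True, granularity=60),
        scope=MetricsScope(org_ids=[1], project_ids=[1]),  # placeholder ids
    )

    request = Request(
        dataset="generic_metrics",
        app_id="example",  # placeholder
        query=query,
        tenant_ids={"referrer": "metrics.example", "organization_id": 1},
    )
    result = run_query(request)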

+ 125 - 50
src/sentry/snuba/metrics_layer/query.py

@@ -10,19 +10,17 @@ from snuba_sdk import (
     Column,
     Condition,
     CurriedFunction,
+    Formula,
     MetricsQuery,
     Request,
+    Timeseries,
 )
 
 from sentry.api.utils import InvalidParams
 from sentry.sentry_metrics.use_case_id_registry import UseCaseID
 from sentry.sentry_metrics.utils import resolve_weak, string_to_use_case_id
 from sentry.snuba.dataset import Dataset, EntityKey
-from sentry.snuba.metrics.naming_layer.mapping import (
-    get_mri,
-    get_public_name_from_mri,
-    is_private_mri,
-)
+from sentry.snuba.metrics.naming_layer.mapping import get_mri
 from sentry.snuba.metrics.naming_layer.mri import parse_mri
 from sentry.snuba.metrics.utils import to_intervals
 from sentry.utils import metrics
@@ -50,6 +48,12 @@ def run_query(request: Request) -> Mapping[str, Any]:
     metrics_query = request.query
     assert isinstance(metrics_query, MetricsQuery)
 
+    # Nested Formula queries are not currently supported, so reject them up front.
+    # TODO: This should be removed once we fully support Formulas.
+    if isinstance(metrics_query.query, Formula):
+        if any(isinstance(p, Formula) for p in metrics_query.query.parameters):
+            raise InvalidParams("Nested formulas are not supported")
+
     assert len(metrics_query.scope.org_ids) == 1  # Initially only allow 1 org id
     organization_id = metrics_query.scope.org_ids[0]
     tenant_ids = request.tenant_ids or {"organization_id": organization_id}
@@ -132,14 +136,38 @@ GENERIC_ENTITIES = {
 
 def _resolve_use_case_id_str(metrics_query: MetricsQuery) -> str:
     # Automatically resolve the use_case_id if it is not provided
-    # TODO: At the moment only a single Timeseries is allowed. In the future this will need to find
-    # all the Timeseries and ensure they all have the same use case.
-    mri = metrics_query.query.metric.mri
-    parsed_mri = parse_mri(mri)
-    if parsed_mri is None:
-        raise InvalidParams(f"'{mri}' is not a valid MRI")
+    if isinstance(metrics_query.query, Timeseries):
+        if metrics_query.query.metric.mri is None:
+            mri = get_mri(metrics_query.query.metric.public_name)
+        else:
+            mri = metrics_query.query.metric.mri
+        parsed_mri = parse_mri(mri)
+        if parsed_mri is None:
+            raise InvalidParams(f"'{mri}' is not a valid MRI")
+
+        return parsed_mri.namespace
 
-    return parsed_mri.namespace
+    assert isinstance(metrics_query.query, Formula)
+    namespaces = set()
+    for p in metrics_query.query.parameters:
+        if isinstance(p, Timeseries):
+            if p.metric.mri is None:
+                mri = get_mri(p.metric.public_name)
+            else:
+                mri = p.metric.mri
+
+            parsed_mri = parse_mri(mri)
+            if parsed_mri is None:
+                raise InvalidParams(f"'{mri}' is not a valid MRI")
+
+            namespaces.add(parsed_mri.namespace)
+
+    if not namespaces:
+        raise InvalidParams("No use case found in formula parameters")
+    if len(namespaces) > 1:
+        raise InvalidParams("Formula parameters must all be from the same use case")
+
+    return namespaces.pop()
 
 
 def _resolve_metrics_entity(mri: str) -> EntityKey:
@@ -188,32 +216,78 @@ def _resolve_granularity(start: datetime, end: datetime, interval: int | None) -
     return min(found_granularities)
 
 
-def _resolve_metrics_query(
+def _resolve_query_metrics(
     metrics_query: MetricsQuery,
-) -> tuple[MetricsQuery, Mapping[str, str | int]]:
+    use_case_id: UseCaseID,
+    org_id: int,
+) -> tuple[MetricsQuery, dict[str, str | int]]:
     """
-    Returns an updated metrics query with all the indexer resolves complete. Also returns a mapping
-    that shows all the strings that were resolved and what they were resolved too.
+    Resolves all the fields of the Metric in the query. Public name -> MRI -> ID -> Entity.
+    Returns a mapping dictionary recording any resolutions the function performed.
     """
     assert metrics_query.query is not None
-    metric = metrics_query.query.metric
+    if isinstance(metrics_query.query, Timeseries):
+        series, mappings = _resolve_timeseries_metric(metrics_query.query, use_case_id, org_id)
+        metrics_query = metrics_query.set_query(series)
+    elif isinstance(metrics_query.query, Formula):
+        formula, mappings = _resolve_formula_metrics(metrics_query.query, use_case_id, org_id)
+        metrics_query = metrics_query.set_query(formula)
+
+    return metrics_query, mappings
+
+
+def _resolve_timeseries_metric(
+    series: Timeseries, use_case_id: UseCaseID, org_id: int
+) -> tuple[Timeseries, dict[str, str | int]]:
+    metric = series.metric
     mappings: dict[str, str | int] = {}
-    if not metric.public_name and metric.mri:
-        if not is_private_mri(metric.mri):
-            public_name = get_public_name_from_mri(metric.mri)
-            metrics_query = metrics_query.set_query(
-                metrics_query.query.set_metric(
-                    metrics_query.query.metric.set_public_name(public_name)
-                )
-            )
-            mappings[public_name] = metric.mri
-    elif not metric.mri and metric.public_name:
+    if not metric.mri and not metric.public_name:
+        raise InvalidParams("Metric must have either an MRI or a public name")
+
+    if not metric.mri and metric.public_name:
         mri = get_mri(metric.public_name)
-        metrics_query = metrics_query.set_query(
-            metrics_query.query.set_metric(metrics_query.query.metric.set_mri(mri))
-        )
+        metric = metric.set_mri(mri)
         mappings[metric.public_name] = mri
 
+    if metric.id is None:
+        metric_id = resolve_weak(
+            use_case_id, org_id, metric.mri
+        )  # only support raw metrics for now
+        metric = metric.set_id(metric_id)
+        mappings[metric.mri] = metric_id
+
+    if not metric.entity:
+        entity = _resolve_metrics_entity(metric.mri)
+        metric = metric.set_entity(entity.value)
+
+    series = series.set_metric(metric)
+    return series, mappings
+
+
+def _resolve_formula_metrics(
+    formula: Formula, use_case_id: UseCaseID, org_id: int
+) -> tuple[Formula, dict[str, str | int]]:
+    # TODO: This will eventually need to recursively resolve Formulas once Formula becomes a valid parameter
+    parameters = formula.parameters
+    formula_mappings: dict[str, str | int] = {}
+    for i, p in enumerate(parameters):
+        if isinstance(p, Timeseries):
+            series, mappings = _resolve_timeseries_metric(p, use_case_id, org_id)
+            parameters[i] = series
+            formula_mappings.update(mappings)
+
+    formula = formula.set_parameters(parameters)
+    return formula, formula_mappings
+
+
+def _resolve_metrics_query(
+    metrics_query: MetricsQuery,
+) -> tuple[MetricsQuery, Mapping[str, str | int]]:
+    """
+    Returns an updated metrics query with all the indexer resolves complete. Also returns a mapping
+    that shows all the strings that were resolved and what they were resolved to.
+    """
+
     org_id = metrics_query.scope.org_ids[0]
     use_case_id_str = _resolve_use_case_id_str(metrics_query)
     if metrics_query.scope.use_case_id is None:
@@ -222,35 +296,36 @@ def _resolve_metrics_query(
         )
 
     use_case_id = string_to_use_case_id(use_case_id_str)
-    if metrics_query.query.metric.id is None:
-        metric_id = resolve_weak(
-            use_case_id, org_id, metrics_query.query.metric.mri
-        )  # only support raw metrics for now
-
-        metrics_query = metrics_query.set_query(
-            metrics_query.query.set_metric(metrics_query.query.metric.set_id(metric_id))
-        )
-    else:
-        metric_id = metrics_query.query.metric.id
-
-    mappings[metrics_query.query.metric.mri] = metric_id
+    metrics_query, mappings = _resolve_query_metrics(metrics_query, use_case_id, org_id)
 
-    if not metrics_query.query.metric.entity:
-        entity = _resolve_metrics_entity(metrics_query.query.metric.mri)
-        metrics_query = metrics_query.set_query(
-            metrics_query.query.set_metric(metrics_query.query.metric.set_entity(entity.value))
-        )
-
-    # TODO: Once we support formula queries, we would need to resolve groupby and filters recursively given a Formula object
-    # For now, metrics_query.query will only ever be a Timeseries
     new_groupby, new_mappings = _resolve_groupby(metrics_query.query.groupby, use_case_id, org_id)
     metrics_query = metrics_query.set_query(metrics_query.query.set_groupby(new_groupby))
     mappings.update(new_mappings)
 
+    if isinstance(metrics_query.query, Formula):
+        parameters = metrics_query.query.parameters
+        for i, p in enumerate(parameters):
+            if isinstance(p, Timeseries):
+                new_groupby, new_mappings = _resolve_groupby(p.groupby, use_case_id, org_id)
+                parameters[i] = p.set_groupby(new_groupby)
+                mappings.update(new_mappings)
+
+        metrics_query = metrics_query.set_query(metrics_query.query.set_parameters(parameters))
+
     new_filters, new_mappings = _resolve_filters(metrics_query.query.filters, use_case_id, org_id)
     metrics_query = metrics_query.set_query(metrics_query.query.set_filters(new_filters))
     mappings.update(new_mappings)
 
+    if isinstance(metrics_query.query, Formula):
+        parameters = metrics_query.query.parameters
+        for i, p in enumerate(parameters):
+            if isinstance(p, Timeseries):
+                new_filters, new_mappings = _resolve_filters(p.filters, use_case_id, org_id)
+                parameters[i] = p.set_filters(new_filters)
+                mappings.update(new_mappings)
+
+        metrics_query = metrics_query.set_query(metrics_query.query.set_parameters(parameters))
+
     return metrics_query, mappings
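
For context on the guard added at the top of run_query, a minimal sketch of the query shape it now rejects; the Timeseries construction mirrors the tests below.

    from snuba_sdk import ArithmeticOperator, Formula, Metric, Timeseries

    from sentry.snuba.metrics.naming_layer import TransactionMRI

    duration = Timeseries(metric=Metric(mri=TransactionMRI.DURATION.value), aggregate="count")
    inner = Formula(ArithmeticOperator.PLUS, [duration, duration])
    nested = Formula(ArithmeticOperator.DIVIDE, [inner, duration])
    # A MetricsQuery whose query is `nested` has a Formula among its parameters,
    # so run_query raises InvalidParams("Nested formulas are not supported").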
 
 

+ 199 - 3
tests/sentry/snuba/metrics/test_metrics_query_layer/test_metrics_query_layer.py

@@ -7,8 +7,10 @@ from datetime import datetime, timedelta
 import pytest
 from snuba_sdk import (
     AliasedExpression,
+    ArithmeticOperator,
     Column,
     Condition,
+    Formula,
     Metric,
     MetricsQuery,
     MetricsScope,
@@ -50,7 +52,6 @@ class MetricsQueryLayerTest(BaseMetricsLayerTestCase, TestCase):
         )
 
         resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query)
-        assert resolved_metrics_query.query.metric.public_name == "transaction.duration"
         expected_metric_id = indexer.resolve(
             UseCaseID.TRANSACTIONS,
             self.project.organization_id,
@@ -59,6 +60,37 @@ class MetricsQueryLayerTest(BaseMetricsLayerTestCase, TestCase):
         assert resolved_metrics_query.query.metric.id == expected_metric_id
         assert mappings[TransactionMRI.DURATION.value] == expected_metric_id
 
+    def test_resolve_formula_metrics_query(self):
+        self.store_performance_metric(
+            name=TransactionMRI.DURATION.value,
+            project_id=self.project.id,
+            tags={},
+            value=1,
+        )
+        metrics_query = MetricsQuery(
+            query=Formula(
+                ArithmeticOperator.PLUS,
+                [
+                    Timeseries(Metric(mri=TransactionMRI.DURATION.value), aggregate="count"),
+                    Timeseries(Metric(mri=TransactionMRI.DURATION.value), aggregate="count"),
+                ],
+            ),
+            scope=MetricsScope(
+                org_ids=[self.project.organization_id],
+                project_ids=[self.project.id],
+                use_case_id=UseCaseID.TRANSACTIONS.value,
+            ),
+        )
+
+        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query)
+        expected_metric_id = indexer.resolve(
+            UseCaseID.TRANSACTIONS,
+            self.project.organization_id,
+            TransactionMRI.DURATION.value,
+        )
+        assert resolved_metrics_query.query.parameters[0].metric.id == expected_metric_id
+        assert mappings[TransactionMRI.DURATION.value] == expected_metric_id
+
     def test_resolve_metrics_query_with_groupby(self):
         self.store_performance_metric(
             name=TransactionMRI.DURATION.value,
@@ -68,7 +100,7 @@ class MetricsQueryLayerTest(BaseMetricsLayerTestCase, TestCase):
         )
         metrics_query = MetricsQuery(
             query=Timeseries(
-                Metric(mri=TransactionMRI.DURATION.value),
+                Metric(public_name="transaction.duration"),
                 aggregate="count",
                 groupby=[Column("transaction")],
             ),
@@ -91,6 +123,7 @@ class MetricsQueryLayerTest(BaseMetricsLayerTestCase, TestCase):
 
         resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query)
         assert resolved_metrics_query.query.metric.public_name == "transaction.duration"
+        assert resolved_metrics_query.query.metric.mri == TransactionMRI.DURATION.value
         assert resolved_metrics_query.query.metric.id == expected_metric_id
         assert resolved_metrics_query.query.groupby == [
             AliasedExpression(Column(f"tags_raw[{expected_transaction_id}]"), "transaction")
@@ -98,6 +131,73 @@ class MetricsQueryLayerTest(BaseMetricsLayerTestCase, TestCase):
         assert mappings[TransactionMRI.DURATION.value] == expected_metric_id
         assert mappings["transaction"] == expected_transaction_id
 
+    def test_resolve_formula_metrics_query_with_groupby(self):
+        self.store_performance_metric(
+            name=TransactionMRI.DURATION.value,
+            project_id=self.project.id,
+            tags={"transaction": "/checkout", "status_code": "200"},
+            value=1,
+        )
+        metrics_query = MetricsQuery(
+            query=Formula(
+                ArithmeticOperator.PLUS,
+                [
+                    Timeseries(
+                        Metric(public_name="transaction.duration"),
+                        aggregate="count",
+                        groupby=[Column("transaction")],
+                    ),
+                    Timeseries(
+                        Metric(mri=TransactionMRI.DURATION.value),
+                        aggregate="count",
+                        groupby=[Column("transaction")],
+                    ),
+                ],
+                groupby=[Column("status_code")],
+            ),
+            scope=MetricsScope(
+                org_ids=[self.project.organization_id],
+                project_ids=[self.project.id],
+                use_case_id=UseCaseID.TRANSACTIONS.value,
+            ),
+        )
+        expected_metric_id = indexer.resolve(
+            UseCaseID.TRANSACTIONS,
+            self.project.organization_id,
+            TransactionMRI.DURATION.value,
+        )
+        expected_transaction_id = indexer.resolve(
+            UseCaseID.TRANSACTIONS,
+            self.project.organization_id,
+            "transaction",
+        )
+        expected_status_code_id = indexer.resolve(
+            UseCaseID.TRANSACTIONS,
+            self.project.organization_id,
+            "status_code",
+        )
+
+        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query)
+        assert (
+            resolved_metrics_query.query.parameters[0].metric.public_name == "transaction.duration"
+        )
+        assert (
+            resolved_metrics_query.query.parameters[0].metric.mri == TransactionMRI.DURATION.value
+        )
+        assert resolved_metrics_query.query.parameters[0].metric.id == expected_metric_id
+        assert resolved_metrics_query.query.parameters[0].groupby == [
+            AliasedExpression(Column(f"tags_raw[{expected_transaction_id}]"), "transaction")
+        ]
+        assert resolved_metrics_query.query.parameters[1].groupby == [
+            AliasedExpression(Column(f"tags_raw[{expected_transaction_id}]"), "transaction")
+        ]
+        assert resolved_metrics_query.query.groupby == [
+            AliasedExpression(Column(f"tags_raw[{expected_status_code_id}]"), "status_code")
+        ]
+        assert mappings[TransactionMRI.DURATION.value] == expected_metric_id
+        assert mappings["transaction"] == expected_transaction_id
+        assert mappings["status_code"] == expected_status_code_id
+
     def test_resolve_metrics_query_with_filters(self):
         self.store_performance_metric(
             name=TransactionMRI.DURATION.value,
@@ -143,7 +243,6 @@ class MetricsQueryLayerTest(BaseMetricsLayerTestCase, TestCase):
         )
 
         resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query)
-        assert resolved_metrics_query.query.metric.public_name == "transaction.duration"
         assert resolved_metrics_query.query.metric.id == expected_metric_id
         assert resolved_metrics_query.query.filters == [
             Condition(Column(f"tags_raw[{expected_transaction_id}]"), Op.EQ, "/checkout"),
@@ -158,6 +257,103 @@ class MetricsQueryLayerTest(BaseMetricsLayerTestCase, TestCase):
         assert mappings["transaction"] == expected_transaction_id
         assert mappings["device"] == expected_device_id
 
+    def test_resolve_formula_metrics_query_with_filters(self):
+        self.store_performance_metric(
+            name=TransactionMRI.DURATION.value,
+            project_id=self.project.id,
+            tags={"transaction": "/checkout", "device": "BlackBerry", "status_code": "200"},
+            value=1,
+        )
+        metrics_query = MetricsQuery(
+            query=Formula(
+                ArithmeticOperator.PLUS,
+                [
+                    Timeseries(
+                        Metric(mri=TransactionMRI.DURATION.value),
+                        aggregate="count",
+                        filters=[
+                            Condition(Column("transaction"), Op.EQ, "/checkout"),
+                            Or(
+                                [
+                                    Condition(Column("device"), Op.EQ, "BlackBerry"),
+                                    Condition(Column("device"), Op.EQ, "Nokia"),
+                                ]
+                            ),
+                        ],
+                        groupby=[Column("transaction")],
+                    ),
+                    Timeseries(
+                        Metric(mri=TransactionMRI.DURATION.value),
+                        aggregate="count",
+                        filters=[
+                            Condition(Column("transaction"), Op.EQ, "/cart"),
+                            Or(
+                                [
+                                    Condition(Column("device"), Op.EQ, "Android"),
+                                    Condition(Column("device"), Op.EQ, "Palm"),
+                                ]
+                            ),
+                        ],
+                        groupby=[Column("transaction")],
+                    ),
+                ],
+                filters=[Condition(Column("status_code"), Op.EQ, "200")],
+            ),
+            scope=MetricsScope(
+                org_ids=[self.project.organization_id],
+                project_ids=[self.project.id],
+                use_case_id=UseCaseID.TRANSACTIONS.value,
+            ),
+        )
+        expected_metric_id = indexer.resolve(
+            UseCaseID.TRANSACTIONS,
+            self.project.organization_id,
+            TransactionMRI.DURATION.value,
+        )
+        expected_transaction_id = indexer.resolve(
+            UseCaseID.TRANSACTIONS,
+            self.project.organization_id,
+            "transaction",
+        )
+        expected_device_id = indexer.resolve(
+            UseCaseID.TRANSACTIONS,
+            self.project.organization_id,
+            "device",
+        )
+        expected_status_code_id = indexer.resolve(
+            UseCaseID.TRANSACTIONS,
+            self.project.organization_id,
+            "status_code",
+        )
+
+        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query)
+        assert resolved_metrics_query.query.parameters[0].metric.id == expected_metric_id
+        assert resolved_metrics_query.query.parameters[0].filters == [
+            Condition(Column(f"tags_raw[{expected_transaction_id}]"), Op.EQ, "/checkout"),
+            Or(
+                [
+                    Condition(Column(f"tags_raw[{expected_device_id}]"), Op.EQ, "BlackBerry"),
+                    Condition(Column(f"tags_raw[{expected_device_id}]"), Op.EQ, "Nokia"),
+                ]
+            ),
+        ]
+        assert resolved_metrics_query.query.parameters[1].filters == [
+            Condition(Column(f"tags_raw[{expected_transaction_id}]"), Op.EQ, "/cart"),
+            Or(
+                [
+                    Condition(Column(f"tags_raw[{expected_device_id}]"), Op.EQ, "Android"),
+                    Condition(Column(f"tags_raw[{expected_device_id}]"), Op.EQ, "Palm"),
+                ]
+            ),
+        ]
+        assert resolved_metrics_query.query.filters == [
+            Condition(Column(f"tags_raw[{expected_status_code_id}]"), Op.EQ, "200"),
+        ]
+        assert mappings[TransactionMRI.DURATION.value] == expected_metric_id
+        assert mappings["transaction"] == expected_transaction_id
+        assert mappings["device"] == expected_device_id
+        assert mappings["status_code"] == expected_status_code_id
+
 
 @pytest.mark.parametrize(
     "day_range, sec_offset, interval, expected",

+ 53 - 0
tests/snuba/test_metrics_layer.py

@@ -5,9 +5,11 @@ from typing import Literal, Mapping
 
 import pytest
 from snuba_sdk import (
+    ArithmeticOperator,
     Column,
     Condition,
     Direction,
+    Formula,
     Metric,
     MetricsQuery,
     MetricsScope,
@@ -20,6 +22,7 @@ from snuba_sdk import (
 from sentry.api.utils import InvalidParams
 from sentry.sentry_metrics.use_case_id_registry import UseCaseID
 from sentry.snuba.metrics.naming_layer import SessionMRI, TransactionMRI
+from sentry.snuba.metrics.naming_layer.public import TransactionStatusTagValue, TransactionTagsKey
 from sentry.snuba.metrics_layer.query import run_query
 from sentry.testutils.cases import BaseMetricsTestCase, TestCase
 
@@ -450,3 +453,53 @@ class SnQLTest(TestCase, BaseMetricsTestCase):
 
         assert len(result["data"]) == 61
         assert result["totals"]["aggregate_value"] == 60
+
+    def test_failure_rate(self) -> None:
+        query = MetricsQuery(
+            query=Formula(
+                ArithmeticOperator.DIVIDE,
+                [
+                    Timeseries(
+                        metric=Metric(
+                            mri=TransactionMRI.DURATION.value,
+                        ),
+                        aggregate="count",
+                        filters=[
+                            Condition(
+                                Column(TransactionTagsKey.TRANSACTION_STATUS.value),
+                                Op.NOT_IN,
+                                [
+                                    TransactionStatusTagValue.OK.value,
+                                    TransactionStatusTagValue.CANCELLED.value,
+                                    TransactionStatusTagValue.UNKNOWN.value,
+                                ],
+                            )
+                        ],
+                    ),
+                    Timeseries(
+                        metric=Metric(
+                            mri=TransactionMRI.DURATION.value,
+                        ),
+                        aggregate="count",
+                    ),
+                ],
+            ),
+            start=self.hour_ago,
+            end=self.now,
+            rollup=Rollup(interval=60, totals=True, granularity=60),
+            scope=MetricsScope(
+                org_ids=[self.org_id],
+                project_ids=[self.project.id],
+            ),
+        )
+
+        request = Request(
+            dataset="generic_metrics",
+            app_id="tests",
+            query=query,
+            tenant_ids={"referrer": "metrics.testing.test", "organization_id": self.org_id},
+        )
+        result = run_query(request)
+
+        assert len(result["data"]) == 61
+        assert result["totals"]["aggregate_value"] == 1.0
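
As a usage note, Formula results come back in the same mapping shape as the Timeseries queries above, so callers read them the same way. A sketch using only the fields this test asserts on, with `request` referring to the Request built above:

    result = run_query(request)

    buckets = result["data"]  # one row per 60-second interval over the queried hour (61 here)
    failure_rate = result["totals"]["aggregate_value"]  # the DIVIDE of the two count() Timeseries
    print(len(buckets), failure_rate)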