feat(metrics-layer): Support release-health tag resolution (#62355)

### Overview
This PR resolves tags appropriately for release health queries in the
metrics layer. Currently, there are two ways to query snuba in the
metrics layer: MQL and SnQL. This change addresses both. For SnQL, all
tag filters are hardcoded to the `tags_raw` column; however, that column
does not exist for release health, which uses the `tags` column instead.
For MQL, the column resolution (`tags`) happens in snuba, so snuba needs
to be aware of the mappings of the tag values. This change checks whether
the query targets `dataset: metrics`, resolves the tag values, adds them
to the `mql_context` mappings, and forwards them to snuba to handle.
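
As a rough illustration (not part of the diff: the indexer ids `111`/`222` and the
`release`/`foobar` tag pair are made up), the same filter ends up in the following
shapes for each dataset:

```python
from snuba_sdk import Column, Condition, Op

# Filter as written by the caller:
original = Condition(Column("release"), Op.EQ, "foobar")

# dataset="generic_metrics" (performance): only the tag key is indexed,
# and tag values stay as raw strings in the `tags_raw` column.
generic = Condition(Column("tags_raw[111]"), Op.EQ, "foobar")

# dataset="metrics" (release health): the tag key and the tag value are both
# indexed, and the column is `tags`.
release_health = Condition(Column("tags[111]"), Op.EQ, 222)

# Mappings recorded while resolving; for MQL they are forwarded to snuba via the
# mql_context indexer mappings so that snuba can perform the column resolution.
mappings = {"release": 111, "foobar": 222}
```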

Depends on: https://github.com/getsentry/snuba/pull/5255.

### What this PR doesn't include
A result converter that maps the resolved integers in the results back to their
original column names and tag values. This can be added in a separate change.
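
For context, a minimal sketch of what such a converter might look like (hypothetical,
not part of this PR): it simply inverts the indexer mappings collected while resolving
the query, and a real implementation would likely restrict itself to the resolved
group-by/tag columns rather than touching every integer cell.

```python
from collections.abc import Mapping
from typing import Any


def convert_indexed_tag_values(
    rows: list[dict[str, Any]], mappings: Mapping[str, str | int]
) -> list[dict[str, Any]]:
    # Invert e.g. {"release": 111, "release_even": 222} -> {111: "release", 222: "release_even"}.
    reverse = {value: key for key, value in mappings.items() if isinstance(value, int)}
    converted = []
    for row in rows:
        converted.append(
            {
                # Only integer cells can be reverse-mapped; everything else is left untouched.
                key: reverse.get(value, value) if isinstance(value, int) else value
                for key, value in row.items()
            }
        )
    return converted
```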

### Testing
Added release health tests to both the MQL and SnQL paths of the metrics layer.
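
For reference, a condensed, self-contained sketch of the release-health query shape the
new tests exercise (adapted from `test_filters_metrics` in the diff below; the placeholder
ids and the top-level snuba_sdk imports are assumptions):

```python
from datetime import datetime, timedelta, timezone

from snuba_sdk import (
    Column,
    Condition,
    Metric,
    MetricsQuery,
    MetricsScope,
    Op,
    Request,
    Rollup,
    Timeseries,
)

from sentry.sentry_metrics.use_case_id_registry import UseCaseID
from sentry.snuba.metrics.naming_layer.mri import SessionMRI

now = datetime.now(tz=timezone.utc).replace(microsecond=0)
hour_ago = now - timedelta(hours=1)
org_id, project_id = 1, 1  # placeholders; the tests use self.org_id / self.project.id

query = MetricsQuery(
    query=Timeseries(
        metric=Metric(None, SessionMRI.RAW_USER.value),
        aggregate="count",
        filters=[Condition(Column("release"), Op.EQ, "release_even")],
    ),
    start=hour_ago,
    end=now,
    rollup=Rollup(interval=60, granularity=60),
    scope=MetricsScope(
        org_ids=[org_id], project_ids=[project_id], use_case_id=UseCaseID.SESSIONS.value
    ),
)

# dataset="metrics" selects the release-health (sessions) storage.
request = Request(
    dataset="metrics",
    app_id="tests",
    query=query,
    tenant_ids={"referrer": "metrics.testing.test", "organization_id": org_id},
)
```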
Enoch Tang · 1 year ago · commit 23ff2de2c3

+ 103 - 37
src/sentry/snuba/metrics_layer/query.py

@@ -33,7 +33,7 @@ from sentry.utils.snuba import bulk_snuba_queries, raw_snql_query
 
 logger = logging.getLogger(__name__)
 
-FilterTypes = Union[Column, CurriedFunction, Condition, BooleanCondition]
+FilterTypes = Union[Column, CurriedFunction, Condition, BooleanCondition, str, list]
 
 
 ALLOWED_GRANULARITIES = [10, 60, 3600, 86400]
@@ -172,10 +172,6 @@ def mql_query(request: Request, start: datetime, end: datetime) -> Mapping[str,
         # There are two kinds of resolving: lookup up in the indexer, and resolving things like
         # aggregate_alias, entities and use_case_id.
         metrics_query, mappings = _resolve_query_metadata(metrics_query)
-        indexer_mappings = _lookup_indexer_resolve(metrics_query)
-        mappings.update(indexer_mappings)
-        request.query = metrics_query.set_indexer_mappings(mappings)
-        request.tenant_ids["use_case_id"] = metrics_query.scope.use_case_id
         # Release health AKA sessions uses a separate Dataset. Change the dataset based on the use case id.
         # This is necessary here because the product code that uses this isn't aware of which feature is
         # using it.
@@ -183,6 +179,10 @@ def mql_query(request: Request, start: datetime, end: datetime) -> Mapping[str,
             request.dataset = Dataset.Metrics.value
         else:
             request.dataset = Dataset.PerformanceMetrics.value
+        indexer_mappings = _lookup_indexer_resolve(metrics_query, request.dataset)
+        mappings.update(indexer_mappings)
+        request.query = metrics_query.set_indexer_mappings(mappings)
+        request.tenant_ids["use_case_id"] = metrics_query.scope.use_case_id
     except Exception as e:
         metrics.incr(
             "metrics_layer.query",
@@ -204,6 +204,7 @@ def mql_query(request: Request, start: datetime, end: datetime) -> Mapping[str,
         )
         raise e
 
+    # TODO: Right now, if the query is release health, the tag values in the results are left unresolved. We need to fix this.
     # If we normalized the start/end, return those values in the response so the caller is aware
     results = {
         **snuba_result,
@@ -356,30 +357,30 @@ GENERIC_ENTITIES = {
 }
 
 
-def _lookup_indexer_resolve(metrics_query: MetricsQuery) -> Mapping[str, str | int]:
+def _lookup_indexer_resolve(metrics_query: MetricsQuery, dataset: str) -> Mapping[str, str | int]:
     """
     Returns an updated metrics query with all the indexer resolves complete. Also returns a mapping
     that shows all the strings that were resolved and what they were resolved too.
     """
     org_id = metrics_query.scope.org_ids[0]
     use_case_id = string_to_use_case_id(metrics_query.scope.use_case_id)
-    return _lookup_indexer_resolve_exp(metrics_query.query, org_id, use_case_id)
+    return _lookup_indexer_resolve_exp(metrics_query.query, org_id, use_case_id, dataset)
 
 
 def _lookup_indexer_resolve_exp(
-    exp: Formula | Timeseries, org_id: int, use_case_id: UseCaseID
+    exp: Formula | Timeseries, org_id: int, use_case_id: UseCaseID, dataset: str
 ) -> Mapping[str, str | int]:
     indexer_mappings: dict[str, str | int] = {}
     new_mappings = _lookup_resolve_groupby(exp.groupby, use_case_id, org_id)
     indexer_mappings.update(new_mappings)
-    new_mappings = _lookup_resolve_filters(exp.filters, use_case_id, org_id)
+    new_mappings = _lookup_resolve_filters(exp.filters, use_case_id, org_id, dataset)
     indexer_mappings.update(new_mappings)
 
     if isinstance(exp, Formula):
         parameters = exp.parameters
         for i, p in enumerate(parameters):
             if isinstance(p, (Formula, Timeseries)):
-                new_mappings = _lookup_indexer_resolve_exp(p, org_id, use_case_id)
+                new_mappings = _lookup_indexer_resolve_exp(p, org_id, use_case_id, dataset)
                 indexer_mappings.update(new_mappings)
 
     return indexer_mappings
@@ -406,7 +407,7 @@ def _lookup_resolve_groupby(
 
 
 def _lookup_resolve_filters(
-    filters: list[Condition | BooleanCondition], use_case_id: UseCaseID, org_id: int
+    filters: list[Condition | BooleanCondition], use_case_id: UseCaseID, org_id: int, dataset: str
 ) -> Mapping[str, str | int]:
     """
     Go through the columns in the filter and resolve any that can be resolved.
@@ -418,22 +419,38 @@ def _lookup_resolve_filters(
 
     mappings = {}
 
-    def lookup_resolve_exp(exp: FilterTypes) -> None:
-        if isinstance(exp, Column):
+    def lookup_resolve_exp(exp: FilterTypes, dataset: str) -> None:
+        if dataset == Dataset.Metrics.value and (isinstance(exp, str) or isinstance(exp, list)):
+            if isinstance(exp, str):
+                resolved = resolve_weak(use_case_id, org_id, exp)
+                if resolved > -1:
+                    mappings[exp] = resolved
+            elif isinstance(exp, list):
+                for value in exp:
+                    assert isinstance(value, str)
+                    resolved = resolve_weak(use_case_id, org_id, value)
+                    if resolved > -1:
+                        mappings[value] = resolved
+            else:
+                raise InvalidParams("Invalid filter tag value type")
+        elif isinstance(exp, Column):
             resolved = resolve_weak(use_case_id, org_id, exp.name)
             if resolved > -1:
                 mappings[exp.name] = resolved
         elif isinstance(exp, CurriedFunction):
             for p in exp.parameters:
-                lookup_resolve_exp(p)
+                lookup_resolve_exp(p, dataset)
         elif isinstance(exp, BooleanCondition):
             for c in exp.conditions:
-                lookup_resolve_exp(c)
+                lookup_resolve_exp(c, dataset)
         elif isinstance(exp, Condition):
-            lookup_resolve_exp(exp.lhs)
+            lookup_resolve_exp(exp.lhs, dataset)
+            # If the dataset is metrics, then we need to resolve the tag values as well
+            if dataset == Dataset.Metrics.value:
+                lookup_resolve_exp(exp.rhs, dataset)
 
     for exp in filters:
-        lookup_resolve_exp(exp)
+        lookup_resolve_exp(exp, dataset)
     return mappings
 
 
@@ -451,7 +468,7 @@ def snql_query(request: Request, start: datetime, end: datetime) -> Mapping[str,
     try:
         # Replace any aggregate aliases with the appropriate aggregate
         metrics_query = metrics_query.set_query(_resolve_aggregate_aliases(metrics_query.query))
-        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query)
+        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query, request.dataset)
         request.query = resolved_metrics_query.set_indexer_mappings(mappings)
         request.tenant_ids["use_case_id"] = resolved_metrics_query.scope.use_case_id
         # Release health AKA sessions uses a separate Dataset. Change the dataset based on the use case id.
@@ -459,6 +476,8 @@ def snql_query(request: Request, start: datetime, end: datetime) -> Mapping[str,
         # using it.
         if resolved_metrics_query.scope.use_case_id == UseCaseID.SESSIONS.value:
             request.dataset = Dataset.Metrics.value
+        else:
+            request.dataset = Dataset.PerformanceMetrics.value
 
     except Exception as e:
         metrics.incr(
@@ -476,6 +495,8 @@ def snql_query(request: Request, start: datetime, end: datetime) -> Mapping[str,
         )
         raise e
 
+    # TODO: Right now, if the query is release health, the tag values in the results are left unresolved. We need to fix this.
+
     # If we normalized the start/end, return those values in the response so the caller is aware
     results = {
         **snuba_results,
@@ -555,7 +576,7 @@ def _resolve_formula_metrics(
 
 
 def _resolve_metrics_query(
-    metrics_query: MetricsQuery,
+    metrics_query: MetricsQuery, dataset: str
 ) -> tuple[MetricsQuery, Mapping[str, str | int]]:
     """
     Returns an updated metrics query with all the indexer resolves complete. Also returns a mapping
@@ -572,7 +593,17 @@ def _resolve_metrics_query(
     use_case_id = string_to_use_case_id(use_case_id_str)
     metrics_query, mappings = _resolve_query_metrics(metrics_query, use_case_id, org_id)
 
-    new_groupby, new_mappings = _resolve_groupby(metrics_query.query.groupby, use_case_id, org_id)
+    # Release health AKA sessions uses a separate Dataset. Change the dataset based on the use case id.
+    # This is necessary here because the product code that uses this isn't aware of which feature is
+    # using it.
+    if metrics_query.scope.use_case_id == UseCaseID.SESSIONS.value:
+        dataset = Dataset.Metrics.value
+    else:
+        dataset = Dataset.PerformanceMetrics.value
+
+    new_groupby, new_mappings = _resolve_groupby(
+        metrics_query.query.groupby, use_case_id, org_id, dataset
+    )
     metrics_query = metrics_query.set_query(metrics_query.query.set_groupby(new_groupby))
     mappings.update(new_mappings)
 
@@ -580,13 +611,17 @@ def _resolve_metrics_query(
         parameters = metrics_query.query.parameters
         for i, p in enumerate(parameters):
             if isinstance(p, Timeseries):
-                new_groupby, new_mappings = _resolve_groupby(p.groupby, use_case_id, org_id)
+                new_groupby, new_mappings = _resolve_groupby(
+                    p.groupby, use_case_id, org_id, dataset
+                )
                 parameters[i] = p.set_groupby(new_groupby)
                 mappings.update(new_mappings)
 
         metrics_query = metrics_query.set_query(metrics_query.query.set_parameters(parameters))
 
-    new_filters, new_mappings = _resolve_filters(metrics_query.query.filters, use_case_id, org_id)
+    new_filters, new_mappings = _resolve_filters(
+        metrics_query.query.filters, use_case_id, org_id, dataset
+    )
     metrics_query = metrics_query.set_query(metrics_query.query.set_filters(new_filters))
     mappings.update(new_mappings)
 
@@ -594,7 +629,9 @@ def _resolve_metrics_query(
         parameters = metrics_query.query.parameters
         for i, p in enumerate(parameters):
             if isinstance(p, Timeseries):
-                new_filters, new_mappings = _resolve_filters(p.filters, use_case_id, org_id)
+                new_filters, new_mappings = _resolve_filters(
+                    p.filters, use_case_id, org_id, dataset
+                )
                 parameters[i] = p.set_filters(new_filters)
                 mappings.update(new_mappings)
 
@@ -604,7 +641,7 @@ def _resolve_metrics_query(
 
 
 def _resolve_groupby(
-    groupby: list[Column] | None, use_case_id: UseCaseID, org_id: int
+    groupby: list[Column] | None, use_case_id: UseCaseID, org_id: int, dataset: str
 ) -> tuple[list[Column] | None, Mapping[str, int]]:
     """
     Go through the groupby columns and resolve any that need to be resolved.
@@ -619,11 +656,16 @@ def _resolve_groupby(
     for col in groupby:
         resolved = resolve_weak(use_case_id, org_id, col.name)
         if resolved > -1:
-            # TODO: This currently assumes the use of `tags_raw` but that might not always be correct
-            # It also doesn't take into account mapping indexed tag values back to their original values
-            new_groupby.append(
-                AliasedExpression(exp=replace(col, name=f"tags_raw[{resolved}]"), alias=col.name)
-            )
+            if dataset == Dataset.Metrics.value:
+                new_groupby.append(
+                    AliasedExpression(exp=replace(col, name=f"tags[{resolved}]"), alias=col.name)
+                )
+            else:
+                new_groupby.append(
+                    AliasedExpression(
+                        exp=replace(col, name=f"tags_raw[{resolved}]"), alias=col.name
+                    )
+                )
             mappings[col.name] = resolved
         else:
             new_groupby.append(col)
@@ -632,7 +674,7 @@ def _resolve_groupby(
 
 
 def _resolve_filters(
-    filters: list[Condition | BooleanCondition], use_case_id: UseCaseID, org_id: int
+    filters: list[Condition | BooleanCondition], use_case_id: UseCaseID, org_id: int, dataset: str
 ) -> tuple[list[Condition | BooleanCondition] | None, Mapping[str, int]]:
     """
     Go through the columns in the filter and resolve any that can be resolved.
@@ -644,19 +686,43 @@ def _resolve_filters(
 
     mappings = {}
 
-    def resolve_exp(exp: FilterTypes) -> FilterTypes:
-        if isinstance(exp, Column):
+    def resolve_exp(exp: FilterTypes, dataset: str) -> FilterTypes:
+        if dataset == Dataset.Metrics.value and (isinstance(exp, str) or isinstance(exp, list)):
+            if isinstance(exp, str):
+                resolved = resolve_weak(use_case_id, org_id, exp)
+                if resolved > -1:
+                    mappings[exp] = resolved
+                    return resolved
+            elif isinstance(exp, list):
+                resolved_values: list[int] = []
+                for value in exp:
+                    assert isinstance(value, str)
+                    resolved = resolve_weak(use_case_id, org_id, value)
+                    if resolved > -1:
+                        resolved_values.append(resolved)
+                        mappings[value] = resolved
+                return resolved_values
+            else:
+                raise InvalidParams("Invalid filter tag value type")
+        elif isinstance(exp, Column):
             resolved = resolve_weak(use_case_id, org_id, exp.name)
             if resolved > -1:
                 mappings[exp.name] = resolved
-                return replace(exp, name=f"tags_raw[{resolved}]")
+                if dataset == Dataset.Metrics.value:
+                    return replace(exp, name=f"tags[{resolved}]")
+                else:
+                    return replace(exp, name=f"tags_raw[{resolved}]")
         elif isinstance(exp, CurriedFunction):
-            return replace(exp, parameters=[resolve_exp(p) for p in exp.parameters])
+            return replace(exp, parameters=[resolve_exp(p, dataset) for p in exp.parameters])
         elif isinstance(exp, BooleanCondition):
-            return replace(exp, conditions=[resolve_exp(c) for c in exp.conditions])
+            return replace(exp, conditions=[resolve_exp(c, dataset) for c in exp.conditions])
         elif isinstance(exp, Condition):
-            return replace(exp, lhs=resolve_exp(exp.lhs))
+            exp = replace(exp, lhs=resolve_exp(exp.lhs, dataset))
+            # If the dataset is metrics, then we need to resolve the tag values as well
+            if dataset == Dataset.Metrics.value:
+                exp = replace(exp, rhs=resolve_exp(exp.rhs, dataset))
+            return exp
         return exp
 
-    new_filters = [resolve_exp(exp) for exp in filters]
+    new_filters = [resolve_exp(exp, dataset) for exp in filters]
     return new_filters, mappings

+ 6 - 6
tests/sentry/snuba/metrics/test_metrics_query_layer/test_metrics_query_layer.py

@@ -51,7 +51,7 @@ class MetricsQueryLayerTest(BaseMetricsLayerTestCase, TestCase):
             ),
         )
 
-        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query)
+        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query, "generic_metrics")
         expected_metric_id = indexer.resolve(
             UseCaseID.TRANSACTIONS,
             self.project.organization_id,
@@ -82,7 +82,7 @@ class MetricsQueryLayerTest(BaseMetricsLayerTestCase, TestCase):
             ),
         )
 
-        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query)
+        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query, "generic_metrics")
         expected_metric_id = indexer.resolve(
             UseCaseID.TRANSACTIONS,
             self.project.organization_id,
@@ -121,7 +121,7 @@ class MetricsQueryLayerTest(BaseMetricsLayerTestCase, TestCase):
             "transaction",
             "transaction",
         )
 
-        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query)
+        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query, "generic_metrics")
         assert resolved_metrics_query.query.metric.public_name == "transaction.duration"
         assert resolved_metrics_query.query.metric.mri == TransactionMRI.DURATION.value
         assert resolved_metrics_query.query.metric.id == expected_metric_id
@@ -177,7 +177,7 @@ class MetricsQueryLayerTest(BaseMetricsLayerTestCase, TestCase):
             "status_code",
             "status_code",
         )
 
-        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query)
+        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query, "generic_metrics")
         assert (
             resolved_metrics_query.query.parameters[0].metric.public_name == "transaction.duration"
         )
@@ -242,7 +242,7 @@ class MetricsQueryLayerTest(BaseMetricsLayerTestCase, TestCase):
             "device",
             "device",
         )
 
-        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query)
+        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query, "generic_metrics")
         assert resolved_metrics_query.query.metric.id == expected_metric_id
         assert resolved_metrics_query.query.filters == [
             Condition(Column(f"tags_raw[{expected_transaction_id}]"), Op.EQ, "/checkout"),
@@ -326,7 +326,7 @@ class MetricsQueryLayerTest(BaseMetricsLayerTestCase, TestCase):
             "status_code",
             "status_code",
         )
 
-        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query)
+        resolved_metrics_query, mappings = _resolve_metrics_query(metrics_query, "generic_metrics")
         assert resolved_metrics_query.query.parameters[0].metric.id == expected_metric_id
         assert resolved_metrics_query.query.parameters[0].filters == [
             Condition(Column(f"tags_raw[{expected_transaction_id}]"), Op.EQ, "/checkout"),

+ 126 - 7
tests/snuba/test_metrics_layer.py

@@ -48,16 +48,21 @@ class MQLTest(TestCase, BaseMetricsTestCase):
     def setUp(self) -> None:
         super().setUp()
 
-        self.metrics: Mapping[str, Literal["counter", "set", "distribution", "gauge"]] = {
+        self.generic_metrics: Mapping[str, Literal["counter", "set", "distribution", "gauge"]] = {
             TransactionMRI.DURATION.value: "distribution",
             TransactionMRI.USER.value: "set",
             TransactionMRI.COUNT_PER_ROOT_PROJECT.value: "counter",
             "g:transactions/test_gauge@none": "gauge",
         }
+        self.metrics: Mapping[str, Literal["counter", "set", "distribution"]] = {
+            SessionMRI.RAW_DURATION.value: "distribution",
+            SessionMRI.RAW_USER.value: "set",
+            SessionMRI.RAW_SESSION.value: "counter",
+        }
         self.now = datetime.now(tz=timezone.utc).replace(microsecond=0)
         self.hour_ago = self.now - timedelta(hours=1)
         self.org_id = self.project.organization_id
-        for mri, metric_type in self.metrics.items():
+        for mri, metric_type in self.generic_metrics.items():
             assert metric_type in {"counter", "distribution", "set", "gauge"}
             for i in range(10):
                 value: int | dict[str, int]
@@ -85,8 +90,24 @@ class MQLTest(TestCase, BaseMetricsTestCase):
                     value,
                     UseCaseID.TRANSACTIONS,
                 )
+        for mri, metric_type in self.metrics.items():
+            assert metric_type in {"counter", "distribution", "set"}
+            for i in range(10):
+                value = i
+                self.store_metric(
+                    self.org_id,
+                    self.project.id,
+                    metric_type,
+                    mri,
+                    {
+                        "release": "release_even" if i % 2 == 0 else "release_odd",
+                    },
+                    self.ts(self.hour_ago + timedelta(minutes=1 * i)),
+                    value,
+                    UseCaseID.SESSIONS,
+                )
 
-    def test_basic(self) -> None:
+    def test_basic_generic_metrics(self) -> None:
         query = MetricsQuery(
             query=Timeseries(
                 metric=Metric(
@@ -123,7 +144,7 @@ class MQLTest(TestCase, BaseMetricsTestCase):
                 ).isoformat()
             )
 
-    def test_groupby(self) -> None:
+    def test_groupby_generic_metrics(self) -> None:
         query = MetricsQuery(
             query=Timeseries(
                 metric=Metric(
@@ -163,7 +184,7 @@ class MQLTest(TestCase, BaseMetricsTestCase):
                 ).isoformat()
             )
 
-    def test_filters(self) -> None:
+    def test_filters_generic_metrics(self) -> None:
         query = MetricsQuery(
             query=Timeseries(
                 metric=Metric(
@@ -199,7 +220,7 @@ class MQLTest(TestCase, BaseMetricsTestCase):
         assert rows[0]["aggregate_value"] == [0]
         assert rows[1]["aggregate_value"] == [6.0]
 
-    def test_complex(self) -> None:
+    def test_complex_generic_metrics(self) -> None:
         query = MetricsQuery(
             query=Timeseries(
                 metric=Metric(
@@ -423,7 +444,7 @@ class MQLTest(TestCase, BaseMetricsTestCase):
         )
         result = self.run_query(request)
         assert request.dataset == "metrics"
-        assert len(result["data"]) == 0
+        assert len(result["data"]) == 10
 
     def test_gauges(self) -> None:
         query = MetricsQuery(
@@ -454,6 +475,104 @@ class MQLTest(TestCase, BaseMetricsTestCase):
         assert len(result["data"]) == 10
         assert result["totals"]["aggregate_value"] == 9.0
 
+    def test_groupby_metrics(self) -> None:
+        query = MetricsQuery(
+            query=Timeseries(
+                metric=Metric(
+                    None,
+                    SessionMRI.RAW_DURATION.value,
+                ),
+                aggregate="max",
+                groupby=[Column("release")],
+            ),
+            start=self.hour_ago,
+            end=self.now,
+            rollup=Rollup(interval=60, granularity=60),
+            scope=MetricsScope(
+                org_ids=[self.org_id],
+                project_ids=[self.project.id],
+                use_case_id=UseCaseID.SESSIONS.value,
+            ),
+        )
+
+        request = Request(
+            dataset="metrics",
+            app_id="tests",
+            query=query,
+            tenant_ids={"referrer": "metrics.testing.test", "organization_id": self.org_id},
+        )
+        result = self.run_query(request)
+        assert request.dataset == "metrics"
+        assert len(result["data"]) == 10
+        # TODO: check reverse resolved tags
+
+    def test_filters_metrics(self) -> None:
+        query = MetricsQuery(
+            query=Timeseries(
+                metric=Metric(
+                    None,
+                    SessionMRI.RAW_USER.value,
+                ),
+                aggregate="count",
+                filters=[
+                    Condition(Column("release"), Op.EQ, "release_even"),
+                ],
+            ),
+            start=self.hour_ago,
+            end=self.now,
+            rollup=Rollup(interval=60, granularity=60),
+            scope=MetricsScope(
+                org_ids=[self.org_id],
+                project_ids=[self.project.id],
+                use_case_id=UseCaseID.SESSIONS.value,
+            ),
+        )
+
+        request = Request(
+            dataset="metrics",
+            app_id="tests",
+            query=query,
+            tenant_ids={"referrer": "metrics.testing.test", "organization_id": self.org_id},
+        )
+        result = self.run_query(request)
+        assert request.dataset == "metrics"
+        assert len(result["data"]) == 5
+        # TODO: check reverse resolved tags
+
+    def test_complex_metrics(self) -> None:
+        query = MetricsQuery(
+            query=Timeseries(
+                metric=Metric(
+                    None,
+                    SessionMRI.RAW_SESSION.value,
+                ),
+                aggregate="count",
+                groupby=[Column("release")],
+                filters=[
+                    Condition(Column("release"), Op.EQ, "release_even"),
+                ],
+            ),
+            start=self.hour_ago,
+            end=self.now,
+            rollup=Rollup(interval=60, granularity=60),
+            scope=MetricsScope(
+                org_ids=[self.org_id],
+                project_ids=[self.project.id],
+                use_case_id=UseCaseID.SESSIONS.value,
+            ),
+        )
+
+        request = Request(
+            dataset="metrics",
+            app_id="tests",
+            query=query,
+            tenant_ids={"referrer": "metrics.testing.test", "organization_id": self.org_id},
+        )
+        result = self.run_query(request)
+        assert request.dataset == "metrics"
+        assert len(result["data"]) == 5
+        # TODO: check reverse resolved tags
+
     @pytest.mark.skip(reason="This is not implemented in MQL")
     def test_failure_rate(self) -> None:
         query = MetricsQuery(