
Revert "fix(metrics): Revert dlq changes" (#59121)

Reverts getsentry/sentry#59073
John committed 1 year ago
commit c26e2de7a3

+ 12 - 2
bin/send_metrics.py

@@ -149,7 +149,13 @@ def produce_msgs(messages, is_generic, host, dryrun):
     show_default=True,
     help="Specify which org id(s) to send",
 )
-def main(use_cases, rand_str, host, dryrun, org_id):
+@click.option(
+    "--num-bad-msg",
+    default=0,
+    show_default=True,
+    help="Number of additional badly formatted metric messages to send",
+)
+def main(use_cases, rand_str, host, dryrun, org_id, num_bad_msg):
     if UseCaseID.SESSIONS.value in use_cases and len(use_cases) > 1:
         click.secho(
             "ERROR: UseCaseID.SESSIONS is in use_cases and there are more than 1 use cases",
@@ -158,9 +164,10 @@ def main(use_cases, rand_str, host, dryrun, org_id):
         )
         exit(1)
 
+    rand_str = rand_str or "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
+
     is_generic = UseCaseID.SESSIONS.value not in use_cases
 
-    rand_str = rand_str or "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
     messages = list(
         itertools.chain.from_iterable(
             (
@@ -172,6 +179,9 @@ def main(use_cases, rand_str, host, dryrun, org_id):
             for org in org_id
         )
     )
+
+    messages.extend([{"BAD_VALUE": rand_str, "idx": i} for i in range(num_bad_msg)])
+
     random.shuffle(messages)
 
     produce_msgs(messages, is_generic, host, dryrun)

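For reference, the reinstated --num-bad-msg option exists so the DLQ path can be exercised end to end: it appends intentionally unparseable payloads to the batch before shuffling. A minimal invocation might look like the line below; this is a sketch that assumes the script's remaining options are left at their defaults (only --num-bad-msg is confirmed by this diff).

    # Hypothetical invocation: send the usual randomized metrics plus five
    # badly formatted messages that should be routed to the DLQ downstream.
    python bin/send_metrics.py --num-bad-msg 5
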
+ 139 - 131
src/sentry/sentry_metrics/consumers/indexer/batch.py

@@ -1,6 +1,6 @@
 import logging
 import random
-from collections import defaultdict
+from collections import defaultdict, deque
 from typing import (
     Any,
     Callable,
@@ -8,7 +8,6 @@ from typing import (
     Mapping,
     MutableMapping,
     MutableSequence,
-    NamedTuple,
     Optional,
     Sequence,
     Set,
@@ -23,14 +22,17 @@ from arroyo.backends.kafka import KafkaPayload
 from arroyo.types import BrokerValue, Message
 from django.conf import settings
 from sentry_kafka_schemas.codecs import Codec, ValidationError
-from sentry_kafka_schemas.schema_types.ingest_metrics_v1 import IngestMetric
 from sentry_kafka_schemas.schema_types.snuba_generic_metrics_v1 import GenericMetric
 from sentry_kafka_schemas.schema_types.snuba_metrics_v1 import Metric
 
 from sentry import options
 from sentry.sentry_metrics.aggregation_option_registry import get_aggregation_option
 from sentry.sentry_metrics.configuration import MAX_INDEXED_COLUMN_LENGTH
-from sentry.sentry_metrics.consumers.indexer.common import IndexerOutputMessageBatch, MessageBatch
+from sentry.sentry_metrics.consumers.indexer.common import (
+    BrokerMeta,
+    IndexerOutputMessageBatch,
+    MessageBatch,
+)
 from sentry.sentry_metrics.consumers.indexer.parsed_message import ParsedMessage
 from sentry.sentry_metrics.consumers.indexer.routing_producer import RoutingPayload
 from sentry.sentry_metrics.indexer.base import Metadata
@@ -49,11 +51,6 @@ OrgId = int
 Headers = MutableSequence[Tuple[str, bytes]]
 
 
-class PartitionIdxOffset(NamedTuple):
-    partition_idx: int
-    offset: int
-
-
 def valid_metric_name(name: Optional[str]) -> bool:
     if name is None:
         return False
@@ -87,105 +84,161 @@ class IndexerBatch:
         self.__message_size_sum: MutableMapping[UseCaseID, int] = defaultdict(int)
         self.__message_size_max: MutableMapping[UseCaseID, int] = defaultdict(int)
 
-        self._extract_messages()
+        # Invalid messages and filtered messages are both skipped during processing
+        # (reconstruct_messages), but we want to put the invalid messages into the
+        # DLQ while discarding the filtered messages
+        self.invalid_msg_meta: Set[BrokerMeta] = set()
+        self.filtered_msg_meta: Set[BrokerMeta] = set()
+        self.parsed_payloads_by_meta: MutableMapping[BrokerMeta, ParsedMessage] = {}
 
-    def _extract_namespace(self, headers: Headers) -> Optional[str]:
-        for string, endcoded in headers:
-            if string == "namespace":
-                return endcoded.decode("utf-8")
-        metrics.incr("sentry-metrics.indexer.killswitch.no-namespace-in-header")
-        return None
+        self._extract_messages()
 
     @metrics.wraps("process_messages.extract_messages")
     def _extract_messages(self) -> None:
-        self.skipped_offsets: Set[PartitionIdxOffset] = set()
-        self.parsed_payloads_by_offset: MutableMapping[PartitionIdxOffset, ParsedMessage] = {}
-
-        disabled_msgs_cnt: MutableMapping[str, int] = defaultdict(int)
+        """
+        For each message:
+        1. Check the header to see if the use case ID is disabled
+        2. Parse the raw bytes into ParsedMessage (_extract_message)
+        3. Semantically validate the content of ParsedMessage (_validate_message)
+
+        We track the offset and partition of the messages that are filtered or
+        invalid so that later we can:
+        - Produce the invalid messages to the DLQ
+        - Skip the filtered/invalid messages during the indexing phase
+          (extract_strings and reconstruct_messages)
+        """
+        skipped_msgs_cnt: MutableMapping[str, int] = defaultdict(int)
 
         for msg in self.outer_message.payload:
+
             assert isinstance(msg.value, BrokerValue)
-            partition_offset = PartitionIdxOffset(msg.value.partition.index, msg.value.offset)
+            broker_meta = BrokerMeta(msg.value.partition, msg.value.offset)
 
             if (namespace := self._extract_namespace(msg.payload.headers)) in options.get(
                 "sentry-metrics.indexer.disabled-namespaces"
             ):
                 assert namespace
-                self.skipped_offsets.add(partition_offset)
-                disabled_msgs_cnt[namespace] += 1
+                skipped_msgs_cnt[namespace] += 1
+                self.filtered_msg_meta.add(broker_meta)
                 continue
-            try:
-                parsed_payload: ParsedMessage = json.loads(
-                    msg.payload.value.decode("utf-8"), use_rapid_json=True
-                )
-            except rapidjson.JSONDecodeError:
-                self.skipped_offsets.add(partition_offset)
-                logger.error(
-                    "process_messages.invalid_json",
-                    extra={"payload_value": str(msg.payload.value)},
-                    exc_info=True,
-                )
-                continue
-            try:
-                if self.__input_codec:
-                    self.__input_codec.validate(parsed_payload)
-            except ValidationError:
-                if settings.SENTRY_METRICS_INDEXER_RAISE_VALIDATION_ERRORS:
-                    raise
-
-                # For now while this is still experimental, those errors are
-                # not supposed to be fatal.
-                logger.warning(
-                    "process_messages.invalid_schema",
-                    extra={"payload_value": str(msg.payload.value)},
-                    exc_info=True,
-                )
 
             try:
-                parsed_payload["use_case_id"] = use_case_id = extract_use_case_id(
-                    parsed_payload["name"]
-                )
-            except ValidationError:
-                self.skipped_offsets.add(partition_offset)
+                parsed_payload = self._extract_message(msg)
+                self._validate_message(parsed_payload)
+                self.parsed_payloads_by_meta[broker_meta] = parsed_payload
+            except Exception as e:
+                self.invalid_msg_meta.add(broker_meta)
                 logger.error(
-                    "process_messages.invalid_metric_resource_identifier",
+                    e,
                     extra={"payload_value": str(msg.payload.value)},
                     exc_info=True,
                 )
-                continue
 
-            self.__message_count[use_case_id] += 1
-            self.__message_size_max[use_case_id] = max(
-                len(msg.payload.value), self.__message_size_max[use_case_id]
-            )
-            self.__message_size_sum[use_case_id] += len(msg.payload.value)
-
-            # Ensure that the parsed_payload can be cast back to to
-            # IngestMetric. If there are any schema changes, this check would
-            # fail and ParsedMessage needs to be adjusted to be a superset of
-            # IngestMetric again.
-            _: IngestMetric = parsed_payload
-
-            self.parsed_payloads_by_offset[partition_offset] = parsed_payload
-
-        for namespace, cnt in disabled_msgs_cnt.items():
+        for namespace, cnt in skipped_msgs_cnt.items():
             metrics.incr(
                 "process_messages.namespace_disabled",
                 amount=cnt,
                 tags={"namespace": namespace},
             )
 
+    def _extract_message(
+        self,
+        msg: Message[KafkaPayload],
+    ) -> ParsedMessage:
+        assert isinstance(msg.value, BrokerValue)
+        try:
+            parsed_payload: ParsedMessage = json.loads(
+                msg.payload.value.decode("utf-8"), use_rapid_json=True
+            )
+        except rapidjson.JSONDecodeError:
+            logger.error(
+                "process_messages.invalid_json",
+                extra={"payload_value": str(msg.payload.value)},
+                exc_info=True,
+            )
+            raise
+        try:
+            if self.__input_codec:
+                self.__input_codec.validate(parsed_payload)
+        except ValidationError:
+            if settings.SENTRY_METRICS_INDEXER_RAISE_VALIDATION_ERRORS:
+                raise
+            logger.warning(
+                "process_messages.invalid_schema",
+                extra={"payload_value": str(msg.payload.value)},
+                exc_info=True,
+            )
+        parsed_payload["use_case_id"] = use_case_id = extract_use_case_id(parsed_payload["name"])
+
+        self.__message_count[use_case_id] += 1
+        self.__message_size_max[use_case_id] = max(
+            len(msg.payload.value), self.__message_size_max[use_case_id]
+        )
+        self.__message_size_sum[use_case_id] += len(msg.payload.value)
+
+        return parsed_payload
+
+    def _extract_namespace(self, headers: Headers) -> Optional[str]:
+        for string, endcoded in headers:
+            if string == "namespace":
+                return endcoded.decode("utf-8")
+        metrics.incr("sentry-metrics.indexer.killswitch.no-namespace-in-header")
+        return None
+
+    def _validate_message(self, parsed_payload: ParsedMessage) -> None:
+        metric_name = parsed_payload["name"]
+        metric_type = parsed_payload["type"]
+        use_case_id = parsed_payload["use_case_id"]
+        org_id = parsed_payload["org_id"]
+        tags = parsed_payload.get("tags", {})
+
+        if not valid_metric_name(metric_name):
+            logger.error(
+                "process_messages.invalid_metric_name",
+                extra={
+                    "use_case_id": use_case_id,
+                    "org_id": org_id,
+                    "metric_name": metric_name,
+                },
+            )
+            raise ValueError(f"Invalid metric name: {metric_name}")
+
+        if metric_type not in ACCEPTED_METRIC_TYPES:
+            logger.error(
+                "process_messages.invalid_metric_type",
+                extra={
+                    "use_case_id": use_case_id,
+                    "org_id": org_id,
+                    "metric_type": metric_type,
+                },
+            )
+            raise ValueError(f"Invalid metric type: {metric_type}")
+
+        if self.tags_validator(tags) is False:
+            # sentry doesn't seem to actually capture nested logger.error extra args
+            sentry_sdk.set_extra("all_metric_tags", tags)
+            logger.error(
+                "process_messages.invalid_tags",
+                extra={
+                    "use_case_id": use_case_id,
+                    "org_id": org_id,
+                    "metric_name": metric_name,
+                    "tags": tags,
+                },
+            )
+            raise ValueError(f"Invalid metric tags: {tags}")
+
     @metrics.wraps("process_messages.filter_messages")
-    def filter_messages(self, keys_to_remove: Sequence[PartitionIdxOffset]) -> None:
+    def filter_messages(self, keys_to_remove: Sequence[BrokerMeta]) -> None:
         # XXX: it is useful to be able to get a sample of organization ids that are affected by rate limits, but this is really slow.
-        for offset in keys_to_remove:
+        for broker_meta in keys_to_remove:
             if _should_sample_debug_log():
                 sentry_sdk.set_tag(
                     "sentry_metrics.organization_id",
-                    self.parsed_payloads_by_offset[offset]["org_id"],
+                    self.parsed_payloads_by_meta[broker_meta]["org_id"],
                 )
                 sentry_sdk.set_tag(
-                    "sentry_metrics.metric_name", self.parsed_payloads_by_offset[offset]["name"]
+                    "sentry_metrics.metric_name", self.parsed_payloads_by_meta[broker_meta]["name"]
                 )
                 logger.error(
                     "process_messages.dropped_message",
@@ -194,7 +247,7 @@ class IndexerBatch:
                     },
                 )
 
-        self.skipped_offsets.update(keys_to_remove)
+        self.filtered_msg_meta.update(keys_to_remove)
 
     @metrics.wraps("process_messages.extract_strings")
     def extract_strings(self) -> Mapping[UseCaseID, Mapping[OrgId, Set[str]]]:
@@ -202,62 +255,15 @@ class IndexerBatch:
             lambda: defaultdict(set)
         )
 
-        for partition_offset, message in self.parsed_payloads_by_offset.items():
-            if partition_offset in self.skipped_offsets:
+        for broker_meta, message in self.parsed_payloads_by_meta.items():
+            if broker_meta in self.invalid_msg_meta or broker_meta in self.filtered_msg_meta:
                 continue
 
-            partition_idx, offset = partition_offset
-
             metric_name = message["name"]
-            metric_type = message["type"]
             use_case_id = message["use_case_id"]
             org_id = message["org_id"]
             tags = message.get("tags", {})
 
-            if not valid_metric_name(metric_name):
-                logger.error(
-                    "process_messages.invalid_metric_name",
-                    extra={
-                        "use_case_id": use_case_id,
-                        "org_id": org_id,
-                        "metric_name": metric_name,
-                        "partition": partition_idx,
-                        "offset": offset,
-                    },
-                )
-                self.skipped_offsets.add(partition_offset)
-                continue
-
-            if metric_type not in ACCEPTED_METRIC_TYPES:
-                logger.error(
-                    "process_messages.invalid_metric_type",
-                    extra={
-                        "use_case_id": use_case_id,
-                        "org_id": org_id,
-                        "metric_type": metric_type,
-                        "offset": offset,
-                    },
-                )
-                self.skipped_offsets.add(partition_offset)
-                continue
-
-            if self.tags_validator(tags) is False:
-                # sentry doesn't seem to actually capture nested logger.error extra args
-                sentry_sdk.set_extra("all_metric_tags", tags)
-                logger.error(
-                    "process_messages.invalid_tags",
-                    extra={
-                        "use_case_id": use_case_id,
-                        "org_id": org_id,
-                        "metric_name": metric_name,
-                        "tags": tags,
-                        "partition": partition_idx,
-                        "offset": offset,
-                    },
-                )
-                self.skipped_offsets.add(partition_offset)
-                continue
-
             strings_in_message = {
                 metric_name,
                 *tags.keys(),
@@ -290,12 +296,10 @@ class IndexerBatch:
             used_tags: Set[str] = set()
             output_message_meta: Dict[str, Dict[str, str]] = defaultdict(dict)
             assert isinstance(message.value, BrokerValue)
-            partition_offset = PartitionIdxOffset(
-                message.value.partition.index, message.value.offset
-            )
-            if partition_offset in self.skipped_offsets:
+            broker_meta = BrokerMeta(message.value.partition, message.value.offset)
+            if broker_meta in self.invalid_msg_meta or broker_meta in self.filtered_msg_meta:
                 continue
-            old_payload_value = self.parsed_payloads_by_offset.pop(partition_offset)
+            old_payload_value = self.parsed_payloads_by_meta.pop(broker_meta)
 
             metric_name = old_payload_value["name"]
             org_id = old_payload_value["org_id"]
@@ -490,4 +494,8 @@ class IndexerBatch:
                 self.__message_size_max[use_case_id],
                 tags={"use_case_id": use_case_id.value},
             )
-        return IndexerOutputMessageBatch(new_messages, cogs_usage)
+        return IndexerOutputMessageBatch(
+            new_messages,
+            deque(sorted(self.invalid_msg_meta)),
+            cogs_usage,
+        )

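The _extract_messages docstring above distinguishes filtered messages (disabled namespace, silently dropped) from invalid messages (unparseable or semantically bad, destined for the DLQ). The self-contained sketch below uses hypothetical stand-ins rather than the real IndexerBatch and BrokerMeta types, and only illustrates that bookkeeping.

    import json
    from typing import Any, MutableMapping, NamedTuple, Set


    class BrokerMeta(NamedTuple):
        # Simplified stand-in for the NamedTuple added to indexer/common.py.
        partition: int
        offset: int


    def classify_messages(raw_messages, disabled_namespaces):
        """Split a batch into parsed payloads, filtered metadata, and invalid metadata.

        raw_messages: iterable of (partition, offset, namespace, raw_bytes) tuples.
        """
        invalid_msg_meta: Set[BrokerMeta] = set()
        filtered_msg_meta: Set[BrokerMeta] = set()
        parsed_by_meta: MutableMapping[BrokerMeta, Any] = {}

        for partition, offset, namespace, raw in raw_messages:
            meta = BrokerMeta(partition, offset)

            # Step 1: header-based killswitch: filtered, never sent to the DLQ.
            if namespace in disabled_namespaces:
                filtered_msg_meta.add(meta)
                continue

            # Steps 2-3: parse and validate; any failure marks the message invalid,
            # which is what later becomes an arroyo InvalidMessage for the DLQ.
            try:
                payload = json.loads(raw.decode("utf-8"))
                if "name" not in payload or "type" not in payload:
                    raise ValueError("missing required fields")
                parsed_by_meta[meta] = payload
            except Exception:
                invalid_msg_meta.add(meta)

        return parsed_by_meta, filtered_msg_meta, invalid_msg_meta


    if __name__ == "__main__":
        batch = [
            (0, 0, "transactions", b'{"name": "d:transactions/duration@millisecond", "type": "d"}'),
            (0, 1, "spans", b'{"name": "c:spans/count@none", "type": "c"}'),
            (0, 2, "transactions", b'{"BAD_VALUE": "XYZ"}'),
        ]
        parsed, filtered, invalid = classify_messages(batch, disabled_namespaces={"spans"})
        assert set(parsed) == {BrokerMeta(0, 0)}
        assert filtered == {BrokerMeta(0, 1)}
        assert invalid == {BrokerMeta(0, 2)}
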
+ 20 - 2
src/sentry/sentry_metrics/consumers/indexer/common.py

@@ -1,8 +1,19 @@
 import logging
 import time
 from dataclasses import dataclass
-from typing import Any, List, Mapping, MutableMapping, MutableSequence, Optional, Union
-
+from typing import (
+    Any,
+    Deque,
+    List,
+    Mapping,
+    MutableMapping,
+    MutableSequence,
+    NamedTuple,
+    Optional,
+    Union,
+)
+
+from arroyo import Partition
 from arroyo.backends.kafka import KafkaPayload
 from arroyo.backends.kafka.configuration import build_kafka_consumer_configuration
 from arroyo.processing.strategies import MessageRejected
@@ -14,6 +25,12 @@ from sentry.sentry_metrics.consumers.indexer.routing_producer import RoutingPayl
 from sentry.sentry_metrics.use_case_id_registry import UseCaseID
 from sentry.utils import kafka_config, metrics
 
+
+class BrokerMeta(NamedTuple):
+    partition: Partition
+    offset: int
+
+
 MessageBatch = List[Message[KafkaPayload]]
 
 logger = logging.getLogger(__name__)
@@ -25,6 +42,7 @@ DEFAULT_QUEUED_MIN_MESSAGES = 100000
 @dataclass(frozen=True)
 class IndexerOutputMessageBatch:
     data: MutableSequence[Message[Union[RoutingPayload, KafkaPayload]]]
+    invalid_msg_meta: Deque[BrokerMeta]
     cogs_data: Mapping[UseCaseID, int]
 
 

+ 10 - 3
src/sentry/sentry_metrics/consumers/indexer/multiprocess.py

@@ -1,12 +1,12 @@
 import logging
 import time
 from functools import partial
-from typing import Any, Mapping, MutableMapping, Optional
+from typing import Any, Mapping, MutableMapping, Optional, Union
 
 from arroyo.backends.abstract import Producer as AbstractProducer
 from arroyo.backends.kafka import KafkaPayload
 from arroyo.processing.strategies import ProcessingStrategy as ProcessingStep
-from arroyo.types import Commit, Message, Partition
+from arroyo.types import Commit, FilteredPayload, Message, Partition
 from confluent_kafka import Producer
 
 from sentry.utils import kafka_config, metrics
@@ -68,7 +68,14 @@ class SimpleProduceStep(ProcessingStep[KafkaPayload]):
             self.__commit_function(self.__produced_message_offsets)
             self.__produced_message_offsets = {}
 
-    def submit(self, message: Message[KafkaPayload]) -> None:
+    def submit(self, message: Message[Union[KafkaPayload, FilteredPayload]]) -> None:
+        if isinstance(message.payload, FilteredPayload):
+            # FilteredPayload offsets are not committed, so the indexer may consume
+            # and produce invalid messages to the DLQ twice if the last messages it
+            # consumed before a shutdown were invalid. It will never produce valid
+            # messages to Snuba twice, though.
+            # TODO: Use the arroyo producer which handles FilteredPayload elegantly
+            return
         self.__producer.produce(
             topic=self.__producer_topic,
             key=None,

+ 18 - 4
src/sentry/sentry_metrics/consumers/indexer/parallel.py

@@ -2,10 +2,12 @@ from __future__ import annotations
 
 import functools
 import logging
-from typing import Any, Mapping, Optional, Union, cast
+from collections import deque
+from typing import Any, Deque, Mapping, NamedTuple, Optional, Union, cast
 
 from arroyo.backends.kafka import KafkaConsumer, KafkaPayload
 from arroyo.commit import ONCE_PER_SECOND
+from arroyo.dlq import InvalidMessage
 from arroyo.processing import StreamProcessor
 from arroyo.processing.strategies import ProcessingStrategy
 from arroyo.processing.strategies import ProcessingStrategy as ProcessingStep
@@ -39,19 +41,31 @@ logger = logging.getLogger(__name__)
 class Unbatcher(ProcessingStep[Union[FilteredPayload, IndexerOutputMessageBatch]]):
     def __init__(
         self,
-        next_step: ProcessingStep[Union[KafkaPayload, RoutingPayload]],
+        next_step: ProcessingStep[Union[FilteredPayload, KafkaPayload, RoutingPayload]],
     ) -> None:
         self.__next_step = next_step
         self.__closed = False
+        self._invalid_msg_meta: Deque[NamedTuple] = deque()
 
     def poll(self) -> None:
+        if self._invalid_msg_meta:
+            partition, offset = self._invalid_msg_meta.popleft()
+            raise InvalidMessage(partition, offset)
+
         self.__next_step.poll()
 
     def submit(self, message: Message[Union[FilteredPayload, IndexerOutputMessageBatch]]) -> None:
         assert not self.__closed
 
-        # FilteredPayloads are not handled in the indexer
-        for transformed_message in cast(IndexerOutputMessageBatch, message.payload).data:
+        if isinstance(message.payload, FilteredPayload):
+            self.__next_step.submit(cast(Message[KafkaPayload], message))
+            return
+
+        self._invalid_msg_meta.extend(message.payload.invalid_msg_meta)
+
+        _ = message.payload.cogs_data
+
+        for transformed_message in message.payload.data:
             self.__next_step.submit(transformed_message)
 
     def close(self) -> None:

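The Unbatcher change above is what connects the collected metadata to arroyo's DLQ machinery: invalid_msg_meta entries are queued, and poll() raises one InvalidMessage at a time so the stream processor can route the original payload to the dead letter topic and skip it on commit. Below is a stripped-down sketch of that pattern, not the real Unbatcher, and without a surrounding StreamProcessor or DLQ policy.

    from collections import deque
    from typing import Deque, Tuple

    from arroyo import Partition, Topic
    from arroyo.dlq import InvalidMessage


    class DeferInvalid:
        """Record bad (partition, offset) pairs and surface them on poll().

        When a DLQ policy is configured, arroyo catches InvalidMessage raised from
        a strategy and produces the corresponding raw message to the dead letter
        topic.
        """

        def __init__(self) -> None:
            self._invalid_msg_meta: Deque[Tuple[Partition, int]] = deque()

        def record_invalid(self, partition: Partition, offset: int) -> None:
            self._invalid_msg_meta.append((partition, offset))

        def poll(self) -> None:
            if self._invalid_msg_meta:
                partition, offset = self._invalid_msg_meta.popleft()
                raise InvalidMessage(partition, offset)


    step = DeferInvalid()
    step.record_invalid(Partition(Topic("ingest-metrics"), 0), 42)
    try:
        step.poll()
    except InvalidMessage as exc:
        print("would be routed to the DLQ:", exc)
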
+ 18 - 10
src/sentry/sentry_metrics/consumers/indexer/processing.py

@@ -76,22 +76,30 @@ class MessageProcessor:
         outer_message: Message[MessageBatch],
     ) -> IndexerOutputMessageBatch:
         """
-        We have an outer_message Message() whose payload is a batch of Message() objects.
-
+        We have an outer_message, which contains a collection of Message() objects.
+        Each of them represents a single metric message on Kafka.
             Message(
-                partition=...,
-                offset=...
-                timestamp=...
                 payload=[Message(...), Message(...), etc]
             )
 
         The inner messages payloads are KafkaPayload's that have:
+            * Kafka metadata (partition/offset)
             * key
             * headers
             * value
 
         The value of the message is what we need to parse and then translate
         using the indexer.
+
+        We create an IndexerBatch object to:
+
+        1. Parse and validate the inner messages from a sequence of bytes into
+           Python objects (initialization)
+        2. Filter messages (filter_messages)
+        3. Create a collection of all the strings that need to be indexed
+           (extract_strings)
+        4. Take a mapping of string -> int (indexed strings) and replace all of
+           the message strings with ints (reconstruct_messages)
         """
         should_index_tag_values = self._config.should_index_tag_values
         is_output_sliced = self._config.is_output_sliced or False
@@ -104,14 +112,14 @@ class MessageProcessor:
             tags_validator=self.__get_tags_validator(),
         )
 
-        sdk.set_measurement("indexer_batch.payloads.len", len(batch.parsed_payloads_by_offset))
+        sdk.set_measurement("indexer_batch.payloads.len", len(batch.parsed_payloads_by_meta))
 
         with metrics.timer("metrics_consumer.check_cardinality_limits"), sentry_sdk.start_span(
             op="check_cardinality_limits"
         ):
             cardinality_limiter = cardinality_limiter_factory.get_ratelimiter(self._config)
             cardinality_limiter_state = cardinality_limiter.check_cardinality_limits(
-                self._config.use_case_id, batch.parsed_payloads_by_offset
+                self._config.use_case_id, batch.parsed_payloads_by_meta
             )
 
         sdk.set_measurement(
@@ -129,9 +137,9 @@ class MessageProcessor:
         mapping = record_result.get_mapped_results()
         bulk_record_meta = record_result.get_fetch_metadata()
 
-        new_messages = batch.reconstruct_messages(mapping, bulk_record_meta)
+        results = batch.reconstruct_messages(mapping, bulk_record_meta)
 
-        sdk.set_measurement("new_messages.len", len(new_messages.data))
+        sdk.set_measurement("new_messages.len", len(results.data))
 
         with metrics.timer("metrics_consumer.apply_cardinality_limits"), sentry_sdk.start_span(
             op="apply_cardinality_limits"
@@ -139,4 +147,4 @@ class MessageProcessor:
             # TODO: move to separate thread
             cardinality_limiter.apply_cardinality_limits(cardinality_limiter_state)
 
-        return new_messages
+        return results

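Putting the pieces of the expanded process_messages docstring together, the ordering the processor now relies on is: parse/validate (collecting invalid metadata), filter (killswitches and cardinality limits), extract strings, then reconstruct messages while passing the invalid metadata through for the DLQ. The sketch below models that ordering with plain dicts and tuples; the names are hypothetical stand-ins, not the real IndexerBatch API.

    from typing import Dict, List, Set, Tuple

    Meta = Tuple[int, int]  # (partition index, offset) stand-in for BrokerMeta


    def process_batch_sketch(
        parsed: Dict[Meta, dict],   # step 1 output: parseable, valid payloads
        invalid: Set[Meta],         # step 1 output: metadata headed for the DLQ
        filtered: Set[Meta],        # killswitched namespaces
        over_limit: Set[Meta],      # keys the cardinality limiter rejected
    ) -> Tuple[List[dict], List[Meta]]:
        # Step 2 (filter_messages): limited keys join the filtered set and are
        # simply dropped, never DLQ'd.
        skipped = filtered | over_limit

        # Step 3 (extract_strings): gather every string that still needs an id.
        strings = sorted({p["name"] for meta, p in parsed.items() if meta not in skipped})

        # Step 4 (reconstruct_messages): swap strings for ids and hand the sorted
        # invalid metadata back so the Unbatcher can raise InvalidMessage per entry.
        mapping = {name: idx for idx, name in enumerate(strings)}
        data = [
            {**p, "metric_id": mapping[p["name"]]}
            for meta, p in parsed.items()
            if meta not in skipped
        ]
        return data, sorted(invalid)


    data, invalid_msg_meta = process_batch_sketch(
        parsed={(0, 0): {"name": "c:custom/foo@none", "org_id": 1}},
        invalid={(0, 1)},
        filtered=set(),
        over_limit=set(),
    )
    assert data == [{"name": "c:custom/foo@none", "org_id": 1, "metric_id": 0}]
    assert invalid_msg_meta == [(0, 1)]
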
+ 8 - 8
src/sentry/sentry_metrics/indexer/limiters/cardinality.py

@@ -14,7 +14,7 @@ from sentry.ratelimits.cardinality import (
     Timestamp,
 )
 from sentry.sentry_metrics.configuration import MetricsIngestConfiguration, UseCaseKey
-from sentry.sentry_metrics.consumers.indexer.batch import PartitionIdxOffset
+from sentry.sentry_metrics.consumers.indexer.common import BrokerMeta
 from sentry.sentry_metrics.use_case_id_registry import (
     USE_CASE_ID_CARDINALITY_LIMIT_QUOTA_OPTIONS,
     UseCaseID,
@@ -32,7 +32,7 @@ class CardinalityLimiterState:
     _metric_path_key: UseCaseKey
     _grants: Optional[Sequence[GrantedQuota]]
     _timestamp: Optional[Timestamp]
-    keys_to_remove: Sequence[PartitionIdxOffset]
+    keys_to_remove: Sequence[BrokerMeta]
 
 
 def _build_quota_key(use_case_id: UseCaseID, org_id: OrgId) -> str:
@@ -80,10 +80,10 @@ class TimeseriesCardinalityLimiter:
         self.backend: CardinalityLimiter = rate_limiter
 
     def check_cardinality_limits(
-        self, metric_path_key: UseCaseKey, messages: Mapping[PartitionIdxOffset, InboundMessage]
+        self, metric_path_key: UseCaseKey, messages: Mapping[BrokerMeta, InboundMessage]
     ) -> CardinalityLimiterState:
         request_hashes = defaultdict(set)
-        hash_to_offset: Mapping[str, Dict[int, PartitionIdxOffset]] = defaultdict(dict)
+        hash_to_meta: Mapping[str, Dict[int, BrokerMeta]] = defaultdict(dict)
         prefix_to_quota = {}
 
         # this works by applying one cardinality limiter rollout option
@@ -109,7 +109,7 @@ class TimeseriesCardinalityLimiter:
                 16,
             )
             prefix = _build_quota_key(message["use_case_id"], org_id)
-            hash_to_offset[prefix][message_hash] = key
+            hash_to_meta[prefix][message_hash] = key
             request_hashes[prefix].add(message_hash)
             configured_quota = _construct_quotas(message["use_case_id"])
 
@@ -144,10 +144,10 @@ class TimeseriesCardinalityLimiter:
 
         timestamp, grants = self.backend.check_within_quotas(requested_quotas)
 
-        keys_to_remove = hash_to_offset
-        # make sure that hash_to_offset is no longer used, as the underlying
+        keys_to_remove = hash_to_meta
+        # make sure that hash_to_meta is no longer used, as the underlying
         # dict will be mutated
-        del hash_to_offset
+        del hash_to_meta
 
         for grant in grants:
             for hash in grant.granted_unit_hashes:

+ 35 - 31
tests/sentry/sentry_metrics/limiters/test_cardinality_limiter.py

@@ -2,6 +2,7 @@ import time
 from typing import Optional, Sequence, Tuple
 
 import pytest
+from arroyo import Partition, Topic
 
 from sentry.ratelimits.cardinality import (
     CardinalityLimiter,
@@ -10,7 +11,7 @@ from sentry.ratelimits.cardinality import (
     Timestamp,
 )
 from sentry.sentry_metrics.configuration import UseCaseKey
-from sentry.sentry_metrics.consumers.indexer.batch import PartitionIdxOffset
+from sentry.sentry_metrics.consumers.indexer.common import BrokerMeta
 from sentry.sentry_metrics.indexer.limiters.cardinality import (
     TimeseriesCardinalityLimiter,
     _build_quota_key,
@@ -98,13 +99,13 @@ def test_reject_all():
         result = limiter.check_cardinality_limits(
             UseCaseKey.PERFORMANCE,
             {
-                PartitionIdxOffset(0, 0): {
+                BrokerMeta(Partition(Topic("topic"), 0), 0): {
                     "org_id": 1,
                     "name": "foo",
                     "tags": {},
                     "use_case_id": UseCaseID.TRANSACTIONS,
                 },
-                PartitionIdxOffset(0, 1): {
+                BrokerMeta(Partition(Topic("topic"), 0), 1): {
                     "org_id": 1,
                     "name": "bar",
                     "tags": {},
@@ -113,7 +114,10 @@ def test_reject_all():
             },
         )
 
-        assert result.keys_to_remove == [PartitionIdxOffset(0, 0), PartitionIdxOffset(0, 1)]
+        assert result.keys_to_remove == [
+            BrokerMeta(Partition(Topic("topic"), 0), 0),
+            BrokerMeta(Partition(Topic("topic"), 0), 1),
+        ]
 
 
 def test_reject_all_with_default():
@@ -147,19 +151,19 @@ def test_reject_all_with_default():
         result = limiter.check_cardinality_limits(
             UseCaseKey.PERFORMANCE,
             {
-                PartitionIdxOffset(0, 0): {
+                BrokerMeta(Partition(Topic("topic"), 0), 0): {
                     "org_id": 1,
                     "name": "foo",
                     "tags": {},
                     "use_case_id": UseCaseID.TRANSACTIONS,
                 },
-                PartitionIdxOffset(0, 1): {
+                BrokerMeta(Partition(Topic("topic"), 0), 1): {
                     "org_id": 1,
                     "name": "bar",
                     "tags": {},
                     "use_case_id": UseCaseID.SPANS,
                 },
-                PartitionIdxOffset(0, 2): {
+                BrokerMeta(Partition(Topic("topic"), 0), 2): {
                     "org_id": 1,
                     "name": "boo",
                     "tags": {},
@@ -169,9 +173,9 @@ def test_reject_all_with_default():
         )
 
         assert result.keys_to_remove == [
-            PartitionIdxOffset(0, 0),
-            PartitionIdxOffset(0, 1),
-            PartitionIdxOffset(0, 2),
+            BrokerMeta(Partition(Topic("topic"), 0), 0),
+            BrokerMeta(Partition(Topic("topic"), 0), 1),
+            BrokerMeta(Partition(Topic("topic"), 0), 2),
         ]
 
 
@@ -200,19 +204,19 @@ def test_reject_partial():
         result = limiter.check_cardinality_limits(
             UseCaseKey.PERFORMANCE,
             {
-                PartitionIdxOffset(0, 0): {
+                BrokerMeta(Partition(Topic("topic"), 0), 0): {
                     "org_id": 1,
                     "name": "foo",
                     "tags": {},
                     "use_case_id": UseCaseID.TRANSACTIONS,
                 },
-                PartitionIdxOffset(0, 1): {
+                BrokerMeta(Partition(Topic("topic"), 0), 1): {
                     "org_id": 1,
                     "name": "bar",
                     "tags": {},
                     "use_case_id": UseCaseID.TRANSACTIONS,
                 },
-                PartitionIdxOffset(0, 2): {
+                BrokerMeta(Partition(Topic("topic"), 0), 2): {
                     "org_id": 1,
                     "name": "baz",
                     "tags": {},
@@ -221,7 +225,7 @@ def test_reject_partial():
             },
         )
 
-        assert result.keys_to_remove == [PartitionIdxOffset(0, 2)]
+        assert result.keys_to_remove == [BrokerMeta(Partition(Topic("topic"), 0), 2)]
 
 
 def test_reject_partial_again():
@@ -249,31 +253,31 @@ def test_reject_partial_again():
         result = limiter.check_cardinality_limits(
             UseCaseKey.PERFORMANCE,
             {
-                PartitionIdxOffset(0, 0): {
+                BrokerMeta(Partition(Topic("topic"), 0), 0): {
                     "org_id": 1,
                     "name": "foo",
                     "tags": {},
                     "use_case_id": UseCaseID.TRANSACTIONS,
                 },
-                PartitionIdxOffset(0, 1): {
+                BrokerMeta(Partition(Topic("topic"), 0), 1): {
                     "org_id": 1,
                     "name": "bar",
                     "tags": {},
                     "use_case_id": UseCaseID.TRANSACTIONS,
                 },
-                PartitionIdxOffset(0, 2): {
+                BrokerMeta(Partition(Topic("topic"), 0), 2): {
                     "org_id": 1,
                     "name": "baz",
                     "tags": {},
                     "use_case_id": UseCaseID.SPANS,
                 },
-                PartitionIdxOffset(0, 3): {
+                BrokerMeta(Partition(Topic("topic"), 0), 3): {
                     "org_id": 1,
                     "name": "boo",
                     "tags": {},
                     "use_case_id": UseCaseID.CUSTOM,
                 },
-                PartitionIdxOffset(0, 4): {
+                BrokerMeta(Partition(Topic("topic"), 0), 4): {
                     "org_id": 1,
                     "name": "bye",
                     "tags": {},
@@ -282,7 +286,7 @@ def test_reject_partial_again():
             },
         )
 
-        assert result.keys_to_remove == [PartitionIdxOffset(0, 3)]
+        assert result.keys_to_remove == [BrokerMeta(Partition(Topic("topic"), 0), 3)]
 
 
 def test_accept_all():
@@ -314,31 +318,31 @@ def test_accept_all():
         result = limiter.check_cardinality_limits(
             UseCaseKey.PERFORMANCE,
             {
-                PartitionIdxOffset(0, 0): {
+                BrokerMeta(Partition(Topic("topic"), 0), 0): {
                     "org_id": 1,
                     "name": "foo",
                     "tags": {},
                     "use_case_id": UseCaseID.TRANSACTIONS,
                 },
-                PartitionIdxOffset(0, 1): {
+                BrokerMeta(Partition(Topic("topic"), 0), 1): {
                     "org_id": 1,
                     "name": "bar",
                     "tags": {},
                     "use_case_id": UseCaseID.TRANSACTIONS,
                 },
-                PartitionIdxOffset(0, 2): {
+                BrokerMeta(Partition(Topic("topic"), 0), 2): {
                     "org_id": 1,
                     "name": "baz",
                     "tags": {},
                     "use_case_id": UseCaseID.SPANS,
                 },
-                PartitionIdxOffset(0, 3): {
+                BrokerMeta(Partition(Topic("topic"), 0), 3): {
                     "org_id": 1,
                     "name": "bazz",
                     "tags": {},
                     "use_case_id": UseCaseID.ESCALATING_ISSUES,
                 },
-                PartitionIdxOffset(0, 4): {
+                BrokerMeta(Partition(Topic("topic"), 0), 4): {
                     "org_id": 1,
                     "name": "bye",
                     "tags": {},
@@ -378,19 +382,19 @@ def test_sample_rate_zero(set_sentry_option):
         result = limiter.check_cardinality_limits(
             UseCaseKey.PERFORMANCE,
             {
-                PartitionIdxOffset(0, 0): {
+                BrokerMeta(Partition(Topic("topic"), 0), 0): {
                     "org_id": 1,
                     "name": "foo",
                     "tags": {},
                     "use_case_id": UseCaseID.TRANSACTIONS,
                 },
-                PartitionIdxOffset(0, 1): {
+                BrokerMeta(Partition(Topic("topic"), 0), 1): {
                     "org_id": 1,
                     "name": "bar",
                     "tags": {},
                     "use_case_id": UseCaseID.SPANS,
                 },
-                PartitionIdxOffset(0, 2): {
+                BrokerMeta(Partition(Topic("topic"), 0), 2): {
                     "org_id": 1,
                     "name": "baz",
                     "tags": {},
@@ -433,13 +437,13 @@ def test_sample_rate_half(set_sentry_option):
         result = limiter.check_cardinality_limits(
             UseCaseKey.PERFORMANCE,
             {
-                PartitionIdxOffset(0, 0): {
+                BrokerMeta(Partition(Topic("topic"), 0), 0): {
                     "org_id": 1,
                     "name": "foo",
                     "tags": {},
                     "use_case_id": UseCaseID.TRANSACTIONS,
                 },
-                PartitionIdxOffset(0, 1): {
+                BrokerMeta(Partition(Topic("topic"), 0), 1): {
                     "org_id": 99,
                     "name": "bar",
                     "tags": {},
@@ -450,4 +454,4 @@ def test_sample_rate_half(set_sentry_option):
 
         # We are sampling org_id=1 into cardinality limiting. Because our quota is
         # zero, only that org's metrics are dropped.
-        assert result.keys_to_remove == [PartitionIdxOffset(0, 0)]
+        assert result.keys_to_remove == [BrokerMeta(Partition(Topic("topic"), 0), 0)]

+ 19 - 4
tests/sentry/sentry_metrics/test_batch.py

@@ -10,7 +10,8 @@ from arroyo.backends.kafka import KafkaPayload
 from arroyo.types import BrokerValue, Message, Partition, Topic, Value
 
 from sentry.sentry_metrics.aggregation_option_registry import AggregationOption
-from sentry.sentry_metrics.consumers.indexer.batch import IndexerBatch, PartitionIdxOffset
+from sentry.sentry_metrics.consumers.indexer.batch import IndexerBatch
+from sentry.sentry_metrics.consumers.indexer.common import BrokerMeta
 from sentry.sentry_metrics.consumers.indexer.tags_validator import ReleaseHealthTagsValidator
 from sentry.sentry_metrics.indexer.base import FetchType, FetchTypeExt, Metadata
 from sentry.sentry_metrics.use_case_id_registry import UseCaseID
@@ -249,6 +250,7 @@ def test_extract_strings_with_rollout(should_index_tag_values, expected):
     )
 
     assert batch.extract_strings() == expected
+    assert not batch.invalid_msg_meta
 
 
 @pytest.mark.django_db
@@ -412,6 +414,7 @@ def test_extract_strings_with_single_use_case_ids_blocked():
             }
         }
     }
+    assert not batch.invalid_msg_meta
 
 
 @override_options({"sentry-metrics.indexer.disabled-namespaces": ["spans", "escalating_issues"]})
@@ -484,6 +487,7 @@ def test_extract_strings_with_multiple_use_case_ids_blocked():
             }
         },
     }
+    assert not batch.invalid_msg_meta
 
 
 @pytest.mark.django_db
@@ -586,6 +590,7 @@ def test_extract_strings_with_invalid_mri():
             }
         },
     }
+    assert batch.invalid_msg_meta == {BrokerMeta(Partition(Topic("topic"), 0), 0)}
 
 
 @pytest.mark.django_db
@@ -676,6 +681,7 @@ def test_extract_strings_with_multiple_use_case_ids_and_org_ids():
             }
         },
     }
+    assert not batch.invalid_msg_meta
 
 
 @pytest.mark.django_db
@@ -728,6 +734,7 @@ def test_resolved_with_aggregation_options(caplog, settings):
             }
         }
     )
+    assert not batch.invalid_msg_meta
 
     caplog.set_level(logging.ERROR)
     snuba_payloads = batch.reconstruct_messages(
@@ -868,6 +875,7 @@ def test_all_resolved(caplog, settings):
             }
         }
     )
+    assert not batch.invalid_msg_meta
 
     caplog.set_level(logging.ERROR)
     snuba_payloads = batch.reconstruct_messages(
@@ -1178,6 +1186,7 @@ def test_all_resolved_retention_days_honored(caplog, settings):
             }
         }
     )
+    assert not batch.invalid_msg_meta
 
     caplog.set_level(logging.ERROR)
     snuba_payloads = batch.reconstruct_messages(
@@ -1330,6 +1339,7 @@ def test_batch_resolve_with_values_not_indexed(caplog, settings):
             }
         }
     )
+    assert not batch.invalid_msg_meta
 
     caplog.set_level(logging.ERROR)
     snuba_payloads = batch.reconstruct_messages(
@@ -1474,6 +1484,7 @@ def test_metric_id_rate_limited(caplog, settings):
             }
         }
     )
+    assert not batch.invalid_msg_meta
 
     caplog.set_level(logging.ERROR)
     snuba_payloads = batch.reconstruct_messages(
@@ -1588,6 +1599,7 @@ def test_tag_key_rate_limited(caplog, settings):
             }
         }
     )
+    assert not batch.invalid_msg_meta
 
     caplog.set_level(logging.ERROR)
     snuba_payloads = batch.reconstruct_messages(
@@ -1679,6 +1691,7 @@ def test_tag_value_rate_limited(caplog, settings):
             }
         }
     )
+    assert not batch.invalid_msg_meta
 
     caplog.set_level(logging.ERROR)
     snuba_payloads = batch.reconstruct_messages(
@@ -1821,6 +1834,7 @@ def test_one_org_limited(caplog, settings):
             }
         }
     )
+    assert not batch.invalid_msg_meta
 
     caplog.set_level(logging.ERROR)
     snuba_payloads = batch.reconstruct_messages(
@@ -1932,12 +1946,12 @@ def test_cardinality_limiter(caplog, settings):
         input_codec=_INGEST_CODEC,
         tags_validator=ReleaseHealthTagsValidator().is_allowed,
     )
-    keys_to_remove = list(batch.parsed_payloads_by_offset)[:2]
+    keys_to_remove = list(batch.parsed_payloads_by_meta)[:2]
     # the messages come in a certain order, and Python dictionaries preserve
     # their insertion order. So we can hardcode offsets here.
     assert keys_to_remove == [
-        PartitionIdxOffset(partition_idx=0, offset=0),
-        PartitionIdxOffset(partition_idx=0, offset=1),
+        BrokerMeta(partition=Partition(Topic("topic"), 0), offset=0),
+        BrokerMeta(partition=Partition(Topic("topic"), 0), offset=1),
     ]
     batch.filter_messages(keys_to_remove)
     assert batch.extract_strings() == {
@@ -1953,6 +1967,7 @@ def test_cardinality_limiter(caplog, settings):
             },
         }
     }
+    assert not batch.invalid_msg_meta
 
     snuba_payloads = batch.reconstruct_messages(
         {

+ 23 - 9
tests/sentry/sentry_metrics/test_gen_metrics_multiprocess_steps.py

@@ -4,6 +4,7 @@ import logging
 import pickle
 import re
 import time
+from collections import deque
 from copy import deepcopy
 from datetime import datetime, timezone
 from typing import Any, Dict, List, MutableMapping, Sequence, Union
@@ -18,7 +19,12 @@ from sentry.ratelimits.cardinality import CardinalityLimiter
 from sentry.sentry_metrics.aggregation_option_registry import get_aggregation_option
 from sentry.sentry_metrics.configuration import IndexerStorage, UseCaseKey, get_ingest_config
 from sentry.sentry_metrics.consumers.indexer.batch import valid_metric_name
-from sentry.sentry_metrics.consumers.indexer.common import BatchMessages, MetricsBatchBuilder
+from sentry.sentry_metrics.consumers.indexer.common import (
+    BatchMessages,
+    BrokerMeta,
+    IndexerOutputMessageBatch,
+    MetricsBatchBuilder,
+)
 from sentry.sentry_metrics.consumers.indexer.processing import MessageProcessor
 from sentry.sentry_metrics.indexer.limiters.cardinality import (
     TimeseriesCardinalityLimiter,
@@ -65,10 +71,10 @@ def compare_messages_ignoring_mapping_metadata(actual: Message, expected: Messag
 
 
 def compare_message_batches_ignoring_metadata(
-    actual: Sequence[Message], expected: Sequence[Message]
+    actual: IndexerOutputMessageBatch, expected: Sequence[Message]
 ) -> None:
-    assert len(actual) == len(expected)
-    for (a, e) in zip(actual, expected):
+    assert len(actual.data) == len(expected)
+    for (a, e) in zip(actual.data, expected):
         compare_messages_ignoring_mapping_metadata(a, e)
 
 
@@ -339,7 +345,7 @@ def test_process_messages() -> None:
 
     outer_message = Message(Value(message_batch, last.committable))
 
-    new_batch = MESSAGE_PROCESSOR.process_messages(outer_message=outer_message).data
+    new_batch = MESSAGE_PROCESSOR.process_messages(outer_message=outer_message)
     expected_new_batch = []
     for i, m in enumerate(message_batch):
         assert isinstance(m.value, BrokerValue)
@@ -361,6 +367,7 @@ def test_process_messages() -> None:
         )
 
     compare_message_batches_ignoring_metadata(new_batch, expected_new_batch)
+    assert not new_batch.invalid_msg_meta
 
 
 @pytest.mark.django_db
@@ -386,7 +393,10 @@ def test_process_messages_default_card_rollout(set_sentry_option) -> None:
         "sentry-metrics.cardinality-limiter.orgs-rollout-rate",
         1.0,
     ):
-        MESSAGE_PROCESSOR.process_messages(outer_message=outer_message)
+        new_batch = MESSAGE_PROCESSOR.process_messages(outer_message=outer_message)
+
+    assert len(new_batch.data) == len(message_batch)
+    assert not new_batch.invalid_msg_meta
 
 
 invalid_payloads = [
@@ -478,7 +488,7 @@ def test_process_messages_invalid_messages(
     outer_message = Message(Value(message_batch, last.committable))
 
     with caplog.at_level(logging.ERROR):
-        new_batch = MESSAGE_PROCESSOR.process_messages(outer_message=outer_message).data
+        new_batch = MESSAGE_PROCESSOR.process_messages(outer_message=outer_message)
 
     # we expect just the valid counter_payload msg to be left
     expected_msg = message_batch[0]
@@ -496,6 +506,7 @@ def test_process_messages_invalid_messages(
     ]
     compare_message_batches_ignoring_metadata(new_batch, expected_new_batch)
     assert error_text in caplog.text
+    assert new_batch.invalid_msg_meta == deque([BrokerMeta(Partition(Topic("topic"), 0), 1)])
 
 
 @pytest.mark.django_db
@@ -543,7 +554,7 @@ def test_process_messages_rate_limited(caplog, settings) -> None:
     raw_simple_string_indexer._strings[UseCaseID(rgx.group(2))][1]["rate_limited_test"] = None
 
     with caplog.at_level(logging.ERROR):
-        new_batch = message_processor.process_messages(outer_message=outer_message).data
+        new_batch = message_processor.process_messages(outer_message=outer_message)
 
     # we expect just the counter_payload msg to be left, as that one didn't
     # cause/depend on string writes that have been rate limited
@@ -567,6 +578,7 @@ def test_process_messages_rate_limited(caplog, settings) -> None:
     ]
     compare_message_batches_ignoring_metadata(new_batch, expected_new_batch)
     assert "dropped_message" in caplog.text
+    assert not new_batch.invalid_msg_meta
 
 
 @pytest.mark.django_db
@@ -615,10 +627,12 @@ def test_process_messages_cardinality_limited(
         outer_message = Message(Value(message_batch, last.committable))
 
         with caplog.at_level(logging.ERROR):
-            new_batch = MESSAGE_PROCESSOR.process_messages(outer_message=outer_message).data
+            new_batch = MESSAGE_PROCESSOR.process_messages(outer_message=outer_message)
 
         compare_message_batches_ignoring_metadata(new_batch, [])
 
+    assert not new_batch.invalid_msg_meta
+
 
 def test_valid_metric_name() -> None:
     assert valid_metric_name("") is True

Some files were not shown because too many files changed in this diff