
feat(replays): Implement recording consumer using arroyo primitives (#42186)

This should be safer and easier to maintain than the custom consumer we
currently have. It was inspired by the recent change that added threading to
the single message type; this pull request should prevent those errors in the
future by simplifying the interface.

- Flattens the consumer module and the companion test module.
- Removes the custom strategy and uses arroyo primitives.
- Keeps consumer behavior in the consumer module and moves ingest-related
  behavior into the usecase module.
- No behavioral changes.

closes: https://github.com/getsentry/replay-backend/issues/238
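For orientation, a minimal sketch of the composition this change adopts: a TransformStep feeding a FilterStep, whose surviving messages are processed on a thread pool before offsets are committed. The ToyFactory class and the lambda callables below are illustrative placeholders, not code from this commit; the real chain is in src/sentry/replays/consumers/recording.py in the diff.

    from typing import Mapping

    from arroyo.backends.kafka.consumer import KafkaPayload
    from arroyo.processing.strategies import RunTaskInThreads, TransformStep
    from arroyo.processing.strategies.abstract import ProcessingStrategyFactory
    from arroyo.processing.strategies.commit import CommitOffsets
    from arroyo.processing.strategies.filter import FilterStep
    from arroyo.types import Commit, Partition


    class ToyFactory(ProcessingStrategyFactory[KafkaPayload]):
        """Illustrative only: the same TransformStep -> FilterStep -> RunTaskInThreads
        -> CommitOffsets chain the replay factory builds, with toy callables."""

        def create_with_partitions(self, commit: Commit, partitions: Mapping[Partition, int]):
            return TransformStep(
                function=lambda msg: msg.payload.value,        # decode/normalize the payload
                next_step=FilterStep(
                    function=lambda msg: bool(msg.payload),    # keep only messages worth ingesting
                    next_step=RunTaskInThreads(
                        processing_function=lambda msg: None,  # I/O-bound work runs on a pool
                        concurrency=4,
                        max_pending_futures=8,
                        next_step=CommitOffsets(commit),       # offsets committed per the policy
                    ),
                ),
            )

The point of the composition is that back-pressure, future bookkeeping, and offset commits are handled by the arroyo primitives instead of the hand-rolled deque and thread pool removed from process_recording.py below.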
Colton Allen · 2 years ago · commit 35c4912c49

+ 3 - 3
src/sentry/replays/consumers/__init__.py

@@ -6,11 +6,11 @@ from typing import Any, MutableMapping
 from arroyo import Topic
 from arroyo.backends.kafka.configuration import build_kafka_consumer_configuration
 from arroyo.backends.kafka.consumer import KafkaConsumer, KafkaPayload
-from arroyo.commit import IMMEDIATE
+from arroyo.commit import ONCE_PER_SECOND
 from arroyo.processing.processor import StreamProcessor
 from django.conf import settings
 
-from sentry.replays.consumers.recording.factory import ProcessReplayRecordingStrategyFactory
+from sentry.replays.consumers.recording import ProcessReplayRecordingStrategyFactory
 from sentry.utils import kafka_config
 
 
@@ -30,7 +30,7 @@ def get_replays_recordings_consumer(
         consumer=consumer,
         topic=Topic(topic),
         processor_factory=ProcessReplayRecordingStrategyFactory(),
-        commit_policy=IMMEDIATE,
+        commit_policy=ONCE_PER_SECOND,
     )
 
     def handler(signum: int, frame: Any) -> None:

+ 99 - 0
src/sentry/replays/consumers/recording.py

@@ -0,0 +1,99 @@
+import dataclasses
+import logging
+import random
+from typing import Any, Dict, Mapping, cast
+
+import msgpack
+import sentry_sdk
+from arroyo.backends.kafka.consumer import KafkaPayload
+from arroyo.processing.strategies import RunTaskInThreads, TransformStep
+from arroyo.processing.strategies.abstract import ProcessingStrategyFactory
+from arroyo.processing.strategies.commit import CommitOffsets
+from arroyo.processing.strategies.filter import FilterStep
+from arroyo.types import Commit, Message, Partition
+from django.conf import settings
+from sentry_sdk.tracing import Transaction
+
+from sentry.replays.lib.consumer import LogExceptionStep
+from sentry.replays.usecases.ingest import (
+    RecordingMessage,
+    RecordingSegmentChunkMessage,
+    RecordingSegmentMessage,
+    ingest_chunk,
+    ingest_recording_chunked,
+    ingest_recording_not_chunked,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class MessageContext:
+    message: Dict[str, Any]
+    transaction: Transaction
+
+
+class ProcessReplayRecordingStrategyFactory(ProcessingStrategyFactory[KafkaPayload]):
+    """
+    This consumer processes replay recordings, which are compressed payloads split up into
+    chunks.
+    """
+
+    def create_with_partitions(
+        self,
+        commit: Commit,
+        partitions: Mapping[Partition, int],
+    ) -> Any:
+        return LogExceptionStep(
+            message="Invalid recording specified.",
+            logger=logger,
+            next_step=TransformStep(
+                function=move_chunks_to_cache_or_skip,
+                next_step=FilterStep(
+                    function=is_capstone_message,
+                    next_step=RunTaskInThreads(
+                        processing_function=move_replay_to_permanent_storage,
+                        concurrency=16,
+                        max_pending_futures=32,
+                        next_step=CommitOffsets(commit),
+                    ),
+                ),
+            ),
+        )
+
+
+def move_chunks_to_cache_or_skip(message: Message[KafkaPayload]) -> MessageContext:
+    """Move chunk messages to cache or skip."""
+    transaction = sentry_sdk.start_transaction(
+        name="replays.consumer.process_recording",
+        op="replays.consumer",
+        sampled=random.random()
+        < getattr(settings, "SENTRY_REPLAY_RECORDINGS_CONSUMER_APM_SAMPLING", 0),
+    )
+
+    message_dict = msgpack.unpackb(message.payload.value)
+
+    if message_dict["type"] == "replay_recording_chunk":
+        ingest_chunk(cast(RecordingSegmentChunkMessage, message_dict), transaction)
+
+    return MessageContext(message_dict, transaction)
+
+
+def is_capstone_message(message: Message[MessageContext]) -> Any:
+    """Return "True" if the message is a capstone and can be processed in parallel."""
+    message_type = message.payload.message["type"]
+    return message_type == "replay_recording_not_chunked" or message_type == "replay_recording"
+
+
+def move_replay_to_permanent_storage(message: Message[MessageContext]) -> Any:
+    """Move the replay payload to permanent storage."""
+    context: MessageContext = message.payload
+    message_dict = context.message
+    message_type = message_dict["type"]
+
+    if message_type == "replay_recording_not_chunked":
+        ingest_recording_not_chunked(cast(RecordingMessage, message_dict), context.transaction)
+    elif message_type == "replay_recording":
+        ingest_recording_chunked(cast(RecordingSegmentMessage, message_dict), context.transaction)
+    else:
+        raise ValueError(f"Invalid replays recording message type specified: {message_type}")

+ 0 - 0
src/sentry/replays/consumers/recording/__init__.py


+ 0 - 24
src/sentry/replays/consumers/recording/factory.py

@@ -1,24 +0,0 @@
-import logging
-from typing import Mapping
-
-from arroyo.backends.kafka.consumer import KafkaPayload
-from arroyo.processing.strategies.abstract import ProcessingStrategy, ProcessingStrategyFactory
-from arroyo.types import Commit, Partition
-
-from sentry.replays.consumers.recording.process_recording import ProcessRecordingSegmentStrategy
-
-logger = logging.getLogger(__name__)
-
-
-class ProcessReplayRecordingStrategyFactory(ProcessingStrategyFactory[KafkaPayload]):
-    """
-    This consumer processes replay recordings, which are compressed payloads split up into
-    chunks.
-    """
-
-    def create_with_partitions(
-        self,
-        commit: Commit,
-        partitions: Mapping[Partition, int],
-    ) -> ProcessingStrategy[KafkaPayload]:
-        return ProcessRecordingSegmentStrategy(commit)

+ 0 - 193
src/sentry/replays/consumers/recording/process_recording.py

@@ -1,193 +0,0 @@
-from __future__ import annotations
-
-import concurrent.futures
-import logging
-import random
-import time
-from collections import deque
-from concurrent.futures import Future
-from typing import Deque, MutableMapping, NamedTuple, Optional, cast
-
-import msgpack
-import sentry_sdk
-from arroyo import Partition
-from arroyo.backends.kafka.consumer import KafkaPayload
-from arroyo.processing.strategies import MessageRejected
-from arroyo.processing.strategies.abstract import ProcessingStrategy
-from arroyo.types import Commit, Message
-from django.conf import settings
-
-from sentry.replays.cache import RecordingSegmentParts
-from sentry.replays.usecases.ingest import (
-    RecordingMessage,
-    RecordingSegmentChunkMessage,
-    RecordingSegmentMessage,
-    ingest_chunk,
-    ingest_recording_chunked,
-    ingest_recording_not_chunked,
-)
-from sentry.utils import metrics
-
-logger = logging.getLogger("sentry.replays")
-
-CACHE_TIMEOUT = 3600
-COMMIT_FREQUENCY_SEC = 1
-
-
-class MissingRecordingSegmentHeaders(ValueError):
-    pass
-
-
-class ReplayRecordingMessageFuture(NamedTuple):
-    """
-    Map a submitted message to a Future returned by the Producer.
-    This is useful for being able to commit the latest offset back
-    to the original consumer.
-    """
-
-    message: Message[KafkaPayload]
-    future: Future[None]
-
-
-class ProcessRecordingSegmentStrategy(ProcessingStrategy[KafkaPayload]):
-    def __init__(
-        self,
-        commit: Commit,
-    ) -> None:
-        self.__closed = False
-        self.__futures: Deque[ReplayRecordingMessageFuture] = deque()
-        self.__threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=16)
-        self.__commit = commit
-        self.__commit_data: MutableMapping[Partition, int] = {}
-        self.__last_committed: float = 0
-        self.__max_pending_futures = 32
-
-    @metrics.wraps("replays.process_recording.submit")
-    def submit(self, message: Message[KafkaPayload]) -> None:
-        assert not self.__closed
-
-        if len(self.__futures) > self.__max_pending_futures:
-            raise MessageRejected
-
-        current_transaction = sentry_sdk.start_transaction(
-            name="replays.consumer.process_recording",
-            op="replays.consumer",
-            sampled=random.random()
-            < getattr(settings, "SENTRY_REPLAY_RECORDINGS_CONSUMER_APM_SAMPLING", 0),
-        )
-
-        try:
-
-            with current_transaction.start_child(op="msg_unpack"):
-                message_dict = msgpack.unpackb(message.payload.value)
-
-            if message_dict["type"] == "replay_recording_chunk":
-                if type(message_dict["payload"]) is str:
-                    # if the payload is uncompressed, we need to encode it as bytes
-                    # as msgpack will decode it as a utf-8 python string
-                    message_dict["payload"] = message_dict["payload"].encode("utf-8")
-
-                ingest_chunk(cast(RecordingSegmentChunkMessage, message_dict), current_transaction)
-            elif message_dict["type"] == "replay_recording":
-                cache_prefix = replay_recording_segment_cache_id(
-                    project_id=message_dict["project_id"],
-                    replay_id=message_dict["replay_id"],
-                    segment_id=message_dict["replay_recording"]["id"],
-                )
-                parts = RecordingSegmentParts(
-                    prefix=cache_prefix, num_parts=message_dict["replay_recording"]["chunks"]
-                )
-
-                self.__futures.append(
-                    ReplayRecordingMessageFuture(
-                        message,
-                        self.__threadpool.submit(
-                            ingest_recording_chunked,
-                            message_dict=cast(RecordingSegmentMessage, message_dict),
-                            parts=parts,
-                            transaction=current_transaction,
-                        ),
-                    )
-                )
-            elif message_dict["type"] == "replay_recording_not_chunked":
-                self.__futures.append(
-                    ReplayRecordingMessageFuture(
-                        message,
-                        self.__threadpool.submit(
-                            ingest_recording_not_chunked,
-                            message_dict=cast(RecordingMessage, message_dict),
-                            transaction=current_transaction,
-                        ),
-                    )
-                )
-        except Exception:
-            # avoid crash looping on bad messsages for now
-            logger.exception(
-                "Failed to process replay recording message",
-                extra={"committable": message.committable},
-            )
-            current_transaction.finish()
-
-    def close(self) -> None:
-        self.__closed = True
-        self.__threadpool.shutdown(wait=False)
-
-    def terminate(self) -> None:
-        self.close()
-
-    def join(self, timeout: Optional[float] = None) -> None:
-        start = time.time()
-
-        # Immediately commit all the offsets we have popped from the queue.
-        self.__throttled_commit(force=True)
-
-        # Any remaining items in the queue are flushed until the process is terminated.
-        while self.__futures:
-            remaining = timeout - (time.time() - start) if timeout is not None else None
-            if remaining is not None and remaining <= 0:
-                logger.warning(f"Timed out with {len(self.__futures)} futures in queue")
-                break
-
-            # Pop the future from the queue.  If it succeeds great but if not it will be discarded
-            # on the next loop iteration without commit.  An error will be logged.
-            message, future = self.__futures.popleft()
-
-            try:
-                future.result(remaining)
-                self.__commit(message.committable)
-            except Exception:
-                logger.exception(
-                    "Async future failed in replays recording-segment consumer.",
-                    extra={"committable": message.committable},
-                )
-
-    def poll(self) -> None:
-        while self.__futures:
-            message, future = self.__futures[0]
-            if not future.done():
-                break
-
-            if future.exception():
-                logger.error(
-                    "Async future failed in replays recording-segment consumer.",
-                    exc_info=future.exception(),
-                    extra={"committable": message.committable},
-                )
-
-            self.__futures.popleft()
-            self.__commit_data.update(message.committable)
-
-        self.__throttled_commit()
-
-    def __throttled_commit(self, force: bool = False) -> None:
-        now = time.time()
-
-        if (now - self.__last_committed) >= COMMIT_FREQUENCY_SEC or force is True:
-            if self.__commit_data:
-                self.__commit(self.__commit_data)
-                self.__last_committed = now
-                self.__commit_data = {}
-
-
-def replay_recording_segment_cache_id(project_id: int, replay_id: str, segment_id: str) -> str:
-    return f"{project_id}:{replay_id}:{segment_id}"

+ 45 - 0
src/sentry/replays/lib/consumer.py

@@ -0,0 +1,45 @@
+import logging
+from typing import Optional
+
+from arroyo.processing.strategies.abstract import ProcessingStrategy
+from arroyo.types import Message, TPayload
+
+
+class LogExceptionStep(ProcessingStrategy[TPayload]):
+    def __init__(
+        self,
+        message: str,
+        logger: logging.Logger,
+        next_step: ProcessingStrategy[TPayload],
+    ) -> None:
+        self.__exception_message = message
+        self.__next_step = next_step
+        self.__closed = False
+        self.__logger = logger
+
+    def submit(self, message: Message[TPayload]) -> None:
+        assert not self.__closed
+
+        try:
+            self.__next_step.submit(message)
+        except Exception:
+            self.__logger.exception(self.__exception_message)
+
+    def poll(self) -> None:
+        try:
+            self.__next_step.poll()
+        except Exception:
+            self.__logger.exception(self.__exception_message)
+
+    def close(self) -> None:
+        self.__closed = True
+
+    def terminate(self) -> None:
+        self.__closed = True
+
+        self.__logger.debug("Terminating %r...", self.__next_step)
+        self.__next_step.terminate()
+
+    def join(self, timeout: Optional[float] = None) -> None:
+        self.__next_step.close()
+        self.__next_step.join(timeout)

+ 15 - 4
src/sentry/replays/usecases/ingest.py

@@ -4,7 +4,7 @@ import dataclasses
 import logging
 from datetime import datetime, timezone
 from io import BytesIO
-from typing import TypedDict
+from typing import TypedDict, Union
 
 import sentry_sdk
 from django.conf import settings
@@ -40,7 +40,7 @@ class RecordingSegmentChunkMessage(TypedDict):
     replay_id: str  # the uuid of the encompassing replay event
     project_id: int
     chunk_index: int  # each segment is split into chunks to fit into kafka
-    payload: bytes
+    payload: Union[bytes, str]
 
 
 class RecordingSegmentMessage(TypedDict):
@@ -78,7 +78,6 @@ class RecordingIngestMessage:
 @metrics.wraps("replays.usecases.ingest.ingest_recording_chunked")
 def ingest_recording_chunked(
     message_dict: RecordingSegmentMessage,
-    parts: RecordingSegmentParts,
     transaction: Transaction,
 ) -> None:
     """Ingest chunked recording messages."""
@@ -86,6 +85,15 @@ def ingest_recording_chunked(
         op="replays.usecases.ingest.ingest_recording_chunked",
         description="ingest_recording_chunked",
     ):
+        cache_prefix = replay_recording_segment_cache_id(
+            project_id=message_dict["project_id"],
+            replay_id=message_dict["replay_id"],
+            segment_id=message_dict["replay_recording"]["id"],
+        )
+        parts = RecordingSegmentParts(
+            prefix=cache_prefix, num_parts=message_dict["replay_recording"]["chunks"]
+        )
+
         try:
             recording_segment_with_headers = collate_segment_chunks(parts)
         except ValueError:
@@ -224,8 +232,11 @@ def ingest_chunk(message_dict: RecordingSegmentChunkMessage, transaction: Transa
             segment_id=message_dict["id"],
         )
 
+        payload = message_dict["payload"]
+        payload = payload.encode("utf-8") if isinstance(payload, str) else payload
+
         part = RecordingSegmentCache(cache_prefix)
-        part[message_dict["chunk_index"]] = message_dict["payload"]
+        part[message_dict["chunk_index"]] = payload
 
     transaction.finish()
 

+ 0 - 0
tests/sentry/replays/consumers/recording_consumer/__init__.py


+ 16 - 6
tests/sentry/replays/consumers/recording_consumer/test_consumer.py → tests/sentry/replays/consumers/test_recording.py

@@ -10,7 +10,7 @@ from arroyo.backends.kafka import KafkaPayload
 from arroyo.types import BrokerValue, Message, Partition, Topic
 
 from sentry.models import File, OnboardingTask, OnboardingTaskStatus
-from sentry.replays.consumers.recording.factory import ProcessReplayRecordingStrategyFactory
+from sentry.replays.consumers.recording import ProcessReplayRecordingStrategyFactory
 from sentry.replays.models import ReplayRecordingSegment
 from sentry.testutils import TransactionTestCase
 
@@ -27,7 +27,9 @@ class TestRecordingsConsumerEndToEnd(TransactionTestCase):
     @patch("sentry.models.OrganizationOnboardingTask.objects.record")
     @patch("sentry.analytics.record")
     def test_basic_flow_compressed(self, mock_record, mock_onboarding_task):
-        processing_strategy = self.processing_factory().create_with_partitions(lambda x: None, None)
+        processing_strategy = self.processing_factory().create_with_partitions(
+            lambda x, force=False: None, None
+        )
         segment_id = 0
         consumer_messages = [
             {
@@ -106,7 +108,9 @@ class TestRecordingsConsumerEndToEnd(TransactionTestCase):
         )
 
     def test_basic_flow_uncompressed(self):
-        processing_strategy = self.processing_factory().create_with_partitions(lambda x: None, None)
+        processing_strategy = self.processing_factory().create_with_partitions(
+            lambda x, force=False: None, None
+        )
         segment_id = 0
         consumer_messages = [
             {
@@ -168,7 +172,9 @@ class TestRecordingsConsumerEndToEnd(TransactionTestCase):
         assert self.project.flags.has_replays
 
     def test_duplicate_segment_flow(self):
-        processing_strategy = self.processing_factory().create_with_partitions(lambda x: None, None)
+        processing_strategy = self.processing_factory().create_with_partitions(
+            lambda x, force=False: None, None
+        )
         segment_id = 0
         consumer_messages = [
             {
@@ -223,7 +229,9 @@ class TestRecordingsConsumerEndToEnd(TransactionTestCase):
     @patch("sentry.models.OrganizationOnboardingTask.objects.record")
     @patch("sentry.analytics.record")
     def test_nonchunked_basic_flow_compressed(self, mock_record, mock_onboarding_task):
-        processing_strategy = self.processing_factory().create_with_partitions(lambda x: None, None)
+        processing_strategy = self.processing_factory().create_with_partitions(
+            lambda x, force=False: None, None
+        )
         segment_id = 0
 
         processing_strategy.submit(
@@ -281,7 +289,9 @@ class TestRecordingsConsumerEndToEnd(TransactionTestCase):
     @patch("sentry.models.OrganizationOnboardingTask.objects.record")
     @patch("sentry.analytics.record")
     def test_nonchunked_basic_flow(self, mock_record, mock_onboarding_task):
-        processing_strategy = self.processing_factory().create_with_partitions(lambda x: None, None)
+        processing_strategy = self.processing_factory().create_with_partitions(
+            lambda x, force=False: None, None
+        )
         segment_id = 0
 
         processing_strategy.submit(