Просмотр исходного кода

ref(post-process-forwarder): Use RunTaskInThreads strategy from Arroyo (#44636)

Use Arroyo's built in RunTaskInThreads strategy to reduce boilerplate in
the post process forwarder. Also deletes some unused code.
Lyn Nagara 2 лет назад
Родитель
Сommit
c86a9a318e

+ 33 - 82
src/sentry/eventstream/kafka/consumer_strategy.py

@@ -1,32 +1,31 @@
 import logging
-import time
-from collections import deque
-from concurrent.futures import Future, ThreadPoolExecutor
-from typing import Any, Deque, Mapping, Optional, Tuple
+from typing import Any, Mapping, Optional
 
 from arroyo.backends.kafka.consumer import KafkaPayload
-from arroyo.processing.strategies import ProcessingStrategy, ProcessingStrategyFactory
-from arroyo.processing.strategies.abstract import MessageRejected
+from arroyo.processing.strategies import (
+    CommitOffsets,
+    ProcessingStrategy,
+    ProcessingStrategyFactory,
+    RunTaskInThreads,
+)
 from arroyo.types import Commit, Message, Partition
 
 from sentry import options
 from sentry.eventstream.base import GroupStates
 from sentry.eventstream.kafka.postprocessworker import _record_metrics, _sampled_eventstream_timer
-from sentry.eventstream.kafka.postprocessworker import (
-    dispatch_post_process_group_task as _dispatch_post_process_group_task,
-)
 from sentry.eventstream.kafka.protocol import (
     get_task_kwargs_for_message,
     get_task_kwargs_for_message_from_headers,
 )
+from sentry.tasks.post_process import post_process_group
 from sentry.utils import metrics
+from sentry.utils.cache import cache_key_for_event
 
 _DURATION_METRIC = "eventstream.duration"
 
 logger = logging.getLogger(__name__)
 
 
-# For testing. Function will eventually move here when postprocessworker is removed.
 def dispatch_post_process_group_task(
     event_id: str,
     project_id: int,
@@ -40,19 +39,24 @@ def dispatch_post_process_group_task(
     group_states: Optional[GroupStates] = None,
     occurrence_id: Optional[str] = None,
 ) -> None:
-    _dispatch_post_process_group_task(
-        event_id,
-        project_id,
-        group_id,
-        is_new,
-        is_regression,
-        is_new_group_environment,
-        primary_hash,
-        queue,
-        skip_consume,
-        group_states,
-        occurrence_id=occurrence_id,
-    )
+    if skip_consume:
+        logger.info("post_process.skip.raw_event", extra={"event_id": event_id})
+    else:
+        cache_key = cache_key_for_event({"project": project_id, "event_id": event_id})
+
+        post_process_group.apply_async(
+            kwargs={
+                "is_new": is_new,
+                "is_regression": is_regression,
+                "is_new_group_environment": is_new_group_environment,
+                "primary_hash": primary_hash,
+                "cache_key": cache_key,
+                "group_id": group_id,
+                "group_states": group_states,
+                "occurrence_id": occurrence_id,
+            },
+            queue=queue,
+        )
 
 
 def _get_task_kwargs(message: Message[KafkaPayload]) -> Optional[Mapping[str, Any]]:
@@ -81,64 +85,6 @@ def _get_task_kwargs_and_dispatch(message: Message[KafkaPayload]) -> None:
     dispatch_post_process_group_task(**task_kwargs)
 
 
-class DispatchTask(ProcessingStrategy[KafkaPayload]):
-    def __init__(
-        self,
-        concurrency: int,
-        max_pending_futures: int,
-        commit: Commit,
-    ) -> None:
-        self.__executor = ThreadPoolExecutor(max_workers=concurrency)
-        self.__futures: Deque[Tuple[Message[KafkaPayload], Future[None]]] = deque()
-        self.__max_pending_futures = max_pending_futures
-        self.__commit = commit
-        self.__closed = False
-
-    def submit(self, message: Message[KafkaPayload]) -> None:
-        assert not self.__closed
-        # The list of pending futures is too long, tell the stream processor to slow down
-        if len(self.__futures) > self.__max_pending_futures:
-            raise MessageRejected
-
-        self.__futures.append(
-            (message, self.__executor.submit(_get_task_kwargs_and_dispatch, message))
-        )
-
-    def poll(self) -> None:
-        # Remove completed futures in order
-        while self.__futures and self.__futures[0][1].done():
-            message, _ = self.__futures.popleft()
-
-            self.__commit(message.committable)
-
-    def join(self, timeout: Optional[float] = None) -> None:
-        start = time.time()
-
-        # Commit all pending offsets
-        self.__commit({}, force=True)
-
-        while self.__futures:
-            remaining = timeout - (time.time() - start) if timeout is not None else None
-            if remaining is not None and remaining <= 0:
-                logger.warning(f"Timed out with {len(self.__futures)} futures in queue")
-                break
-
-            message, future = self.__futures.popleft()
-
-            future.result(remaining)
-
-            self.__commit(message.committable, force=True)
-
-        self.__executor.shutdown()
-
-    def close(self) -> None:
-        self.__closed = True
-
-    def terminate(self) -> None:
-        self.__closed = True
-        self.__executor.shutdown()
-
-
 class PostProcessForwarderStrategyFactory(ProcessingStrategyFactory[KafkaPayload]):
     def __init__(
         self,
@@ -153,4 +99,9 @@ class PostProcessForwarderStrategyFactory(ProcessingStrategyFactory[KafkaPayload
         commit: Commit,
         partitions: Mapping[Partition, int],
     ) -> ProcessingStrategy[KafkaPayload]:
-        return DispatchTask(self.__concurrency, self.__max_pending_futures, commit)
+        return RunTaskInThreads(
+            _get_task_kwargs_and_dispatch,
+            self.__concurrency,
+            self.__max_pending_futures,
+            CommitOffsets(commit),
+        )

+ 0 - 45
src/sentry/eventstream/kafka/postprocessworker.py

@@ -10,14 +10,11 @@ from threading import Lock
 from typing import Any, Generator, Mapping, MutableMapping, Optional, Tuple
 
 from sentry import options
-from sentry.eventstream.base import GroupStates
 from sentry.eventstream.kafka.protocol import (
     get_task_kwargs_for_message,
     get_task_kwargs_for_message_from_headers,
 )
-from sentry.tasks.post_process import post_process_group
 from sentry.utils import metrics
-from sentry.utils.cache import cache_key_for_event
 
 logger = logging.getLogger(__name__)
 
@@ -87,45 +84,3 @@ def _record_metrics(partition: int, task_kwargs: Mapping[str, Any]) -> None:
                 tags={"partition": partition, "type": event_type},
                 sample_rate=1,
             )
-
-
-def dispatch_post_process_group_task(
-    event_id: str,
-    project_id: int,
-    group_id: Optional[int],
-    is_new: bool,
-    is_regression: Optional[bool],
-    is_new_group_environment: bool,
-    primary_hash: Optional[str],
-    queue: str,
-    skip_consume: bool = False,
-    group_states: Optional[GroupStates] = None,
-    occurrence_id: Optional[str] = None,
-) -> None:
-    if skip_consume:
-        logger.info("post_process.skip.raw_event", extra={"event_id": event_id})
-    else:
-        cache_key = cache_key_for_event({"project": project_id, "event_id": event_id})
-
-        post_process_group.apply_async(
-            kwargs={
-                "is_new": is_new,
-                "is_regression": is_regression,
-                "is_new_group_environment": is_new_group_environment,
-                "primary_hash": primary_hash,
-                "cache_key": cache_key,
-                "group_id": group_id,
-                "group_states": group_states,
-                "occurrence_id": occurrence_id,
-            },
-            queue=queue,
-        )
-
-
-def _get_task_kwargs_and_dispatch(message: Message) -> None:
-    task_kwargs = _get_task_kwargs(message)
-    if not task_kwargs:
-        return None
-
-    _record_metrics(message.partition(), task_kwargs)
-    dispatch_post_process_group_task(**task_kwargs)