
ref(post-process-forwarder): Use RunTaskInThreads strategy from Arroyo (#44636)

Use Arroyo's built-in RunTaskInThreads strategy to reduce boilerplate in
the post-process forwarder. Also delete some unused code.
Lyn Nagara 2 years ago
Parent
Commit
c86a9a318e
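
For orientation, here is a minimal sketch (not part of this commit; the handler name and tuning values are illustrative) of the pattern the diff below adopts: RunTaskInThreads runs a callable on a thread pool for each message, then hands the message to a next step, here CommitOffsets, which commits its offset.

from typing import Mapping

from arroyo.backends.kafka.consumer import KafkaPayload
from arroyo.processing.strategies import (
    CommitOffsets,
    ProcessingStrategy,
    ProcessingStrategyFactory,
    RunTaskInThreads,
)
from arroyo.types import Commit, Message, Partition


def handle(message: Message[KafkaPayload]) -> None:
    # Runs on a worker thread; the offset is committed only after this returns.
    ...


class SketchFactory(ProcessingStrategyFactory[KafkaPayload]):
    def create_with_partitions(
        self, commit: Commit, partitions: Mapping[Partition, int]
    ) -> ProcessingStrategy[KafkaPayload]:
        # Arguments mirror the factory change below: callable, concurrency,
        # max pending futures (the backpressure bound), then the next step.
        return RunTaskInThreads(handle, 4, 40, CommitOffsets(commit))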

+ 33 - 82
src/sentry/eventstream/kafka/consumer_strategy.py

@@ -1,32 +1,31 @@
 import logging
-import time
-from collections import deque
-from concurrent.futures import Future, ThreadPoolExecutor
-from typing import Any, Deque, Mapping, Optional, Tuple
+from typing import Any, Mapping, Optional
 
 from arroyo.backends.kafka.consumer import KafkaPayload
-from arroyo.processing.strategies import ProcessingStrategy, ProcessingStrategyFactory
-from arroyo.processing.strategies.abstract import MessageRejected
+from arroyo.processing.strategies import (
+    CommitOffsets,
+    ProcessingStrategy,
+    ProcessingStrategyFactory,
+    RunTaskInThreads,
+)
 from arroyo.types import Commit, Message, Partition
 
 from sentry import options
 from sentry.eventstream.base import GroupStates
 from sentry.eventstream.kafka.postprocessworker import _record_metrics, _sampled_eventstream_timer
-from sentry.eventstream.kafka.postprocessworker import (
-    dispatch_post_process_group_task as _dispatch_post_process_group_task,
-)
 from sentry.eventstream.kafka.protocol import (
     get_task_kwargs_for_message,
     get_task_kwargs_for_message_from_headers,
 )
+from sentry.tasks.post_process import post_process_group
 from sentry.utils import metrics
+from sentry.utils.cache import cache_key_for_event
 
 _DURATION_METRIC = "eventstream.duration"
 
 logger = logging.getLogger(__name__)
 
 
-# For testing. Function will eventually move here when postprocessworker is removed.
 def dispatch_post_process_group_task(
     event_id: str,
     project_id: int,
@@ -40,19 +39,24 @@ def dispatch_post_process_group_task(
     group_states: Optional[GroupStates] = None,
     occurrence_id: Optional[str] = None,
 ) -> None:
-    _dispatch_post_process_group_task(
-        event_id,
-        project_id,
-        group_id,
-        is_new,
-        is_regression,
-        is_new_group_environment,
-        primary_hash,
-        queue,
-        skip_consume,
-        group_states,
-        occurrence_id=occurrence_id,
-    )
+    if skip_consume:
+        logger.info("post_process.skip.raw_event", extra={"event_id": event_id})
+    else:
+        cache_key = cache_key_for_event({"project": project_id, "event_id": event_id})
+
+        post_process_group.apply_async(
+            kwargs={
+                "is_new": is_new,
+                "is_regression": is_regression,
+                "is_new_group_environment": is_new_group_environment,
+                "primary_hash": primary_hash,
+                "cache_key": cache_key,
+                "group_id": group_id,
+                "group_states": group_states,
+                "occurrence_id": occurrence_id,
+            },
+            queue=queue,
+        )
 
 
 
 
 def _get_task_kwargs(message: Message[KafkaPayload]) -> Optional[Mapping[str, Any]]:
 def _get_task_kwargs(message: Message[KafkaPayload]) -> Optional[Mapping[str, Any]]:
@@ -81,64 +85,6 @@ def _get_task_kwargs_and_dispatch(message: Message[KafkaPayload]) -> None:
     dispatch_post_process_group_task(**task_kwargs)
 
 
-class DispatchTask(ProcessingStrategy[KafkaPayload]):
-    def __init__(
-        self,
-        concurrency: int,
-        max_pending_futures: int,
-        commit: Commit,
-    ) -> None:
-        self.__executor = ThreadPoolExecutor(max_workers=concurrency)
-        self.__futures: Deque[Tuple[Message[KafkaPayload], Future[None]]] = deque()
-        self.__max_pending_futures = max_pending_futures
-        self.__commit = commit
-        self.__closed = False
-
-    def submit(self, message: Message[KafkaPayload]) -> None:
-        assert not self.__closed
-        # The list of pending futures is too long, tell the stream processor to slow down
-        if len(self.__futures) > self.__max_pending_futures:
-            raise MessageRejected
-
-        self.__futures.append(
-            (message, self.__executor.submit(_get_task_kwargs_and_dispatch, message))
-        )
-
-    def poll(self) -> None:
-        # Remove completed futures in order
-        while self.__futures and self.__futures[0][1].done():
-            message, _ = self.__futures.popleft()
-
-            self.__commit(message.committable)
-
-    def join(self, timeout: Optional[float] = None) -> None:
-        start = time.time()
-
-        # Commit all pending offsets
-        self.__commit({}, force=True)
-
-        while self.__futures:
-            remaining = timeout - (time.time() - start) if timeout is not None else None
-            if remaining is not None and remaining <= 0:
-                logger.warning(f"Timed out with {len(self.__futures)} futures in queue")
-                break
-
-            message, future = self.__futures.popleft()
-
-            future.result(remaining)
-
-            self.__commit(message.committable, force=True)
-
-        self.__executor.shutdown()
-
-    def close(self) -> None:
-        self.__closed = True
-
-    def terminate(self) -> None:
-        self.__closed = True
-        self.__executor.shutdown()
-
-
 class PostProcessForwarderStrategyFactory(ProcessingStrategyFactory[KafkaPayload]):
     def __init__(
         self,
@@ -153,4 +99,9 @@ class PostProcessForwarderStrategyFactory(ProcessingStrategyFactory[KafkaPayload
         commit: Commit,
         partitions: Mapping[Partition, int],
     ) -> ProcessingStrategy[KafkaPayload]:
-        return DispatchTask(self.__concurrency, self.__max_pending_futures, commit)
+        return RunTaskInThreads(
+            _get_task_kwargs_and_dispatch,
+            self.__concurrency,
+            self.__max_pending_futures,
+            CommitOffsets(commit),
+        )
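
The deleted DispatchTask above hand-rolled exactly this lifecycle: submit() raising MessageRejected past a bound on pending futures, poll() committing completed work in order, join() draining the pool. A sketch exercising the replacement directly, assuming arroyo's Message/BrokerValue types; the commit callback and topic name here are stand-ins:

from datetime import datetime

from arroyo.backends.kafka.consumer import KafkaPayload
from arroyo.processing.strategies import CommitOffsets, RunTaskInThreads
from arroyo.types import BrokerValue, Message, Partition, Topic


def work(message: Message[KafkaPayload]) -> None:
    # Stand-in for _get_task_kwargs_and_dispatch.
    pass


committed = []
strategy = RunTaskInThreads(
    work,
    2,   # concurrency: thread pool size
    20,  # max_pending_futures: submit() raises MessageRejected beyond this
    CommitOffsets(lambda offsets, force=False: committed.append(offsets)),
)

partition = Partition(Topic("events"), 0)
strategy.submit(
    Message(BrokerValue(KafkaPayload(None, b"{}", []), partition, 1, datetime.now()))
)
strategy.poll()             # forwards finished futures to CommitOffsets
strategy.close()
strategy.join(timeout=5.0)  # drains remaining futures, then commits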

+ 0 - 45
src/sentry/eventstream/kafka/postprocessworker.py

@@ -10,14 +10,11 @@ from threading import Lock
 from typing import Any, Generator, Mapping, MutableMapping, Optional, Tuple
 
 from sentry import options
-from sentry.eventstream.base import GroupStates
 from sentry.eventstream.kafka.protocol import (
     get_task_kwargs_for_message,
     get_task_kwargs_for_message_from_headers,
 )
-from sentry.tasks.post_process import post_process_group
 from sentry.utils import metrics
-from sentry.utils.cache import cache_key_for_event
 
 logger = logging.getLogger(__name__)
 
@@ -87,45 +84,3 @@ def _record_metrics(partition: int, task_kwargs: Mapping[str, Any]) -> None:
                 tags={"partition": partition, "type": event_type},
                 sample_rate=1,
             )
-
-
-def dispatch_post_process_group_task(
-    event_id: str,
-    project_id: int,
-    group_id: Optional[int],
-    is_new: bool,
-    is_regression: Optional[bool],
-    is_new_group_environment: bool,
-    primary_hash: Optional[str],
-    queue: str,
-    skip_consume: bool = False,
-    group_states: Optional[GroupStates] = None,
-    occurrence_id: Optional[str] = None,
-) -> None:
-    if skip_consume:
-        logger.info("post_process.skip.raw_event", extra={"event_id": event_id})
-    else:
-        cache_key = cache_key_for_event({"project": project_id, "event_id": event_id})
-
-        post_process_group.apply_async(
-            kwargs={
-                "is_new": is_new,
-                "is_regression": is_regression,
-                "is_new_group_environment": is_new_group_environment,
-                "primary_hash": primary_hash,
-                "cache_key": cache_key,
-                "group_id": group_id,
-                "group_states": group_states,
-                "occurrence_id": occurrence_id,
-            },
-            queue=queue,
-        )
-
-
-def _get_task_kwargs_and_dispatch(message: Message) -> None:
-    task_kwargs = _get_task_kwargs(message)
-    if not task_kwargs:
-        return None
-
-    _record_metrics(message.partition(), task_kwargs)
-    dispatch_post_process_group_task(**task_kwargs)
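
For completeness, a wiring sketch under stated assumptions: arroyo's KafkaConsumer/StreamProcessor APIs as of this change, and a (concurrency, max_pending_futures) constructor for the factory, whose signature is elided in the hunks above. Broker address, group id, and topic name are illustrative.

from arroyo.backends.kafka.configuration import build_kafka_consumer_configuration
from arroyo.backends.kafka.consumer import KafkaConsumer
from arroyo.commit import ONCE_PER_SECOND
from arroyo.processing.processor import StreamProcessor
from arroyo.types import Topic

from sentry.eventstream.kafka.consumer_strategy import PostProcessForwarderStrategyFactory

consumer = KafkaConsumer(
    build_kafka_consumer_configuration(
        default_config={"bootstrap.servers": "localhost:9092"},  # illustrative
        group_id="post-process-forwarder",
        auto_offset_reset="latest",
    )
)

processor = StreamProcessor(
    consumer=consumer,
    topic=Topic("events"),  # illustrative topic name
    processor_factory=PostProcessForwarderStrategyFactory(4, 40),  # assumed (concurrency, max_pending_futures)
    commit_policy=ONCE_PER_SECOND,  # batch offset commits at most once per second
)
processor.run()  # blocks until crash or signal-driven shutdown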