Browse Source

fix(kafka): Unified way to handle signals, fix missing ones (#42942)

Markus Unterwaditzer 2 years ago
parent
commit
62c970dc63

+ 1 - 9
src/sentry/profiles/consumers/__init__.py

@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import signal
 from typing import Any, MutableMapping
 
 from arroyo import Topic
@@ -25,20 +24,13 @@ def get_profiles_process_consumer(
     topic = force_topic or topic
     consumer_config = get_config(topic, group_id, auto_offset_reset, force_cluster)
     consumer = KafkaConsumer(consumer_config)
-    processor = StreamProcessor(
+    return StreamProcessor(
         consumer=consumer,
         topic=Topic(topic),
         processor_factory=ProcessProfileStrategyFactory(),
         commit_policy=ONCE_PER_SECOND,
     )
 
-    def handler(signum: int, frame: Any) -> None:
-        processor.signal_shutdown()
-
-    signal.signal(signal.SIGINT, handler)
-    signal.signal(signal.SIGTERM, handler)
-    return processor
-
 
 def get_config(
     topic: str, group_id: str, auto_offset_reset: str, force_cluster: str | None

+ 1 - 10
src/sentry/region_to_control/consumer.py

@@ -1,4 +1,3 @@
-import signal
 from typing import Any, Mapping, Optional
 
 import sentry_sdk
@@ -38,21 +37,13 @@ def get_region_to_control_consumer(
         )
     )
 
-    processor = StreamProcessor(
+    return StreamProcessor(
         consumer=consumer,
         topic=Topic(settings.KAFKA_REGION_TO_CONTROL),
         processor_factory=RegionToControlStrategyFactory(),
         commit_policy=ONCE_PER_SECOND,
     )
 
-    def handler(*args: Any) -> None:
-        processor.signal_shutdown()
-
-    signal.signal(signal.SIGINT, handler)
-    signal.signal(signal.SIGTERM, handler)
-
-    return processor
-
 
 class ProcessRegionToControlMessage(ProcessingStrategy[KafkaPayload]):
     def __init__(self, next_step: ProcessingStrategy[KafkaPayload]) -> None:

+ 1 - 9
src/sentry/replays/consumers/__init__.py

@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import signal
 from typing import Any, MutableMapping
 
 from arroyo import Topic
@@ -26,20 +25,13 @@ def get_replays_recordings_consumer(
     topic = force_topic or topic
     consumer_config = get_config(topic, group_id, auto_offset_reset, force_cluster)
     consumer = KafkaConsumer(consumer_config)
-    processor = StreamProcessor(
+    return StreamProcessor(
         consumer=consumer,
         topic=Topic(topic),
         processor_factory=ProcessReplayRecordingStrategyFactory(),
         commit_policy=ONCE_PER_SECOND,
     )
 
-    def handler(signum: int, frame: Any) -> None:
-        processor.signal_shutdown()
-
-    signal.signal(signal.SIGINT, handler)
-    signal.signal(signal.SIGTERM, handler)
-    return processor
-
 
 def get_config(
     topic: str, group_id: str, auto_offset_reset: str, force_cluster: str | None

+ 17 - 28
src/sentry/runner/commands/run.py

@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import signal
 import sys
 from concurrent.futures import ThreadPoolExecutor
 from multiprocessing import cpu_count
@@ -11,6 +10,7 @@ from sentry.bgtasks.api import managed_bgtasks
 from sentry.ingest.types import ConsumerType
 from sentry.runner.decorators import configuration, log_options
 from sentry.sentry_metrics.consumers.indexer.slicing_router import get_slicing_router
+from sentry.utils.kafka import run_processor_with_signals
 
 DEFAULT_BLOCK_SIZE = int(32 * 1e6)
 
@@ -371,6 +371,7 @@ def post_process_forwarder(**options):
     from sentry.eventstream.base import ForwarderNotRequired
 
     try:
+        # TODO(markus): convert to use run_processor_with_signals -- can't yet because there's a custom shutdown handler
         eventstream.run_post_process_forwarder(
             entity=options["entity"],
             consumer_group=options["consumer_group"],
@@ -437,13 +438,7 @@ def query_subscription_consumer(**options):
         force_offset_reset=options["force_offset_reset"],
     )
 
-    def handler(signum, frame):
-        subscriber.shutdown()
-
-    signal.signal(signal.SIGINT, handler)
-    signal.signal(signal.SIGTERM, handler)
-
-    subscriber.run()
+    run_processor_with_signals(subscriber)
 
 
 def batching_kafka_options(group, max_batch_size=None, max_batch_time_ms=1000):
@@ -561,7 +556,8 @@ def ingest_consumer(consumer_types, all_consumer_types, **options):
     with metrics.global_tags(
         ingest_consumer_types=",".join(sorted(consumer_types)), _all_threads=True
     ):
-        get_ingest_consumer(consumer_types=consumer_types, executor=executor, **options).run()
+        consumer = get_ingest_consumer(consumer_types=consumer_types, executor=executor, **options)
+        run_processor_with_signals(consumer)
 
 
 @run.command("occurrences-ingest-consumer")
@@ -575,7 +571,8 @@ def occurrences_ingest_consumer():
     consumer_type = settings.KAFKA_INGEST_OCCURRENCES
 
     with metrics.global_tags(ingest_consumer_types=consumer_type, _all_threads=True):
-        get_occurrences_ingest_consumer(consumer_type).run()
+        consumer = get_occurrences_ingest_consumer(consumer_type)
+        run_processor_with_signals(consumer)
 
 
 @run.command("region-to-control-consumer")
@@ -598,8 +595,10 @@ def region_to_control_consumer(region_name, **kafka_options):
     from sentry.region_to_control.consumer import get_region_to_control_consumer
     from sentry.utils import metrics
 
+    consumer = get_region_to_control_consumer(**kafka_options)
+
     with metrics.global_tags(region_name=region_name):
-        get_region_to_control_consumer(**kafka_options).run()
+        run_processor_with_signals(consumer)
 
 
 @run.command("ingest-metrics-parallel-consumer")
@@ -637,14 +636,8 @@ def metrics_parallel_consumer(**options):
         indexer_profile=ingest_config, slicing_router=slicing_router, **options
     )
 
-    def handler(signum, frame):
-        streamer.signal_shutdown()
-
-    signal.signal(signal.SIGINT, handler)
-    signal.signal(signal.SIGTERM, handler)
-
     initialize_global_consumer_state(ingest_config)
-    streamer.run()
+    run_processor_with_signals(streamer)
 
 
 @run.command("billing-metrics-consumer")
@@ -655,7 +648,7 @@ def metrics_billing_consumer(**options):
     from sentry.ingest.billing_metrics_consumer import get_metrics_billing_consumer
 
     consumer = get_metrics_billing_consumer(**options)
-    consumer.run()
+    run_processor_with_signals(consumer)
 
 
 @run.command("ingest-profiles")
@@ -666,7 +659,8 @@ def metrics_billing_consumer(**options):
 def profiles_consumer(**options):
     from sentry.profiles.consumers import get_profiles_process_consumer
 
-    get_profiles_process_consumer(**options).run()
+    consumer = get_profiles_process_consumer(**options)
+    run_processor_with_signals(consumer)
 
 
 @run.command("ingest-replay-recordings")
@@ -679,7 +673,8 @@ def profiles_consumer(**options):
 def replays_recordings_consumer(**options):
     from sentry.replays.consumers import get_replays_recordings_consumer
 
-    get_replays_recordings_consumer(**options).run()
+    consumer = get_replays_recordings_consumer(**options)
+    run_processor_with_signals(consumer)
 
 
 @run.command("indexer-last-seen-updater")
@@ -702,11 +697,5 @@ def last_seen_updater(**options):
 
     consumer = get_last_seen_updater(ingest_config=ingest_config, **options)
 
-    def handler(signum, frame):
-        consumer.signal_shutdown()
-
-    signal.signal(signal.SIGINT, handler)
-    signal.signal(signal.SIGTERM, handler)
-
     with global_tags(_all_threads=True, pipeline=ingest_config.internal_metrics_tag):
-        consumer.run()
+        run_processor_with_signals(consumer)

+ 1 - 1
src/sentry/snuba/query_subscription_consumer.py

@@ -237,7 +237,7 @@ class QuerySubscriptionConsumer:
 
         self._reset_batch()
 
-    def shutdown(self) -> None:
+    def signal_shutdown(self) -> None:
         self.__shutdown_requested = True
 
     def handle_message(self, message: Message) -> None:

+ 8 - 3
src/sentry/utils/kafka.py

@@ -3,13 +3,14 @@ import signal
 
 from django.conf import settings
 
-from sentry.utils import metrics
 from sentry.utils.batching_kafka_consumer import BatchingKafkaConsumer
 
 logger = logging.getLogger(__name__)
 
 
 def create_batching_kafka_consumer(topic_names, worker, **options):
+    from sentry.utils import metrics
+
     # In some cases we want to override the configuration stored in settings from the command line
     force_topic = options.pop("force_topic", None)
     force_cluster = options.pop("force_cluster", None)
@@ -43,10 +44,14 @@ def create_batching_kafka_consumer(topic_names, worker, **options):
         **options,
     )
 
+    return consumer
+
+
+def run_processor_with_signals(processor):
     def handler(signum, frame):
-        consumer.signal_shutdown()
+        processor.signal_shutdown()
 
     signal.signal(signal.SIGINT, handler)
     signal.signal(signal.SIGTERM, handler)
 
-    return consumer
+    processor.run()

+ 1 - 1
tests/snuba/incidents/test_tasks.py

@@ -129,7 +129,7 @@ class HandleSnubaQueryUpdateTest(TestCase):
             # We want to just exit after the callback so that we can see the result of
             # processing.
             original_callback(*args, **kwargs)
-            consumer.shutdown()
+            consumer.signal_shutdown()
 
         subscriber_registry[INCIDENTS_SNUBA_SUBSCRIPTION_TYPE] = shutdown_callback
 

+ 4 - 4
tests/snuba/snuba/test_query_subscription_consumer.py

@@ -126,7 +126,7 @@ class QuerySubscriptionConsumerTest(TestCase, SnubaTestCase):
         producer.flush()
 
         consumer = QuerySubscriptionConsumer("hi", topic=self.topic, commit_batch_size=1)
-        mock_callback = Mock(side_effect=lambda *a, **k: consumer.shutdown())
+        mock_callback = Mock(side_effect=lambda *a, **k: consumer.signal_shutdown())
         register_subscriber(self.registration_key)(mock_callback)
         sub = self.create_subscription()
         consumer.run()
@@ -151,7 +151,7 @@ class QuerySubscriptionConsumerTest(TestCase, SnubaTestCase):
         producer.flush()
 
         consumer = QuerySubscriptionConsumer("hi", topic=self.topic, commit_batch_size=1)
-        mock_callback = Mock(side_effect=lambda *a, **k: consumer.shutdown())
+        mock_callback = Mock(side_effect=lambda *a, **k: consumer.signal_shutdown())
         register_subscriber(self.registration_key)(mock_callback)
         sub = self.create_subscription()
         consumer.run()
@@ -184,7 +184,7 @@ class QuerySubscriptionConsumerTest(TestCase, SnubaTestCase):
 
         def mock_callback(*args, **kwargs):
             if mock.call_count >= len(expected_calls):
-                consumer.shutdown()
+                consumer.signal_shutdown()
 
         mock = Mock(side_effect=mock_callback)
 
@@ -218,7 +218,7 @@ class QuerySubscriptionConsumerTest(TestCase, SnubaTestCase):
 
         def mock_callback(*args, **kwargs):
             time.sleep(0.1)
-            consumer.shutdown()
+            consumer.signal_shutdown()
 
         mock = Mock(side_effect=mock_callback)