
fix(taskbroker): Add at_most_once support (#81048)

If a task is marked as at_most_once, check the cache to see if the task has already been seen before. If it has, assume the task has already been executed and continue. Otherwise, store the task ID and execute the task.

Depends on https://github.com/getsentry/sentry-protos/pull/66

---------

Co-authored-by: getsantry[bot] <66042841+getsantry[bot]@users.noreply.github.com>
Evan Hicks committed 3 months ago (4618b45444)
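
A minimal sketch of that flow, assuming a Django-style cache backend. The helper name claim_activation is hypothetical and used only for illustration; in the actual change the logic is inlined in TaskWorker.process_task, while the key format and 1-day timeout match the helper added to src/sentry/taskworker/worker.py below.

    from django.core.cache import cache

    AT_MOST_ONCE_TIMEOUT = 60 * 60 * 24  # 1 day

    def claim_activation(namespace: str, taskname: str, task_id: str) -> bool:
        """Record the activation id; returns True only if it was not seen before."""
        # Django's cache.add is set-if-absent: it stores the key and returns True
        # only when the key did not already exist, so only one worker wins the claim.
        key = f"tw:amo:{namespace}:{taskname}:{task_id}"
        return cache.add(key, "1", timeout=AT_MOST_ONCE_TIMEOUT)

    # In the worker: if the task is at_most_once and the claim fails,
    # the activation is skipped instead of being executed again.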

+ 1 - 1
requirements-base.txt

@@ -68,7 +68,7 @@ rfc3986-validator>=0.1.1
 sentry-arroyo>=2.16.5
 sentry-kafka-schemas>=0.1.120
 sentry-ophio==1.0.0
-sentry-protos>=0.1.34
+sentry-protos>=0.1.37
 sentry-redis-tools>=0.1.7
 sentry-relay>=0.9.3
 sentry-sdk[http2]>=2.18.0

+ 2 - 2
requirements-dev-frozen.txt

@@ -146,7 +146,7 @@ pyflakes==3.2.0
 pyjwt==2.4.0
 pymemcache==4.0.0
 pysocks==1.7.1
-pytest==8.1.2
+pytest==8.3.3
 pytest-cov==4.0.0
 pytest-django==4.9.0
 pytest-fail-slow==0.3.0
@@ -188,7 +188,7 @@ sentry-forked-django-stubs==5.1.1.post1
 sentry-forked-djangorestframework-stubs==3.15.1.post2
 sentry-kafka-schemas==0.1.120
 sentry-ophio==1.0.0
-sentry-protos==0.1.34
+sentry-protos==0.1.37
 sentry-redis-tools==0.1.7
 sentry-relay==0.9.3
 sentry-sdk==2.18.0

+ 1 - 1
requirements-dev.txt

@@ -10,7 +10,7 @@ time-machine>=2.16.0
 honcho>=2
 openapi-core>=0.18.2
 openapi-pydantic>=0.4.0
-pytest>=8.1
+pytest>=8.3
 pytest-cov>=4.0.0
 pytest-django>=4.9.0
 pytest-fail-slow>=0.3.0

+ 1 - 1
requirements-frozen.txt

@@ -127,7 +127,7 @@ s3transfer==0.10.0
 sentry-arroyo==2.16.5
 sentry-kafka-schemas==0.1.120
 sentry-ophio==1.0.0
-sentry-protos==0.1.34
+sentry-protos==0.1.37
 sentry-redis-tools==0.1.7
 sentry-relay==0.9.3
 sentry-sdk==2.18.0

+ 2 - 0
src/sentry/taskworker/registry.py

@@ -71,6 +71,7 @@ class TaskNamespace:
         retry: Retry | None = None,
         expires: int | datetime.timedelta | None = None,
         processing_deadline_duration: int | datetime.timedelta | None = None,
+        at_most_once: bool = False,
     ) -> Callable[[Callable[P, R]], Task[P, R]]:
         """register a task, used as a decorator"""
 
@@ -84,6 +85,7 @@ class TaskNamespace:
                 processing_deadline_duration=(
                     processing_deadline_duration or self.default_processing_deadline_duration
                 ),
+                at_most_once=at_most_once,
             )
             # TODO(taskworker) tasks should be registered into the registry
             # so that we can ensure task names are globally unique
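
For reference, the new flag rides along with the other options on the register decorator. A hedged usage sketch, assuming a TaskNamespace instance named namespace is already available (how it is obtained is out of scope here); the task name and body are illustrative:

    # namespace is assumed to be an existing TaskNamespace instance.
    @namespace.register(name="examples.send_welcome_email", at_most_once=True)
    def send_welcome_email(user_id: int) -> None:
        # Runs at most once per activation id, even if the activation is redelivered.
        ...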

+ 6 - 1
src/sentry/taskworker/task.py

@@ -32,6 +32,7 @@ class Task(Generic[P, R]):
         retry: Retry | None = None,
         expires: int | datetime.timedelta | None = None,
         processing_deadline_duration: int | datetime.timedelta | None = None,
+        at_most_once: bool = False,
     ):
         # TODO(taskworker) Implement task execution deadlines
         self.name = name
@@ -42,6 +43,7 @@ class Task(Generic[P, R]):
         self._processing_deadline_duration = (
             processing_deadline_duration or DEFAULT_PROCESSING_DEADLINE
         )
+        self.at_most_once = at_most_once
         update_wrapper(self, func)
 
     @property
@@ -92,8 +94,11 @@ class Task(Generic[P, R]):
                 attempts=0,
                 kind="sentry.taskworker.retry.Retry",
                 discard_after_attempt=1,
+                at_most_once=self.at_most_once,
             )
-        return retry.initial_state()
+        retry_state = retry.initial_state()
+        retry_state.at_most_once = self.at_most_once
+        return retry_state
 
     def should_retry(self, state: RetryState, exc: Exception) -> bool:
         # No retry policy means no retries.
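
Whether or not the task has a Retry policy, create_activation now stamps at_most_once onto the activation's retry state, which is what the worker reads. A condensed sketch based on the tests added below (do_things and task_namespace come from the test module):

    task = Task(
        name="examples.at_most_once",
        func=do_things,
        namespace=task_namespace,
        at_most_once=True,
    )
    activation = task.create_activation()
    assert activation.retry_state.at_most_once is True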

+ 28 - 6
src/sentry/taskworker/worker.py

@@ -11,6 +11,7 @@ from uuid import uuid4
 import grpc
 import orjson
 from django.conf import settings
+from django.core.cache import cache
 from sentry_protos.sentry.v1.taskworker_pb2 import (
     TASK_ACTIVATION_STATUS_COMPLETE,
     TASK_ACTIVATION_STATUS_FAILURE,
@@ -20,6 +21,7 @@ from sentry_protos.sentry.v1.taskworker_pb2 import (
 
 from sentry.taskworker.client import TaskworkerClient
 from sentry.taskworker.registry import taskregistry
+from sentry.taskworker.task import Task
 from sentry.utils import metrics
 
 logger = logging.getLogger("sentry.taskworker.worker")
@@ -35,6 +37,14 @@ def _process_activation(
     taskregistry.get(namespace).get(task_name)(*args, **kwargs)
 
 
+AT_MOST_ONCE_TIMEOUT = 60 * 60 * 24  # 1 day
+
+
+def get_at_most_once_key(namespace: str, taskname: str, task_id: str) -> str:
+    # tw:amo -> taskworker:at_most_once
+    return f"tw:amo:{namespace}:{taskname}:{task_id}"
+
+
 class TaskWorker:
     """
     A TaskWorker fetches tasks from a taskworker RPC host and handles executing task activations.
@@ -128,13 +138,13 @@ class TaskWorker:
         metrics.incr("taskworker.worker.get_task.success")
         return activation
 
-    def _known_task(self, activation: TaskActivation) -> bool:
+    def _get_known_task(self, activation: TaskActivation) -> Task[Any, Any] | None:
         if not taskregistry.contains(activation.namespace):
             logger.error(
                 "taskworker.invalid_namespace",
                 extra={"namespace": activation.namespace, "taskname": activation.taskname},
             )
-            return False
+            return None
 
         namespace = taskregistry.get(activation.namespace)
         if not namespace.contains(activation.taskname):
@@ -142,12 +152,13 @@ class TaskWorker:
                 "taskworker.invalid_taskname",
                 extra={"namespace": activation.namespace, "taskname": activation.taskname},
             )
-            return False
-        return True
+            return None
+        return namespace.get(activation.taskname)
 
     def process_task(self, activation: TaskActivation) -> TaskActivation | None:
         assert self._pool
-        if not self._known_task(activation):
+        task = self._get_known_task(activation)
+        if not task:
             metrics.incr(
                 "taskworker.worker.unknown_task",
                 tags={"namespace": activation.namespace, "taskname": activation.taskname},
@@ -158,7 +169,18 @@ class TaskWorker:
                 status=TASK_ACTIVATION_STATUS_FAILURE,
             )
 
-        # TODO(taskworker): Add at_most_once checks
+        if task.at_most_once:
+            key = get_at_most_once_key(activation.namespace, activation.taskname, activation.id)
+            if cache.add(key, "1", timeout=AT_MOST_ONCE_TIMEOUT):  # The key didn't exist
+                metrics.incr(
+                    "taskworker.task.at_most_once.executed", tags={"task": activation.taskname}
+                )
+            else:
+                metrics.incr(
+                    "taskworker.worker.at_most_once.skipped", tags={"task": activation.taskname}
+                )
+                return None
+
         processing_timeout = activation.processing_deadline_duration
         namespace = taskregistry.get(activation.namespace)
         next_state = TASK_ACTIVATION_STATUS_FAILURE
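
Note that the claim is written before the task runs: if a worker crashes mid-execution, the key already exists and a redelivered activation is skipped rather than retried, which is the at-most-once trade-off. A short illustration of the cache behaviour the check relies on (the key shown is made up for the example):

    from django.core.cache import cache

    key = "tw:amo:examples:examples.task:abc123"
    assert cache.add(key, "1", timeout=60)      # first delivery: claim succeeds, task executes
    assert not cache.add(key, "1", timeout=60)  # redelivery: key exists, task is skipped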

+ 31 - 0
tests/sentry/taskworker/test_task.py

@@ -110,6 +110,22 @@ def test_create_activation(task_namespace: TaskNamespace) -> None:
         processing_deadline_duration=30,
     )
 
+    at_most_once_task = Task(
+        name="test.at_most_once",
+        func=do_things,
+        namespace=task_namespace,
+        at_most_once=True,
+    )
+
+    retry = Retry(times=3, times_exceeded=LastAction.Deadletter)
+    retry_at_most_once_task = Task(
+        name="test.with_retry_at_most_once",
+        func=do_things,
+        namespace=task_namespace,
+        retry=retry,
+        at_most_once=True,
+    )
+
     # No retries will be made as there is no retry policy on the task or namespace.
     activation = no_retry_task.create_activation()
     assert activation.taskname == "test.no_retry"
@@ -136,6 +152,21 @@ def test_create_activation(task_namespace: TaskNamespace) -> None:
     assert activation.expires == 300
     assert activation.processing_deadline_duration == 30
 
+    activation = at_most_once_task.create_activation()
+    assert activation.taskname == "test.at_most_once"
+    assert activation.namespace == task_namespace.name
+    assert activation.retry_state
+    assert activation.retry_state.at_most_once is True
+
+    activation = retry_at_most_once_task.create_activation()
+    assert activation.taskname == "test.with_retry_at_most_once"
+    assert activation.namespace == task_namespace.name
+    assert activation.retry_state
+    assert activation.retry_state.at_most_once is True
+    assert activation.retry_state.attempts == 0
+    assert activation.retry_state.discard_after_attempt == 0
+    assert activation.retry_state.deadletter_after_attempt == 3
+
 
 def test_create_activation_parameters(task_namespace: TaskNamespace) -> None:
     @task_namespace.register(name="test.parameters")

+ 32 - 3
tests/sentry/taskworker/test_worker.py

@@ -34,6 +34,11 @@ def fail_task():
     raise ValueError("nope")
 
 
+@test_namespace.register(name="test.at_most_once", at_most_once=True)
+def at_most_once_task():
+    pass
+
+
 SIMPLE_TASK = TaskActivation(
     id="111",
     taskname="test.simple_task",
@@ -66,6 +71,14 @@ UNDEFINED_TASK = TaskActivation(
     processing_deadline_duration=1,
 )
 
+AT_MOST_ONCE_TASK = TaskActivation(
+    id="555",
+    taskname="test.at_most_once",
+    namespace="tests",
+    parameters='{"args": [], "kwargs": {}}',
+    processing_deadline_duration=1,
+)
+
 
 @override_settings(TASKWORKER_IMPORTS=("tests.sentry.taskworker.test_worker",))
 class TestTaskWorker(TestCase):
@@ -118,16 +131,32 @@ class TestTaskWorker(TestCase):
 
     def test_process_task_failure(self) -> None:
         taskworker = TaskWorker(rpc_host="127.0.0.1:50051", max_task_count=100)
-        with mock.patch.object(taskworker.client, "update_task") as mock_update_task:
-            mock_update_task.return_value = SIMPLE_TASK
+        with mock.patch.object(taskworker.client, "update_task") as mock_update:
+            mock_update.return_value = SIMPLE_TASK
             result = taskworker.process_task(FAIL_TASK)
 
-            mock_update_task.assert_called_with(
+            mock_update.assert_called_with(
                 task_id=FAIL_TASK.id, status=TASK_ACTIVATION_STATUS_FAILURE
             )
             assert result
             assert result.id == SIMPLE_TASK.id
 
+    def test_process_task_at_most_once(self) -> None:
+        taskworker = TaskWorker(rpc_host="127.0.0.1:50051", max_task_count=100)
+        with mock.patch.object(taskworker.client, "update_task") as mock_update:
+            mock_update.return_value = SIMPLE_TASK
+            result = taskworker.process_task(AT_MOST_ONCE_TASK)
+
+            mock_update.assert_called_with(
+                task_id=AT_MOST_ONCE_TASK.id, status=TASK_ACTIVATION_STATUS_COMPLETE
+            )
+        assert result
+        assert result.id == SIMPLE_TASK.id
+
+        # The second attempt is skipped because the activation id was already recorded.
+        result = taskworker.process_task(AT_MOST_ONCE_TASK)
+        assert result is None
+
     def test_start_max_task_count(self) -> None:
         taskworker = TaskWorker(rpc_host="127.0.0.1:50051", max_task_count=1)
         with mock.patch.object(taskworker, "client") as mock_client: