Browse Source

fix(typing): Add types for codeowners (#71601)

Snigdha Sharma 9 months ago
parent
commit
4d70405e5c
2 changed files with 13 additions and 8 deletions
  1. 0 1
      pyproject.toml
  2. 13 7
      src/sentry/tasks/base.py

+ 0 - 1
pyproject.toml

@@ -431,7 +431,6 @@ module = [
     "sentry.tagstore.types",
     "sentry.tagstore.types",
     "sentry.tasks.auth",
     "sentry.tasks.auth",
     "sentry.tasks.base",
     "sentry.tasks.base",
-    "sentry.tasks.codeowners.update_code_owners_schema",
     "sentry.tasks.integrations",
     "sentry.tasks.integrations",
     "sentry.tasks.integrations.migrate_issues",
     "sentry.tasks.integrations.migrate_issues",
     "sentry.tasks.integrations.migrate_repo",
     "sentry.tasks.integrations.migrate_repo",

+ 13 - 7
src/sentry/tasks/base.py

@@ -5,15 +5,18 @@ from collections.abc import Callable, Iterable
 from contextlib import contextmanager
 from contextlib import contextmanager
 from datetime import datetime
 from datetime import datetime
 from functools import wraps
 from functools import wraps
-from typing import Any
+from typing import Any, TypeVar
 
 
 from celery import current_task
 from celery import current_task
+from django.db.models import Model
 
 
 from sentry.celery import app
 from sentry.celery import app
 from sentry.silo.base import SiloLimit, SiloMode
 from sentry.silo.base import SiloLimit, SiloMode
 from sentry.utils import metrics
 from sentry.utils import metrics
 from sentry.utils.sdk import capture_exception, configure_scope
 from sentry.utils.sdk import capture_exception, configure_scope
 
 
+ModelT = TypeVar("ModelT", bound=Model)
+
 
 
 class TaskSiloLimit(SiloLimit):
 class TaskSiloLimit(SiloLimit):
     """
     """
@@ -65,12 +68,14 @@ def track_memory_usage(metric, **kwargs):
         metrics.distribution(metric, get_rss_usage() - before, unit="byte", **kwargs)
         metrics.distribution(metric, get_rss_usage() - before, unit="byte", **kwargs)
 
 
 
 
-def load_model_from_db(cls, instance_or_id, allow_cache=True):
+def load_model_from_db(
+    tp: type[ModelT], instance_or_id: ModelT | int, allow_cache: bool = True
+) -> ModelT:
     """Utility function to allow a task to transition to passing ids rather than model instances."""
     """Utility function to allow a task to transition to passing ids rather than model instances."""
     if isinstance(instance_or_id, int):
     if isinstance(instance_or_id, int):
-        if hasattr(cls.objects, "get_from_cache") and allow_cache:
-            return cls.objects.get_from_cache(pk=instance_or_id)
-        return cls.objects.get(pk=instance_or_id)
+        if hasattr(tp.objects, "get_from_cache") and allow_cache:
+            return tp.objects.get_from_cache(pk=instance_or_id)
+        return tp.objects.get(pk=instance_or_id)
     return instance_or_id
     return instance_or_id
 
 
 
 
@@ -111,8 +116,9 @@ def instrumented_task(name, stat_suffix=None, silo_mode=None, record_timing=Fals
                 scope.set_tag("task_name", name)
                 scope.set_tag("task_name", name)
                 scope.set_tag("transaction_id", transaction_id)
                 scope.set_tag("transaction_id", transaction_id)
 
 
-            with metrics.timer(key, instance=instance), track_memory_usage(
-                "jobs.memory_change", instance=instance
+            with (
+                metrics.timer(key, instance=instance),
+                track_memory_usage("jobs.memory_change", instance=instance),
             ):
             ):
                 result = func(*args, **kwargs)
                 result = func(*args, **kwargs)