Просмотр исходного кода

feat(typing) First pass at narrowing types on django APIs (#53157)

We will be enforcing usage of `using` in transaction.atomic() and
get_connection() soon via monkeypatching. This mypy plugin allows us to
reflect the new type requirements in mypy allowing developers to get
feedback earlier than runtime. 

Alternative to #52969

---------

Co-authored-by: anthony sottile <anthony.sottile@sentry.io>
Mark Story 1 год назад
Родитель
Сommit
11c23edb1f

+ 1 - 1
pyproject.toml

@@ -50,7 +50,7 @@ looponfailroots = ["src", "tests"]
 [tool.mypy]
 python_version = "3.8"
 mypy_path = ["fixtures/stubs-for-mypy"]
-plugins = ["pydantic.mypy", "mypy_django_plugin.main"]
+plugins = ["pydantic.mypy", "mypy_django_plugin.main", "tools.mypy_helpers.plugin"]
 files = ["."]
 exclude = ["^.venv/", "^self-hosted/"]
 

+ 1 - 1
src/sentry/testutils/hybrid_cloud.py

@@ -162,7 +162,7 @@ class SimulatedTransactionWatermarks(threading.local):
         self, using: str | None = None, connection: BaseDatabaseWrapper | None = None
     ) -> int:
         if connection is None:
-            connection = transaction.get_connection(using)
+            connection = transaction.get_connection(using or "default")
         return max(self.get_transaction_depth(connection) - self.state.get(connection.alias, 0), 0)
 
     def connections_above_watermark(self) -> Set[str]:

+ 145 - 0
tests/tools/mypy_helpers/test_plugin.py

@@ -0,0 +1,145 @@
+from __future__ import annotations
+
+import pathlib
+import subprocess
+import sys
+from typing import Callable
+
+import pytest
+
+
+@pytest.fixture
+def call_mypy(tmp_path: pathlib.Path) -> Callable[[str], tuple[int, str]]:
+    cfg = """\
+[tool.mypy]
+plugins = ["tools.mypy_helpers.plugin"]
+"""
+    cfg_path = tmp_path.joinpath("mypy.toml")
+    cfg_path.write_text(cfg)
+
+    def _call_mypy(contents: str) -> tuple[int, str]:
+        ret = subprocess.run(
+            (
+                *(sys.executable, "-m", "mypy"),
+                *("--config", str(cfg_path)),
+                *("-c", contents),
+            ),
+            capture_output=True,
+            encoding="UTF-8",
+        )
+        return ret.returncode, ret.stdout
+
+    return _call_mypy
+
+
+def test_invalid_get_connection_call(call_mypy):
+    code = """
+from django.db.transaction import get_connection
+
+with get_connection() as cursor:
+    cursor.execute("SELECT 1")
+"""
+    expected = """\
+<string>:4: error: Missing positional argument "using" in call to "get_connection"  [call-arg]
+Found 1 error in 1 file (checked 1 source file)
+"""
+    ret, out = call_mypy(code)
+    assert ret
+    assert out == expected
+
+
+def test_ok_get_connection(call_mypy):
+    code = """
+from django.db.transaction import get_connection
+
+with get_connection("default") as cursor:
+    cursor.execute("SELECT 1")
+"""
+    ret, out = call_mypy(code)
+    assert ret == 0
+
+
+def test_invalid_transaction_atomic(call_mypy):
+    code = """
+from django.db import transaction
+
+with transaction.atomic():
+    value = 10 / 2
+"""
+    expected = """\
+<string>:4: error: All overload variants of "atomic" require at least one argument  [call-overload]
+<string>:4: note: Possible overload variants:
+<string>:4: note:     def [_C] atomic(using: _C) -> _C
+<string>:4: note:     def atomic(using: str, savepoint: bool = ..., durable: bool = ...) -> Atomic
+Found 1 error in 1 file (checked 1 source file)
+"""
+    ret, out = call_mypy(code)
+    assert ret
+    assert out == expected
+
+
+def test_ok_transaction_atomic(call_mypy):
+    code = """
+from django.db import transaction
+
+with transaction.atomic("default"):
+    value = 10 / 2
+"""
+    ret, _ = call_mypy(code)
+    assert ret == 0
+
+
+def test_ok_transaction_on_commit(call_mypy):
+    code = """
+from django.db import transaction
+
+def completed():
+    pass
+
+transaction.on_commit(completed, "default")
+"""
+    ret, _ = call_mypy(code)
+    assert ret == 0
+
+
+def test_invalid_transaction_on_commit(call_mypy):
+    code = """
+from django.db import transaction
+
+def completed():
+    pass
+
+transaction.on_commit(completed)
+"""
+    expected = """\
+<string>:7: error: Missing positional argument "using" in call to "on_commit"  [call-arg]
+Found 1 error in 1 file (checked 1 source file)
+"""
+    ret, out = call_mypy(code)
+    assert ret
+    assert out == expected
+
+
+def test_invalid_transaction_set_rollback(call_mypy):
+    code = """
+from django.db import transaction
+
+transaction.set_rollback(True)
+"""
+    expected = """\
+<string>:4: error: Missing positional argument "using" in call to "set_rollback"  [call-arg]
+Found 1 error in 1 file (checked 1 source file)
+"""
+    ret, out = call_mypy(code)
+    assert ret
+    assert out == expected
+
+
+def test_ok_transaction_set_rollback(call_mypy):
+    code = """
+from django.db import transaction
+
+transaction.set_rollback(True, "default")
+"""
+    ret, _ = call_mypy(code)
+    assert ret == 0

+ 78 - 0
tools/mypy_helpers/plugin.py

@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+from typing import Callable
+
+from mypy.nodes import ARG_POS
+from mypy.plugin import FunctionSigContext, Plugin
+from mypy.types import CallableType, FunctionLike, Instance
+
+
+def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType:
+    signature = ctx.default_signature
+
+    using_arg = signature.argument_by_name("using")
+    if not using_arg:
+        # No using arg in the signature, bail
+        return signature
+
+    # We care about context managers.
+    ret_type = signature.ret_type
+    if not isinstance(ret_type, Instance):
+        return signature
+
+    # Replace the type and remove the default value of using.
+    str_type = ctx.api.named_generic_type("builtins.str", [])
+
+    arg_types = signature.arg_types[1:]
+    arg_kinds = signature.arg_kinds[1:]
+
+    return signature.copy_modified(
+        arg_kinds=[ARG_POS, *arg_kinds],
+        arg_types=[str_type, *arg_types],
+    )
+
+
+def replace_get_connection_sig_callback(ctx: FunctionSigContext) -> CallableType:
+    signature = ctx.default_signature
+    using_arg = signature.argument_by_name("using")
+    if not using_arg:
+        ctx.api.fail("The using parameter is required", ctx.context)
+
+    str_type = ctx.api.named_generic_type("builtins.str", [])
+
+    return signature.copy_modified(arg_kinds=[ARG_POS], arg_types=[str_type])
+
+
+def replace_trailing_using_sig_callback(ctx: FunctionSigContext) -> CallableType:
+    signature = ctx.default_signature
+    using_arg = signature.argument_by_name("using")
+    if not using_arg:
+        ctx.api.fail("The using parameter is required", ctx.context)
+
+    # Update the parameter type to be required and str
+    str_type = ctx.api.named_generic_type("builtins.str", [])
+    arg_kinds = signature.arg_kinds[0:-1]
+    arg_types = signature.arg_types[0:-1]
+
+    return signature.copy_modified(
+        arg_kinds=[*arg_kinds, ARG_POS], arg_types=[*arg_types, str_type]
+    )
+
+
+class SentryMypyPlugin(Plugin):
+    def get_function_signature_hook(
+        self, fullname: str
+    ) -> Callable[[FunctionSigContext], FunctionLike] | None:
+        if fullname == "django.db.transaction.atomic":
+            return replace_transaction_atomic_sig_callback
+        if fullname == "django.db.transaction.get_connection":
+            return replace_get_connection_sig_callback
+        if fullname == "django.db.transaction.on_commit":
+            return replace_trailing_using_sig_callback
+        if fullname == "django.db.transaction.set_rollback":
+            return replace_trailing_using_sig_callback
+        return None
+
+
+def plugin(version: str) -> type[SentryMypyPlugin]:
+    return SentryMypyPlugin