123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- from __future__ import annotations
- from typing import Callable
- from mypy.nodes import ARG_POS
- from mypy.plugin import FunctionSigContext, Plugin
- from mypy.types import CallableType, FunctionLike, Instance
def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType:
    """Make ``using`` a required ``str`` in the context-manager form of ``atomic``.

    Used as the signature hook for ``django.db.transaction.atomic``. The
    signature is returned unchanged when it has no positional ``using``
    parameter, or when its return type is not an ``Instance`` (the
    context-manager form; other overloads presumably return the decorated
    callable's type - confirm against the stubs in use).

    Args:
        ctx: mypy's function-signature context; ``ctx.default_signature`` is
            the signature being adjusted.

    Returns:
        The adjusted signature, or the default signature when no rewrite
        applies.
    """
    signature = ctx.default_signature
    using_arg = signature.argument_by_name("using")
    if not using_arg or using_arg.pos is None:
        # No ``using`` parameter that can be passed positionally (absent,
        # keyword-only, or only matched through ``**kwargs``), bail.
        return signature

    # We care about context managers.
    if not isinstance(signature.ret_type, Instance):
        return signature

    # Replace the type and remove the default value of ``using`` in the slot
    # where it actually lives, instead of assuming it is the first parameter.
    index = using_arg.pos
    arg_kinds = list(signature.arg_kinds)
    arg_types = list(signature.arg_types)
    arg_kinds[index] = ARG_POS
    arg_types[index] = ctx.api.named_generic_type("builtins.str", [])

    return signature.copy_modified(arg_kinds=arg_kinds, arg_types=arg_types)
def replace_get_connection_sig_callback(ctx: FunctionSigContext) -> CallableType:
    """Make ``using`` a required ``str`` parameter of ``get_connection``.

    Used as the signature hook for ``django.db.transaction.get_connection``.
    If the default signature has no ``using`` parameter, an error is reported
    at the call site and the signature is returned unchanged.

    Args:
        ctx: mypy's function-signature context; ``ctx.default_signature`` is
            the signature being adjusted.

    Returns:
        The adjusted signature, or the default signature when no rewrite
        applies.
    """
    signature = ctx.default_signature
    using_arg = signature.argument_by_name("using")
    if not using_arg:
        ctx.api.fail("The using parameter is required", ctx.context)
        # Nothing to rewrite; fabricating a signature here would retype an
        # unrelated parameter.
        return signature
    if using_arg.pos is None:
        # ``using`` cannot be passed positionally (keyword-only, or only
        # matched via ``**kwargs``); leave the signature alone.
        return signature

    # Retype only the ``using`` slot. Replacing the whole kinds/types lists
    # would drop any other parameters and misalign them with ``arg_names``.
    arg_kinds = list(signature.arg_kinds)
    arg_types = list(signature.arg_types)
    arg_kinds[using_arg.pos] = ARG_POS
    arg_types[using_arg.pos] = ctx.api.named_generic_type("builtins.str", [])
    return signature.copy_modified(arg_kinds=arg_kinds, arg_types=arg_types)
def replace_trailing_using_sig_callback(ctx: FunctionSigContext) -> CallableType:
    """Make ``using`` a required ``str`` parameter of a transaction helper.

    Used as the signature hook for ``django.db.transaction.on_commit`` and
    ``django.db.transaction.set_rollback``. Despite this function's name,
    ``using`` is located by name rather than assumed to be the last
    parameter. For example, Django 4.2+ declares
    ``on_commit(func, using=None, robust=False)``; retyping the final slot
    there would turn ``robust`` into a required ``str`` and leave ``using``
    optional.

    If the default signature has no ``using`` parameter, an error is reported
    at the call site and the signature is returned unchanged.

    Args:
        ctx: mypy's function-signature context; ``ctx.default_signature`` is
            the signature being adjusted.

    Returns:
        The adjusted signature, or the default signature when no rewrite
        applies.
    """
    signature = ctx.default_signature
    using_arg = signature.argument_by_name("using")
    if not using_arg:
        ctx.api.fail("The using parameter is required", ctx.context)
        # Nothing to rewrite; retyping some other parameter would be wrong.
        return signature
    if using_arg.pos is None:
        # ``using`` cannot be passed positionally (keyword-only, or only
        # matched via ``**kwargs``); leave the signature alone.
        return signature

    # Update the ``using`` parameter, wherever it sits, to be required and str.
    arg_kinds = list(signature.arg_kinds)
    arg_types = list(signature.arg_types)
    arg_kinds[using_arg.pos] = ARG_POS
    arg_types[using_arg.pos] = ctx.api.named_generic_type("builtins.str", [])
    return signature.copy_modified(arg_kinds=arg_kinds, arg_types=arg_types)
class SentryMypyPlugin(Plugin):
    """mypy plugin that makes ``using`` a required ``str`` on Django transaction helpers."""

    # Fully-qualified function name -> signature hook handed back to mypy.
    _signature_hooks: dict[str, Callable[[FunctionSigContext], FunctionLike]] = {
        "django.db.transaction.atomic": replace_transaction_atomic_sig_callback,
        "django.db.transaction.get_connection": replace_get_connection_sig_callback,
        "django.db.transaction.on_commit": replace_trailing_using_sig_callback,
        "django.db.transaction.set_rollback": replace_trailing_using_sig_callback,
    }

    def get_function_signature_hook(
        self, fullname: str
    ) -> Callable[[FunctionSigContext], FunctionLike] | None:
        """Return the signature hook registered for ``fullname``, or ``None``."""
        return self._signature_hooks.get(fullname)
def plugin(version: str) -> type[SentryMypyPlugin]:
    """Return the plugin class for mypy's plugin loader.

    ``version`` is presumably the running mypy version string; it is not
    consulted, so the same class is returned for every version.
    """
    plugin_class = SentryMypyPlugin
    return plugin_class
|