plugin.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from __future__ import annotations
  2. from typing import Callable
  3. from mypy.nodes import ARG_POS
  4. from mypy.plugin import FunctionSigContext, Plugin
  5. from mypy.types import CallableType, FunctionLike, Instance
  6. def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType:
  7. signature = ctx.default_signature
  8. using_arg = signature.argument_by_name("using")
  9. if not using_arg:
  10. # No using arg in the signature, bail
  11. return signature
  12. # We care about context managers.
  13. ret_type = signature.ret_type
  14. if not isinstance(ret_type, Instance):
  15. return signature
  16. # Replace the type and remove the default value of using.
  17. str_type = ctx.api.named_generic_type("builtins.str", [])
  18. arg_types = signature.arg_types[1:]
  19. arg_kinds = signature.arg_kinds[1:]
  20. return signature.copy_modified(
  21. arg_kinds=[ARG_POS, *arg_kinds],
  22. arg_types=[str_type, *arg_types],
  23. )
  24. def replace_get_connection_sig_callback(ctx: FunctionSigContext) -> CallableType:
  25. signature = ctx.default_signature
  26. using_arg = signature.argument_by_name("using")
  27. if not using_arg:
  28. ctx.api.fail("The using parameter is required", ctx.context)
  29. str_type = ctx.api.named_generic_type("builtins.str", [])
  30. return signature.copy_modified(arg_kinds=[ARG_POS], arg_types=[str_type])
  31. def replace_trailing_using_sig_callback(ctx: FunctionSigContext) -> CallableType:
  32. signature = ctx.default_signature
  33. using_arg = signature.argument_by_name("using")
  34. if not using_arg:
  35. ctx.api.fail("The using parameter is required", ctx.context)
  36. # Update the parameter type to be required and str
  37. str_type = ctx.api.named_generic_type("builtins.str", [])
  38. arg_kinds = signature.arg_kinds[0:-1]
  39. arg_types = signature.arg_types[0:-1]
  40. return signature.copy_modified(
  41. arg_kinds=[*arg_kinds, ARG_POS], arg_types=[*arg_types, str_type]
  42. )
  43. class SentryMypyPlugin(Plugin):
  44. def get_function_signature_hook(
  45. self, fullname: str
  46. ) -> Callable[[FunctionSigContext], FunctionLike] | None:
  47. if fullname == "django.db.transaction.atomic":
  48. return replace_transaction_atomic_sig_callback
  49. if fullname == "django.db.transaction.get_connection":
  50. return replace_get_connection_sig_callback
  51. if fullname == "django.db.transaction.on_commit":
  52. return replace_trailing_using_sig_callback
  53. if fullname == "django.db.transaction.set_rollback":
  54. return replace_trailing_using_sig_callback
  55. return None
  56. def plugin(version: str) -> type[SentryMypyPlugin]:
  57. return SentryMypyPlugin