plugin.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from __future__ import annotations
  2. from typing import Callable
  3. from mypy.nodes import ARG_POS, TypeInfo
  4. from mypy.plugin import FunctionSigContext, MethodSigContext, Plugin
  5. from mypy.types import CallableType, FunctionLike, Instance
  6. def _make_using_required_str(ctx: FunctionSigContext) -> CallableType:
  7. sig = ctx.default_signature
  8. using_arg = sig.argument_by_name("using")
  9. if using_arg is None or using_arg.pos is None:
  10. ctx.api.fail("The using parameter is required", ctx.context)
  11. return sig
  12. for kind in sig.arg_kinds[: using_arg.pos]:
  13. if kind != ARG_POS:
  14. ctx.api.fail("Expected using to be the first optional", ctx.context)
  15. return sig
  16. str_type = ctx.api.named_generic_type("builtins.str", [])
  17. arg_kinds = [*sig.arg_kinds[: using_arg.pos], ARG_POS, *sig.arg_kinds[using_arg.pos + 1 :]]
  18. arg_types = [*sig.arg_types[: using_arg.pos], str_type, *sig.arg_types[using_arg.pos + 1 :]]
  19. return sig.copy_modified(arg_kinds=arg_kinds, arg_types=arg_types)
  20. def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType:
  21. sig = ctx.default_signature
  22. if not sig.argument_by_name("using"):
  23. # No using arg in the signature, bail
  24. return sig
  25. # We care about context managers.
  26. if not isinstance(sig.ret_type, Instance):
  27. return sig
  28. return _make_using_required_str(ctx)
  29. _FUNCTION_SIGNATURE_HOOKS = {
  30. "django.db.transaction.atomic": replace_transaction_atomic_sig_callback,
  31. "django.db.transaction.get_connection": _make_using_required_str,
  32. "django.db.transaction.on_commit": _make_using_required_str,
  33. "django.db.transaction.set_rollback": _make_using_required_str,
  34. }
  35. def field_descriptor_no_overloads(ctx: MethodSigContext) -> FunctionLike:
  36. # ignore the class / non-model instance descriptor overloads
  37. signature = ctx.default_signature
  38. # replace `def __get__(self, inst: Model, owner: Any) -> _GT:`
  39. # with `def __get__(self, inst: Any, owner: Any) -> _GT:`
  40. if str(signature.arg_types[0]) == "django.db.models.base.Model":
  41. return signature.copy_modified(arg_types=[signature.arg_types[1]] * 2)
  42. else:
  43. return signature
  44. class SentryMypyPlugin(Plugin):
  45. def get_function_signature_hook(
  46. self, fullname: str
  47. ) -> Callable[[FunctionSigContext], FunctionLike] | None:
  48. return _FUNCTION_SIGNATURE_HOOKS.get(fullname)
  49. def get_method_signature_hook(
  50. self, fullname: str
  51. ) -> Callable[[MethodSigContext], FunctionLike] | None:
  52. if fullname == "django.db.models.fields.Field":
  53. return field_descriptor_no_overloads
  54. clsname, _, methodname = fullname.rpartition(".")
  55. if methodname != "__get__":
  56. return None
  57. clsinfo = self.lookup_fully_qualified(clsname)
  58. if clsinfo is None or not isinstance(clsinfo.node, TypeInfo):
  59. return None
  60. fieldinfo = self.lookup_fully_qualified("django.db.models.fields.Field")
  61. if fieldinfo is None:
  62. return None
  63. if fieldinfo.node in clsinfo.node.mro:
  64. return field_descriptor_no_overloads
  65. else:
  66. return None
  67. def plugin(version: str) -> type[SentryMypyPlugin]:
  68. return SentryMypyPlugin