plugin.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from __future__ import annotations
  2. from typing import Callable
  3. from mypy.nodes import ARG_POS, TypeInfo
  4. from mypy.plugin import FunctionSigContext, MethodSigContext, Plugin
  5. from mypy.types import AnyType, CallableType, FunctionLike, Instance, TypeOfAny
  6. def _make_using_required_str(ctx: FunctionSigContext) -> CallableType:
  7. sig = ctx.default_signature
  8. using_arg = sig.argument_by_name("using")
  9. if using_arg is None or using_arg.pos is None:
  10. ctx.api.fail("The using parameter is required", ctx.context)
  11. return sig
  12. for kind in sig.arg_kinds[: using_arg.pos]:
  13. if kind != ARG_POS:
  14. ctx.api.fail("Expected using to be the first optional", ctx.context)
  15. return sig
  16. str_type = ctx.api.named_generic_type("builtins.str", [])
  17. arg_kinds = [*sig.arg_kinds[: using_arg.pos], ARG_POS, *sig.arg_kinds[using_arg.pos + 1 :]]
  18. arg_types = [*sig.arg_types[: using_arg.pos], str_type, *sig.arg_types[using_arg.pos + 1 :]]
  19. return sig.copy_modified(arg_kinds=arg_kinds, arg_types=arg_types)
  20. def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType:
  21. sig = ctx.default_signature
  22. if not sig.argument_by_name("using"):
  23. # No using arg in the signature, bail
  24. return sig
  25. # We care about context managers.
  26. if not isinstance(sig.ret_type, Instance):
  27. return sig
  28. return _make_using_required_str(ctx)
  29. def _choice_field_choices_sequence(ctx: FunctionSigContext) -> CallableType:
  30. sig = ctx.default_signature
  31. assert sig.arg_names[0] == "choices", sig
  32. any_type = AnyType(TypeOfAny.explicit)
  33. sequence_any = ctx.api.named_generic_type("typing.Sequence", [any_type])
  34. return sig.copy_modified(arg_types=[sequence_any, *sig.arg_types[1:]])
  35. _FUNCTION_SIGNATURE_HOOKS = {
  36. "django.db.transaction.atomic": replace_transaction_atomic_sig_callback,
  37. "django.db.transaction.get_connection": _make_using_required_str,
  38. "django.db.transaction.on_commit": _make_using_required_str,
  39. "django.db.transaction.set_rollback": _make_using_required_str,
  40. "rest_framework.fields.ChoiceField": _choice_field_choices_sequence,
  41. "rest_framework.fields.MultipleChoiceField": _choice_field_choices_sequence,
  42. }
  43. def field_descriptor_no_overloads(ctx: MethodSigContext) -> FunctionLike:
  44. # ignore the class / non-model instance descriptor overloads
  45. signature = ctx.default_signature
  46. # replace `def __get__(self, inst: Model, owner: Any) -> _GT:`
  47. # with `def __get__(self, inst: Any, owner: Any) -> _GT:`
  48. if str(signature.arg_types[0]) == "django.db.models.base.Model":
  49. return signature.copy_modified(arg_types=[signature.arg_types[1]] * 2)
  50. else:
  51. return signature
  52. class SentryMypyPlugin(Plugin):
  53. def get_function_signature_hook(
  54. self, fullname: str
  55. ) -> Callable[[FunctionSigContext], FunctionLike] | None:
  56. return _FUNCTION_SIGNATURE_HOOKS.get(fullname)
  57. def get_method_signature_hook(
  58. self, fullname: str
  59. ) -> Callable[[MethodSigContext], FunctionLike] | None:
  60. if fullname == "django.db.models.fields.Field":
  61. return field_descriptor_no_overloads
  62. clsname, _, methodname = fullname.rpartition(".")
  63. if methodname != "__get__":
  64. return None
  65. clsinfo = self.lookup_fully_qualified(clsname)
  66. if clsinfo is None or not isinstance(clsinfo.node, TypeInfo):
  67. return None
  68. fieldinfo = self.lookup_fully_qualified("django.db.models.fields.Field")
  69. if fieldinfo is None:
  70. return None
  71. if fieldinfo.node in clsinfo.node.mro:
  72. return field_descriptor_no_overloads
  73. else:
  74. return None
  75. def plugin(version: str) -> type[SentryMypyPlugin]:
  76. return SentryMypyPlugin