# plugin.py (5.6 KB)

  1. from __future__ import annotations
  2. from typing import Callable
  3. from mypy.nodes import ARG_POS, TypeInfo
  4. from mypy.plugin import ClassDefContext, FunctionSigContext, MethodSigContext, Plugin
  5. from mypy.plugins.common import add_attribute_to_class
  6. from mypy.types import AnyType, CallableType, FunctionLike, Instance, NoneType, TypeOfAny, UnionType
  7. def _make_using_required_str(ctx: FunctionSigContext) -> CallableType:
  8. sig = ctx.default_signature
  9. using_arg = sig.argument_by_name("using")
  10. if using_arg is None or using_arg.pos is None:
  11. ctx.api.fail("The using parameter is required", ctx.context)
  12. return sig
  13. for kind in sig.arg_kinds[: using_arg.pos]:
  14. if kind != ARG_POS:
  15. ctx.api.fail("Expected using to be the first optional", ctx.context)
  16. return sig
  17. str_type = ctx.api.named_generic_type("builtins.str", [])
  18. arg_kinds = [*sig.arg_kinds[: using_arg.pos], ARG_POS, *sig.arg_kinds[using_arg.pos + 1 :]]
  19. arg_types = [*sig.arg_types[: using_arg.pos], str_type, *sig.arg_types[using_arg.pos + 1 :]]
  20. return sig.copy_modified(arg_kinds=arg_kinds, arg_types=arg_types)
  21. def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType:
  22. sig = ctx.default_signature
  23. if not sig.argument_by_name("using"):
  24. # No using arg in the signature, bail
  25. return sig
  26. # We care about context managers.
  27. if not isinstance(sig.ret_type, Instance):
  28. return sig
  29. return _make_using_required_str(ctx)
  30. def _choice_field_choices_sequence(ctx: FunctionSigContext) -> CallableType:
  31. sig = ctx.default_signature
  32. assert sig.arg_names[0] == "choices", sig
  33. any_type = AnyType(TypeOfAny.explicit)
  34. sequence_any = ctx.api.named_generic_type("typing.Sequence", [any_type])
  35. return sig.copy_modified(arg_types=[sequence_any, *sig.arg_types[1:]])
  36. _FUNCTION_SIGNATURE_HOOKS = {
  37. "django.db.transaction.atomic": replace_transaction_atomic_sig_callback,
  38. "django.db.transaction.get_connection": _make_using_required_str,
  39. "django.db.transaction.on_commit": _make_using_required_str,
  40. "django.db.transaction.set_rollback": _make_using_required_str,
  41. "rest_framework.fields.ChoiceField": _choice_field_choices_sequence,
  42. "rest_framework.fields.MultipleChoiceField": _choice_field_choices_sequence,
  43. }
  44. def field_descriptor_no_overloads(ctx: MethodSigContext) -> FunctionLike:
  45. # ignore the class / non-model instance descriptor overloads
  46. signature = ctx.default_signature
  47. # replace `def __get__(self, inst: Model, owner: Any) -> _GT:`
  48. # with `def __get__(self, inst: Any, owner: Any) -> _GT:`
  49. if str(signature.arg_types[0]) == "django.db.models.base.Model":
  50. return signature.copy_modified(arg_types=[signature.arg_types[1]] * 2)
  51. else:
  52. return signature
  53. def _adjust_http_request_members(ctx: ClassDefContext) -> None:
  54. if ctx.cls.name == "HttpRequest":
  55. # added by sentry.api.base and sentry.web.frontend.base
  56. # TODO: idk why I can't use the real type here :/
  57. add_attribute_to_class(ctx.api, ctx.cls, "access", AnyType(TypeOfAny.explicit))
  58. # added by sentry.middleware.auth
  59. # TODO: figure out how to get the real types here
  60. add_attribute_to_class(ctx.api, ctx.cls, "auth", AnyType(TypeOfAny.explicit))
  61. # added by csp.middleware.CSPMiddleware
  62. add_attribute_to_class(ctx.api, ctx.cls, "csp_nonce", ctx.api.named_type("builtins.str"))
  63. # added by sudo.middleware.SudoMiddleware
  64. # this is slightly better than a method returning bool for overriding
  65. returns_bool = CallableType(
  66. arg_types=[],
  67. arg_kinds=[],
  68. arg_names=[],
  69. ret_type=ctx.api.named_type("builtins.bool"),
  70. fallback=ctx.api.named_type("builtins.function"),
  71. name="is_sudo",
  72. )
  73. add_attribute_to_class(ctx.api, ctx.cls, "is_sudo", returns_bool)
  74. # added by sentry.middleware.subdomain
  75. subdomain_tp = UnionType([NoneType(), ctx.api.named_type("builtins.str")])
  76. add_attribute_to_class(ctx.api, ctx.cls, "subdomain", subdomain_tp)
  77. # added by sentry.middleware.superuser
  78. # TODO: figure out how to get the real types here
  79. add_attribute_to_class(ctx.api, ctx.cls, "superuser", AnyType(TypeOfAny.explicit))
  80. class SentryMypyPlugin(Plugin):
  81. def get_function_signature_hook(
  82. self, fullname: str
  83. ) -> Callable[[FunctionSigContext], FunctionLike] | None:
  84. return _FUNCTION_SIGNATURE_HOOKS.get(fullname)
  85. def get_method_signature_hook(
  86. self, fullname: str
  87. ) -> Callable[[MethodSigContext], FunctionLike] | None:
  88. if fullname == "django.db.models.fields.Field":
  89. return field_descriptor_no_overloads
  90. clsname, _, methodname = fullname.rpartition(".")
  91. if methodname != "__get__":
  92. return None
  93. clsinfo = self.lookup_fully_qualified(clsname)
  94. if clsinfo is None or not isinstance(clsinfo.node, TypeInfo):
  95. return None
  96. fieldinfo = self.lookup_fully_qualified("django.db.models.fields.Field")
  97. if fieldinfo is None:
  98. return None
  99. if fieldinfo.node in clsinfo.node.mro:
  100. return field_descriptor_no_overloads
  101. else:
  102. return None
  103. def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
  104. # XXX: this is a hack -- I don't know if there's a better callback to modify a class
  105. if fullname == "io.BytesIO":
  106. return _adjust_http_request_members
  107. else:
  108. return None
  109. def plugin(version: str) -> type[SentryMypyPlugin]:
  110. return SentryMypyPlugin