test_plugin.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. from __future__ import annotations
  2. import os.path
  3. import subprocess
  4. import sys
  5. import tempfile
  6. import pytest
  7. def call_mypy(src: str, *, plugins: list[str] | None = None) -> tuple[int, str]:
  8. if plugins is None:
  9. plugins = ["tools.mypy_helpers.plugin"]
  10. with tempfile.TemporaryDirectory() as tmpdir:
  11. cfg = os.path.join(tmpdir, "mypy.toml")
  12. with open(cfg, "w") as f:
  13. f.write(f"[tool.mypy]\nplugins = {plugins!r}\n")
  14. ret = subprocess.run(
  15. (
  16. *(sys.executable, "-m", "mypy"),
  17. *("--config", cfg),
  18. *("-c", src),
  19. ),
  20. capture_output=True,
  21. encoding="UTF-8",
  22. )
  23. assert not ret.stderr
  24. return ret.returncode, ret.stdout
  25. def test_invalid_get_connection_call():
  26. code = """
  27. from django.db.transaction import get_connection
  28. with get_connection() as cursor:
  29. cursor.execute("SELECT 1")
  30. """
  31. expected = """\
  32. <string>:4: error: Missing positional argument "using" in call to "get_connection" [call-arg]
  33. Found 1 error in 1 file (checked 1 source file)
  34. """
  35. ret, out = call_mypy(code)
  36. assert ret
  37. assert out == expected
  38. def test_ok_get_connection():
  39. code = """
  40. from django.db.transaction import get_connection
  41. with get_connection("default") as cursor:
  42. cursor.execute("SELECT 1")
  43. """
  44. ret, out = call_mypy(code)
  45. assert ret == 0
  46. def test_invalid_transaction_atomic():
  47. code = """
  48. from django.db import transaction
  49. with transaction.atomic():
  50. value = 10 / 2
  51. """
  52. expected = """\
  53. <string>:4: error: All overload variants of "atomic" require at least one argument [call-overload]
  54. <string>:4: note: Possible overload variants:
  55. <string>:4: note: def [_C] atomic(using: _C) -> _C
  56. <string>:4: note: def atomic(using: str, savepoint: bool = ..., durable: bool = ...) -> Atomic
  57. Found 1 error in 1 file (checked 1 source file)
  58. """
  59. ret, out = call_mypy(code)
  60. assert ret
  61. assert out == expected
  62. def test_ok_transaction_atomic():
  63. code = """
  64. from django.db import transaction
  65. with transaction.atomic("default"):
  66. value = 10 / 2
  67. """
  68. ret, _ = call_mypy(code)
  69. assert ret == 0
  70. def test_ok_transaction_on_commit():
  71. code = """
  72. from django.db import transaction
  73. def completed():
  74. pass
  75. transaction.on_commit(completed, "default")
  76. """
  77. ret, _ = call_mypy(code)
  78. assert ret == 0
  79. def test_invalid_transaction_on_commit():
  80. code = """
  81. from django.db import transaction
  82. def completed():
  83. pass
  84. transaction.on_commit(completed)
  85. """
  86. expected = """\
  87. <string>:7: error: Missing positional argument "using" in call to "on_commit" [call-arg]
  88. Found 1 error in 1 file (checked 1 source file)
  89. """
  90. ret, out = call_mypy(code)
  91. assert ret
  92. assert out == expected
  93. def test_invalid_transaction_set_rollback():
  94. code = """
  95. from django.db import transaction
  96. transaction.set_rollback(True)
  97. """
  98. expected = """\
  99. <string>:4: error: Missing positional argument "using" in call to "set_rollback" [call-arg]
  100. Found 1 error in 1 file (checked 1 source file)
  101. """
  102. ret, out = call_mypy(code)
  103. assert ret
  104. assert out == expected
  105. def test_ok_transaction_set_rollback():
  106. code = """
  107. from django.db import transaction
  108. transaction.set_rollback(True, "default")
  109. """
  110. ret, _ = call_mypy(code)
  111. assert ret == 0
  112. def test_field_descriptor_hack():
  113. code = """\
  114. from __future__ import annotations
  115. from django.db import models
  116. class M1(models.Model):
  117. f: models.Field[int, int] = models.IntegerField()
  118. class C:
  119. f: int
  120. def f(inst: C | M1 | M2) -> int:
  121. return inst.f
  122. # should also work with field subclasses
  123. class F(models.Field[int, int]):
  124. pass
  125. class M2(models.Model):
  126. f = F()
  127. def g(inst: C | M2) -> int:
  128. return inst.f
  129. """
  130. # should be an error with default plugins
  131. # mypy may fix this at some point hopefully: python/mypy#5570
  132. ret, out = call_mypy(code, plugins=[])
  133. assert ret
  134. assert (
  135. out
  136. == """\
  137. <string>:12: error: Incompatible return value type (got "Union[int, Field[int, int]]", expected "int") [return-value]
  138. <string>:22: error: Incompatible return value type (got "Union[int, F]", expected "int") [return-value]
  139. Found 2 errors in 1 file (checked 1 source file)
  140. """
  141. )
  142. # should be fixed with our special plugin
  143. ret, _ = call_mypy(code)
  144. assert ret == 0
  145. def test_rest_framework_serializers_require_sequence():
  146. code = """\
  147. from __future__ import annotations
  148. from rest_framework import serializers
  149. SOME_FSET = frozenset(('a', 'b', 'c'))
  150. SOME_SET = {'a', 'b', 'c'}
  151. SOME_TUPLE = ('a', 'b', 'c')
  152. SOME_LIST = ['a', 'b', 'c']
  153. # ok
  154. serializers.ChoiceField(choices=SOME_TUPLE)
  155. serializers.ChoiceField(choices=SOME_LIST)
  156. serializers.MultipleChoiceField(choices=SOME_TUPLE)
  157. serializers.MultipleChoiceField(choices=SOME_LIST)
  158. # not ok
  159. serializers.ChoiceField(choices=SOME_SET)
  160. serializers.ChoiceField(choices=SOME_FSET)
  161. serializers.MultipleChoiceField(choices=SOME_SET)
  162. serializers.MultipleChoiceField(choices=SOME_FSET)
  163. """
  164. expected = """\
  165. <string>:16: error: Argument "choices" to "ChoiceField" has incompatible type "Set[str]"; expected "Sequence[Any]" [arg-type]
  166. <string>:17: error: Argument "choices" to "ChoiceField" has incompatible type "FrozenSet[str]"; expected "Sequence[Any]" [arg-type]
  167. <string>:18: error: Argument "choices" to "MultipleChoiceField" has incompatible type "Set[str]"; expected "Sequence[Any]" [arg-type]
  168. <string>:19: error: Argument "choices" to "MultipleChoiceField" has incompatible type "FrozenSet[str]"; expected "Sequence[Any]" [arg-type]
  169. Found 4 errors in 1 file (checked 1 source file)
  170. """
  171. # should be ok without plugins
  172. ret, _ = call_mypy(code, plugins=[])
  173. assert ret == 0
  174. # should be an error with plugins
  175. ret, out = call_mypy(code)
  176. assert ret
  177. assert out == expected
  178. @pytest.mark.parametrize(
  179. "attr",
  180. (
  181. pytest.param("access", id="access from sentry.api.base"),
  182. pytest.param("csp_nonce", id="csp_nonce from csp.middleware"),
  183. pytest.param("is_sudo", id="is_sudo from sudo.middleware"),
  184. pytest.param("subdomain", id="subdomain from sentry.middleware.subdomain"),
  185. ),
  186. )
  187. def test_added_http_request_attribute(attr: str) -> None:
  188. src = f"""\
  189. from django.http.request import HttpRequest
  190. x: HttpRequest
  191. x.{attr}
  192. """
  193. ret, out = call_mypy(src, plugins=[])
  194. assert ret
  195. ret, out = call_mypy(src)
  196. assert ret == 0, (ret, out)