test_plugin.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. from __future__ import annotations
  2. import os.path
  3. import subprocess
  4. import sys
  5. import tempfile
  6. def call_mypy(src: str, *, plugins: list[str] | None = None) -> tuple[int, str]:
  7. if plugins is None:
  8. plugins = ["tools.mypy_helpers.plugin"]
  9. with tempfile.TemporaryDirectory() as tmpdir:
  10. cfg = os.path.join(tmpdir, "mypy.toml")
  11. with open(cfg, "w") as f:
  12. f.write(f"[tool.mypy]\nplugins = {plugins!r}\n")
  13. ret = subprocess.run(
  14. (
  15. *(sys.executable, "-m", "mypy"),
  16. *("--config", cfg),
  17. *("-c", src),
  18. ),
  19. capture_output=True,
  20. encoding="UTF-8",
  21. )
  22. return ret.returncode, ret.stdout
  23. def test_invalid_get_connection_call():
  24. code = """
  25. from django.db.transaction import get_connection
  26. with get_connection() as cursor:
  27. cursor.execute("SELECT 1")
  28. """
  29. expected = """\
  30. <string>:4: error: Missing positional argument "using" in call to "get_connection" [call-arg]
  31. Found 1 error in 1 file (checked 1 source file)
  32. """
  33. ret, out = call_mypy(code)
  34. assert ret
  35. assert out == expected
  36. def test_ok_get_connection():
  37. code = """
  38. from django.db.transaction import get_connection
  39. with get_connection("default") as cursor:
  40. cursor.execute("SELECT 1")
  41. """
  42. ret, out = call_mypy(code)
  43. assert ret == 0
  44. def test_invalid_transaction_atomic():
  45. code = """
  46. from django.db import transaction
  47. with transaction.atomic():
  48. value = 10 / 2
  49. """
  50. expected = """\
  51. <string>:4: error: All overload variants of "atomic" require at least one argument [call-overload]
  52. <string>:4: note: Possible overload variants:
  53. <string>:4: note: def [_C] atomic(using: _C) -> _C
  54. <string>:4: note: def atomic(using: str, savepoint: bool = ..., durable: bool = ...) -> Atomic
  55. Found 1 error in 1 file (checked 1 source file)
  56. """
  57. ret, out = call_mypy(code)
  58. assert ret
  59. assert out == expected
  60. def test_ok_transaction_atomic():
  61. code = """
  62. from django.db import transaction
  63. with transaction.atomic("default"):
  64. value = 10 / 2
  65. """
  66. ret, _ = call_mypy(code)
  67. assert ret == 0
  68. def test_ok_transaction_on_commit():
  69. code = """
  70. from django.db import transaction
  71. def completed():
  72. pass
  73. transaction.on_commit(completed, "default")
  74. """
  75. ret, _ = call_mypy(code)
  76. assert ret == 0
  77. def test_invalid_transaction_on_commit():
  78. code = """
  79. from django.db import transaction
  80. def completed():
  81. pass
  82. transaction.on_commit(completed)
  83. """
  84. expected = """\
  85. <string>:7: error: Missing positional argument "using" in call to "on_commit" [call-arg]
  86. Found 1 error in 1 file (checked 1 source file)
  87. """
  88. ret, out = call_mypy(code)
  89. assert ret
  90. assert out == expected
  91. def test_invalid_transaction_set_rollback():
  92. code = """
  93. from django.db import transaction
  94. transaction.set_rollback(True)
  95. """
  96. expected = """\
  97. <string>:4: error: Missing positional argument "using" in call to "set_rollback" [call-arg]
  98. Found 1 error in 1 file (checked 1 source file)
  99. """
  100. ret, out = call_mypy(code)
  101. assert ret
  102. assert out == expected
  103. def test_ok_transaction_set_rollback():
  104. code = """
  105. from django.db import transaction
  106. transaction.set_rollback(True, "default")
  107. """
  108. ret, _ = call_mypy(code)
  109. assert ret == 0
  110. def test_field_descriptor_hack():
  111. code = """\
  112. from __future__ import annotations
  113. from django.db import models
  114. class M1(models.Model):
  115. f: models.Field[int, int] = models.IntegerField()
  116. class C:
  117. f: int
  118. def f(inst: C | M1 | M2) -> int:
  119. return inst.f
  120. # should also work with field subclasses
  121. class F(models.Field[int, int]):
  122. pass
  123. class M2(models.Model):
  124. f = F()
  125. def g(inst: C | M2) -> int:
  126. return inst.f
  127. """
  128. # should be an error with default plugins
  129. # mypy may fix this at some point hopefully: python/mypy#5570
  130. ret, out = call_mypy(code, plugins=[])
  131. assert ret
  132. assert (
  133. out
  134. == """\
  135. <string>:12: error: Incompatible return value type (got "Union[int, Field[int, int]]", expected "int") [return-value]
  136. <string>:22: error: Incompatible return value type (got "Union[int, F]", expected "int") [return-value]
  137. Found 2 errors in 1 file (checked 1 source file)
  138. """
  139. )
  140. # should be fixed with our special plugin
  141. ret, _ = call_mypy(code)
  142. assert ret == 0