source.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import ast
  2. import inspect
  3. import textwrap
  4. import tokenize
  5. import types
  6. import warnings
  7. from bisect import bisect_right
  8. from typing import Iterable
  9. from typing import Iterator
  10. from typing import List
  11. from typing import Optional
  12. from typing import overload
  13. from typing import Tuple
  14. from typing import Union
  class Source:
      """An immutable object holding a source code fragment.

      When using Source(...), the source lines are deindented.  The fragment
      is stored as a list of strings in ``lines``; methods that transform the
      source (slicing, strip, indent, deindent) return a new Source.
      """

      def __init__(self, obj: object = None) -> None:
          """Create a Source from *obj*.

          Accepted inputs:

          - a falsy value (``None``, ``""``, ``[]``, an empty Source): no lines;
          - another Source: its ``lines`` list is reused (shared, not copied);
          - a tuple or list of strings: trailing newlines are stripped from
            each element, then the lines are deindented;
          - a string: split on newlines, then deindented;
          - any other object: its source is fetched with ``inspect.getsource``
            (through its code object when ``getrawcode`` finds one) and
            deindented.  Errors raised by ``inspect.getsource`` propagate.
          """
          if not obj:
              self.lines: List[str] = []
          elif isinstance(obj, Source):
              self.lines = obj.lines
          elif isinstance(obj, (tuple, list)):
              self.lines = deindent(x.rstrip("\n") for x in obj)
          elif isinstance(obj, str):
              self.lines = deindent(obj.split("\n"))
          else:
              try:
                  rawcode = getrawcode(obj)
                  src = inspect.getsource(rawcode)
              except TypeError:
                  # No code object could be found (e.g. a class or a module):
                  # let inspect try the object itself.
                  src = inspect.getsource(obj)  # type: ignore[arg-type]
              self.lines = deindent(src.split("\n"))

      def __eq__(self, other: object) -> bool:
          """Compare by lines; non-Source operands are not comparable."""
          if not isinstance(other, Source):
              return NotImplemented
          return self.lines == other.lines

      # Ignore type because of https://github.com/python/mypy/issues/4266.
      __hash__ = None  # type: ignore

      @overload
      def __getitem__(self, key: int) -> str:
          ...

      @overload
      def __getitem__(self, key: slice) -> "Source":
          ...

      def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]:
          """Return a single line for an int key, or a new Source for a slice.

          :raises IndexError: for slices with a step other than None or 1.
          """
          if isinstance(key, int):
              return self.lines[key]
          else:
              if key.step not in (None, 1):
                  raise IndexError("cannot slice a Source with a step")
              newsource = Source()
              newsource.lines = self.lines[key.start : key.stop]
              return newsource

      def __iter__(self) -> Iterator[str]:
          return iter(self.lines)

      def __len__(self) -> int:
          return len(self.lines)

      def strip(self) -> "Source":
          """Return a new Source object with leading and trailing blank lines removed."""
          start, end = 0, len(self)
          while start < end and not self.lines[start].strip():
              start += 1
          while end > start and not self.lines[end - 1].strip():
              end -= 1
          source = Source()
          source.lines[:] = self.lines[start:end]
          return source

      def indent(self, indent: str = " " * 4) -> "Source":
          """Return a copy of the source object with every line prefixed by
          the given indent string."""
          newsource = Source()
          newsource.lines = [(indent + line) for line in self.lines]
          return newsource

      def getstatement(self, lineno: int) -> "Source":
          """Return the Source statement which contains the given line number
          (counted from 0)."""
          start, end = self.getstatementrange(lineno)
          return self[start:end]

      def getstatementrange(self, lineno: int) -> Tuple[int, int]:
          """Return the ``(start, end)`` tuple spanning the minimal statement
          region which contains the given 0-based *lineno*.

          :raises IndexError: if *lineno* is outside the source.
          """
          if not (0 <= lineno < len(self)):
              raise IndexError("lineno out of range")
          # The parsed tree is not needed here.  Discard it instead of binding
          # it to a local named ``ast``, which would shadow the ``ast`` module.
          _, start, end = getstatementrange_ast(lineno, self)
          return start, end

      def deindent(self) -> "Source":
          """Return a new Source object with common leading whitespace removed."""
          newsource = Source()
          newsource.lines[:] = deindent(self.lines)
          return newsource

      def __str__(self) -> str:
          return "\n".join(self.lines)
  95. #
  96. # helper functions
  97. #
  def findsource(obj: object) -> Tuple[Optional[Source], int]:
      """Locate the source file lines for *obj*.

      Returns ``(source, lineno)``.  *source* holds the lines returned by
      ``inspect.findsource`` with trailing whitespace stripped and no
      deindenting.  *lineno* is the 0-based start line reported by that
      call.  If ``inspect.findsource`` raises anything, ``(None, -1)`` is
      returned instead.
      """
      try:
          file_lines, start_lineno = inspect.findsource(obj)
      except Exception:
          # Best effort: a missing file, a builtin, etc. simply means "unknown".
          return None, -1
      found = Source()
      found.lines = list(map(str.rstrip, file_lines))
      return found, start_lineno
  def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
      """Return the code object behind *obj*.

      If *obj* exposes ``__code__`` (functions, bound methods), that object
      is returned directly.  Otherwise, when *trycall* is true and *obj* is
      not a class, the lookup is retried once on its ``__call__`` attribute.

      :raises TypeError: if no code object can be found.  When the retry on
          ``__call__`` fails, the message names that attribute rather than
          *obj* itself.
      """
      try:
          return obj.__code__  # type: ignore[attr-defined,no-any-return]
      except AttributeError:
          pass

      call_attr = getattr(obj, "__call__", None) if trycall else None
      # Classes are callable too, but they are deliberately excluded from
      # the __call__ retry.
      if call_attr and not isinstance(obj, type):
          # Only one level of unwrapping: the retry runs with trycall=False.
          return getrawcode(call_attr, trycall=False)

      raise TypeError(f"could not get code object for {obj!r}")
  def deindent(lines: Iterable[str]) -> List[str]:
      """Remove the whitespace prefix common to all *lines*.

      The lines are joined with newlines, dedented with ``textwrap.dedent``
      and split again with ``str.splitlines``.  That final split also breaks
      on other line boundaries (carriage return, form feed, ``\\u2028``, ...)
      and drops a trailing empty line.  The result can therefore differ in
      length from the input.
      """
      joined = "\n".join(lines)
      dedented = textwrap.dedent(joined)
      return dedented.splitlines()
  def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
      """Return the ``(start, end)`` boundaries of the statement containing *lineno*.

      *lineno* and the returned values are 0-based.  The boundaries are:

      - the first line of every statement and except handler in *node*;
      - the decorator lines of classes and functions;
      - the line before the first statement of every ``else``/``finally``
        body, normally the ``else:``/``finally:`` line itself.

      ``start`` is the nearest boundary at or before *lineno*.  ``end`` is
      the next boundary after *lineno*, or ``None`` if none follows.  If
      *lineno* precedes every boundary, or *node* has no statements at all,
      ``start`` is 0.
      """
      # Flatten all statements and except handlers into one lineno-list.
      # AST's line numbers start indexing at 1.
      values: List[int] = []
      for x in ast.walk(node):
          if isinstance(x, (ast.stmt, ast.ExceptHandler)):
              # Before Python 3.8, the lineno of a decorated class or function pointed at the decorator.
              # Since Python 3.8, the lineno points to the class/def, so need to include the decorators.
              if isinstance(x, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
                  for d in x.decorator_list:
                      values.append(d.lineno - 1)
              values.append(x.lineno - 1)
              for name in ("finalbody", "orelse"):
                  val: Optional[List[ast.stmt]] = getattr(x, name, None)
                  if val:
                      # Treat the finally/orelse part as its own statement.
                      values.append(val[0].lineno - 1 - 1)
      values.sort()
      insert_index = bisect_right(values, lineno)
      if insert_index == 0:
          # No boundary at or before lineno (leading comments/blank lines, or
          # no statements at all).  Indexing values[-1] here would wrap around
          # to the *last* statement, or raise IndexError on an empty list.
          start = 0
      else:
          start = values[insert_index - 1]
      end: Optional[int] = values[insert_index] if insert_index < len(values) else None
      return start, end


  def getstatementrange_ast(
      lineno: int,
      source: "Source",
      assertion: bool = False,
      astnode: Optional[ast.AST] = None,
  ) -> Tuple[ast.AST, int, int]:
      """Return ``(astnode, start, end)`` for the statement containing *lineno*.

      *lineno*, *start* and *end* are 0-based, and ``source.lines[start:end]``
      is the statement.  If *astnode* is None, ``str(source)`` is parsed
      here with warnings suppressed.  The tree actually used is returned, so
      callers can reuse it.  *assertion* is accepted for backward
      compatibility and is not used.

      The returned range never has ``end < start``.
      """
      if astnode is None:
          content = str(source)
          # See #4260:
          # Don't produce duplicate warnings when compiling source to find AST.
          with warnings.catch_warnings():
              warnings.simplefilter("ignore")
              astnode = ast.parse(content, "source", "exec")

      start, end = get_statement_startend2(lineno, astnode)
      # We need to correct the end:
      # - ast-parsing strips comments
      # - there might be empty lines
      # - we might have lesser indented code blocks at the end
      if end is None:
          end = len(source.lines)

      if end > start + 1:
          # Make sure we don't span differently indented code blocks
          # by using the BlockFinder helper which inspect.getsource() uses itself.
          # (BlockFinder and EndOfBlock are undocumented inspect internals.)
          block_finder = inspect.BlockFinder()
          # If we start with an indented line, put blockfinder to "started" mode.
          # Slicing instead of indexing keeps an empty line from raising IndexError.
          block_finder.started = source.lines[start][:1].isspace()
          it = ((x + "\n") for x in source.lines[start:end])
          try:
              for tok in tokenize.generate_tokens(lambda: next(it)):
                  block_finder.tokeneater(*tok)
          except (inspect.EndOfBlock, IndentationError):
              # Token rows are 1-based within the fragment that starts at
              # `start`, so last + start is an exclusive end index.
              end = block_finder.last + start
          except Exception:
              # Best effort: a fragment may not tokenize cleanly; keep the
              # AST-derived end in that case.
              pass

      # The end might still point to a comment or empty line, correct it.
      # Never trim past the statement's own first line.
      while end > start:
          line = source.lines[end - 1].lstrip()
          if line.startswith("#") or not line:
              end -= 1
          else:
              break
      return astnode, start, end