source.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. # -*- coding: utf-8 -*-
  2. from __future__ import absolute_import
  3. from __future__ import division
  4. from __future__ import print_function
  5. import ast
  6. import inspect
  7. import linecache
  8. import sys
  9. import textwrap
  10. import tokenize
  11. import warnings
  12. from ast import PyCF_ONLY_AST as _AST_FLAG
  13. from bisect import bisect_right
  14. import py
  15. import six
  16. class Source(object):
  17. """ an immutable object holding a source code fragment,
  18. possibly deindenting it.
  19. """
  20. _compilecounter = 0
  21. def __init__(self, *parts, **kwargs):
  22. self.lines = lines = []
  23. de = kwargs.get("deindent", True)
  24. for part in parts:
  25. if not part:
  26. partlines = []
  27. elif isinstance(part, Source):
  28. partlines = part.lines
  29. elif isinstance(part, (tuple, list)):
  30. partlines = [x.rstrip("\n") for x in part]
  31. elif isinstance(part, six.string_types):
  32. partlines = part.split("\n")
  33. else:
  34. partlines = getsource(part, deindent=de).lines
  35. if de:
  36. partlines = deindent(partlines)
  37. lines.extend(partlines)
  38. def __eq__(self, other):
  39. try:
  40. return self.lines == other.lines
  41. except AttributeError:
  42. if isinstance(other, str):
  43. return str(self) == other
  44. return False
  45. __hash__ = None
  46. def __getitem__(self, key):
  47. if isinstance(key, int):
  48. return self.lines[key]
  49. else:
  50. if key.step not in (None, 1):
  51. raise IndexError("cannot slice a Source with a step")
  52. newsource = Source()
  53. newsource.lines = self.lines[key.start : key.stop]
  54. return newsource
  55. def __len__(self):
  56. return len(self.lines)
  57. def strip(self):
  58. """ return new source object with trailing
  59. and leading blank lines removed.
  60. """
  61. start, end = 0, len(self)
  62. while start < end and not self.lines[start].strip():
  63. start += 1
  64. while end > start and not self.lines[end - 1].strip():
  65. end -= 1
  66. source = Source()
  67. source.lines[:] = self.lines[start:end]
  68. return source
  69. def putaround(self, before="", after="", indent=" " * 4):
  70. """ return a copy of the source object with
  71. 'before' and 'after' wrapped around it.
  72. """
  73. before = Source(before)
  74. after = Source(after)
  75. newsource = Source()
  76. lines = [(indent + line) for line in self.lines]
  77. newsource.lines = before.lines + lines + after.lines
  78. return newsource
  79. def indent(self, indent=" " * 4):
  80. """ return a copy of the source object with
  81. all lines indented by the given indent-string.
  82. """
  83. newsource = Source()
  84. newsource.lines = [(indent + line) for line in self.lines]
  85. return newsource
  86. def getstatement(self, lineno):
  87. """ return Source statement which contains the
  88. given linenumber (counted from 0).
  89. """
  90. start, end = self.getstatementrange(lineno)
  91. return self[start:end]
  92. def getstatementrange(self, lineno):
  93. """ return (start, end) tuple which spans the minimal
  94. statement region which containing the given lineno.
  95. """
  96. if not (0 <= lineno < len(self)):
  97. raise IndexError("lineno out of range")
  98. ast, start, end = getstatementrange_ast(lineno, self)
  99. return start, end
  100. def deindent(self):
  101. """return a new source object deindented."""
  102. newsource = Source()
  103. newsource.lines[:] = deindent(self.lines)
  104. return newsource
  105. def isparseable(self, deindent=True):
  106. """ return True if source is parseable, heuristically
  107. deindenting it by default.
  108. """
  109. if deindent:
  110. source = str(self.deindent())
  111. else:
  112. source = str(self)
  113. try:
  114. ast.parse(source)
  115. except (SyntaxError, ValueError, TypeError):
  116. return False
  117. else:
  118. return True
  119. def __str__(self):
  120. return "\n".join(self.lines)
  121. def compile(
  122. self, filename=None, mode="exec", flag=0, dont_inherit=0, _genframe=None
  123. ):
  124. """ return compiled code object. if filename is None
  125. invent an artificial filename which displays
  126. the source/line position of the caller frame.
  127. """
  128. if not filename or py.path.local(filename).check(file=0):
  129. if _genframe is None:
  130. _genframe = sys._getframe(1) # the caller
  131. fn, lineno = _genframe.f_code.co_filename, _genframe.f_lineno
  132. base = "<%d-codegen " % self._compilecounter
  133. self.__class__._compilecounter += 1
  134. if not filename:
  135. filename = base + "%s:%d>" % (fn, lineno)
  136. else:
  137. filename = base + "%r %s:%d>" % (filename, fn, lineno)
  138. source = "\n".join(self.lines) + "\n"
  139. try:
  140. co = compile(source, filename, mode, flag)
  141. except SyntaxError:
  142. ex = sys.exc_info()[1]
  143. # re-represent syntax errors from parsing python strings
  144. msglines = self.lines[: ex.lineno]
  145. if ex.offset:
  146. msglines.append(" " * ex.offset + "^")
  147. msglines.append("(code was compiled probably from here: %s)" % filename)
  148. newex = SyntaxError("\n".join(msglines))
  149. newex.offset = ex.offset
  150. newex.lineno = ex.lineno
  151. newex.text = ex.text
  152. raise newex
  153. else:
  154. if flag & _AST_FLAG:
  155. return co
  156. lines = [(x + "\n") for x in self.lines]
  157. linecache.cache[filename] = (1, None, lines, filename)
  158. return co
  159. #
  160. # public API shortcut functions
  161. #
  162. def compile_(source, filename=None, mode="exec", flags=0, dont_inherit=0):
  163. """ compile the given source to a raw code object,
  164. and maintain an internal cache which allows later
  165. retrieval of the source code for the code object
  166. and any recursively created code objects.
  167. """
  168. if isinstance(source, ast.AST):
  169. # XXX should Source support having AST?
  170. return compile(source, filename, mode, flags, dont_inherit)
  171. _genframe = sys._getframe(1) # the caller
  172. s = Source(source)
  173. co = s.compile(filename, mode, flags, _genframe=_genframe)
  174. return co
  175. def getfslineno(obj):
  176. """ Return source location (path, lineno) for the given object.
  177. If the source cannot be determined return ("", -1).
  178. The line number is 0-based.
  179. """
  180. from .code import Code
  181. try:
  182. code = Code(obj)
  183. except TypeError:
  184. try:
  185. fn = inspect.getsourcefile(obj) or inspect.getfile(obj)
  186. except TypeError:
  187. return "", -1
  188. fspath = fn and py.path.local(fn) or None
  189. lineno = -1
  190. if fspath:
  191. try:
  192. _, lineno = findsource(obj)
  193. except IOError:
  194. pass
  195. else:
  196. fspath = code.path
  197. lineno = code.firstlineno
  198. assert isinstance(lineno, int)
  199. return fspath, lineno
  200. #
  201. # helper functions
  202. #
  203. def findsource(obj):
  204. try:
  205. sourcelines, lineno = inspect.findsource(obj)
  206. except Exception:
  207. return None, -1
  208. source = Source()
  209. source.lines = [line.rstrip() for line in sourcelines]
  210. return source, lineno
  211. def getsource(obj, **kwargs):
  212. from .code import getrawcode
  213. obj = getrawcode(obj)
  214. try:
  215. strsrc = inspect.getsource(obj)
  216. except IndentationError:
  217. strsrc = '"Buggy python version consider upgrading, cannot get source"'
  218. assert isinstance(strsrc, str)
  219. return Source(strsrc, **kwargs)
  220. def deindent(lines):
  221. return textwrap.dedent("\n".join(lines)).splitlines()
  222. def get_statement_startend2(lineno, node):
  223. import ast
  224. # flatten all statements and except handlers into one lineno-list
  225. # AST's line numbers start indexing at 1
  226. values = []
  227. for x in ast.walk(node):
  228. if isinstance(x, (ast.stmt, ast.ExceptHandler)):
  229. values.append(x.lineno - 1)
  230. for name in ("finalbody", "orelse"):
  231. val = getattr(x, name, None)
  232. if val:
  233. # treat the finally/orelse part as its own statement
  234. values.append(val[0].lineno - 1 - 1)
  235. values.sort()
  236. insert_index = bisect_right(values, lineno)
  237. start = values[insert_index - 1]
  238. if insert_index >= len(values):
  239. end = None
  240. else:
  241. end = values[insert_index]
  242. return start, end
  243. def getstatementrange_ast(lineno, source, assertion=False, astnode=None):
  244. if astnode is None:
  245. content = str(source)
  246. # See #4260:
  247. # don't produce duplicate warnings when compiling source to find ast
  248. with warnings.catch_warnings():
  249. warnings.simplefilter("ignore")
  250. astnode = compile(content, "source", "exec", _AST_FLAG)
  251. start, end = get_statement_startend2(lineno, astnode)
  252. # we need to correct the end:
  253. # - ast-parsing strips comments
  254. # - there might be empty lines
  255. # - we might have lesser indented code blocks at the end
  256. if end is None:
  257. end = len(source.lines)
  258. if end > start + 1:
  259. # make sure we don't span differently indented code blocks
  260. # by using the BlockFinder helper used which inspect.getsource() uses itself
  261. block_finder = inspect.BlockFinder()
  262. # if we start with an indented line, put blockfinder to "started" mode
  263. block_finder.started = source.lines[start][0].isspace()
  264. it = ((x + "\n") for x in source.lines[start:end])
  265. try:
  266. for tok in tokenize.generate_tokens(lambda: next(it)):
  267. block_finder.tokeneater(*tok)
  268. except (inspect.EndOfBlock, IndentationError):
  269. end = block_finder.last + start
  270. except Exception:
  271. pass
  272. # the end might still point to a comment or empty line, correct it
  273. while end:
  274. line = source.lines[end - 1].lstrip()
  275. if line.startswith("#") or not line:
  276. end -= 1
  277. else:
  278. break
  279. return astnode, start, end