refactoring.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
"""
This module is not in active development; please check
https://github.com/davidhalter/jedi/issues/667 before editing it.

Introduces some basic refactoring functions to |jedi|. This module is still in
a very early development stage and needs much testing and improvement.

.. warning:: I won't do too much here, but if anyone wants to step in, please
             do. Refactoring is not one of my priorities.

It uses the |jedi| `API <api.html>`_ and currently supports the
following functions (sometimes bug-prone):

- rename
- extract variable
- inline variable
"""
  14. import difflib
  15. from parso import python_bytes_to_unicode, split_lines
  16. from jedi.evaluate import helpers
  17. class Refactoring(object):
  18. def __init__(self, change_dct):
  19. """
  20. :param change_dct: dict(old_path=(new_path, old_lines, new_lines))
  21. """
  22. self.change_dct = change_dct
  23. def old_files(self):
  24. dct = {}
  25. for old_path, (new_path, old_l, new_l) in self.change_dct.items():
  26. dct[old_path] = '\n'.join(old_l)
  27. return dct
  28. def new_files(self):
  29. dct = {}
  30. for old_path, (new_path, old_l, new_l) in self.change_dct.items():
  31. dct[new_path] = '\n'.join(new_l)
  32. return dct
  33. def diff(self):
  34. texts = []
  35. for old_path, (new_path, old_l, new_l) in self.change_dct.items():
  36. if old_path:
  37. udiff = difflib.unified_diff(old_l, new_l)
  38. else:
  39. udiff = difflib.unified_diff(old_l, new_l, old_path, new_path)
  40. texts.append('\n'.join(udiff))
  41. return '\n'.join(texts)
  42. def rename(script, new_name):
  43. """ The `args` / `kwargs` params are the same as in `api.Script`.
  44. :param new_name: The new name of the script.
  45. :param script: The source Script object.
  46. :return: list of changed lines/changed files
  47. """
  48. return Refactoring(_rename(script.usages(), new_name))
  49. def _rename(names, replace_str):
  50. """ For both rename and inline. """
  51. order = sorted(names, key=lambda x: (x.module_path, x.line, x.column),
  52. reverse=True)
  53. def process(path, old_lines, new_lines):
  54. if new_lines is not None: # goto next file, save last
  55. dct[path] = path, old_lines, new_lines
  56. dct = {}
  57. current_path = object()
  58. new_lines = old_lines = None
  59. for name in order:
  60. if name.in_builtin_module():
  61. continue
  62. if current_path != name.module_path:
  63. current_path = name.module_path
  64. process(current_path, old_lines, new_lines)
  65. if current_path is not None:
  66. # None means take the source that is a normal param.
  67. with open(current_path) as f:
  68. source = f.read()
  69. new_lines = split_lines(python_bytes_to_unicode(source))
  70. old_lines = new_lines[:]
  71. nr, indent = name.line, name.column
  72. line = new_lines[nr - 1]
  73. new_lines[nr - 1] = line[:indent] + replace_str + \
  74. line[indent + len(name.name):]
  75. process(current_path, old_lines, new_lines)
  76. return dct
  77. def extract(script, new_name):
  78. """ The `args` / `kwargs` params are the same as in `api.Script`.
  79. :param operation: The refactoring operation to execute.
  80. :type operation: str
  81. :type source: str
  82. :return: list of changed lines/changed files
  83. """
  84. new_lines = split_lines(python_bytes_to_unicode(script.source))
  85. old_lines = new_lines[:]
  86. user_stmt = script._parser.user_stmt()
  87. # TODO care for multi-line extracts
  88. dct = {}
  89. if user_stmt:
  90. pos = script._pos
  91. line_index = pos[0] - 1
  92. # Be careful here. 'array_for_pos' does not exist in 'helpers'.
  93. arr, index = helpers.array_for_pos(user_stmt, pos)
  94. if arr is not None:
  95. start_pos = arr[index].start_pos
  96. end_pos = arr[index].end_pos
  97. # take full line if the start line is different from end line
  98. e = end_pos[1] if end_pos[0] == start_pos[0] else None
  99. start_line = new_lines[start_pos[0] - 1]
  100. text = start_line[start_pos[1]:e]
  101. for l in range(start_pos[0], end_pos[0] - 1):
  102. text += '\n' + str(l)
  103. if e is None:
  104. end_line = new_lines[end_pos[0] - 1]
  105. text += '\n' + end_line[:end_pos[1]]
  106. # remove code from new lines
  107. t = text.lstrip()
  108. del_start = start_pos[1] + len(text) - len(t)
  109. text = t.rstrip()
  110. del_end = len(t) - len(text)
  111. if e is None:
  112. new_lines[end_pos[0] - 1] = end_line[end_pos[1] - del_end:]
  113. e = len(start_line)
  114. else:
  115. e = e - del_end
  116. start_line = start_line[:del_start] + new_name + start_line[e:]
  117. new_lines[start_pos[0] - 1] = start_line
  118. new_lines[start_pos[0]:end_pos[0] - 1] = []
  119. # add parentheses in multi-line case
  120. open_brackets = ['(', '[', '{']
  121. close_brackets = [')', ']', '}']
  122. if '\n' in text and not (text[0] in open_brackets and text[-1] ==
  123. close_brackets[open_brackets.index(text[0])]):
  124. text = '(%s)' % text
  125. # add new line before statement
  126. indent = user_stmt.start_pos[1]
  127. new = "%s%s = %s" % (' ' * indent, new_name, text)
  128. new_lines.insert(line_index, new)
  129. dct[script.path] = script.path, old_lines, new_lines
  130. return Refactoring(dct)
  131. def inline(script):
  132. """
  133. :type script: api.Script
  134. """
  135. new_lines = split_lines(python_bytes_to_unicode(script.source))
  136. dct = {}
  137. definitions = script.goto_assignments()
  138. assert len(definitions) == 1
  139. stmt = definitions[0]._definition
  140. usages = script.usages()
  141. inlines = [r for r in usages
  142. if not stmt.start_pos <= (r.line, r.column) <= stmt.end_pos]
  143. inlines = sorted(inlines, key=lambda x: (x.module_path, x.line, x.column),
  144. reverse=True)
  145. expression_list = stmt.expression_list()
  146. # don't allow multi-line refactorings for now.
  147. assert stmt.start_pos[0] == stmt.end_pos[0]
  148. index = stmt.start_pos[0] - 1
  149. line = new_lines[index]
  150. replace_str = line[expression_list[0].start_pos[1]:stmt.end_pos[1] + 1]
  151. replace_str = replace_str.strip()
  152. # tuples need parentheses
  153. if expression_list and isinstance(expression_list[0], pr.Array):
  154. arr = expression_list[0]
  155. if replace_str[0] not in ['(', '[', '{'] and len(arr) > 1:
  156. replace_str = '(%s)' % replace_str
  157. # if it's the only assignment, remove the statement
  158. if len(stmt.get_defined_names()) == 1:
  159. line = line[:stmt.start_pos[1]] + line[stmt.end_pos[1]:]
  160. dct = _rename(inlines, replace_str)
  161. # remove the empty line
  162. new_lines = dct[script.path][2]
  163. if line.strip():
  164. new_lines[index] = line
  165. else:
  166. new_lines.pop(index)
  167. return Refactoring(dct)