RefactoringCallbacks.cpp 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. //===--- RefactoringCallbacks.cpp - Structural query framework ------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. //
  10. //===----------------------------------------------------------------------===//
  11. #include "clang/Tooling/RefactoringCallbacks.h"
  12. #include "clang/ASTMatchers/ASTMatchFinder.h"
  13. #include "clang/Basic/SourceLocation.h"
  14. #include "clang/Lex/Lexer.h"
  15. using llvm::StringError;
  16. using llvm::make_error;
  17. namespace clang {
  18. namespace tooling {
  19. RefactoringCallback::RefactoringCallback() {}
  20. tooling::Replacements &RefactoringCallback::getReplacements() {
  21. return Replace;
  22. }
  23. ASTMatchRefactorer::ASTMatchRefactorer(
  24. std::map<std::string, Replacements> &FileToReplaces)
  25. : FileToReplaces(FileToReplaces) {}
  26. void ASTMatchRefactorer::addDynamicMatcher(
  27. const ast_matchers::internal::DynTypedMatcher &Matcher,
  28. RefactoringCallback *Callback) {
  29. MatchFinder.addDynamicMatcher(Matcher, Callback);
  30. Callbacks.push_back(Callback);
  31. }
  32. class RefactoringASTConsumer : public ASTConsumer {
  33. public:
  34. explicit RefactoringASTConsumer(ASTMatchRefactorer &Refactoring)
  35. : Refactoring(Refactoring) {}
  36. void HandleTranslationUnit(ASTContext &Context) override {
  37. // The ASTMatchRefactorer is re-used between translation units.
  38. // Clear the matchers so that each Replacement is only emitted once.
  39. for (const auto &Callback : Refactoring.Callbacks) {
  40. Callback->getReplacements().clear();
  41. }
  42. Refactoring.MatchFinder.matchAST(Context);
  43. for (const auto &Callback : Refactoring.Callbacks) {
  44. for (const auto &Replacement : Callback->getReplacements()) {
  45. llvm::Error Err =
  46. Refactoring.FileToReplaces[std::string(Replacement.getFilePath())]
  47. .add(Replacement);
  48. if (Err) {
  49. llvm::errs() << "Skipping replacement " << Replacement.toString()
  50. << " due to this error:\n"
  51. << toString(std::move(Err)) << "\n";
  52. }
  53. }
  54. }
  55. }
  56. private:
  57. ASTMatchRefactorer &Refactoring;
  58. };
  59. std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() {
  60. return std::make_unique<RefactoringASTConsumer>(*this);
  61. }
  62. static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From,
  63. StringRef Text) {
  64. return tooling::Replacement(
  65. Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text);
  66. }
  67. static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From,
  68. const Stmt &To) {
  69. return replaceStmtWithText(
  70. Sources, From,
  71. Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()),
  72. Sources, LangOptions()));
  73. }
  74. ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText)
  75. : FromId(std::string(FromId)), ToText(std::string(ToText)) {}
  76. void ReplaceStmtWithText::run(
  77. const ast_matchers::MatchFinder::MatchResult &Result) {
  78. if (const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId)) {
  79. auto Err = Replace.add(tooling::Replacement(
  80. *Result.SourceManager,
  81. CharSourceRange::getTokenRange(FromMatch->getSourceRange()), ToText));
  82. // FIXME: better error handling. For now, just print error message in the
  83. // release version.
  84. if (Err) {
  85. llvm::errs() << llvm::toString(std::move(Err)) << "\n";
  86. assert(false);
  87. }
  88. }
  89. }
  90. ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId)
  91. : FromId(std::string(FromId)), ToId(std::string(ToId)) {}
  92. void ReplaceStmtWithStmt::run(
  93. const ast_matchers::MatchFinder::MatchResult &Result) {
  94. const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId);
  95. const Stmt *ToMatch = Result.Nodes.getNodeAs<Stmt>(ToId);
  96. if (FromMatch && ToMatch) {
  97. auto Err = Replace.add(
  98. replaceStmtWithStmt(*Result.SourceManager, *FromMatch, *ToMatch));
  99. // FIXME: better error handling. For now, just print error message in the
  100. // release version.
  101. if (Err) {
  102. llvm::errs() << llvm::toString(std::move(Err)) << "\n";
  103. assert(false);
  104. }
  105. }
  106. }
  107. ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id,
  108. bool PickTrueBranch)
  109. : Id(std::string(Id)), PickTrueBranch(PickTrueBranch) {}
  110. void ReplaceIfStmtWithItsBody::run(
  111. const ast_matchers::MatchFinder::MatchResult &Result) {
  112. if (const IfStmt *Node = Result.Nodes.getNodeAs<IfStmt>(Id)) {
  113. const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse();
  114. if (Body) {
  115. auto Err =
  116. Replace.add(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body));
  117. // FIXME: better error handling. For now, just print error message in the
  118. // release version.
  119. if (Err) {
  120. llvm::errs() << llvm::toString(std::move(Err)) << "\n";
  121. assert(false);
  122. }
  123. } else if (!PickTrueBranch) {
  124. // If we want to use the 'else'-branch, but it doesn't exist, delete
  125. // the whole 'if'.
  126. auto Err =
  127. Replace.add(replaceStmtWithText(*Result.SourceManager, *Node, ""));
  128. // FIXME: better error handling. For now, just print error message in the
  129. // release version.
  130. if (Err) {
  131. llvm::errs() << llvm::toString(std::move(Err)) << "\n";
  132. assert(false);
  133. }
  134. }
  135. }
  136. }
  137. ReplaceNodeWithTemplate::ReplaceNodeWithTemplate(
  138. llvm::StringRef FromId, std::vector<TemplateElement> Template)
  139. : FromId(std::string(FromId)), Template(std::move(Template)) {}
  140. llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>>
  141. ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) {
  142. std::vector<TemplateElement> ParsedTemplate;
  143. for (size_t Index = 0; Index < ToTemplate.size();) {
  144. if (ToTemplate[Index] == '$') {
  145. if (ToTemplate.substr(Index, 2) == "$$") {
  146. Index += 2;
  147. ParsedTemplate.push_back(
  148. TemplateElement{TemplateElement::Literal, "$"});
  149. } else if (ToTemplate.substr(Index, 2) == "${") {
  150. size_t EndOfIdentifier = ToTemplate.find("}", Index);
  151. if (EndOfIdentifier == std::string::npos) {
  152. return make_error<StringError>(
  153. "Unterminated ${...} in replacement template near " +
  154. ToTemplate.substr(Index),
  155. llvm::inconvertibleErrorCode());
  156. }
  157. std::string SourceNodeName = std::string(
  158. ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2));
  159. ParsedTemplate.push_back(
  160. TemplateElement{TemplateElement::Identifier, SourceNodeName});
  161. Index = EndOfIdentifier + 1;
  162. } else {
  163. return make_error<StringError>(
  164. "Invalid $ in replacement template near " +
  165. ToTemplate.substr(Index),
  166. llvm::inconvertibleErrorCode());
  167. }
  168. } else {
  169. size_t NextIndex = ToTemplate.find('$', Index + 1);
  170. ParsedTemplate.push_back(TemplateElement{
  171. TemplateElement::Literal,
  172. std::string(ToTemplate.substr(Index, NextIndex - Index))});
  173. Index = NextIndex;
  174. }
  175. }
  176. return std::unique_ptr<ReplaceNodeWithTemplate>(
  177. new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate)));
  178. }
  179. void ReplaceNodeWithTemplate::run(
  180. const ast_matchers::MatchFinder::MatchResult &Result) {
  181. const auto &NodeMap = Result.Nodes.getMap();
  182. std::string ToText;
  183. for (const auto &Element : Template) {
  184. switch (Element.Type) {
  185. case TemplateElement::Literal:
  186. ToText += Element.Value;
  187. break;
  188. case TemplateElement::Identifier: {
  189. auto NodeIter = NodeMap.find(Element.Value);
  190. if (NodeIter == NodeMap.end()) {
  191. llvm::errs() << "Node " << Element.Value
  192. << " used in replacement template not bound in Matcher \n";
  193. llvm::report_fatal_error("Unbound node in replacement template.");
  194. }
  195. CharSourceRange Source =
  196. CharSourceRange::getTokenRange(NodeIter->second.getSourceRange());
  197. ToText += Lexer::getSourceText(Source, *Result.SourceManager,
  198. Result.Context->getLangOpts());
  199. break;
  200. }
  201. }
  202. }
  203. if (NodeMap.count(FromId) == 0) {
  204. llvm::errs() << "Node to be replaced " << FromId
  205. << " not bound in query.\n";
  206. llvm::report_fatal_error("FromId node not bound in MatchResult");
  207. }
  208. auto Replacement =
  209. tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText,
  210. Result.Context->getLangOpts());
  211. llvm::Error Err = Replace.add(Replacement);
  212. if (Err) {
  213. llvm::errs() << "Query and replace failed in " << Replacement.getFilePath()
  214. << "! " << llvm::toString(std::move(Err)) << "\n";
  215. llvm::report_fatal_error("Replacement failed");
  216. }
  217. }
  218. } // end namespace tooling
  219. } // end namespace clang