TransProtectedScope.cpp 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. //===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Adds braces around case statements that "contain" the initialization of a
  10. // retaining variable and thus trigger the "switch case is in protected scope" error.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "Internals.h"
  14. #include "Transforms.h"
  15. #include "clang/AST/ASTContext.h"
  16. #include "clang/Basic/SourceManager.h"
  17. #include "clang/Sema/SemaDiagnostic.h"
  18. using namespace clang;
  19. using namespace arcmt;
  20. using namespace trans;
  21. namespace {
  22. class LocalRefsCollector : public RecursiveASTVisitor<LocalRefsCollector> {
  23. SmallVectorImpl<DeclRefExpr *> &Refs;
  24. public:
  25. LocalRefsCollector(SmallVectorImpl<DeclRefExpr *> &refs)
  26. : Refs(refs) { }
  27. bool VisitDeclRefExpr(DeclRefExpr *E) {
  28. if (ValueDecl *D = E->getDecl())
  29. if (D->getDeclContext()->getRedeclContext()->isFunctionOrMethod())
  30. Refs.push_back(E);
  31. return true;
  32. }
  33. };
  34. struct CaseInfo {
  35. SwitchCase *SC;
  36. SourceRange Range;
  37. enum {
  38. St_Unchecked,
  39. St_CannotFix,
  40. St_Fixed
  41. } State;
  42. CaseInfo() : SC(nullptr), State(St_Unchecked) {}
  43. CaseInfo(SwitchCase *S, SourceRange Range)
  44. : SC(S), Range(Range), State(St_Unchecked) {}
  45. };
  46. class CaseCollector : public RecursiveASTVisitor<CaseCollector> {
  47. ParentMap &PMap;
  48. SmallVectorImpl<CaseInfo> &Cases;
  49. public:
  50. CaseCollector(ParentMap &PMap, SmallVectorImpl<CaseInfo> &Cases)
  51. : PMap(PMap), Cases(Cases) { }
  52. bool VisitSwitchStmt(SwitchStmt *S) {
  53. SwitchCase *Curr = S->getSwitchCaseList();
  54. if (!Curr)
  55. return true;
  56. Stmt *Parent = getCaseParent(Curr);
  57. Curr = Curr->getNextSwitchCase();
  58. // Make sure all case statements are in the same scope.
  59. while (Curr) {
  60. if (getCaseParent(Curr) != Parent)
  61. return true;
  62. Curr = Curr->getNextSwitchCase();
  63. }
  64. SourceLocation NextLoc = S->getEndLoc();
  65. Curr = S->getSwitchCaseList();
  66. // We iterate over case statements in reverse source-order.
  67. while (Curr) {
  68. Cases.push_back(
  69. CaseInfo(Curr, SourceRange(Curr->getBeginLoc(), NextLoc)));
  70. NextLoc = Curr->getBeginLoc();
  71. Curr = Curr->getNextSwitchCase();
  72. }
  73. return true;
  74. }
  75. Stmt *getCaseParent(SwitchCase *S) {
  76. Stmt *Parent = PMap.getParent(S);
  77. while (Parent && (isa<SwitchCase>(Parent) || isa<LabelStmt>(Parent)))
  78. Parent = PMap.getParent(Parent);
  79. return Parent;
  80. }
  81. };
  82. class ProtectedScopeFixer {
  83. MigrationPass &Pass;
  84. SourceManager &SM;
  85. SmallVector<CaseInfo, 16> Cases;
  86. SmallVector<DeclRefExpr *, 16> LocalRefs;
  87. public:
  88. ProtectedScopeFixer(BodyContext &BodyCtx)
  89. : Pass(BodyCtx.getMigrationContext().Pass),
  90. SM(Pass.Ctx.getSourceManager()) {
  91. CaseCollector(BodyCtx.getParentMap(), Cases)
  92. .TraverseStmt(BodyCtx.getTopStmt());
  93. LocalRefsCollector(LocalRefs).TraverseStmt(BodyCtx.getTopStmt());
  94. SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange();
  95. const CapturedDiagList &DiagList = Pass.getDiags();
  96. // Copy the diagnostics so we don't have to worry about invaliding iterators
  97. // from the diagnostic list.
  98. SmallVector<StoredDiagnostic, 16> StoredDiags;
  99. StoredDiags.append(DiagList.begin(), DiagList.end());
  100. SmallVectorImpl<StoredDiagnostic>::iterator
  101. I = StoredDiags.begin(), E = StoredDiags.end();
  102. while (I != E) {
  103. if (I->getID() == diag::err_switch_into_protected_scope &&
  104. isInRange(I->getLocation(), BodyRange)) {
  105. handleProtectedScopeError(I, E);
  106. continue;
  107. }
  108. ++I;
  109. }
  110. }
  111. void handleProtectedScopeError(
  112. SmallVectorImpl<StoredDiagnostic>::iterator &DiagI,
  113. SmallVectorImpl<StoredDiagnostic>::iterator DiagE){
  114. Transaction Trans(Pass.TA);
  115. assert(DiagI->getID() == diag::err_switch_into_protected_scope);
  116. SourceLocation ErrLoc = DiagI->getLocation();
  117. bool handledAllNotes = true;
  118. ++DiagI;
  119. for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note;
  120. ++DiagI) {
  121. if (!handleProtectedNote(*DiagI))
  122. handledAllNotes = false;
  123. }
  124. if (handledAllNotes)
  125. Pass.TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc);
  126. }
  127. bool handleProtectedNote(const StoredDiagnostic &Diag) {
  128. assert(Diag.getLevel() == DiagnosticsEngine::Note);
  129. for (unsigned i = 0; i != Cases.size(); i++) {
  130. CaseInfo &info = Cases[i];
  131. if (isInRange(Diag.getLocation(), info.Range)) {
  132. if (info.State == CaseInfo::St_Unchecked)
  133. tryFixing(info);
  134. assert(info.State != CaseInfo::St_Unchecked);
  135. if (info.State == CaseInfo::St_Fixed) {
  136. Pass.TA.clearDiagnostic(Diag.getID(), Diag.getLocation());
  137. return true;
  138. }
  139. return false;
  140. }
  141. }
  142. return false;
  143. }
  144. void tryFixing(CaseInfo &info) {
  145. assert(info.State == CaseInfo::St_Unchecked);
  146. if (hasVarReferencedOutside(info)) {
  147. info.State = CaseInfo::St_CannotFix;
  148. return;
  149. }
  150. Pass.TA.insertAfterToken(info.SC->getColonLoc(), " {");
  151. Pass.TA.insert(info.Range.getEnd(), "}\n");
  152. info.State = CaseInfo::St_Fixed;
  153. }
  154. bool hasVarReferencedOutside(CaseInfo &info) {
  155. for (unsigned i = 0, e = LocalRefs.size(); i != e; ++i) {
  156. DeclRefExpr *DRE = LocalRefs[i];
  157. if (isInRange(DRE->getDecl()->getLocation(), info.Range) &&
  158. !isInRange(DRE->getLocation(), info.Range))
  159. return true;
  160. }
  161. return false;
  162. }
  163. bool isInRange(SourceLocation Loc, SourceRange R) {
  164. if (Loc.isInvalid())
  165. return false;
  166. return !SM.isBeforeInTranslationUnit(Loc, R.getBegin()) &&
  167. SM.isBeforeInTranslationUnit(Loc, R.getEnd());
  168. }
  169. };
  170. } // anonymous namespace
  171. void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) {
  172. ProtectedScopeFixer Fix(BodyCtx);
  173. }