UnrollLoopsCheck.cpp 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. //===--- UnrollLoopsCheck.cpp - clang-tidy --------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "UnrollLoopsCheck.h"
  9. #include "clang/AST/APValue.h"
  10. #include "clang/AST/ASTContext.h"
  11. #include "clang/AST/ASTTypeTraits.h"
  12. #include "clang/AST/OperationKinds.h"
  13. #include "clang/AST/ParentMapContext.h"
  14. #include "clang/ASTMatchers/ASTMatchFinder.h"
  15. #include <math.h>
  16. using namespace clang::ast_matchers;
  17. namespace clang::tidy::altera {
  18. UnrollLoopsCheck::UnrollLoopsCheck(StringRef Name, ClangTidyContext *Context)
  19. : ClangTidyCheck(Name, Context),
  20. MaxLoopIterations(Options.get("MaxLoopIterations", 100U)) {}
  21. void UnrollLoopsCheck::registerMatchers(MatchFinder *Finder) {
  22. const auto HasLoopBound = hasDescendant(
  23. varDecl(allOf(matchesName("__end*"),
  24. hasDescendant(integerLiteral().bind("cxx_loop_bound")))));
  25. const auto CXXForRangeLoop =
  26. cxxForRangeStmt(anyOf(HasLoopBound, unless(HasLoopBound)));
  27. const auto AnyLoop = anyOf(forStmt(), whileStmt(), doStmt(), CXXForRangeLoop);
  28. Finder->addMatcher(
  29. stmt(allOf(AnyLoop, unless(hasDescendant(stmt(AnyLoop))))).bind("loop"),
  30. this);
  31. }
  32. void UnrollLoopsCheck::check(const MatchFinder::MatchResult &Result) {
  33. const auto *Loop = Result.Nodes.getNodeAs<Stmt>("loop");
  34. const auto *CXXLoopBound =
  35. Result.Nodes.getNodeAs<IntegerLiteral>("cxx_loop_bound");
  36. const ASTContext *Context = Result.Context;
  37. switch (unrollType(Loop, Result.Context)) {
  38. case NotUnrolled:
  39. diag(Loop->getBeginLoc(),
  40. "kernel performance could be improved by unrolling this loop with a "
  41. "'#pragma unroll' directive");
  42. break;
  43. case PartiallyUnrolled:
  44. // Loop already partially unrolled, do nothing.
  45. break;
  46. case FullyUnrolled:
  47. if (hasKnownBounds(Loop, CXXLoopBound, Context)) {
  48. if (hasLargeNumIterations(Loop, CXXLoopBound, Context)) {
  49. diag(Loop->getBeginLoc(),
  50. "loop likely has a large number of iterations and thus "
  51. "cannot be fully unrolled; to partially unroll this loop, use "
  52. "the '#pragma unroll <num>' directive");
  53. return;
  54. }
  55. return;
  56. }
  57. if (isa<WhileStmt, DoStmt>(Loop)) {
  58. diag(Loop->getBeginLoc(),
  59. "full unrolling requested, but loop bounds may not be known; to "
  60. "partially unroll this loop, use the '#pragma unroll <num>' "
  61. "directive",
  62. DiagnosticIDs::Note);
  63. break;
  64. }
  65. diag(Loop->getBeginLoc(),
  66. "full unrolling requested, but loop bounds are not known; to "
  67. "partially unroll this loop, use the '#pragma unroll <num>' "
  68. "directive");
  69. break;
  70. }
  71. }
  72. enum UnrollLoopsCheck::UnrollType
  73. UnrollLoopsCheck::unrollType(const Stmt *Statement, ASTContext *Context) {
  74. const DynTypedNodeList Parents = Context->getParents<Stmt>(*Statement);
  75. for (const DynTypedNode &Parent : Parents) {
  76. const auto *ParentStmt = Parent.get<AttributedStmt>();
  77. if (!ParentStmt)
  78. continue;
  79. for (const Attr *Attribute : ParentStmt->getAttrs()) {
  80. const auto *LoopHint = dyn_cast<LoopHintAttr>(Attribute);
  81. if (!LoopHint)
  82. continue;
  83. switch (LoopHint->getState()) {
  84. case LoopHintAttr::Numeric:
  85. return PartiallyUnrolled;
  86. case LoopHintAttr::Disable:
  87. return NotUnrolled;
  88. case LoopHintAttr::Full:
  89. return FullyUnrolled;
  90. case LoopHintAttr::Enable:
  91. return FullyUnrolled;
  92. case LoopHintAttr::AssumeSafety:
  93. return NotUnrolled;
  94. case LoopHintAttr::FixedWidth:
  95. return NotUnrolled;
  96. case LoopHintAttr::ScalableWidth:
  97. return NotUnrolled;
  98. }
  99. }
  100. }
  101. return NotUnrolled;
  102. }
  103. bool UnrollLoopsCheck::hasKnownBounds(const Stmt *Statement,
  104. const IntegerLiteral *CXXLoopBound,
  105. const ASTContext *Context) {
  106. if (isa<CXXForRangeStmt>(Statement))
  107. return CXXLoopBound != nullptr;
  108. // Too many possibilities in a while statement, so always recommend partial
  109. // unrolling for these.
  110. if (isa<WhileStmt, DoStmt>(Statement))
  111. return false;
  112. // The last loop type is a for loop.
  113. const auto *ForLoop = cast<ForStmt>(Statement);
  114. const Stmt *Initializer = ForLoop->getInit();
  115. const Expr *Conditional = ForLoop->getCond();
  116. const Expr *Increment = ForLoop->getInc();
  117. if (!Initializer || !Conditional || !Increment)
  118. return false;
  119. // If the loop variable value isn't known, loop bounds are unknown.
  120. if (const auto *InitDeclStatement = dyn_cast<DeclStmt>(Initializer)) {
  121. if (const auto *VariableDecl =
  122. dyn_cast<VarDecl>(InitDeclStatement->getSingleDecl())) {
  123. APValue *Evaluation = VariableDecl->evaluateValue();
  124. if (!Evaluation || !Evaluation->hasValue())
  125. return false;
  126. }
  127. }
  128. // If increment is unary and not one of ++ and --, loop bounds are unknown.
  129. if (const auto *Op = dyn_cast<UnaryOperator>(Increment))
  130. if (!Op->isIncrementDecrementOp())
  131. return false;
  132. if (const auto *BinaryOp = dyn_cast<BinaryOperator>(Conditional)) {
  133. const Expr *LHS = BinaryOp->getLHS();
  134. const Expr *RHS = BinaryOp->getRHS();
  135. // If both sides are value dependent or constant, loop bounds are unknown.
  136. return LHS->isEvaluatable(*Context) != RHS->isEvaluatable(*Context);
  137. }
  138. return false; // If it's not a binary operator, loop bounds are unknown.
  139. }
  140. const Expr *UnrollLoopsCheck::getCondExpr(const Stmt *Statement) {
  141. if (const auto *ForLoop = dyn_cast<ForStmt>(Statement))
  142. return ForLoop->getCond();
  143. if (const auto *WhileLoop = dyn_cast<WhileStmt>(Statement))
  144. return WhileLoop->getCond();
  145. if (const auto *DoWhileLoop = dyn_cast<DoStmt>(Statement))
  146. return DoWhileLoop->getCond();
  147. if (const auto *CXXRangeLoop = dyn_cast<CXXForRangeStmt>(Statement))
  148. return CXXRangeLoop->getCond();
  149. llvm_unreachable("Unknown loop");
  150. }
  151. bool UnrollLoopsCheck::hasLargeNumIterations(const Stmt *Statement,
  152. const IntegerLiteral *CXXLoopBound,
  153. const ASTContext *Context) {
  154. // Because hasKnownBounds is called before this, if this is true, then
  155. // CXXLoopBound is also matched.
  156. if (isa<CXXForRangeStmt>(Statement)) {
  157. assert(CXXLoopBound && "CXX ranged for loop has no loop bound");
  158. return exprHasLargeNumIterations(CXXLoopBound, Context);
  159. }
  160. const auto *ForLoop = cast<ForStmt>(Statement);
  161. const Stmt *Initializer = ForLoop->getInit();
  162. const Expr *Conditional = ForLoop->getCond();
  163. const Expr *Increment = ForLoop->getInc();
  164. int InitValue;
  165. // If the loop variable value isn't known, we can't know the loop bounds.
  166. if (const auto *InitDeclStatement = dyn_cast<DeclStmt>(Initializer)) {
  167. if (const auto *VariableDecl =
  168. dyn_cast<VarDecl>(InitDeclStatement->getSingleDecl())) {
  169. APValue *Evaluation = VariableDecl->evaluateValue();
  170. if (!Evaluation || !Evaluation->isInt())
  171. return true;
  172. InitValue = Evaluation->getInt().getExtValue();
  173. }
  174. }
  175. int EndValue;
  176. const auto *BinaryOp = cast<BinaryOperator>(Conditional);
  177. if (!extractValue(EndValue, BinaryOp, Context))
  178. return true;
  179. double Iterations;
  180. // If increment is unary and not one of ++, --, we can't know the loop bounds.
  181. if (const auto *Op = dyn_cast<UnaryOperator>(Increment)) {
  182. if (Op->isIncrementOp())
  183. Iterations = EndValue - InitValue;
  184. else if (Op->isDecrementOp())
  185. Iterations = InitValue - EndValue;
  186. else
  187. llvm_unreachable("Unary operator neither increment nor decrement");
  188. }
  189. // If increment is binary and not one of +, -, *, /, we can't know the loop
  190. // bounds.
  191. if (const auto *Op = dyn_cast<BinaryOperator>(Increment)) {
  192. int ConstantValue;
  193. if (!extractValue(ConstantValue, Op, Context))
  194. return true;
  195. switch (Op->getOpcode()) {
  196. case (BO_AddAssign):
  197. Iterations = ceil(float(EndValue - InitValue) / ConstantValue);
  198. break;
  199. case (BO_SubAssign):
  200. Iterations = ceil(float(InitValue - EndValue) / ConstantValue);
  201. break;
  202. case (BO_MulAssign):
  203. Iterations = 1 + (log(EndValue) - log(InitValue)) / log(ConstantValue);
  204. break;
  205. case (BO_DivAssign):
  206. Iterations = 1 + (log(InitValue) - log(EndValue)) / log(ConstantValue);
  207. break;
  208. default:
  209. // All other operators are not handled; assume large bounds.
  210. return true;
  211. }
  212. }
  213. return Iterations > MaxLoopIterations;
  214. }
  215. bool UnrollLoopsCheck::extractValue(int &Value, const BinaryOperator *Op,
  216. const ASTContext *Context) {
  217. const Expr *LHS = Op->getLHS();
  218. const Expr *RHS = Op->getRHS();
  219. Expr::EvalResult Result;
  220. if (LHS->isEvaluatable(*Context))
  221. LHS->EvaluateAsRValue(Result, *Context);
  222. else if (RHS->isEvaluatable(*Context))
  223. RHS->EvaluateAsRValue(Result, *Context);
  224. else
  225. return false; // Cannot evaluate either side.
  226. if (!Result.Val.isInt())
  227. return false; // Cannot check number of iterations, return false to be
  228. // safe.
  229. Value = Result.Val.getInt().getExtValue();
  230. return true;
  231. }
  232. bool UnrollLoopsCheck::exprHasLargeNumIterations(const Expr *Expression,
  233. const ASTContext *Context) {
  234. Expr::EvalResult Result;
  235. if (Expression->EvaluateAsRValue(Result, *Context)) {
  236. if (!Result.Val.isInt())
  237. return false; // Cannot check number of iterations, return false to be
  238. // safe.
  239. // The following assumes values go from 0 to Val in increments of 1.
  240. return Result.Val.getInt() > MaxLoopIterations;
  241. }
  242. // Cannot evaluate Expression as an r-value, so cannot check number of
  243. // iterations.
  244. return false;
  245. }
  246. void UnrollLoopsCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
  247. Options.store(Opts, "MaxLoopIterations", MaxLoopIterations);
  248. }
  249. } // namespace clang::tidy::altera