//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a general divergence analysis for loop vectorization
// and GPU programs. It determines which branches and values in a loop or GPU
// program are divergent. It can help branch optimizations such as jump
// threading and loop unswitching to make better decisions.
//
// GPU programs typically use the SIMD execution model, where multiple threads
// in the same execution group have to execute in lock-step. Therefore, if the
// code contains divergent branches (i.e., threads in a group do not agree on
// which path of the branch to take), the group of threads has to execute all
// the paths from that branch with different subsets of threads enabled until
// they re-converge.
//
// Due to this execution model, some optimizations such as jump threading and
// loop unswitching can interfere with thread re-convergence. Therefore, an
// analysis that computes which branches in a GPU program are divergent can
// help the compiler to selectively run these optimizations.
//
// This implementation is derived from the Vectorization Analysis of the
// Region Vectorizer (RV). The analysis is based on the approach described in
//
//   An abstract interpretation for SPMD divergence
//   on reducible control flow graphs.
//   Julian Rosemann, Simon Moll and Sebastian Hack
//   POPL '21
//
// This implementation is generic in the sense that it does not itself identify
// original sources of divergence. Instead, specialized adapter classes
// (LoopDivergenceAnalysis for loops and DivergenceAnalysis for functions)
// identify the sources of divergence (e.g., special variables that hold the
// thread ID or the loop iteration variable).
//
// The generic implementation propagates divergence to variables that are data
// or sync dependent on a source of divergence.
//
// While data dependency is a well-known concept, the notion of sync dependency
// is worth more explanation. Sync dependence characterizes the control flow
// aspect of the propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// The sync dependence detection (which branch induces divergence in which join
// points) is implemented in the SyncDependenceAnalysis.
//
// The current implementation has the following limitations:
// 1. It is intra-procedural. It conservatively considers the arguments of a
//    non-kernel-entry function and the return value of a function call as
//    divergent.
// 2. It treats memory as a black box. It conservatively considers values
//    loaded from generic or local addresses as divergent. This can be improved
//    by leveraging pointer analysis and/or by modelling non-escaping memory
//    objects in SSA as done in RV.
//
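// A minimal sketch of how a client transform might consult the result of this
// analysis before performing a divergence-sensitive rewrite (the surrounding
// pass boilerplate is omitted and the client code is illustrative):
//
//   const DivergenceInfo &DI = FAM.getResult<DivergenceAnalysis>(F);
//   for (const BasicBlock &BB : F) {
//     const Instruction *Term = BB.getTerminator();
//     if (DI.hasDivergence() && DI.isDivergent(*Term)) {
//       // Divergent terminator: avoid jump threading / loop unswitching here.
//     }
//   }
//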
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/DivergenceAnalysis.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

#define DEBUG_TYPE "divergence"

DivergenceAnalysisImpl::DivergenceAnalysisImpl(
    const Function &F, const Loop *RegionLoop, const DominatorTree &DT,
    const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm)
    : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA),
      IsLCSSAForm(IsLCSSAForm) {}
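
// Mark \p DivVal as divergent unless it has been registered as always-uniform.
// Returns true if \p DivVal was not already known to be divergent.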
bool DivergenceAnalysisImpl::markDivergent(const Value &DivVal) {
  if (isAlwaysUniform(DivVal))
    return false;
  assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal));
  assert(!isAlwaysUniform(DivVal) && "cannot be divergent");
  return DivergentValues.insert(&DivVal).second;
}

void DivergenceAnalysisImpl::addUniformOverride(const Value &UniVal) {
  UniformOverrides.insert(&UniVal);
}
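
// A value \p Val is temporally divergent with respect to \p ObservingBlock if
// it is carried by a divergent loop that has already terminated by the time
// control reaches \p ObservingBlock: different threads may then observe the
// value of different iterations.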
bool DivergenceAnalysisImpl::isTemporalDivergent(
    const BasicBlock &ObservingBlock, const Value &Val) const {
  const auto *Inst = dyn_cast<const Instruction>(&Val);
  if (!Inst)
    return false;
  // Check whether any divergent loop carrying Val terminates before control
  // proceeds to ObservingBlock.
  for (const auto *Loop = LI.getLoopFor(Inst->getParent());
       Loop != RegionLoop && !Loop->contains(&ObservingBlock);
       Loop = Loop->getParentLoop()) {
    if (DivergentLoops.contains(Loop))
      return true;
  }
  return false;
}

bool DivergenceAnalysisImpl::inRegion(const Instruction &I) const {
  return I.getParent() && inRegion(*I.getParent());
}

bool DivergenceAnalysisImpl::inRegion(const BasicBlock &BB) const {
  return RegionLoop ? RegionLoop->contains(&BB) : (BB.getParent() == &F);
}
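
// Push all users of the newly divergent value \p V onto the worklist. If \p V
// is a terminator, its divergence is control divergence and is handed off to
// analyzeControlDivergence instead.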
void DivergenceAnalysisImpl::pushUsers(const Value &V) {
  const auto *I = dyn_cast<const Instruction>(&V);

  if (I && I->isTerminator()) {
    analyzeControlDivergence(*I);
    return;
  }

  for (const auto *User : V.users()) {
    const auto *UserInst = dyn_cast<const Instruction>(User);
    if (!UserInst)
      continue;

    // Only compute divergence inside the region.
    if (!inRegion(*UserInst))
      continue;

    // All users of divergent values are immediately divergent.
    if (markDivergent(*UserInst))
      Worklist.push_back(UserInst);
  }
}
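
// Return the instruction behind use \p U if it is defined inside \p DivLoop
// (i.e., a loop-carried definition from the perspective of users outside the
// loop); otherwise return nullptr.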
static const Instruction *getIfCarriedInstruction(const Use &U,
                                                  const Loop &DivLoop) {
  const auto *I = dyn_cast<const Instruction>(&U);
  if (!I)
    return nullptr;
  if (!DivLoop.contains(I))
    return nullptr;
  return I;
}

void DivergenceAnalysisImpl::analyzeTemporalDivergence(
    const Instruction &I, const Loop &OuterDivLoop) {
  if (isAlwaysUniform(I))
    return;
  if (isDivergent(I))
    return;

  LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n");
  assert((isa<PHINode>(I) || !IsLCSSAForm) &&
         "In LCSSA form all users of loop-exiting defs are Phi nodes.");

  // If any operand of I is carried by OuterDivLoop, I observes temporal
  // divergence and becomes divergent itself.
  for (const Use &Op : I.operands()) {
    const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop);
    if (!OpInst)
      continue;
    if (markDivergent(I))
      pushUsers(I);
    return;
  }
}

// Mark all users of loop-carried values of \p OuterDivLoop that are reachable
// from the divergent loop exit \p DivExit as divergent.
void DivergenceAnalysisImpl::analyzeLoopExitDivergence(
    const BasicBlock &DivExit, const Loop &OuterDivLoop) {
  // All users are in immediate exit blocks.
  if (IsLCSSAForm) {
    for (const auto &Phi : DivExit.phis()) {
      analyzeTemporalDivergence(Phi, OuterDivLoop);
    }
    return;
  }

  // For non-LCSSA we have to follow all live-out edges wherever they may lead.
  const BasicBlock &LoopHeader = *OuterDivLoop.getHeader();
  SmallVector<const BasicBlock *, 8> TaintStack;
  TaintStack.push_back(&DivExit);

  // Otherwise, potential users of loop-carried values could be anywhere in the
  // dominance region of OuterDivLoop (including its fringes for phi nodes).
  DenseSet<const BasicBlock *> Visited;
  Visited.insert(&DivExit);

  do {
    auto *UserBlock = TaintStack.pop_back_val();

    // Don't spread divergence beyond the region.
    if (!inRegion(*UserBlock))
      continue;

    assert(!OuterDivLoop.contains(UserBlock) &&
           "irreducible control flow detected");

    // Phi nodes at the fringes of the dominance region.
    if (!DT.dominates(&LoopHeader, UserBlock)) {
      // All PHI nodes of UserBlock become divergent.
      for (const auto &Phi : UserBlock->phis()) {
        analyzeTemporalDivergence(Phi, OuterDivLoop);
      }
      continue;
    }

    // Taint outside users of values carried by OuterDivLoop.
    for (const auto &I : *UserBlock) {
      analyzeTemporalDivergence(I, OuterDivLoop);
    }

    // Visit all blocks in the dominance region.
    for (const auto *SuccBlock : successors(UserBlock)) {
      if (!Visited.insert(SuccBlock).second) {
        continue;
      }
      TaintStack.push_back(SuccBlock);
    }
  } while (!TaintStack.empty());
}
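
// Divergence has escaped through the loop exit \p DivExit of \p InnerDivLoop.
// Mark every loop left on the way to \p DivExit as divergent and propagate
// temporal divergence to users of values carried by the outer-most such loop.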
void DivergenceAnalysisImpl::propagateLoopExitDivergence(
    const BasicBlock &DivExit, const Loop &InnerDivLoop) {
  LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n");

  // Find the outer-most loop that does not contain \p DivExit.
  const Loop *DivLoop = &InnerDivLoop;
  const Loop *OuterDivLoop = DivLoop;
  const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit);
  const unsigned LoopExitDepth =
      ExitLevelLoop ? ExitLevelLoop->getLoopDepth() : 0;
  while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) {
    DivergentLoops.insert(DivLoop); // all crossed loops are divergent
    OuterDivLoop = DivLoop;
    DivLoop = DivLoop->getParentLoop();
  }
  LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName()
                    << "\n");

  analyzeLoopExitDivergence(DivExit, *OuterDivLoop);
}

// This is a divergent join point - mark all phi nodes as divergent and push
// them onto the worklist.
void DivergenceAnalysisImpl::taintAndPushPhiNodes(const BasicBlock &JoinBlock) {
  LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName()
                    << "\n");

  // Ignore divergence outside the region.
  if (!inRegion(JoinBlock)) {
    return;
  }

  // Push non-divergent phi nodes in JoinBlock to the worklist.
  for (const auto &Phi : JoinBlock.phis()) {
    if (isDivergent(Phi))
      continue;
    // FIXME: Theoretically, the 'undef' value could be replaced by any other
    // value, causing spurious divergence.
    if (Phi.hasConstantOrUndefValue())
      continue;
    if (markDivergent(Phi))
      Worklist.push_back(&Phi);
  }
}
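
// The terminator \p Term is divergent. Consult SyncDependenceAnalysis for the
// join points and divergent loop exits it induces, taint the phi nodes at the
// joins, and propagate divergence out of the exited loops.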
void DivergenceAnalysisImpl::analyzeControlDivergence(const Instruction &Term) {
  LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName()
                    << "\n");

  // Don't propagate divergence from unreachable blocks.
  if (!DT.isReachableFromEntry(Term.getParent()))
    return;

  const auto *BranchLoop = LI.getLoopFor(Term.getParent());

  const auto &DivDesc = SDA.getJoinBlocks(Term);

  // Iterate over all blocks now reachable by a disjoint path join.
  for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
    taintAndPushPhiNodes(*JoinBlock);
  }

  assert(DivDesc.LoopDivBlocks.empty() || BranchLoop);
  for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) {
    propagateLoopExitDivergence(*DivExitBlock, *BranchLoop);
  }
}
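
// Propagate divergence to a fixed point: seed the worklist with the users of
// all initially divergent values, then keep pushing users of every value that
// becomes divergent until no new divergent values are discovered.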
void DivergenceAnalysisImpl::compute() {
  // Initialize the worklist.
  auto DivValuesCopy = DivergentValues;
  for (const auto *DivVal : DivValuesCopy) {
    assert(isDivergent(*DivVal) && "Worklist invariant violated!");
    pushUsers(*DivVal);
  }

  // All values on the Worklist are divergent.
  // Their users may not have been updated yet.
  while (!Worklist.empty()) {
    const Instruction &I = *Worklist.back();
    Worklist.pop_back();

    // Propagate value divergence to users.
    assert(isDivergent(I) && "Worklist invariant violated!");
    pushUsers(I);
  }
}

bool DivergenceAnalysisImpl::isAlwaysUniform(const Value &V) const {
  return UniformOverrides.contains(&V);
}

bool DivergenceAnalysisImpl::isDivergent(const Value &V) const {
  return DivergentValues.contains(&V);
}
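
// A use is divergent if the used value is divergent, or if it is temporally
// divergent when observed from the block of the using instruction.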
bool DivergenceAnalysisImpl::isDivergentUse(const Use &U) const {
  Value &V = *U.get();
  Instruction &I = *cast<Instruction>(U.getUser());
  return isDivergent(V) || isTemporalDivergent(*I.getParent(), V);
}
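
// Build the function-level divergence info: bail out on irreducible control
// flow, seed divergence from TTI's sources of divergence (e.g., values holding
// the thread ID) and TTI's always-uniform overrides, then run the propagation.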
DivergenceInfo::DivergenceInfo(Function &F, const DominatorTree &DT,
                               const PostDominatorTree &PDT, const LoopInfo &LI,
                               const TargetTransformInfo &TTI,
                               bool KnownReducible)
    : F(F) {
  if (!KnownReducible) {
    using RPOTraversal = ReversePostOrderTraversal<const Function *>;
    RPOTraversal FuncRPOT(&F);
    if (containsIrreducibleCFG<const BasicBlock *, const RPOTraversal,
                               const LoopInfo>(FuncRPOT, LI)) {
      ContainsIrreducible = true;
      return;
    }
  }
  SDA = std::make_unique<SyncDependenceAnalysis>(DT, PDT, LI);
  DA = std::make_unique<DivergenceAnalysisImpl>(F, nullptr, DT, LI, *SDA,
                                                /* LCSSA */ false);
  for (auto &I : instructions(F)) {
    if (TTI.isSourceOfDivergence(&I)) {
      DA->markDivergent(I);
    } else if (TTI.isAlwaysUniform(&I)) {
      DA->addUniformOverride(I);
    }
  }
  for (auto &Arg : F.args()) {
    if (TTI.isSourceOfDivergence(&Arg)) {
      DA->markDivergent(Arg);
    }
  }
  DA->compute();
}

AnalysisKey DivergenceAnalysis::Key;

DivergenceAnalysis::Result
DivergenceAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F);
  auto &LI = AM.getResult<LoopAnalysis>(F);
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  return DivergenceInfo(F, DT, PDT, LI, TTI, /* KnownReducible = */ false);
}

PreservedAnalyses
DivergenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) {
  auto &DI = FAM.getResult<DivergenceAnalysis>(F);
  OS << "'Divergence Analysis' for function '" << F.getName() << "':\n";
  if (DI.hasDivergence()) {
    for (auto &Arg : F.args()) {
      OS << (DI.isDivergent(Arg) ? "DIVERGENT: " : " ");
      OS << Arg << "\n";
    }
    for (const BasicBlock &BB : F) {
      OS << "\n " << BB.getName() << ":\n";
      for (const auto &I : BB.instructionsWithoutDebug()) {
        OS << (DI.isDivergent(I) ? "DIVERGENT: " : " ");
        OS << I << "\n";
      }
    }
  }
  return PreservedAnalyses::all();
}