DivergenceAnalysis.cpp

//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a general divergence analysis for loop vectorization
// and GPU programs. It determines which branches and values in a loop or GPU
// program are divergent. It can help branch optimizations such as jump
// threading and loop unswitching to make better decisions.
//
// GPU programs typically use the SIMD execution model, where multiple threads
// in the same execution group have to execute in lock-step. Therefore, if the
// code contains divergent branches (i.e., threads in a group do not agree on
// which path of the branch to take), the group of threads has to execute all
// the paths from that branch with different subsets of threads enabled until
// they re-converge.
//
// Due to this execution model, some optimizations such as jump
// threading and loop unswitching can interfere with thread re-convergence.
// Therefore, an analysis that computes which branches in a GPU program are
// divergent can help the compiler to selectively run these optimizations.
//
// This implementation is derived from the Vectorization Analysis of the
// Region Vectorizer (RV). The analysis is based on the approach described in
//
//   An abstract interpretation for SPMD divergence
//   on reducible control flow graphs.
//   Julian Rosemann, Simon Moll and Sebastian Hack
//   POPL '21
//
// This implementation is generic in the sense that it does not itself
// identify the original sources of divergence. Instead, specialized adapter
// classes (LoopDivergenceAnalysis for loops and DivergenceAnalysis for
// functions) identify the sources of divergence (e.g., special variables that
// hold the thread ID or the loop iteration variable).
//
// The generic implementation propagates divergence to variables that are data
// or sync dependent on a source of divergence.
//
// While data dependency is a well-known concept, the notion of sync dependency
// is worth more explanation. Sync dependence characterizes the control flow
// aspect of the propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// The sync dependence detection (which branch induces divergence in which join
// points) is implemented in the SyncDependenceAnalysis.
//
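// A related notion handled below is temporal divergence across loop exits.
// As a minimal illustrative sketch (reusing %tid from the example above): if
// a loop exits under a divergent condition, a value that is uniform within
// every iteration can still be divergent when observed outside the loop,
// because threads may leave the loop in different iterations. For example,
//
// loop:
//   %i = phi i32 [ 0, %entry ], [ %i.next, %loop ]
//   %i.next = add i32 %i, 1
//   %exit.cond = icmp eq i32 %i, %tid   ; divergent exit condition
//   br i1 %exit.cond, label %exit, label %loop
// exit:
//   %lcssa = phi i32 [ %i, %loop ]      ; observes a different %i per thread
//
// Even though %i takes the same sequence of values in every thread, the value
// that escapes to %exit differs per thread, so %lcssa is divergent.
//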
// The current implementation has the following limitations:
// 1. Intra-procedural. It conservatively considers the arguments of a
//    non-kernel-entry function and the return value of a function call as
//    divergent.
// 2. Memory as a black box. It conservatively considers values loaded from
//    generic or local address spaces as divergent. This can be improved by
//    leveraging pointer analysis and/or by modelling non-escaping memory
//    objects in SSA as done in RV.
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/DivergenceAnalysis.h"
#include "llvm/ADT/PostOrderIterator.h" // ReversePostOrderTraversal
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/Passes.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

#define DEBUG_TYPE "divergence"

DivergenceAnalysisImpl::DivergenceAnalysisImpl(
    const Function &F, const Loop *RegionLoop, const DominatorTree &DT,
    const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm)
    : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA),
      IsLCSSAForm(IsLCSSAForm) {}

bool DivergenceAnalysisImpl::markDivergent(const Value &DivVal) {
  if (isAlwaysUniform(DivVal))
    return false;
  assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal));
  assert(!isAlwaysUniform(DivVal) && "cannot be divergent");
  return DivergentValues.insert(&DivVal).second;
}

void DivergenceAnalysisImpl::addUniformOverride(const Value &UniVal) {
  UniformOverrides.insert(&UniVal);
}

bool DivergenceAnalysisImpl::isTemporalDivergent(
    const BasicBlock &ObservingBlock, const Value &Val) const {
  const auto *Inst = dyn_cast<const Instruction>(&Val);
  if (!Inst)
    return false;
  // Check whether any divergent loop carrying Val terminates before control
  // proceeds to ObservingBlock.
  for (const auto *Loop = LI.getLoopFor(Inst->getParent());
       Loop != RegionLoop && !Loop->contains(&ObservingBlock);
       Loop = Loop->getParentLoop()) {
    if (DivergentLoops.contains(Loop))
      return true;
  }
  return false;
}

bool DivergenceAnalysisImpl::inRegion(const Instruction &I) const {
  return I.getParent() && inRegion(*I.getParent());
}

bool DivergenceAnalysisImpl::inRegion(const BasicBlock &BB) const {
  return RegionLoop ? RegionLoop->contains(&BB) : (BB.getParent() == &F);
}

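// Mark all users of the divergent value V as divergent and queue them for
// further propagation. A divergent terminator induces control divergence
// instead and is handed to analyzeControlDivergence rather than being
// propagated along use-def chains.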
void DivergenceAnalysisImpl::pushUsers(const Value &V) {
  const auto *I = dyn_cast<const Instruction>(&V);

  if (I && I->isTerminator()) {
    analyzeControlDivergence(*I);
    return;
  }

  for (const auto *User : V.users()) {
    const auto *UserInst = dyn_cast<const Instruction>(User);
    if (!UserInst)
      continue;

    // Only compute divergence inside the region of interest.
    if (!inRegion(*UserInst))
      continue;

    // All users of divergent values are immediately divergent.
    if (markDivergent(*UserInst))
      Worklist.push_back(UserInst);
  }
}

static const Instruction *getIfCarriedInstruction(const Use &U,
                                                  const Loop &DivLoop) {
  const auto *I = dyn_cast<const Instruction>(&U);
  if (!I)
    return nullptr;
  if (!DivLoop.contains(I))
    return nullptr;
  return I;
}

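// If any operand of I is an instruction defined inside OuterDivLoop, then I
// observes a loop-carried value after the divergent loop has been left and
// becomes divergent itself (temporal divergence).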
void DivergenceAnalysisImpl::analyzeTemporalDivergence(
    const Instruction &I, const Loop &OuterDivLoop) {
  if (isAlwaysUniform(I))
    return;
  if (isDivergent(I))
    return;

  LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n");
  assert((isa<PHINode>(I) || !IsLCSSAForm) &&
         "In LCSSA form all users of loop-exiting defs are Phi nodes.");
  for (const Use &Op : I.operands()) {
    const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop);
    if (!OpInst)
      continue;
    if (markDivergent(I))
      pushUsers(I);
    return;
  }
}

// Mark all users of loop-carried values of OuterDivLoop that are reachable
// from the divergent loop exit DivExit as divergent.
void DivergenceAnalysisImpl::analyzeLoopExitDivergence(
    const BasicBlock &DivExit, const Loop &OuterDivLoop) {
  // All users are in immediate exit blocks.
  if (IsLCSSAForm) {
    for (const auto &Phi : DivExit.phis()) {
      analyzeTemporalDivergence(Phi, OuterDivLoop);
    }
    return;
  }

  // For non-LCSSA we have to follow all live-out edges wherever they may lead.
  const BasicBlock &LoopHeader = *OuterDivLoop.getHeader();
  SmallVector<const BasicBlock *, 8> TaintStack;
  TaintStack.push_back(&DivExit);

  // Otherwise, potential users of loop-carried values could be anywhere in the
  // dominance region of OuterDivLoop (including its fringes for phi nodes).
  DenseSet<const BasicBlock *> Visited;
  Visited.insert(&DivExit);

  do {
    auto *UserBlock = TaintStack.pop_back_val();

    // Don't spread divergence beyond the region.
    if (!inRegion(*UserBlock))
      continue;

    assert(!OuterDivLoop.contains(UserBlock) &&
           "irreducible control flow detected");

    // Phi nodes at the fringes of the dominance region.
    if (!DT.dominates(&LoopHeader, UserBlock)) {
      // All PHI nodes of UserBlock become divergent.
      for (auto &Phi : UserBlock->phis()) {
        analyzeTemporalDivergence(Phi, OuterDivLoop);
      }
      continue;
    }

    // Taint outside users of values carried by OuterDivLoop.
    for (auto &I : *UserBlock) {
      analyzeTemporalDivergence(I, OuterDivLoop);
    }

    // Visit all blocks in the dominance region.
    for (auto *SuccBlock : successors(UserBlock)) {
      if (!Visited.insert(SuccBlock).second) {
        continue;
      }
      TaintStack.push_back(SuccBlock);
    }
  } while (!TaintStack.empty());
}

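// A divergent branch inside InnerDivLoop makes threads leave through DivExit
// in different iterations. Walk up the loop nest from InnerDivLoop to the
// loop level of DivExit, record every crossed loop as divergent, and then
// taint the users of values carried by the outermost crossed loop.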
void DivergenceAnalysisImpl::propagateLoopExitDivergence(
    const BasicBlock &DivExit, const Loop &InnerDivLoop) {
  LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n");

  // Find the outer-most loop that does not contain \p DivExit.
  const Loop *DivLoop = &InnerDivLoop;
  const Loop *OuterDivLoop = DivLoop;
  const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit);
  const unsigned LoopExitDepth =
      ExitLevelLoop ? ExitLevelLoop->getLoopDepth() : 0;
  while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) {
    DivergentLoops.insert(DivLoop); // all crossed loops are divergent
    OuterDivLoop = DivLoop;
    DivLoop = DivLoop->getParentLoop();
  }
  LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName()
                    << "\n");

  analyzeLoopExitDivergence(DivExit, *OuterDivLoop);
}

// This is a divergent join point - mark all phi nodes as divergent and push
// them onto the worklist.
void DivergenceAnalysisImpl::taintAndPushPhiNodes(const BasicBlock &JoinBlock) {
  LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName()
                    << "\n");

  // Ignore divergence outside the region.
  if (!inRegion(JoinBlock)) {
    return;
  }

  // Push non-divergent phi nodes in JoinBlock to the worklist.
  for (const auto &Phi : JoinBlock.phis()) {
    if (isDivergent(Phi))
      continue;
    // FIXME: Theoretically, the 'undef' value could be replaced by any other
    // value, causing spurious divergence.
    if (Phi.hasConstantOrUndefValue())
      continue;
    if (markDivergent(Phi))
      Worklist.push_back(&Phi);
  }
}

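// Term is a divergent terminator. Query the SyncDependenceAnalysis for the
// blocks it reaches along disjoint paths: phi nodes in such join blocks
// become divergent, and divergent loop exits propagate temporal divergence.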
void DivergenceAnalysisImpl::analyzeControlDivergence(const Instruction &Term) {
  LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName()
                    << "\n");

  // Don't propagate divergence from unreachable blocks.
  if (!DT.isReachableFromEntry(Term.getParent()))
    return;

  const auto *BranchLoop = LI.getLoopFor(Term.getParent());

  const auto &DivDesc = SDA.getJoinBlocks(Term);

  // Iterate over all blocks that are now reachable by disjoint paths and mark
  // their phi nodes as divergent.
  for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
    taintAndPushPhiNodes(*JoinBlock);
  }

  assert(DivDesc.LoopDivBlocks.empty() || BranchLoop);
  for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) {
    propagateLoopExitDivergence(*DivExitBlock, *BranchLoop);
  }
}

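// Propagate divergence to a fixed point: seed the worklist with the users of
// all known divergent values, then keep marking users (and control-dependent
// phi nodes and loop exits) until no new value becomes divergent.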
void DivergenceAnalysisImpl::compute() {
  // Initialize worklist.
  auto DivValuesCopy = DivergentValues;
  for (const auto *DivVal : DivValuesCopy) {
    assert(isDivergent(*DivVal) && "Worklist invariant violated!");
    pushUsers(*DivVal);
  }

  // All values on the Worklist are divergent.
  // Their users may not have been updated yet.
  while (!Worklist.empty()) {
    const Instruction &I = *Worklist.back();
    Worklist.pop_back();

    // Propagate value divergence to users.
    assert(isDivergent(I) && "Worklist invariant violated!");
    pushUsers(I);
  }
}

bool DivergenceAnalysisImpl::isAlwaysUniform(const Value &V) const {
  return UniformOverrides.contains(&V);
}

bool DivergenceAnalysisImpl::isDivergent(const Value &V) const {
  return DivergentValues.contains(&V);
}

bool DivergenceAnalysisImpl::isDivergentUse(const Use &U) const {
  Value &V = *U.get();
  Instruction &I = *cast<Instruction>(U.getUser());
  return isDivergent(V) || isTemporalDivergent(*I.getParent(), V);
}

DivergenceInfo::DivergenceInfo(Function &F, const DominatorTree &DT,
                               const PostDominatorTree &PDT, const LoopInfo &LI,
                               const TargetTransformInfo &TTI,
                               bool KnownReducible)
    : F(F) {
  if (!KnownReducible) {
    using RPOTraversal = ReversePostOrderTraversal<const Function *>;
    RPOTraversal FuncRPOT(&F);
    if (containsIrreducibleCFG<const BasicBlock *, const RPOTraversal,
                               const LoopInfo>(FuncRPOT, LI)) {
      ContainsIrreducible = true;
      return;
    }
  }
  SDA = std::make_unique<SyncDependenceAnalysis>(DT, PDT, LI);
  DA = std::make_unique<DivergenceAnalysisImpl>(F, nullptr, DT, LI, *SDA,
                                                /* LCSSA */ false);
  // Seed the analysis with target-specific sources of divergence (e.g. thread
  // IDs) and with values the target guarantees to be uniform.
  for (auto &I : instructions(F)) {
    if (TTI.isSourceOfDivergence(&I)) {
      DA->markDivergent(I);
    } else if (TTI.isAlwaysUniform(&I)) {
      DA->addUniformOverride(I);
    }
  }
  for (auto &Arg : F.args()) {
    if (TTI.isSourceOfDivergence(&Arg)) {
      DA->markDivergent(Arg);
    }
  }
  DA->compute();
}

AnalysisKey DivergenceAnalysis::Key;

DivergenceAnalysis::Result
DivergenceAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F);
  auto &LI = AM.getResult<LoopAnalysis>(F);
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  return DivergenceInfo(F, DT, PDT, LI, TTI, /* KnownReducible = */ false);
}

PreservedAnalyses
DivergenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) {
  auto &DI = FAM.getResult<DivergenceAnalysis>(F);
  OS << "'Divergence Analysis' for function '" << F.getName() << "':\n";
  if (DI.hasDivergence()) {
    for (auto &Arg : F.args()) {
      OS << (DI.isDivergent(Arg) ? "DIVERGENT: " : "           ");
      OS << Arg << "\n";
    }
    for (const BasicBlock &BB : F) {
      OS << "\n           " << BB.getName() << ":\n";
      for (auto &I : BB.instructionsWithoutDebug()) {
        OS << (DI.isDivergent(I) ? "DIVERGENT: " : "           ");
        OS << I << "\n";
      }
    }
  }
  return PreservedAnalyses::all();
}
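
// Usage sketch (illustrative only, assuming the new pass manager): a client
// transform could query the result along these lines, using only the
// interfaces exercised in this file:
//
//   auto &DI = AM.getResult<DivergenceAnalysis>(F);
//   if (!DI.hasDivergence()) {
//     // Every value in F is uniform; divergence-sensitive rewrites are safe.
//   }
//   for (Instruction &I : instructions(F))
//     if (DI.isDivergent(I)) {
//       // Avoid optimizations that would break re-convergence around I.
//     }
//
// The printer pass above is typically reachable from `opt` via
// `-passes='print<divergence>'`; treat the exact pass-name spelling as an
// assumption of this sketch (it is defined by the pass registry, not here).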