//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an algorithm that returns, for a divergent branch,
// the set of basic blocks whose phi nodes become divergent due to divergent
// control. These are the blocks that are reachable by two disjoint paths from
// the branch, or loop exits that have a reaching path that is disjoint from
// a path to the loop latch.
//
// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
// control-induced divergence in phi nodes.
//
//
// -- Reference --
// The algorithm is presented in Section 5 of
//
//   An abstract interpretation for SPMD divergence
//       on reducible control flow graphs.
//   Julian Rosemann, Simon Moll and Sebastian Hack
//   POPL '21
//
//
// -- Sync dependence --
// Sync dependence characterizes the control flow aspect of the
// propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
//
// -- Reduction to SSA construction --
// There are two disjoint paths from A to X if a certain variant of SSA
// construction places a phi node in X under the following set-up scheme.
//
// This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
// phi nodes.
//
//     entry
//    /      \
//   A        \
//  / \        Y
// B   C     /
//  \ / \   /
//   D   E
//    \ /
//     F
//
// Assume that A contains a divergent branch. We are interested
// in the set of all blocks where each block is reachable from A
// via two disjoint paths. This would be the set {D, F} in this
// case.
// To generally reduce this query to SSA construction we introduce
// a virtual variable x and assign to x different values in each
// successor block of A.
//
//        entry
//       /      \
//      A        \
//     / \        Y
// x = 0  x = 1 /
//     \ / \   /
//      D   E
//       \ /
//        F
//
// Our flavor of SSA construction for x will construct the following
//
//         entry
//        /      \
//       A        \
//      / \        Y
// x0 = 0  x1 = 1 /
//      \ / \    /
//  x2 = phi    E
//        \    /
//       x3 = phi
//
// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
//
// -- Remarks --
//   * In case of loop exits we need to check the disjoint path criterion for
//     loops. To this end, we check whether the definition of x differs between
//     the loop exit and the loop header (_after_ SSA construction).
//
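//     For illustration (a sketch, with block names made up for this example):
//     consider a loop whose exit is taken under a divergent condition.
//
//       header:
//         %i = phi i32 [ 0, %entry ], [ %i.next, %body ]
//         %c = icmp slt i32 %i, %tid
//         br i1 %c, label %body, label %exit   ; exit depends on %tid
//       body:
//         %i.next = add i32 %i, 1
//         br label %header
//       exit:
//         ...
//
//     Threads leave the loop in different iterations. With the virtual
//     variable scheme from above (x = 0 in %body, x = 1 in %exit), the
//     definition of x reaching %exit differs from the one reaching %header
//     after SSA construction, so %exit satisfies the loop criterion and is
//     treated as a divergent loop exit.
//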
// -- Known Limitations & Future Work --
//   * The algorithm requires reducible loops because the implementation
//     implicitly performs a single iteration of the underlying data flow
//     analysis. This was done for pragmatism, simplicity and speed.
//
//     Relevant related work for extending the algorithm to irreducible control:
//       A simple algorithm for global data flow analysis problems.
//       Matthew S. Hecht and Jeffrey D. Ullman.
//       SIAM Journal on Computing, 4(4):519–532, December 1975.
//
//   * Another reason for requiring reducible loops is that points of
//     synchronization in irreducible loops aren't 'obvious' - there is no
//     unique header where threads 'should' synchronize when entering or
//     coming back around from the latch.
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/SyncDependenceAnalysis.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"

#include <functional>
#include <set>
#include <vector>

#define DEBUG_TYPE "sync-dependence"

// The SDA algorithm operates on a modified CFG - we modify the edges leaving
// loop headers as follows:
//
// * We remove all edges leaving all loop headers.
// * We add additional edges from the loop headers to their exit blocks.
//
// The modification is virtual, that is, whenever we visit a loop header we
// pretend it had different successors.
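//
// For example (an illustrative sketch with arbitrary block names): for a loop
// with header H, in-loop successor B, latch L, and exit block X where the
// original CFG has succ(H) = { B, X }, the analysis instead uses
//
//   succ(H) = exits(loop) = { X }
//
// i.e. when the traversal and the propagation below reach H, they step
// directly to the loop's exit blocks rather than into the loop body.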
namespace {
using namespace llvm;

// Custom Post-Order Traversal
//
// We cannot use the vanilla (R)PO computation of LLVM because:
// * We (virtually) modify the CFG.
// * We want a loop-compact block enumeration, that is, the numbers assigned
//   to the blocks of a loop form an interval.
using POCB = std::function<void(const BasicBlock &)>;
using VisitedSet = std::set<const BasicBlock *>;
using BlockStack = std::vector<const BasicBlock *>;

// forward
static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
                          VisitedSet &Finalized);

// for a nested region (top-level loop or nested loop)
static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop,
                           POCB CallBack, VisitedSet &Finalized) {
  const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr;
  while (!Stack.empty()) {
    const auto *NextBB = Stack.back();

    auto *NestedLoop = LI.getLoopFor(NextBB);
    bool IsNestedLoop = NestedLoop != Loop;

    // Treat the loop as a node
    if (IsNestedLoop) {
      SmallVector<BasicBlock *, 3> NestedExits;
      NestedLoop->getUniqueExitBlocks(NestedExits);
      bool PushedNodes = false;
      for (const auto *NestedExitBB : NestedExits) {
        if (NestedExitBB == LoopHeader)
          continue;
        if (Loop && !Loop->contains(NestedExitBB))
          continue;
        if (Finalized.count(NestedExitBB))
          continue;
        PushedNodes = true;
        Stack.push_back(NestedExitBB);
      }
      if (!PushedNodes) {
        // All loop exits finalized -> finish this node
        Stack.pop_back();
        computeLoopPO(LI, *NestedLoop, CallBack, Finalized);
      }
      continue;
    }

    // DAG-style
    bool PushedNodes = false;
    for (const auto *SuccBB : successors(NextBB)) {
      if (SuccBB == LoopHeader)
        continue;
      if (Loop && !Loop->contains(SuccBB))
        continue;
      if (Finalized.count(SuccBB))
        continue;
      PushedNodes = true;
      Stack.push_back(SuccBB);
    }
    if (!PushedNodes) {
      // Never push nodes twice
      Stack.pop_back();
      if (!Finalized.insert(NextBB).second)
        continue;
      CallBack(*NextBB);
    }
  }
}
static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) {
  VisitedSet Finalized;
  BlockStack Stack;
  Stack.reserve(24); // FIXME made-up number
  Stack.push_back(&F.getEntryBlock());
  computeStackPO(Stack, LI, nullptr, CallBack, Finalized);
}

static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
                          VisitedSet &Finalized) {
  /// Call CallBack on all loop blocks.
  std::vector<const BasicBlock *> Stack;
  const auto *LoopHeader = Loop.getHeader();

  // Visit the header last: it receives the lowest index of the loop, so the
  // reverse (RPO-style) walk over indices reaches it after all other loop
  // blocks.
  Finalized.insert(LoopHeader);
  CallBack(*LoopHeader);

  // Initialize with immediate successors
  for (const auto *BB : successors(LoopHeader)) {
    if (!Loop.contains(BB))
      continue;
    if (BB == LoopHeader)
      continue;
    Stack.push_back(BB);
  }

  // Compute PO inside region
  computeStackPO(Stack, LI, &Loop, CallBack, Finalized);
}

} // namespace
namespace llvm {

ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc;

SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
                                               const PostDominatorTree &PDT,
                                               const LoopInfo &LI)
    : DT(DT), PDT(PDT), LI(LI) {
  computeTopLevelPO(*DT.getRoot()->getParent(), LI,
                    [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); });
}

SyncDependenceAnalysis::~SyncDependenceAnalysis() = default;
namespace {

// divergence propagator for reducible CFGs
struct DivergencePropagator {
  const ModifiedPO &LoopPOT;
  const DominatorTree &DT;
  const PostDominatorTree &PDT;
  const LoopInfo &LI;
  const BasicBlock &DivTermBlock;

  // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at
  //   block B
  // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet
  // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths
  //   from X or B is an immediate successor of X (initial value).
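  //
  // For illustration, using the diamond CFG from the file header comment:
  // after bootstrapping, the successors B and C of the divergent block A
  // carry the labels B and C; once both labels reach D, the conflict is
  // resolved by setting BlockLabels[IndexOf(D)] = D, i.e. D becomes its own
  // label and is recorded as a divergent join.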
  using BlockLabelVec = std::vector<const BasicBlock *>;
  BlockLabelVec BlockLabels;

  // divergent join and loop exit descriptor.
  std::unique_ptr<ControlDivergenceDesc> DivDesc;

  DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT,
                       const PostDominatorTree &PDT, const LoopInfo &LI,
                       const BasicBlock &DivTermBlock)
      : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock),
        BlockLabels(LoopPOT.size(), nullptr),
        DivDesc(new ControlDivergenceDesc) {}

  void printDefs(raw_ostream &Out) {
    Out << "Propagator::BlockLabels {\n";
    for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) {
      const auto *Label = BlockLabels[BlockIdx];
      Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx
          << ") : ";
      if (!Label) {
        Out << "<null>\n";
      } else {
        Out << Label->getName() << "\n";
      }
    }
    Out << "}\n";
  }

  // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
  // causes a divergent join.
  bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) {
    auto SuccIdx = LoopPOT.getIndexOf(SuccBlock);

    // unset or same reaching label
    const auto *OldLabel = BlockLabels[SuccIdx];
    if (!OldLabel || (OldLabel == &PushedLabel)) {
      BlockLabels[SuccIdx] = &PushedLabel;
      return false;
    }

    // Update the definition
    BlockLabels[SuccIdx] = &SuccBlock;
    return true;
  }
  // visiting a virtual loop exit edge from the loop header --> temporal
  // divergence on join
  bool visitLoopExitEdge(const BasicBlock &ExitBlock,
                         const BasicBlock &DefBlock, bool FromParentLoop) {
    // Pushing from a non-parent loop cannot cause temporal divergence.
    if (!FromParentLoop)
      return visitEdge(ExitBlock, DefBlock);

    if (!computeJoin(ExitBlock, DefBlock))
      return false;

    // Identified a divergent loop exit
    DivDesc->LoopDivBlocks.insert(&ExitBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName()
                      << "\n");
    return true;
  }

  // process \p SuccBlock with reaching definition \p DefBlock
  bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) {
    if (!computeJoin(SuccBlock, DefBlock))
      return false;

    // Divergent, disjoint paths join.
    DivDesc->JoinDivBlocks.insert(&SuccBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName());
    return true;
  }
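  // Worked sketch (illustrative; assumes one possible block numbering) on the
  // diamond CFG from the file header, with the divergent terminator in A:
  //   bootstrap:  label(B) = B, label(C) = C
  //   visit B:    push B to D              -> label(D) = B
  //   visit C:    push C to D and E        -> D sees B != C, D joins
  //                                           (label(D) = D); label(E) = C
  //   visit D:    push D to F              -> label(F) = D
  //   visit E:    push C to F              -> F sees D != C, F joins
  //   result:     JoinDivBlocks = {D, F}, matching the set derived in the
  //               header comment. Visiting C before B yields the same result.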
  std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() {
    assert(DivDesc);

    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName()
                      << "\n");

    const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock);

    // Early stopping criterion
    int FloorIdx = LoopPOT.size() - 1;
    const BasicBlock *FloorLabel = nullptr;
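    // FloorIdx bounds how far down the reverse walk has to go: it is lowered
    // whenever a pushed label caused a join, or whenever a label different
    // from the last lowering label (FloorLabel) was pushed. If only one label
    // keeps being pushed and no joins occur, the floor stays put and the walk
    // below terminates early.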
    // bootstrap with branch targets
    int BlockIdx = 0;

    for (const auto *SuccBlock : successors(&DivTermBlock)) {
      auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock);
      BlockLabels[SuccIdx] = SuccBlock;

      // Find the successor with the highest index to start with
      BlockIdx = std::max<int>(BlockIdx, SuccIdx);
      FloorIdx = std::min<int>(FloorIdx, SuccIdx);

      // Identify immediate divergent loop exits
      if (!DivBlockLoop)
        continue;

      const auto *BlockLoop = LI.getLoopFor(SuccBlock);
      if (BlockLoop && DivBlockLoop->contains(BlockLoop))
        continue;
      DivDesc->LoopDivBlocks.insert(SuccBlock);
      LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: "
                        << SuccBlock->getName() << "\n");
    }

    // propagate definitions at the immediate successors of the node in RPO
    for (; BlockIdx >= FloorIdx; --BlockIdx) {
      LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs()));

      // Any label available here
      const auto *Label = BlockLabels[BlockIdx];
      if (!Label)
        continue;

      // Ok. Get the block
      const auto *Block = LoopPOT.getBlockAt(BlockIdx);
      LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");

      auto *BlockLoop = LI.getLoopFor(Block);
      bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block;
      bool CausedJoin = false;
      int LoweredFloorIdx = FloorIdx;
      if (IsLoopHeader) {
        // Disconnect from immediate successors and propagate directly to the
        // loop exits.
        SmallVector<BasicBlock *, 4> BlockLoopExits;
        BlockLoop->getExitBlocks(BlockLoopExits);

        bool IsParentLoop = BlockLoop->contains(&DivTermBlock);
        for (const auto *BlockLoopExit : BlockLoopExits) {
          CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop);
          LoweredFloorIdx = std::min<int>(LoweredFloorIdx,
                                          LoopPOT.getIndexOf(*BlockLoopExit));
        }
      } else {
        // Acyclic successor case
        for (const auto *SuccBlock : successors(Block)) {
          CausedJoin |= visitEdge(*SuccBlock, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock));
        }
      }

      // Floor update
      if (CausedJoin) {
        // 1. Different labels pushed to successors
        FloorIdx = LoweredFloorIdx;
      } else if (FloorLabel != Label) {
        // 2. No join caused BUT we pushed a label that is different than the
        //    last pushed label
        FloorIdx = LoweredFloorIdx;
        FloorLabel = Label;
      }
    }

    LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));

    return std::move(DivDesc);
  }
};
} // end anonymous namespace

#ifndef NDEBUG
static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) {
  Out << "[";
  ListSeparator LS;
  for (const auto *BB : Blocks)
    Out << LS << BB->getName();
  Out << "]";
}
#endif
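// Typical use (a sketch; the helper names below are hypothetical): a client
// such as the DivergenceAnalysis queries the divergent terminator and marks
// values in the returned blocks as divergent:
//
//   const auto &DivDesc = SDA.getJoinBlocks(Term);
//   for (const auto *JoinBlock : DivDesc.JoinDivBlocks)
//     taintPhiNodes(*JoinBlock);        // phi nodes receive divergent values
//   for (const auto *ExitBlock : DivDesc.LoopDivBlocks)
//     taintLoopLiveOuts(*ExitBlock);    // temporal divergence at loop exits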
const ControlDivergenceDesc &
SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) {
  // trivial case
  if (Term.getNumSuccessors() <= 1) {
    return EmptyDivergenceDesc;
  }

  // already available in cache?
  auto ItCached = CachedControlDivDescs.find(&Term);
  if (ItCached != CachedControlDivDescs.end())
    return *ItCached->second;

  // compute all join points
  // Special handling of divergent loop exits is not needed for LCSSA
  const auto &TermBlock = *Term.getParent();
  DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock);
  auto DivDesc = Propagator.computeJoinPoints();

  LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n";
             dbgs() << "JoinDivBlocks: ";
             printBlockSet(DivDesc->JoinDivBlocks, dbgs());
             dbgs() << "\nLoopDivBlocks: ";
             printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";);
  auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc));
  assert(ItInserted.second);
  return *ItInserted.first->second;
}

} // namespace llvm