GenericUniformityImpl.h 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- GenericUniformAnalysis.cpp --------------------*- C++ -*------------===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. // This template implementation resides in a separate file so that it
  15. // does not get injected into every .cpp file that includes the
  16. // generic header.
  17. //
  18. // DO NOT INCLUDE THIS FILE WHEN MERELY USING UNIFORMITYINFO.
  19. //
  20. // This file should only be included by files that implement a
// specialization of the relevant templates. Currently these are:
  22. // - UniformityAnalysis.cpp
  23. //
  24. // Note: The DEBUG_TYPE macro should be defined before using this
  25. // file so that any use of LLVM_DEBUG is associated with the
  26. // including file rather than this file.
  27. //
  28. //===----------------------------------------------------------------------===//
  29. ///
  30. /// \file
  31. /// \brief Implementation of uniformity analysis.
  32. ///
  33. /// The algorithm is a fixed point iteration that starts with the assumption
  34. /// that all control flow and all values are uniform. Starting from sources of
  35. /// divergence (whose discovery must be implemented by a CFG- or even
  36. /// target-specific derived class), divergence of values is propagated from
  37. /// definition to uses in a straight-forward way. The main complexity lies in
  38. /// the propagation of the impact of divergent control flow on the divergence of
  39. /// values (sync dependencies).
  40. ///
  41. //===----------------------------------------------------------------------===//
  42. #ifndef LLVM_ADT_GENERICUNIFORMITYIMPL_H
  43. #define LLVM_ADT_GENERICUNIFORMITYIMPL_H
  44. #include "llvm/ADT/GenericUniformityInfo.h"
  45. #include "llvm/ADT/SmallPtrSet.h"
  46. #include "llvm/ADT/SparseBitVector.h"
  47. #include "llvm/ADT/StringExtras.h"
  48. #include "llvm/Support/raw_ostream.h"
  49. #include <set>
  50. #define DEBUG_TYPE "uniformity"
  51. using namespace llvm;
  52. namespace llvm {
  53. template <typename Range> auto unique(Range &&R) {
  54. return std::unique(adl_begin(R), adl_end(R));
  55. }
  56. /// Construct a specially modified post-order traversal of cycles.
  57. ///
/// The ModifiedPO is constructed using a virtually modified CFG as follows:
  59. ///
/// 1. The successors of pre-entry nodes (predecessors of a cycle
  61. /// entry that are outside the cycle) are replaced by the
  62. /// successors of the successors of the header.
  63. /// 2. Successors of the cycle header are replaced by the exit blocks
  64. /// of the cycle.
  65. ///
  66. /// Effectively, we produce a depth-first numbering with the following
  67. /// properties:
  68. ///
  69. /// 1. Nodes after a cycle are numbered earlier than the cycle header.
  70. /// 2. The header is numbered earlier than the nodes in the cycle.
  71. /// 3. The numbering of the nodes within the cycle forms an interval
  72. /// starting with the header.
  73. ///
  74. /// Effectively, the virtual modification arranges the nodes in a
  75. /// cycle as a DAG with the header as the sole leaf, and successors of
  76. /// the header as the roots. A reverse traversal of this numbering has
  77. /// the following invariant on the unmodified original CFG:
  78. ///
  79. /// Each node is visited after all its predecessors, except if that
  80. /// predecessor is the cycle header.
  81. ///
template <typename ContextT> class ModifiedPostOrder {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;
  using const_iterator = typename std::vector<BlockT *>::const_iterator;

  ModifiedPostOrder(const ContextT &C) : Context(C) {}

  bool empty() const { return m_order.empty(); }
  size_t size() const { return m_order.size(); }

  void clear() { m_order.clear(); }
  /// Compute the modified post-order numbering for the whole function
  /// described by \p CI (defined out-of-line).
  void compute(const CycleInfoT &CI);

  /// Whether \p BB has already been assigned a number.
  unsigned count(BlockT *BB) const { return POIndex.count(BB); }
  const BlockT *operator[](size_t idx) const { return m_order[idx]; }

  /// Append \p BB as the next block in the traversal and record its index.
  /// \p isReducibleCycleHeader marks headers of reducible cycles, which are
  /// treated specially during label propagation.
  void appendBlock(const BlockT &BB, bool isReducibleCycleHeader = false) {
    POIndex[&BB] = m_order.size();
    m_order.push_back(&BB);
    LLVM_DEBUG(dbgs() << "ModifiedPO(" << POIndex[&BB]
                      << "): " << Context.print(&BB) << "\n");
    if (isReducibleCycleHeader)
      ReducibleCycleHeaders.insert(&BB);
  }

  /// Position of \p BB in the traversal; \p BB must already be numbered.
  unsigned getIndex(const BlockT *BB) const {
    assert(POIndex.count(BB));
    return POIndex.lookup(BB);
  }

  /// Whether \p BB was appended as the header of a reducible cycle.
  bool isReducibleCycleHeader(const BlockT *BB) const {
    return ReducibleCycleHeaders.contains(BB);
  }

private:
  // Blocks in modified post-order; POIndex is the inverse mapping.
  SmallVector<const BlockT *> m_order;
  DenseMap<const BlockT *, unsigned> POIndex;
  SmallPtrSet<const BlockT *, 32> ReducibleCycleHeaders;
  const ContextT &Context;

  // Recursive helpers for compute(); defined out-of-line. Finalized
  // tracks blocks whose numbering is complete.
  void computeCyclePO(const CycleInfoT &CI, const CycleT *Cycle,
                      SmallPtrSetImpl<BlockT *> &Finalized);

  void computeStackPO(SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI,
                      const CycleT *Cycle,
                      SmallPtrSetImpl<BlockT *> &Finalized);
};
  123. template <typename> class DivergencePropagator;
  124. /// \class GenericSyncDependenceAnalysis
  125. ///
  126. /// \brief Locate join blocks for disjoint paths starting at a divergent branch.
  127. ///
  128. /// An analysis per divergent branch that returns the set of basic
  129. /// blocks whose phi nodes become divergent due to divergent control.
  130. /// These are the blocks that are reachable by two disjoint paths from
  131. /// the branch, or cycle exits reachable along a path that is disjoint
  132. /// from a path to the cycle latch.
  133. // --- Above line is not a doxygen comment; intentionally left blank ---
  134. //
  135. // Originally implemented in SyncDependenceAnalysis.cpp for DivergenceAnalysis.
  136. //
  137. // The SyncDependenceAnalysis is used in the UniformityAnalysis to model
  138. // control-induced divergence in phi nodes.
  139. //
  140. // -- Reference --
  141. // The algorithm is an extension of Section 5 of
  142. //
  143. // An abstract interpretation for SPMD divergence
  144. // on reducible control flow graphs.
  145. // Julian Rosemann, Simon Moll and Sebastian Hack
  146. // POPL '21
  147. //
  148. //
  149. // -- Sync dependence --
  150. // Sync dependence characterizes the control flow aspect of the
  151. // propagation of branch divergence. For example,
  152. //
  153. // %cond = icmp slt i32 %tid, 10
  154. // br i1 %cond, label %then, label %else
  155. // then:
  156. // br label %merge
  157. // else:
  158. // br label %merge
  159. // merge:
  160. // %a = phi i32 [ 0, %then ], [ 1, %else ]
  161. //
  162. // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
  163. // because %tid is not on its use-def chains, %a is sync dependent on %tid
  164. // because the branch "br i1 %cond" depends on %tid and affects which value %a
  165. // is assigned to.
  166. //
  167. //
  168. // -- Reduction to SSA construction --
  169. // There are two disjoint paths from A to X, if a certain variant of SSA
  170. // construction places a phi node in X under the following set-up scheme.
  171. //
  172. // This variant of SSA construction ignores incoming undef values.
  173. // That is paths from the entry without a definition do not result in
  174. // phi nodes.
  175. //
  176. // entry
  177. // / \
  178. // A \
  179. // / \ Y
  180. // B C /
  181. // \ / \ /
  182. // D E
  183. // \ /
  184. // F
  185. //
  186. // Assume that A contains a divergent branch. We are interested
  187. // in the set of all blocks where each block is reachable from A
  188. // via two disjoint paths. This would be the set {D, F} in this
  189. // case.
  190. // To generally reduce this query to SSA construction we introduce
  191. // a virtual variable x and assign to x different values in each
  192. // successor block of A.
  193. //
  194. // entry
  195. // / \
  196. // A \
  197. // / \ Y
  198. // x = 0 x = 1 /
  199. // \ / \ /
  200. // D E
  201. // \ /
  202. // F
  203. //
  204. // Our flavor of SSA construction for x will construct the following
  205. //
  206. // entry
  207. // / \
  208. // A \
  209. // / \ Y
  210. // x0 = 0 x1 = 1 /
  211. // \ / \ /
  212. // x2 = phi E
  213. // \ /
  214. // x3 = phi
  215. //
  216. // The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
  218. //
  219. // -- Remarks --
  220. // * In case of cycle exits we need to check for temporal divergence.
  221. // To this end, we check whether the definition of x differs between the
  222. // cycle exit and the cycle header (_after_ SSA construction).
  223. //
  224. // * In the presence of irreducible control flow, the fixed point is
  225. // reached only after multiple iterations. This is because labels
  226. // reaching the header of a cycle must be repropagated through the
  227. // cycle. This is true even in a reducible cycle, since the labels
  228. // may have been produced by a nested irreducible cycle.
  229. //
  230. // * Note that SyncDependenceAnalysis is not concerned with the points
// of convergence in an irreducible cycle. Its only purpose is to
  232. // identify join blocks. The "diverged entry" criterion is
  233. // separately applied on join blocks to determine if an entire
  234. // irreducible cycle is assumed to be divergent.
  235. //
  236. // * Relevant related work:
  237. // A simple algorithm for global data flow analysis problems.
  238. // Matthew S. Hecht and Jeffrey D. Ullman.
  239. // SIAM Journal on Computing, 4(4):519–532, December 1975.
  240. //
template <typename ContextT> class GenericSyncDependenceAnalysis {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using InstructionT = typename ContextT::InstructionT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ConstBlockSet = SmallPtrSet<const BlockT *, 4>;
  using ModifiedPO = ModifiedPostOrder<ContextT>;

  // Reaching-definition labels used during propagation, where X is the
  // block with the divergent terminator:
  // * if BlockLabels[B] == C then C is the dominating definition at
  //   block B
  // * if BlockLabels[B] == nullptr then we haven't seen B yet
  // * if BlockLabels[B] == B then:
  //   - B is a join point of disjoint paths from X, or,
  //   - B is an immediate successor of X (initial value), or,
  //   - B is X
  using BlockLabelMap = DenseMap<const BlockT *, const BlockT *>;

  /// Information discovered by the sync dependence analysis for each
  /// divergent branch.
  struct DivergenceDescriptor {
    // Join points of diverged paths.
    ConstBlockSet JoinDivBlocks;
    // Divergent cycle exits.
    ConstBlockSet CycleDivBlocks;
    // Labels assigned to blocks on diverged paths.
    BlockLabelMap BlockLabels;
  };

  using DivergencePropagatorT = DivergencePropagator<ContextT>;

  GenericSyncDependenceAnalysis(const ContextT &Context,
                                const DominatorTreeT &DT, const CycleInfoT &CI);

  /// \brief Computes divergent join points and cycle exits caused by branch
  /// divergence in \p Term.
  ///
  /// This returns a pair of sets:
  /// * The set of blocks which are reachable by disjoint paths from
  ///   \p Term.
  /// * The set also contains cycle exits if there are two disjoint paths:
  ///   one from \p Term to the cycle exit and another from \p Term to
  ///   the cycle header.
  const DivergenceDescriptor &getJoinBlocks(const BlockT *DivTermBlock);

private:
  // Shared descriptor returned for terminators with at most one successor.
  static DivergenceDescriptor EmptyDivergenceDesc;

  // Modified post-order traversal, computed once in the constructor and
  // reused by every DivergencePropagator.
  ModifiedPO CyclePO;

  const DominatorTreeT &DT;
  const CycleInfoT &CI;

  // Memoized per-branch results of getJoinBlocks().
  DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
      CachedControlDivDescs;
};
  291. /// \brief Analysis that identifies uniform values in a data-parallel
  292. /// execution.
  293. ///
  294. /// This analysis propagates divergence in a data-parallel context
  295. /// from sources of divergence to all users. It can be instantiated
  296. /// for an IR that provides a suitable SSAContext.
template <typename ContextT> class GenericUniformityAnalysisImpl {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using ConstValueRefT = typename ContextT::ConstValueRefT;
  using InstructionT = typename ContextT::InstructionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  GenericUniformityAnalysisImpl(const FunctionT &F, const DominatorTreeT &DT,
                                const CycleInfoT &CI,
                                const TargetTransformInfo *TTI)
      : Context(CI.getSSAContext()), F(F), CI(CI), TTI(TTI), DT(DT),
        SDA(Context, DT, CI) {}

  /// Seed the analysis with sources of divergence (defined out-of-line by
  /// the IR-specific specialization).
  void initialize();

  const FunctionT &getFunction() const { return F; }

  /// \brief Mark \p Instr as a value that is always uniform.
  void addUniformOverride(const InstructionT &Instr);

  /// \brief Mark \p DivVal as a value that is always divergent.
  /// \returns Whether the tracked divergence state of \p DivVal changed.
  bool markDivergent(const InstructionT &I);
  bool markDivergent(ConstValueRefT DivVal);
  bool markDefsDivergent(const InstructionT &Instr,
                         bool AllDefsDivergent = true);

  /// \brief Propagate divergence to all instructions in the region.
  /// Divergence is seeded by calls to \p markDivergent.
  void compute();

  /// \brief Whether any value was marked or analyzed to be divergent.
  bool hasDivergence() const { return !DivergentValues.empty(); }

  /// \brief Whether \p Instr will always return a uniform value regardless
  /// of its operands.
  bool isAlwaysUniform(const InstructionT &Instr) const;

  bool hasDivergentDefs(const InstructionT &I) const;

  // Terminator divergence is tracked per-block; all other instructions
  // are divergent iff they have a divergent definition.
  bool isDivergent(const InstructionT &I) const {
    if (I.isTerminator()) {
      return DivergentTermBlocks.contains(I.getParent());
    }
    return hasDivergentDefs(I);
  };

  /// \brief Whether \p V is divergent at its definition.
  bool isDivergent(ConstValueRefT V) const { return DivergentValues.count(V); }

  bool hasDivergentTerminator(const BlockT &B) const {
    return DivergentTermBlocks.contains(&B);
  }

  void print(raw_ostream &out) const;

protected:
  /// \brief Value/block pair representing a single phi input.
  struct PhiInput {
    ConstValueRefT value;
    BlockT *predBlock;

    PhiInput(ConstValueRefT value, BlockT *predBlock)
        : value(value), predBlock(predBlock) {}
  };

  const ContextT &Context;
  const FunctionT &F;
  const CycleInfoT &CI;
  const TargetTransformInfo *TTI = nullptr;

  // Detected/marked divergent values.
  std::set<ConstValueRefT> DivergentValues;
  SmallPtrSet<const BlockT *, 32> DivergentTermBlocks;

  // Internal worklist for divergence propagation.
  std::vector<const InstructionT *> Worklist;

  /// \brief Mark \p Term as divergent and push all Instructions that become
  /// divergent as a result on the worklist.
  void analyzeControlDivergence(const InstructionT &Term);

private:
  const DominatorTreeT &DT;

  // Recognized cycles with divergent exits.
  SmallPtrSet<const CycleT *, 16> DivergentExitCycles;

  // Cycles assumed to be divergent.
  //
  // We don't use a set here because every insertion needs an explicit
  // traversal of all existing members.
  SmallVector<const CycleT *> AssumedDivergent;

  // The SDA links divergent branches to divergent control-flow joins.
  SyncDependenceAnalysisT SDA;

  // Set of known-uniform values.
  SmallPtrSet<const InstructionT *, 32> UniformOverrides;

  /// \brief Mark all nodes in \p JoinBlock as divergent and push them on
  /// the worklist.
  void taintAndPushAllDefs(const BlockT &JoinBlock);

  /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on
  /// the worklist.
  void taintAndPushPhiNodes(const BlockT &JoinBlock);

  /// \brief Identify all Instructions that become divergent because \p DivExit
  /// is a divergent cycle exit of \p DivCycle. Mark those instructions as
  /// divergent and push them on the worklist.
  void propagateCycleExitDivergence(const BlockT &DivExit,
                                    const CycleT &DivCycle);

  /// \brief Internal implementation function for propagateCycleExitDivergence.
  void analyzeCycleExitDivergence(const CycleT &OuterDivCycle);

  /// \brief Mark all instruction as divergent that use a value defined in \p
  /// OuterDivCycle. Push their users on the worklist.
  void analyzeTemporalDivergence(const InstructionT &I,
                                 const CycleT &OuterDivCycle);

  /// \brief Push all users of \p I (in the region) to the worklist.
  void pushUsers(const InstructionT &I);
  void pushUsers(ConstValueRefT V);

  bool usesValueFromCycle(const InstructionT &I, const CycleT &DefCycle) const;

  /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
  bool isTemporalDivergent(const BlockT &ObservingBlock,
                           ConstValueRefT Val) const;
};
// Out-of-line deleter body: presumably this lets the owning smart pointer
// be declared where ImplT is still an incomplete type, with the delete
// instantiated only here -- TODO confirm against GenericUniformityInfo.h.
template <typename ImplT>
void GenericUniformityAnalysisImplDeleter<ImplT>::operator()(ImplT *Impl) {
  delete Impl;
}
  409. /// Compute divergence starting with a divergent branch.
/// Compute divergence starting with a divergent branch.
template <typename ContextT> class DivergencePropagator {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ModifiedPO = ModifiedPostOrder<ContextT>;
  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  const ModifiedPO &CyclePOT;
  const DominatorTreeT &DT;
  const CycleInfoT &CI;
  const BlockT &DivTermBlock;
  const ContextT &Context;

  // Track blocks that receive a new label. Every time we relabel a
  // cycle header, we need another pass over the modified post-order in
  // order to propagate the header label. The bit vector also allows
  // us to skip labels that have not changed.
  SparseBitVector<> FreshLabels;

  // Divergent join and cycle exit descriptor; ownership is released to
  // the caller by computeJoinPoints().
  std::unique_ptr<DivergenceDescriptorT> DivDesc;
  // Shorthand reference into DivDesc->BlockLabels.
  BlockLabelMapT &BlockLabels;

  DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
                       const CycleInfoT &CI, const BlockT &DivTermBlock)
      : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
        Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
        BlockLabels(DivDesc->BlockLabels) {}

  /// Dump the current label of every block, in reverse traversal order.
  void printDefs(raw_ostream &Out) {
    Out << "Propagator::BlockLabels {\n";
    for (int BlockIdx = (int)CyclePOT.size() - 1; BlockIdx >= 0; --BlockIdx) {
      const auto *Block = CyclePOT[BlockIdx];
      const auto *Label = BlockLabels[Block];
      Out << Context.print(Block) << "(" << BlockIdx << ") : ";
      if (!Label) {
        Out << "<null>\n";
      } else {
        Out << Context.print(Label) << "\n";
      }
    }
    Out << "}\n";
  }

  // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
  // causes a divergent join.
  bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
    const auto *OldLabel = BlockLabels[&SuccBlock];

    LLVM_DEBUG(dbgs() << "labeling " << Context.print(&SuccBlock) << ":\n"
                      << "\tpushed label: " << Context.print(&PushedLabel)
                      << "\n"
                      << "\told label: " << Context.print(OldLabel) << "\n");

    // Early exit if there is no change in the label.
    if (OldLabel == &PushedLabel)
      return false;

    if (OldLabel != &SuccBlock) {
      auto SuccIdx = CyclePOT.getIndex(&SuccBlock);
      // Assigning a new label, mark this in FreshLabels.
      LLVM_DEBUG(dbgs() << "\tfresh label: " << SuccIdx << "\n");
      FreshLabels.set(SuccIdx);
    }

    // This is not a join if the succ was previously unlabeled.
    if (!OldLabel) {
      LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&PushedLabel)
                        << "\n");
      BlockLabels[&SuccBlock] = &PushedLabel;
      return false;
    }

    // This is a new join. Label the join block as itself, and not as
    // the pushed label.
    LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&SuccBlock) << "\n");
    BlockLabels[&SuccBlock] = &SuccBlock;

    return true;
  }

  // Visiting a virtual cycle exit edge from the cycle header --> temporal
  // divergence on join.
  bool visitCycleExitEdge(const BlockT &ExitBlock, const BlockT &Label) {
    if (!computeJoin(ExitBlock, Label))
      return false;

    // Identified a divergent cycle exit
    DivDesc->CycleDivBlocks.insert(&ExitBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(&ExitBlock)
                      << "\n");
    return true;
  }

  // Process \p SuccBlock with reaching definition \p Label.
  bool visitEdge(const BlockT &SuccBlock, const BlockT &Label) {
    if (!computeJoin(SuccBlock, Label))
      return false;

    // Divergent, disjoint paths join.
    DivDesc->JoinDivBlocks.insert(&SuccBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent join: " << Context.print(&SuccBlock)
                      << "\n");
    return true;
  }

  /// Fixed-point propagation of reaching labels over the modified
  /// post-order. On return, DivDesc holds all join blocks and divergent
  /// cycle exits caused by the branch in DivTermBlock.
  std::unique_ptr<DivergenceDescriptorT> computeJoinPoints() {
    assert(DivDesc);

    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
                      << Context.print(&DivTermBlock) << "\n");

    // Early stopping criterion: no block below FloorIdx (in traversal
    // order) can still be affected.
    int FloorIdx = CyclePOT.size() - 1;
    const BlockT *FloorLabel = nullptr;
    int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);

    // Bootstrap with branch targets: each successor starts labeled as
    // itself (the "x = 0 / x = 1" virtual definitions).
    auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
    for (const auto *SuccBlock : successors(&DivTermBlock)) {
      if (DivTermCycle && !DivTermCycle->contains(SuccBlock)) {
        // If DivTerm exits the cycle immediately, computeJoin() might
        // not reach SuccBlock with a different label. We need to
        // check for this exit now.
        DivDesc->CycleDivBlocks.insert(SuccBlock);
        LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
                          << Context.print(SuccBlock) << "\n");
      }
      auto SuccIdx = CyclePOT.getIndex(SuccBlock);
      visitEdge(*SuccBlock, *SuccBlock);
      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
    }

    // Iterate blocks with fresh labels from highest to lowest index,
    // pushing each block's label to its (possibly virtual) successors.
    while (true) {
      auto BlockIdx = FreshLabels.find_last();
      if (BlockIdx == -1 || BlockIdx < FloorIdx)
        break;

      LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));

      FreshLabels.reset(BlockIdx);
      if (BlockIdx == DivTermIdx) {
        LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
        continue;
      }

      const auto *Block = CyclePOT[BlockIdx];
      LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
                        << BlockIdx << "\n");

      const auto *Label = BlockLabels[Block];
      assert(Label);

      bool CausedJoin = false;
      int LoweredFloorIdx = FloorIdx;

      // If the current block is the header of a reducible cycle that
      // contains the divergent branch, then the label should be
      // propagated to the cycle exits. Such a header is the "last
      // possible join" of any disjoint paths within this cycle. This
      // prevents detection of spurious joins at the entries of any
      // irreducible child cycles.
      //
      // This conclusion about the header is true for any choice of DFS:
      //
      //   If some DFS has a reducible cycle C with header H, then for
      //   any other DFS, H is the header of a cycle C' that is a
      //   superset of C. For a divergent branch inside the subgraph
      //   C, any join node inside C is either H, or some node
      //   encountered without passing through H.
      //
      auto getReducibleParent = [&](const BlockT *Block) -> const CycleT * {
        if (!CyclePOT.isReducibleCycleHeader(Block))
          return nullptr;
        const auto *BlockCycle = CI.getCycle(Block);
        if (BlockCycle->contains(&DivTermBlock))
          return BlockCycle;
        return nullptr;
      };

      if (const auto *BlockCycle = getReducibleParent(Block)) {
        // Virtual edges: header label flows directly to the cycle exits.
        SmallVector<BlockT *, 4> BlockCycleExits;
        BlockCycle->getExitBlocks(BlockCycleExits);
        for (auto *BlockCycleExit : BlockCycleExits) {
          CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
        }
      } else {
        for (const auto *SuccBlock : successors(Block)) {
          CausedJoin |= visitEdge(*SuccBlock, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
        }
      }

      // Floor update
      if (CausedJoin) {
        // 1. Different labels pushed to successors
        FloorIdx = LoweredFloorIdx;
      } else if (FloorLabel != Label) {
        // 2. No join caused BUT we pushed a label that is different than the
        //    last pushed label
        FloorIdx = LoweredFloorIdx;
        FloorLabel = Label;
      }
    }

    LLVM_DEBUG(dbgs() << "Final labeling:\n"; printDefs(dbgs()));

    // Check every cycle containing DivTermBlock for exit divergence.
    // A cycle has exit divergence if the label of an exit block does
    // not match the label of its header.
    for (const auto *Cycle = CI.getCycle(&DivTermBlock); Cycle;
         Cycle = Cycle->getParentCycle()) {
      if (Cycle->isReducible()) {
        // The exit divergence of a reducible cycle is recorded while
        // propagating labels.
        continue;
      }
      SmallVector<BlockT *> Exits;
      Cycle->getExitBlocks(Exits);
      auto *Header = Cycle->getHeader();
      auto *HeaderLabel = BlockLabels[Header];
      for (const auto *Exit : Exits) {
        if (BlockLabels[Exit] != HeaderLabel) {
          // Identified a divergent cycle exit
          DivDesc->CycleDivBlocks.insert(Exit);
          LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(Exit)
                            << "\n");
        }
      }
    }

    return std::move(DivDesc);
  }
};
// Out-of-class definition of the shared, immutable empty descriptor that
// getJoinBlocks() returns for trivially uniform branches (<= 1 successor),
// avoiding a cache entry / allocation per such query.
template <typename ContextT>
typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
    llvm::GenericSyncDependenceAnalysis<ContextT>::EmptyDivergenceDesc;
// Construct the sync dependence analysis for one function: capture the
// dominator tree and cycle info, and eagerly compute the cycle-modified
// post order that all later join-point queries share.
template <typename ContextT>
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
    const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
    : CyclePO(Context), DT(DT), CI(CI) {
  CyclePO.compute(CI);
}
// Return the set of blocks where disjoint paths from \p DivTermBlock may
// join, computing and caching the result on first request.
template <typename ContextT>
auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
    const BlockT *DivTermBlock) -> const DivergenceDescriptor & {
  // Trivial case: with at most one successor there are no disjoint paths,
  // hence no joins and no cycle exit divergence.
  if (succ_size(DivTermBlock) <= 1) {
    return EmptyDivergenceDesc;
  }

  // Already available in cache?
  auto ItCached = CachedControlDivDescs.find(DivTermBlock);
  if (ItCached != CachedControlDivDescs.end())
    return *ItCached->second;

  // Compute all join points by propagating labels from the branch.
  DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
  auto DivDesc = Propagator.computeJoinPoints();

  // Debug-only pretty printer for a block set; (void)-consumed below so
  // non-debug builds do not warn about an unused variable.
  auto printBlockSet = [&](ConstBlockSet &Blocks) {
    return Printable([&](raw_ostream &Out) {
      Out << "[";
      ListSeparator LS;
      for (const auto *BB : Blocks) {
        Out << LS << CI.getSSAContext().print(BB);
      }
      Out << "]\n";
    });
  };

  LLVM_DEBUG(
      dbgs() << "\nResult (" << CI.getSSAContext().print(DivTermBlock)
             << "):\n  JoinDivBlocks: " << printBlockSet(DivDesc->JoinDivBlocks)
             << "  CycleDivBlocks: " << printBlockSet(DivDesc->CycleDivBlocks)
             << "\n");
  (void)printBlockSet;

  // Cache ownership of the descriptor; the branch must not have been
  // cached concurrently since the lookup above.
  auto ItInserted =
      CachedControlDivDescs.try_emplace(DivTermBlock, std::move(DivDesc));
  assert(ItInserted.second);
  return *ItInserted.first->second;
}
  666. template <typename ContextT>
  667. bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
  668. const InstructionT &I) {
  669. if (I.isTerminator()) {
  670. if (DivergentTermBlocks.insert(I.getParent()).second) {
  671. LLVM_DEBUG(dbgs() << "marked divergent term block: "
  672. << Context.print(I.getParent()) << "\n");
  673. return true;
  674. }
  675. return false;
  676. }
  677. return markDefsDivergent(I);
  678. }
  679. template <typename ContextT>
  680. bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
  681. ConstValueRefT Val) {
  682. if (DivergentValues.insert(Val).second) {
  683. LLVM_DEBUG(dbgs() << "marked divergent: " << Context.print(Val) << "\n");
  684. return true;
  685. }
  686. return false;
  687. }
// Record that \p Instr must always be treated as uniform, overriding any
// divergence that operand-based propagation would otherwise assign
// (e.g. target operations known to produce uniform results).
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
    const InstructionT &Instr) {
  UniformOverrides.insert(&Instr);
}
// Mark \p I divergent if it uses a value defined inside \p OuterDivCycle:
// the use observes the value after threads left the cycle in different
// iterations (temporal divergence).
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeTemporalDivergence(
    const InstructionT &I, const CycleT &OuterDivCycle) {
  // Already divergent; nothing new to learn or propagate.
  if (isDivergent(I))
    return;

  LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << Context.print(&I)
                    << "\n");
  // Only users of values that live out of the cycle are affected.
  if (!usesValueFromCycle(I, OuterDivCycle))
    return;

  // Respect explicit uniformity overrides.
  if (isAlwaysUniform(I))
    return;

  if (markDivergent(I))
    Worklist.push_back(&I);
}
  707. // Mark all external users of values defined inside \param
  708. // OuterDivCycle as divergent.
  709. //
  710. // This follows all live out edges wherever they may lead. Potential
  711. // users of values defined inside DivCycle could be anywhere in the
  712. // dominance region of DivCycle (including its fringes for phi nodes).
  713. // A cycle C dominates a block B iff every path from the entry block
  714. // to B must pass through a block contained in C. If C is a reducible
  715. // cycle (or natural loop), C dominates B iff the header of C
// dominates B. But in general, we iteratively examine cycle
  717. // exits and their successors.
  718. template <typename ContextT>
  719. void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
  720. const CycleT &OuterDivCycle) {
  721. // Set of blocks that are dominated by the cycle, i.e., each is only
  722. // reachable from paths that pass through the cycle.
  723. SmallPtrSet<BlockT *, 16> DomRegion;
  724. // The boundary of DomRegion, formed by blocks that are not
  725. // dominated by the cycle.
  726. SmallVector<BlockT *> DomFrontier;
  727. OuterDivCycle.getExitBlocks(DomFrontier);
  728. // Returns true if BB is dominated by the cycle.
  729. auto isInDomRegion = [&](BlockT *BB) {
  730. for (auto *P : predecessors(BB)) {
  731. if (OuterDivCycle.contains(P))
  732. continue;
  733. if (DomRegion.count(P))
  734. continue;
  735. return false;
  736. }
  737. return true;
  738. };
  739. // Keep advancing the frontier along successor edges, while
  740. // promoting blocks to DomRegion.
  741. while (true) {
  742. bool Promoted = false;
  743. SmallVector<BlockT *> Temp;
  744. for (auto *W : DomFrontier) {
  745. if (!isInDomRegion(W)) {
  746. Temp.push_back(W);
  747. continue;
  748. }
  749. DomRegion.insert(W);
  750. Promoted = true;
  751. for (auto *Succ : successors(W)) {
  752. if (DomRegion.contains(Succ))
  753. continue;
  754. Temp.push_back(Succ);
  755. }
  756. }
  757. if (!Promoted)
  758. break;
  759. DomFrontier = Temp;
  760. }
  761. // At DomFrontier, only the PHI nodes are affected by temporal
  762. // divergence.
  763. for (const auto *UserBlock : DomFrontier) {
  764. LLVM_DEBUG(dbgs() << "Analyze phis after cycle exit: "
  765. << Context.print(UserBlock) << "\n");
  766. for (const auto &Phi : UserBlock->phis()) {
  767. LLVM_DEBUG(dbgs() << " " << Context.print(&Phi) << "\n");
  768. analyzeTemporalDivergence(Phi, OuterDivCycle);
  769. }
  770. }
  771. // All instructions inside the dominance region are affected by
  772. // temporal divergence.
  773. for (const auto *UserBlock : DomRegion) {
  774. LLVM_DEBUG(dbgs() << "Analyze non-phi users after cycle exit: "
  775. << Context.print(UserBlock) << "\n");
  776. for (const auto &I : *UserBlock) {
  777. LLVM_DEBUG(dbgs() << " " << Context.print(&I) << "\n");
  778. analyzeTemporalDivergence(I, OuterDivCycle);
  779. }
  780. }
  781. }
// Record and analyze divergence at the cycle exit \p DivExit caused by a
// divergent branch inside \p InnerDivCycle: walk up to the outermost
// cycle that still exits at \p DivExit and propagate temporal divergence
// from it.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::propagateCycleExitDivergence(
    const BlockT &DivExit, const CycleT &InnerDivCycle) {
  LLVM_DEBUG(dbgs() << "\tpropCycleExitDiv " << Context.print(&DivExit)
                    << "\n");
  auto *DivCycle = &InnerDivCycle;
  auto *OuterDivCycle = DivCycle;
  auto *ExitLevelCycle = CI.getCycle(&DivExit);
  const unsigned CycleExitDepth =
      ExitLevelCycle ? ExitLevelCycle->getDepth() : 0;

  // Find outer-most cycle that does not contain \p DivExit. The exit
  // divergence of the inner cycle is also exit divergence of each cycle
  // it is nested in, up to that level.
  while (DivCycle && DivCycle->getDepth() > CycleExitDepth) {
    LLVM_DEBUG(dbgs() << "  Found exiting cycle: "
                      << Context.print(DivCycle->getHeader()) << "\n");
    OuterDivCycle = DivCycle;
    DivCycle = DivCycle->getParentCycle();
  }
  LLVM_DEBUG(dbgs() << "\tOuter-most exiting cycle: "
                    << Context.print(OuterDivCycle->getHeader()) << "\n");

  // Already processed this outer cycle; avoid re-analysis.
  if (!DivergentExitCycles.insert(OuterDivCycle).second)
    return;

  // Exit divergence does not matter if the cycle itself is assumed to
  // be divergent.
  for (const auto *C : AssumedDivergent) {
    if (C->contains(OuterDivCycle))
      return;
  }

  analyzeCycleExitDivergence(*OuterDivCycle);
}
  811. template <typename ContextT>
  812. void GenericUniformityAnalysisImpl<ContextT>::taintAndPushAllDefs(
  813. const BlockT &BB) {
  814. LLVM_DEBUG(dbgs() << "taintAndPushAllDefs " << Context.print(&BB) << "\n");
  815. for (const auto &I : instrs(BB)) {
  816. // Terminators do not produce values; they are divergent only if
  817. // the condition is divergent. That is handled when the divergent
  818. // condition is placed in the worklist.
  819. if (I.isTerminator())
  820. break;
  821. // Mark this as divergent. We don't check if the instruction is
  822. // always uniform. In a cycle where the thread convergence is not
  823. // statically known, the instruction is not statically converged,
  824. // and its outputs cannot be statically uniform.
  825. if (markDivergent(I))
  826. Worklist.push_back(&I);
  827. }
  828. }
  829. /// Mark divergent phi nodes in a join block
  830. template <typename ContextT>
  831. void GenericUniformityAnalysisImpl<ContextT>::taintAndPushPhiNodes(
  832. const BlockT &JoinBlock) {
  833. LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << Context.print(&JoinBlock)
  834. << "\n");
  835. for (const auto &Phi : JoinBlock.phis()) {
  836. if (ContextT::isConstantValuePhi(Phi))
  837. continue;
  838. if (markDivergent(Phi))
  839. Worklist.push_back(&Phi);
  840. }
  841. }
  842. /// Add \p Candidate to \p Cycles if it is not already contained in \p Cycles.
  843. ///
  844. /// \return true iff \p Candidate was added to \p Cycles.
  845. template <typename CycleT>
  846. static bool insertIfNotContained(SmallVector<CycleT *> &Cycles,
  847. CycleT *Candidate) {
  848. if (llvm::any_of(Cycles,
  849. [Candidate](CycleT *C) { return C->contains(Candidate); }))
  850. return false;
  851. Cycles.push_back(Candidate);
  852. return true;
  853. }
/// Return the outermost cycle made divergent by a branch outside it.
///
/// If two paths that diverged outside an irreducible cycle join
/// inside that cycle, then that whole cycle is assumed to be
/// divergent. This does not apply if the cycle is reducible.
///
/// \return the outermost such cycle, or nullptr if the criterion does
/// not apply (branch inside the cycle, or the cycle is reducible).
template <typename CycleT, typename BlockT>
static const CycleT *getExtDivCycle(const CycleT *Cycle,
                                    const BlockT *DivTermBlock,
                                    const BlockT *JoinBlock) {
  assert(Cycle);
  assert(Cycle->contains(JoinBlock));

  // The criterion only concerns branches *outside* the cycle.
  if (Cycle->contains(DivTermBlock))
    return nullptr;

  // In a reducible cycle, the only join reachable from outside is the
  // header itself, which does not make the cycle divergent.
  if (Cycle->isReducible()) {
    assert(Cycle->getHeader() == JoinBlock);
    return nullptr;
  }

  // Expand to the outermost cycle that still excludes the branch.
  const auto *Parent = Cycle->getParentCycle();
  while (Parent && !Parent->contains(DivTermBlock)) {
    // If the join is inside a child, then the parent must be
    // irreducible. The only join in a reducible cycle is its own
    // header.
    assert(!Parent->isReducible());
    Cycle = Parent;
    Parent = Cycle->getParentCycle();
  }

  LLVM_DEBUG(dbgs() << "cycle made divergent by external branch\n");
  return Cycle;
}
  883. /// Return the outermost cycle made divergent by branch inside it.
  884. ///
  885. /// This checks the "diverged entry" criterion defined in the
  886. /// docs/ConvergenceAnalysis.html.
  887. template <typename ContextT, typename CycleT, typename BlockT,
  888. typename DominatorTreeT>
  889. static const CycleT *
  890. getIntDivCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
  891. const BlockT *JoinBlock, const DominatorTreeT &DT,
  892. ContextT &Context) {
  893. LLVM_DEBUG(dbgs() << "examine join " << Context.print(JoinBlock)
  894. << "for internal branch " << Context.print(DivTermBlock)
  895. << "\n");
  896. if (DT.properlyDominates(DivTermBlock, JoinBlock))
  897. return nullptr;
  898. // Find the smallest common cycle, if one exists.
  899. assert(Cycle && Cycle->contains(JoinBlock));
  900. while (Cycle && !Cycle->contains(DivTermBlock)) {
  901. Cycle = Cycle->getParentCycle();
  902. }
  903. if (!Cycle || Cycle->isReducible())
  904. return nullptr;
  905. if (DT.properlyDominates(Cycle->getHeader(), JoinBlock))
  906. return nullptr;
  907. LLVM_DEBUG(dbgs() << " header " << Context.print(Cycle->getHeader())
  908. << " does not dominate join\n");
  909. const auto *Parent = Cycle->getParentCycle();
  910. while (Parent && !DT.properlyDominates(Parent->getHeader(), JoinBlock)) {
  911. LLVM_DEBUG(dbgs() << " header " << Context.print(Parent->getHeader())
  912. << " does not dominate join\n");
  913. Cycle = Parent;
  914. Parent = Parent->getParentCycle();
  915. }
  916. LLVM_DEBUG(dbgs() << " cycle made divergent by internal branch\n");
  917. return Cycle;
  918. }
  919. template <typename ContextT, typename CycleT, typename BlockT,
  920. typename DominatorTreeT>
  921. static const CycleT *
  922. getOutermostDivergentCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
  923. const BlockT *JoinBlock, const DominatorTreeT &DT,
  924. ContextT &Context) {
  925. if (!Cycle)
  926. return nullptr;
  927. // First try to expand Cycle to the largest that contains JoinBlock
  928. // but not DivTermBlock.
  929. const auto *Ext = getExtDivCycle(Cycle, DivTermBlock, JoinBlock);
  930. // Continue expanding to the largest cycle that contains both.
  931. const auto *Int = getIntDivCycle(Cycle, DivTermBlock, JoinBlock, DT, Context);
  932. if (Int)
  933. return Int;
  934. return Ext;
  935. }
// Propagate the divergence of the terminator \p Term: taint phi nodes at
// join blocks, whole cycles made divergent by the branch, and exits of
// cycles left divergently.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeControlDivergence(
    const InstructionT &Term) {
  const auto *DivTermBlock = Term.getParent();
  DivergentTermBlocks.insert(DivTermBlock);
  LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Context.print(DivTermBlock)
                    << "\n");

  // Don't propagate divergence from unreachable blocks.
  if (!DT.isReachableFromEntry(DivTermBlock))
    return;

  const auto &DivDesc = SDA.getJoinBlocks(DivTermBlock);
  SmallVector<const CycleT *> DivCycles;

  // Iterate over all blocks now reachable by a disjoint path join
  for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
    const auto *Cycle = CI.getCycle(JoinBlock);
    LLVM_DEBUG(dbgs() << "visiting join block " << Context.print(JoinBlock)
                      << "\n");
    // If the join makes a whole cycle divergent, defer the cycle for
    // bulk tainting below; otherwise only the phis at the join matter.
    if (const auto *Outermost = getOutermostDivergentCycle(
            Cycle, DivTermBlock, JoinBlock, DT, Context)) {
      LLVM_DEBUG(dbgs() << "found divergent cycle\n");
      DivCycles.push_back(Outermost);
      continue;
    }
    taintAndPushPhiNodes(*JoinBlock);
  }

  // Sort by order of decreasing depth. This allows later cycles to be skipped
  // because they are already contained in earlier ones.
  llvm::sort(DivCycles, [](const CycleT *A, const CycleT *B) {
    return A->getDepth() > B->getDepth();
  });

  // Cycles that are assumed divergent due to the diverged entry
  // criterion potentially contain temporal divergence depending on
  // the DFS chosen. Conservatively, all values produced in such a
  // cycle are assumed divergent. "Cycle invariant" values may be
  // assumed uniform, but that requires further analysis.
  for (auto *C : DivCycles) {
    if (!insertIfNotContained(AssumedDivergent, C))
      continue;
    LLVM_DEBUG(dbgs() << "process divergent cycle\n");
    for (const BlockT *BB : C->blocks()) {
      taintAndPushAllDefs(*BB);
    }
  }

  // Divergent cycle exits can only exist if the branch is inside a cycle.
  const auto *BranchCycle = CI.getCycle(DivTermBlock);
  assert(DivDesc.CycleDivBlocks.empty() || BranchCycle);
  for (const auto *DivExitBlock : DivDesc.CycleDivBlocks) {
    propagateCycleExitDivergence(*DivExitBlock, *BranchCycle);
  }
}
// Run the worklist-based fixed-point propagation of divergence from the
// initially seeded divergent values.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::compute() {
  // Initialize worklist. Copy first: pushUsers may grow DivergentValues
  // while we iterate.
  auto DivValuesCopy = DivergentValues;
  for (const auto DivVal : DivValuesCopy) {
    assert(isDivergent(DivVal) && "Worklist invariant violated!");
    pushUsers(DivVal);
  }

  // All values on the Worklist are divergent.
  // Their users may not have been updated yet.
  while (!Worklist.empty()) {
    const InstructionT *I = Worklist.back();
    Worklist.pop_back();
    LLVM_DEBUG(dbgs() << "worklist pop: " << Context.print(I) << "\n");

    // Divergent terminators trigger control-divergence analysis instead
    // of plain use-def propagation.
    if (I->isTerminator()) {
      analyzeControlDivergence(*I);
      continue;
    }

    // propagate value divergence to users
    assert(isDivergent(*I) && "Worklist invariant violated!");
    pushUsers(*I);
  }
}
// Whether \p Instr was registered via addUniformOverride() and must be
// treated as uniform regardless of operand divergence.
template <typename ContextT>
bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
    const InstructionT &Instr) const {
  return UniformOverrides.contains(&Instr);
}
// Construct the public uniformity-info facade: build the implementation
// and eagerly run the analysis (seed via initialize, then fixed point).
template <typename ContextT>
GenericUniformityInfo<ContextT>::GenericUniformityInfo(
    FunctionT &Func, const DominatorTreeT &DT, const CycleInfoT &CI,
    const TargetTransformInfo *TTI)
    : F(&Func) {
  DA.reset(new ImplT{Func, DT, CI, TTI});
  DA->initialize();
  DA->compute();
}
// Print a human-readable dump of the analysis result: divergent
// arguments, assumed-divergent cycles, cycles with divergent exits, and
// a per-block listing of definitions and terminators.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
  bool haveDivergentArgs = false;

  // Fast path: nothing anywhere was marked divergent.
  if (DivergentValues.empty()) {
    assert(DivergentTermBlocks.empty());
    assert(DivergentExitCycles.empty());
    OS << "ALL VALUES UNIFORM\n";
    return;
  }

  // Arguments have no defining block (getDefBlock returns null); list
  // them in a dedicated section before the per-block dump.
  for (const auto &entry : DivergentValues) {
    const BlockT *parent = Context.getDefBlock(entry);
    if (!parent) {
      if (!haveDivergentArgs) {
        OS << "DIVERGENT ARGUMENTS:\n";
        haveDivergentArgs = true;
      }
      OS << "  DIVERGENT: " << Context.print(entry) << '\n';
    }
  }

  if (!AssumedDivergent.empty()) {
    // NOTE(review): "ASSSUMED" typo is emitted verbatim; downstream test
    // checks may match this exact string — fix both together if ever
    // corrected.
    OS << "CYCLES ASSSUMED DIVERGENT:\n";
    for (const CycleT *cycle : AssumedDivergent) {
      OS << "  " << cycle->print(Context) << '\n';
    }
  }

  if (!DivergentExitCycles.empty()) {
    OS << "CYCLES WITH DIVERGENT EXIT:\n";
    for (const CycleT *cycle : DivergentExitCycles) {
      OS << "  " << cycle->print(Context) << '\n';
    }
  }

  for (auto &block : F) {
    OS << "\nBLOCK " << Context.print(&block) << '\n';

    OS << "DEFINITIONS\n";
    SmallVector<ConstValueRefT, 16> defs;
    Context.appendBlockDefs(defs, block);
    for (auto value : defs) {
      if (isDivergent(value))
        OS << "  DIVERGENT: ";
      else
        OS << "             ";
      OS << Context.print(value) << '\n';
    }

    OS << "TERMINATORS\n";
    SmallVector<const InstructionT *, 8> terms;
    Context.appendBlockTerms(terms, block);
    // Terminator divergence is tracked per block, not per instruction,
    // so all terminators of the block share the same marking.
    bool divergentTerminators = hasDivergentTerminator(block);
    for (auto *T : terms) {
      if (divergentTerminators)
        OS << "  DIVERGENT: ";
      else
        OS << "             ";
      OS << Context.print(T) << '\n';
    }

    OS << "END BLOCK\n";
  }
}
// Whether any divergence was detected in the analyzed function.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergence() const {
  return DA->hasDivergence();
}
/// Whether \p V is divergent at its definition.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergent(ConstValueRefT V) const {
  return DA->isDivergent(V);
}
// Whether block \p B was marked as having a divergent terminator.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergentTerminator(const BlockT &B) {
  return DA->hasDivergentTerminator(B);
}
/// \brief Helper function for printing; forwards to the implementation.
template <typename ContextT>
void GenericUniformityInfo<ContextT>::print(raw_ostream &out) const {
  DA->print(out);
}
// Iterative DFS post-order over the region belonging to \p Cycle
// (nullptr means the whole function), treating each maximal child cycle
// as a single unit: a child cycle is finished (and recursed into via
// computeCyclePO) only after all of its in-region exits are finalized.
template <typename ContextT>
void llvm::ModifiedPostOrder<ContextT>::computeStackPO(
    SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI, const CycleT *Cycle,
    SmallPtrSetImpl<BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeStackPO\n");
  while (!Stack.empty()) {
    auto *NextBB = Stack.back();
    // Blocks may be pushed multiple times; skip already finished ones.
    if (Finalized.count(NextBB)) {
      Stack.pop_back();
      continue;
    }
    LLVM_DEBUG(dbgs() << "  visiting " << CI.getSSAContext().print(NextBB)
                      << "\n");
    auto *NestedCycle = CI.getCycle(NextBB);
    if (Cycle != NestedCycle && (!Cycle || Cycle->contains(NestedCycle))) {
      LLVM_DEBUG(dbgs() << "  found a cycle\n");
      // Normalize to the maximal child cycle directly nested in Cycle.
      while (NestedCycle->getParentCycle() != Cycle)
        NestedCycle = NestedCycle->getParentCycle();

      SmallVector<BlockT *, 3> NestedExits;
      NestedCycle->getExitBlocks(NestedExits);
      bool PushedNodes = false;
      for (auto *NestedExitBB : NestedExits) {
        LLVM_DEBUG(dbgs() << "  examine exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
        // Exits leaving the current region are handled by the caller.
        if (Cycle && !Cycle->contains(NestedExitBB))
          continue;
        if (Finalized.count(NestedExitBB))
          continue;
        PushedNodes = true;
        Stack.push_back(NestedExitBB);
        LLVM_DEBUG(dbgs() << "  pushed exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
      }
      if (!PushedNodes) {
        // All loop exits finalized -> finish this node
        Stack.pop_back();
        computeCyclePO(CI, NestedCycle, Finalized);
      }
      continue;
    }

    LLVM_DEBUG(dbgs() << "  no nested cycle, going into DAG\n");
    // DAG-style: finish NextBB only after all in-region successors are
    // finalized (post order).
    bool PushedNodes = false;
    for (auto *SuccBB : successors(NextBB)) {
      LLVM_DEBUG(dbgs() << "  examine succ: "
                        << CI.getSSAContext().print(SuccBB) << "\n");
      if (Cycle && !Cycle->contains(SuccBB))
        continue;
      if (Finalized.count(SuccBB))
        continue;
      PushedNodes = true;
      Stack.push_back(SuccBB);
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(SuccBB)
                        << "\n");
    }
    if (!PushedNodes) {
      // Never push nodes twice
      LLVM_DEBUG(dbgs() << "  finishing node: "
                        << CI.getSSAContext().print(NextBB) << "\n");
      Stack.pop_back();
      Finalized.insert(NextBB);
      appendBlock(*NextBB);
    }
  }
  LLVM_DEBUG(dbgs() << "exited computeStackPO\n");
}
// Compute the modified post order of the blocks inside \p Cycle. The
// header is emitted first here (so it appears last in the reversed
// order) and the rest of the cycle is traversed DAG-style.
template <typename ContextT>
void ModifiedPostOrder<ContextT>::computeCyclePO(
    const CycleInfoT &CI, const CycleT *Cycle,
    SmallPtrSetImpl<BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeCyclePO\n");
  SmallVector<BlockT *> Stack;
  auto *CycleHeader = Cycle->getHeader();

  LLVM_DEBUG(dbgs() << "  noted header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  assert(!Finalized.count(CycleHeader));
  Finalized.insert(CycleHeader);

  // Visit the header last
  LLVM_DEBUG(dbgs() << "  finishing header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  // The second argument records whether this is a reducible cycle header,
  // which the uniformity propagation consults later.
  appendBlock(*CycleHeader, Cycle->isReducible());

  // Initialize with immediate successors
  for (auto *BB : successors(CycleHeader)) {
    LLVM_DEBUG(dbgs() << "  examine succ: " << CI.getSSAContext().print(BB)
                      << "\n");
    // Only blocks inside the cycle, excluding a self-loop back to the
    // header, seed the traversal.
    if (!Cycle->contains(BB))
      continue;
    if (BB == CycleHeader)
      continue;
    if (!Finalized.count(BB)) {
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(BB)
                        << "\n");
      Stack.push_back(BB);
    }
  }

  // Compute PO inside region
  computeStackPO(Stack, CI, Cycle, Finalized);

  LLVM_DEBUG(dbgs() << "exited computeCyclePO\n");
}
  1196. /// \brief Generically compute the modified post order.
  1197. template <typename ContextT>
  1198. void llvm::ModifiedPostOrder<ContextT>::compute(const CycleInfoT &CI) {
  1199. SmallPtrSet<BlockT *, 32> Finalized;
  1200. SmallVector<BlockT *> Stack;
  1201. auto *F = CI.getFunction();
  1202. Stack.reserve(24); // FIXME made-up number
  1203. Stack.push_back(GraphTraits<FunctionT *>::getEntryNode(F));
  1204. computeStackPO(Stack, CI, nullptr, Finalized);
  1205. }
  1206. } // namespace llvm
  1207. #undef DEBUG_TYPE
  1208. #endif // LLVM_ADT_GENERICUNIFORMITYIMPL_H
  1209. #ifdef __GNUC__
  1210. #pragma GCC diagnostic pop
  1211. #endif