DivergenceAnalysis.h 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. // \file
  15. // The divergence analysis determines which instructions and branches are
  16. // divergent given a set of divergent source instructions.
  17. //
  18. //===----------------------------------------------------------------------===//
  19. #ifndef LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
  20. #define LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
  21. #include "llvm/ADT/DenseSet.h"
  22. #include "llvm/Analysis/SyncDependenceAnalysis.h"
  23. #include "llvm/IR/Function.h"
  24. #include "llvm/Pass.h"
  25. #include <vector>
  26. namespace llvm {
  // Forward declarations for every LLVM type this header names only by
  // pointer or reference, so the header does not depend on which names its
  // includes happen to pull in transitively.
  class BasicBlock;
  class DominatorTree;
  class Instruction;
  class Loop;
  class LoopInfo;
  class Module;
  class PostDominatorTree;
  class TargetTransformInfo;
  class Use;
  class Value;
  class raw_ostream;
  33. /// \brief Generic divergence analysis for reducible CFGs.
  34. ///
  35. /// This analysis propagates divergence in a data-parallel context from sources
  36. /// of divergence to all users. It requires reducible CFGs. All assignments
  37. /// should be in SSA form.
  38. class DivergenceAnalysis {
  39. public:
  40. /// \brief This instance will analyze the whole function \p F or the loop \p
  41. /// RegionLoop.
  42. ///
  43. /// \param RegionLoop if non-null the analysis is restricted to \p RegionLoop.
  44. /// Otherwise the whole function is analyzed.
  45. /// \param IsLCSSAForm whether the analysis may assume that the IR in the
  46. /// region in in LCSSA form.
  47. DivergenceAnalysis(const Function &F, const Loop *RegionLoop,
  48. const DominatorTree &DT, const LoopInfo &LI,
  49. SyncDependenceAnalysis &SDA, bool IsLCSSAForm);
  50. /// \brief The loop that defines the analyzed region (if any).
  51. const Loop *getRegionLoop() const { return RegionLoop; }
  52. const Function &getFunction() const { return F; }
  53. /// \brief Whether \p BB is part of the region.
  54. bool inRegion(const BasicBlock &BB) const;
  55. /// \brief Whether \p I is part of the region.
  56. bool inRegion(const Instruction &I) const;
  57. /// \brief Mark \p UniVal as a value that is always uniform.
  58. void addUniformOverride(const Value &UniVal);
  59. /// \brief Mark \p DivVal as a value that is always divergent. Will not do so
  60. /// if `isAlwaysUniform(DivVal)`.
  61. /// \returns Whether the tracked divergence state of \p DivVal changed.
  62. bool markDivergent(const Value &DivVal);
  63. /// \brief Propagate divergence to all instructions in the region.
  64. /// Divergence is seeded by calls to \p markDivergent.
  65. void compute();
  66. /// \brief Whether any value was marked or analyzed to be divergent.
  67. bool hasDetectedDivergence() const { return !DivergentValues.empty(); }
  68. /// \brief Whether \p Val will always return a uniform value regardless of its
  69. /// operands
  70. bool isAlwaysUniform(const Value &Val) const;
  71. /// \brief Whether \p Val is divergent at its definition.
  72. bool isDivergent(const Value &Val) const;
  73. /// \brief Whether \p U is divergent. Uses of a uniform value can be
  74. /// divergent.
  75. bool isDivergentUse(const Use &U) const;
  76. void print(raw_ostream &OS, const Module *) const;
  77. private:
  78. /// \brief Mark \p Term as divergent and push all Instructions that become
  79. /// divergent as a result on the worklist.
  80. void analyzeControlDivergence(const Instruction &Term);
  81. /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on
  82. /// the worklist.
  83. void taintAndPushPhiNodes(const BasicBlock &JoinBlock);
  84. /// \brief Identify all Instructions that become divergent because \p DivExit
  85. /// is a divergent loop exit of \p DivLoop. Mark those instructions as
  86. /// divergent and push them on the worklist.
  87. void propagateLoopExitDivergence(const BasicBlock &DivExit,
  88. const Loop &DivLoop);
  89. /// \brief Internal implementation function for propagateLoopExitDivergence.
  90. void analyzeLoopExitDivergence(const BasicBlock &DivExit,
  91. const Loop &OuterDivLoop);
  92. /// \brief Mark all instruction as divergent that use a value defined in \p
  93. /// OuterDivLoop. Push their users on the worklist.
  94. void analyzeTemporalDivergence(const Instruction &I,
  95. const Loop &OuterDivLoop);
  96. /// \brief Push all users of \p Val (in the region) to the worklist.
  97. void pushUsers(const Value &I);
  98. /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
  99. bool isTemporalDivergent(const BasicBlock &ObservingBlock,
  100. const Value &Val) const;
  101. /// \brief Whether \p Block is join divergent
  102. ///
  103. /// (see markBlockJoinDivergent).
  104. bool isJoinDivergent(const BasicBlock &Block) const {
  105. return DivergentJoinBlocks.contains(&Block);
  106. }
  107. private:
  108. const Function &F;
  109. // If regionLoop != nullptr, analysis is only performed within \p RegionLoop.
  110. // Otherwise, analyze the whole function
  111. const Loop *RegionLoop;
  112. const DominatorTree &DT;
  113. const LoopInfo &LI;
  114. // Recognized divergent loops
  115. DenseSet<const Loop *> DivergentLoops;
  116. // The SDA links divergent branches to divergent control-flow joins.
  117. SyncDependenceAnalysis &SDA;
  118. // Use simplified code path for LCSSA form.
  119. bool IsLCSSAForm;
  120. // Set of known-uniform values.
  121. DenseSet<const Value *> UniformOverrides;
  122. // Blocks with joining divergent control from different predecessors.
  123. DenseSet<const BasicBlock *> DivergentJoinBlocks; // FIXME Deprecated
  124. // Detected/marked divergent values.
  125. DenseSet<const Value *> DivergentValues;
  126. // Internal worklist for divergence propagation.
  127. std::vector<const Instruction *> Worklist;
  128. };
  129. /// \brief Divergence analysis frontend for GPU kernels.
  130. class GPUDivergenceAnalysis {
  131. SyncDependenceAnalysis SDA;
  132. DivergenceAnalysis DA;
  133. public:
  134. /// Runs the divergence analysis on @F, a GPU kernel
  135. GPUDivergenceAnalysis(Function &F, const DominatorTree &DT,
  136. const PostDominatorTree &PDT, const LoopInfo &LI,
  137. const TargetTransformInfo &TTI);
  138. /// Whether any divergence was detected.
  139. bool hasDivergence() const { return DA.hasDetectedDivergence(); }
  140. /// The GPU kernel this analysis result is for
  141. const Function &getFunction() const { return DA.getFunction(); }
  142. /// Whether \p V is divergent at its definition.
  143. bool isDivergent(const Value &V) const;
  144. /// Whether \p U is divergent. Uses of a uniform value can be divergent.
  145. bool isDivergentUse(const Use &U) const;
  146. /// Whether \p V is uniform/non-divergent.
  147. bool isUniform(const Value &V) const { return !isDivergent(V); }
  148. /// Whether \p U is uniform/non-divergent. Uses of a uniform value can be
  149. /// divergent.
  150. bool isUniformUse(const Use &U) const { return !isDivergentUse(U); }
  151. /// Print all divergent values in the kernel.
  152. void print(raw_ostream &OS, const Module *) const;
  153. };
  154. } // namespace llvm
  155. #endif // LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
  156. #ifdef __GNUC__
  157. #pragma GCC diagnostic pop
  158. #endif