// DivergenceAnalysis.h (7.6 KB)
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. // \file
  15. // The divergence analysis determines which instructions and branches are
  16. // divergent given a set of divergent source instructions.
  17. //
  18. //===----------------------------------------------------------------------===//
  19. #ifndef LLVM_ANALYSIS_DIVERGENCEANALYSIS_H
  20. #define LLVM_ANALYSIS_DIVERGENCEANALYSIS_H
  #include "llvm/ADT/DenseSet.h"
  #include "llvm/Analysis/SyncDependenceAnalysis.h"
  #include "llvm/IR/PassManager.h"
  #include <memory>
  #include <vector>
  25. namespace llvm {
  // Forward declarations for every LLVM type this header names only by
  // pointer or reference. Declaring them here keeps the header from depending
  // on what its includes happen to pull in transitively.
  class BasicBlock;
  class DominatorTree;
  class Function;
  class Instruction;
  class Loop;
  class LoopInfo;
  class PostDominatorTree;
  class raw_ostream;
  class TargetTransformInfo;
  class Use;
  class Value;
  32. /// \brief Generic divergence analysis for reducible CFGs.
  33. ///
  34. /// This analysis propagates divergence in a data-parallel context from sources
  35. /// of divergence to all users. It requires reducible CFGs. All assignments
  36. /// should be in SSA form.
  37. class DivergenceAnalysisImpl {
  38. public:
  39. /// \brief This instance will analyze the whole function \p F or the loop \p
  40. /// RegionLoop.
  41. ///
  42. /// \param RegionLoop if non-null the analysis is restricted to \p RegionLoop.
  43. /// Otherwise the whole function is analyzed.
  44. /// \param IsLCSSAForm whether the analysis may assume that the IR in the
  45. /// region in LCSSA form.
  46. DivergenceAnalysisImpl(const Function &F, const Loop *RegionLoop,
  47. const DominatorTree &DT, const LoopInfo &LI,
  48. SyncDependenceAnalysis &SDA, bool IsLCSSAForm);
  49. /// \brief The loop that defines the analyzed region (if any).
  50. const Loop *getRegionLoop() const { return RegionLoop; }
  51. const Function &getFunction() const { return F; }
  52. /// \brief Whether \p BB is part of the region.
  53. bool inRegion(const BasicBlock &BB) const;
  54. /// \brief Whether \p I is part of the region.
  55. bool inRegion(const Instruction &I) const;
  56. /// \brief Mark \p UniVal as a value that is always uniform.
  57. void addUniformOverride(const Value &UniVal);
  58. /// \brief Mark \p DivVal as a value that is always divergent. Will not do so
  59. /// if `isAlwaysUniform(DivVal)`.
  60. /// \returns Whether the tracked divergence state of \p DivVal changed.
  61. bool markDivergent(const Value &DivVal);
  62. /// \brief Propagate divergence to all instructions in the region.
  63. /// Divergence is seeded by calls to \p markDivergent.
  64. void compute();
  65. /// \brief Whether any value was marked or analyzed to be divergent.
  66. bool hasDetectedDivergence() const { return !DivergentValues.empty(); }
  67. /// \brief Whether \p Val will always return a uniform value regardless of its
  68. /// operands
  69. bool isAlwaysUniform(const Value &Val) const;
  70. /// \brief Whether \p Val is divergent at its definition.
  71. bool isDivergent(const Value &Val) const;
  72. /// \brief Whether \p U is divergent. Uses of a uniform value can be
  73. /// divergent.
  74. bool isDivergentUse(const Use &U) const;
  75. private:
  76. /// \brief Mark \p Term as divergent and push all Instructions that become
  77. /// divergent as a result on the worklist.
  78. void analyzeControlDivergence(const Instruction &Term);
  79. /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on
  80. /// the worklist.
  81. void taintAndPushPhiNodes(const BasicBlock &JoinBlock);
  82. /// \brief Identify all Instructions that become divergent because \p DivExit
  83. /// is a divergent loop exit of \p DivLoop. Mark those instructions as
  84. /// divergent and push them on the worklist.
  85. void propagateLoopExitDivergence(const BasicBlock &DivExit,
  86. const Loop &DivLoop);
  87. /// \brief Internal implementation function for propagateLoopExitDivergence.
  88. void analyzeLoopExitDivergence(const BasicBlock &DivExit,
  89. const Loop &OuterDivLoop);
  90. /// \brief Mark all instruction as divergent that use a value defined in \p
  91. /// OuterDivLoop. Push their users on the worklist.
  92. void analyzeTemporalDivergence(const Instruction &I,
  93. const Loop &OuterDivLoop);
  94. /// \brief Push all users of \p Val (in the region) to the worklist.
  95. void pushUsers(const Value &I);
  96. /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
  97. bool isTemporalDivergent(const BasicBlock &ObservingBlock,
  98. const Value &Val) const;
  99. private:
  100. const Function &F;
  101. // If regionLoop != nullptr, analysis is only performed within \p RegionLoop.
  102. // Otherwise, analyze the whole function
  103. const Loop *RegionLoop;
  104. const DominatorTree &DT;
  105. const LoopInfo &LI;
  106. // Recognized divergent loops
  107. DenseSet<const Loop *> DivergentLoops;
  108. // The SDA links divergent branches to divergent control-flow joins.
  109. SyncDependenceAnalysis &SDA;
  110. // Use simplified code path for LCSSA form.
  111. bool IsLCSSAForm;
  112. // Set of known-uniform values.
  113. DenseSet<const Value *> UniformOverrides;
  114. // Detected/marked divergent values.
  115. DenseSet<const Value *> DivergentValues;
  116. // Internal worklist for divergence propagation.
  117. std::vector<const Instruction *> Worklist;
  118. };
  119. class DivergenceInfo {
  120. Function &F;
  121. // If the function contains an irreducible region the divergence
  122. // analysis can run indefinitely. We set ContainsIrreducible and no
  123. // analysis is actually performed on the function. All values in
  124. // this function are conservatively reported as divergent instead.
  125. bool ContainsIrreducible = false;
  126. std::unique_ptr<SyncDependenceAnalysis> SDA;
  127. std::unique_ptr<DivergenceAnalysisImpl> DA;
  128. public:
  129. DivergenceInfo(Function &F, const DominatorTree &DT,
  130. const PostDominatorTree &PDT, const LoopInfo &LI,
  131. const TargetTransformInfo &TTI, bool KnownReducible);
  132. /// Whether any divergence was detected.
  133. bool hasDivergence() const {
  134. return ContainsIrreducible || DA->hasDetectedDivergence();
  135. }
  136. /// The GPU kernel this analysis result is for
  137. const Function &getFunction() const { return F; }
  138. /// Whether \p V is divergent at its definition.
  139. bool isDivergent(const Value &V) const {
  140. return ContainsIrreducible || DA->isDivergent(V);
  141. }
  142. /// Whether \p U is divergent. Uses of a uniform value can be divergent.
  143. bool isDivergentUse(const Use &U) const {
  144. return ContainsIrreducible || DA->isDivergentUse(U);
  145. }
  146. /// Whether \p V is uniform/non-divergent.
  147. bool isUniform(const Value &V) const { return !isDivergent(V); }
  148. /// Whether \p U is uniform/non-divergent. Uses of a uniform value can be
  149. /// divergent.
  150. bool isUniformUse(const Use &U) const { return !isDivergentUse(U); }
  151. };
  152. /// \brief Divergence analysis frontend for GPU kernels.
  153. class DivergenceAnalysis : public AnalysisInfoMixin<DivergenceAnalysis> {
  154. friend AnalysisInfoMixin<DivergenceAnalysis>;
  155. static AnalysisKey Key;
  156. public:
  157. using Result = DivergenceInfo;
  158. /// Runs the divergence analysis on @F, a GPU kernel
  159. Result run(Function &F, FunctionAnalysisManager &AM);
  160. };
  161. /// Printer pass to dump divergence analysis results.
  162. struct DivergenceAnalysisPrinterPass
  163. : public PassInfoMixin<DivergenceAnalysisPrinterPass> {
  164. DivergenceAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {}
  165. PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
  166. private:
  167. raw_ostream &OS;
  168. }; // class DivergenceAnalysisPrinterPass
  169. } // namespace llvm
  170. #endif // LLVM_ANALYSIS_DIVERGENCEANALYSIS_H
  171. #ifdef __GNUC__
  172. #pragma GCC diagnostic pop
  173. #endif