UniformityAnalysis.cpp 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
//===- UniformityAnalysis.cpp ---------------------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "llvm/Analysis/UniformityAnalysis.h"
  9. #include "llvm/ADT/GenericUniformityImpl.h"
  10. #include "llvm/Analysis/CycleAnalysis.h"
  11. #include "llvm/Analysis/TargetTransformInfo.h"
  12. #include "llvm/IR/Constants.h"
  13. #include "llvm/IR/Dominators.h"
  14. #include "llvm/IR/InstIterator.h"
  15. #include "llvm/IR/Instructions.h"
  16. #include "llvm/InitializePasses.h"
  17. using namespace llvm;
  18. template <>
  19. bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
  20. const Instruction &I) const {
  21. return isDivergent((const Value *)&I);
  22. }
  23. template <>
  24. bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
  25. const Instruction &Instr, bool AllDefsDivergent) {
  26. return markDivergent(&Instr);
  27. }
  28. template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
  29. for (auto &I : instructions(F)) {
  30. if (TTI->isSourceOfDivergence(&I)) {
  31. assert(!I.isTerminator());
  32. markDivergent(I);
  33. } else if (TTI->isAlwaysUniform(&I)) {
  34. addUniformOverride(I);
  35. }
  36. }
  37. for (auto &Arg : F.args()) {
  38. if (TTI->isSourceOfDivergence(&Arg)) {
  39. markDivergent(&Arg);
  40. }
  41. }
  42. }
  43. template <>
  44. void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
  45. const Value *V) {
  46. for (const auto *User : V->users()) {
  47. const auto *UserInstr = dyn_cast<const Instruction>(User);
  48. if (!UserInstr)
  49. continue;
  50. if (isAlwaysUniform(*UserInstr))
  51. continue;
  52. if (markDivergent(*UserInstr)) {
  53. Worklist.push_back(UserInstr);
  54. }
  55. }
  56. }
  57. template <>
  58. void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
  59. const Instruction &Instr) {
  60. assert(!isAlwaysUniform(Instr));
  61. if (Instr.isTerminator())
  62. return;
  63. pushUsers(cast<Value>(&Instr));
  64. }
  65. template <>
  66. bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
  67. const Instruction &I, const Cycle &DefCycle) const {
  68. if (isAlwaysUniform(I))
  69. return false;
  70. for (const Use &U : I.operands()) {
  71. if (auto *I = dyn_cast<Instruction>(&U)) {
  72. if (DefCycle.contains(I->getParent()))
  73. return true;
  74. }
  75. }
  76. return false;
  77. }
  78. // This ensures explicit instantiation of
  79. // GenericUniformityAnalysisImpl::ImplDeleter::operator()
  80. template class llvm::GenericUniformityInfo<SSAContext>;
  81. template struct llvm::GenericUniformityAnalysisImplDeleter<
  82. llvm::GenericUniformityAnalysisImpl<SSAContext>>;
  83. //===----------------------------------------------------------------------===//
  84. // UniformityInfoAnalysis and related pass implementations
  85. //===----------------------------------------------------------------------===//
  86. llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
  87. FunctionAnalysisManager &FAM) {
  88. auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  89. auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
  90. auto &CI = FAM.getResult<CycleAnalysis>(F);
  91. return UniformityInfo{F, DT, CI, &TTI};
  92. }
  93. AnalysisKey UniformityInfoAnalysis::Key;
  94. UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
  95. : OS(OS) {}
  96. PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
  97. FunctionAnalysisManager &AM) {
  98. OS << "UniformityInfo for function '" << F.getName() << "':\n";
  99. AM.getResult<UniformityInfoAnalysis>(F).print(OS);
  100. return PreservedAnalyses::all();
  101. }
  102. //===----------------------------------------------------------------------===//
  103. // UniformityInfoWrapperPass Implementation
  104. //===----------------------------------------------------------------------===//
  105. char UniformityInfoWrapperPass::ID = 0;
  106. UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {
  107. initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
  108. }
  109. INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniforminfo",
  110. "Uniform Info Analysis", true, true)
  111. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  112. INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
  113. INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniforminfo",
  114. "Uniform Info Analysis", true, true)
  115. void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  116. AU.setPreservesAll();
  117. AU.addRequired<DominatorTreeWrapperPass>();
  118. AU.addRequired<CycleInfoWrapperPass>();
  119. AU.addRequired<TargetTransformInfoWrapperPass>();
  120. }
  121. bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
  122. auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
  123. auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  124. auto &targetTransformInfo =
  125. getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  126. m_function = &F;
  127. m_uniformityInfo =
  128. UniformityInfo{F, domTree, cycleInfo, &targetTransformInfo};
  129. return false;
  130. }
  131. void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
  132. OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
  133. }
  134. void UniformityInfoWrapperPass::releaseMemory() {
  135. m_uniformityInfo = UniformityInfo{};
  136. m_function = nullptr;
  137. }