MachineUniformityAnalysis.cpp 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. //===- MachineUniformityAnalysis.cpp --------------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "llvm/CodeGen/MachineUniformityAnalysis.h"
  9. #include "llvm/ADT/GenericUniformityImpl.h"
  10. #include "llvm/CodeGen/MachineCycleAnalysis.h"
  11. #include "llvm/CodeGen/MachineDominators.h"
  12. #include "llvm/CodeGen/MachineRegisterInfo.h"
  13. #include "llvm/CodeGen/MachineSSAContext.h"
  14. #include "llvm/CodeGen/TargetInstrInfo.h"
  15. #include "llvm/InitializePasses.h"
  16. using namespace llvm;
  17. template <>
  18. bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
  19. const MachineInstr &I) const {
  20. for (auto &op : I.operands()) {
  21. if (!op.isReg() || !op.isDef())
  22. continue;
  23. if (isDivergent(op.getReg()))
  24. return true;
  25. }
  26. return false;
  27. }
  28. template <>
  29. bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
  30. const MachineInstr &Instr, bool AllDefsDivergent) {
  31. bool insertedDivergent = false;
  32. const auto &MRI = F.getRegInfo();
  33. const auto &TRI = *MRI.getTargetRegisterInfo();
  34. for (auto &op : Instr.operands()) {
  35. if (!op.isReg() || !op.isDef())
  36. continue;
  37. if (!op.getReg().isVirtual())
  38. continue;
  39. assert(!op.getSubReg());
  40. if (!AllDefsDivergent) {
  41. auto *RC = MRI.getRegClassOrNull(op.getReg());
  42. if (RC && !TRI.isDivergentRegClass(RC))
  43. continue;
  44. }
  45. insertedDivergent |= markDivergent(op.getReg());
  46. }
  47. return insertedDivergent;
  48. }
  49. template <>
  50. void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
  51. const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
  52. for (const MachineBasicBlock &block : F) {
  53. for (const MachineInstr &instr : block) {
  54. auto uniformity = InstrInfo.getInstructionUniformity(instr);
  55. if (uniformity == InstructionUniformity::AlwaysUniform) {
  56. addUniformOverride(instr);
  57. continue;
  58. }
  59. if (uniformity == InstructionUniformity::NeverUniform) {
  60. markDefsDivergent(instr, /* AllDefsDivergent = */ false);
  61. }
  62. }
  63. }
  64. }
  65. template <>
  66. void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
  67. Register Reg) {
  68. const auto &RegInfo = F.getRegInfo();
  69. for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
  70. if (isAlwaysUniform(UserInstr))
  71. continue;
  72. if (markDivergent(UserInstr))
  73. Worklist.push_back(&UserInstr);
  74. }
  75. }
  76. template <>
  77. void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
  78. const MachineInstr &Instr) {
  79. assert(!isAlwaysUniform(Instr));
  80. if (Instr.isTerminator())
  81. return;
  82. for (const MachineOperand &op : Instr.operands()) {
  83. if (op.isReg() && op.isDef() && op.getReg().isVirtual())
  84. pushUsers(op.getReg());
  85. }
  86. }
  87. template <>
  88. bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
  89. const MachineInstr &I, const MachineCycle &DefCycle) const {
  90. assert(!isAlwaysUniform(I));
  91. for (auto &Op : I.operands()) {
  92. if (!Op.isReg() || !Op.readsReg())
  93. continue;
  94. auto Reg = Op.getReg();
  95. assert(Reg.isVirtual());
  96. auto *Def = F.getRegInfo().getVRegDef(Reg);
  97. if (DefCycle.contains(Def->getParent()))
  98. return true;
  99. }
  100. return false;
  101. }
// This ensures explicit instantiation of
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
//
// These are explicit instantiation definitions for the MachineSSAContext
// specializations of the following, so their members are emitted in this
// translation unit:
//   - GenericUniformityInfo
//   - GenericUniformityAnalysisImplDeleter
// NOTE(review): presumably required because
// GenericUniformityAnalysisImpl is only a complete type where
// GenericUniformityImpl.h is included (as it is above) — confirm before
// moving or removing these lines.
template class llvm::GenericUniformityInfo<MachineSSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
  107. MachineUniformityInfo
  108. llvm::computeMachineUniformityInfo(MachineFunction &F,
  109. const MachineCycleInfo &cycleInfo,
  110. const MachineDomTree &domTree) {
  111. assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
  112. return MachineUniformityInfo(F, domTree, cycleInfo);
  113. }
  114. namespace {
  115. /// Legacy analysis pass which computes a \ref MachineUniformityInfo.
  116. class MachineUniformityAnalysisPass : public MachineFunctionPass {
  117. MachineUniformityInfo UI;
  118. public:
  119. static char ID;
  120. MachineUniformityAnalysisPass();
  121. MachineUniformityInfo &getUniformityInfo() { return UI; }
  122. const MachineUniformityInfo &getUniformityInfo() const { return UI; }
  123. bool runOnMachineFunction(MachineFunction &F) override;
  124. void getAnalysisUsage(AnalysisUsage &AU) const override;
  125. void print(raw_ostream &OS, const Module *M = nullptr) const override;
  126. // TODO: verify analysis
  127. };
  128. class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
  129. public:
  130. static char ID;
  131. MachineUniformityInfoPrinterPass();
  132. bool runOnMachineFunction(MachineFunction &F) override;
  133. void getAnalysisUsage(AnalysisUsage &AU) const override;
  134. };
  135. } // namespace
  136. char MachineUniformityAnalysisPass::ID = 0;
  137. MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
  138. : MachineFunctionPass(ID) {
  139. initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry());
  140. }
// Register the analysis pass under the command-line name "machine-uniformity".
// Its declared dependencies match the passes required in getAnalysisUsage():
// MachineCycleInfoWrapperPass and MachineDominatorTree.
// The two trailing `true` arguments are the macro's CFG-only and is-analysis
// flags.
// NOTE(review): the result also depends on instruction contents (target
// uniformity hooks, register classes), not only on the CFG. Confirm that
// marking it CFG-only cannot let a CFG-preserving pass keep a stale result.
INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
                      "Machine Uniformity Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
                    "Machine Uniformity Info Analysis", true, true)
  147. void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
  148. AU.setPreservesAll();
  149. AU.addRequired<MachineCycleInfoWrapperPass>();
  150. AU.addRequired<MachineDominatorTree>();
  151. MachineFunctionPass::getAnalysisUsage(AU);
  152. }
  153. bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
  154. auto &DomTree = getAnalysis<MachineDominatorTree>().getBase();
  155. auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
  156. UI = computeMachineUniformityInfo(MF, CI, DomTree);
  157. return false;
  158. }
  159. void MachineUniformityAnalysisPass::print(raw_ostream &OS,
  160. const Module *) const {
  161. OS << "MachineUniformityInfo for function: " << UI.getFunction().getName()
  162. << "\n";
  163. UI.print(OS);
  164. }
  165. char MachineUniformityInfoPrinterPass::ID = 0;
  166. MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
  167. : MachineFunctionPass(ID) {
  168. initializeMachineUniformityInfoPrinterPassPass(
  169. *PassRegistry::getPassRegistry());
  170. }
// Register the printer pass under the command-line name
// "print-machine-uniformity". It depends on MachineUniformityAnalysisPass,
// whose result it prints. The two trailing `true` arguments are the macro's
// CFG-only and is-analysis flags.
INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
                      "print-machine-uniformity",
                      "Print Machine Uniformity Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
                    "print-machine-uniformity",
                    "Print Machine Uniformity Info Analysis", true, true)
  178. void MachineUniformityInfoPrinterPass::getAnalysisUsage(
  179. AnalysisUsage &AU) const {
  180. AU.setPreservesAll();
  181. AU.addRequired<MachineUniformityAnalysisPass>();
  182. MachineFunctionPass::getAnalysisUsage(AU);
  183. }
  184. bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
  185. MachineFunction &F) {
  186. auto &UI = getAnalysis<MachineUniformityAnalysisPass>();
  187. UI.print(errs());
  188. return false;
  189. }