ExpandReductions.cpp 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. //===--- ExpandReductions.cpp - Expand experimental reduction intrinsics --===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This pass implements IR expansion for reduction intrinsics, allowing targets
  10. // to enable the intrinsics until just before codegen.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/CodeGen/ExpandReductions.h"
  14. #include "llvm/Analysis/TargetTransformInfo.h"
  15. #include "llvm/CodeGen/Passes.h"
  16. #include "llvm/IR/Function.h"
  17. #include "llvm/IR/IRBuilder.h"
  18. #include "llvm/IR/InstIterator.h"
  19. #include "llvm/IR/IntrinsicInst.h"
  20. #include "llvm/IR/Intrinsics.h"
  21. #include "llvm/IR/Module.h"
  22. #include "llvm/InitializePasses.h"
  23. #include "llvm/Pass.h"
  24. #include "llvm/Transforms/Utils/LoopUtils.h"
  25. using namespace llvm;
  26. namespace {
  27. unsigned getOpcode(Intrinsic::ID ID) {
  28. switch (ID) {
  29. case Intrinsic::vector_reduce_fadd:
  30. return Instruction::FAdd;
  31. case Intrinsic::vector_reduce_fmul:
  32. return Instruction::FMul;
  33. case Intrinsic::vector_reduce_add:
  34. return Instruction::Add;
  35. case Intrinsic::vector_reduce_mul:
  36. return Instruction::Mul;
  37. case Intrinsic::vector_reduce_and:
  38. return Instruction::And;
  39. case Intrinsic::vector_reduce_or:
  40. return Instruction::Or;
  41. case Intrinsic::vector_reduce_xor:
  42. return Instruction::Xor;
  43. case Intrinsic::vector_reduce_smax:
  44. case Intrinsic::vector_reduce_smin:
  45. case Intrinsic::vector_reduce_umax:
  46. case Intrinsic::vector_reduce_umin:
  47. return Instruction::ICmp;
  48. case Intrinsic::vector_reduce_fmax:
  49. case Intrinsic::vector_reduce_fmin:
  50. return Instruction::FCmp;
  51. default:
  52. llvm_unreachable("Unexpected ID");
  53. }
  54. }
  55. RecurKind getRK(Intrinsic::ID ID) {
  56. switch (ID) {
  57. case Intrinsic::vector_reduce_smax:
  58. return RecurKind::SMax;
  59. case Intrinsic::vector_reduce_smin:
  60. return RecurKind::SMin;
  61. case Intrinsic::vector_reduce_umax:
  62. return RecurKind::UMax;
  63. case Intrinsic::vector_reduce_umin:
  64. return RecurKind::UMin;
  65. case Intrinsic::vector_reduce_fmax:
  66. return RecurKind::FMax;
  67. case Intrinsic::vector_reduce_fmin:
  68. return RecurKind::FMin;
  69. default:
  70. return RecurKind::None;
  71. }
  72. }
  73. bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
  74. bool Changed = false;
  75. SmallVector<IntrinsicInst *, 4> Worklist;
  76. for (auto &I : instructions(F)) {
  77. if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
  78. switch (II->getIntrinsicID()) {
  79. default: break;
  80. case Intrinsic::vector_reduce_fadd:
  81. case Intrinsic::vector_reduce_fmul:
  82. case Intrinsic::vector_reduce_add:
  83. case Intrinsic::vector_reduce_mul:
  84. case Intrinsic::vector_reduce_and:
  85. case Intrinsic::vector_reduce_or:
  86. case Intrinsic::vector_reduce_xor:
  87. case Intrinsic::vector_reduce_smax:
  88. case Intrinsic::vector_reduce_smin:
  89. case Intrinsic::vector_reduce_umax:
  90. case Intrinsic::vector_reduce_umin:
  91. case Intrinsic::vector_reduce_fmax:
  92. case Intrinsic::vector_reduce_fmin:
  93. if (TTI->shouldExpandReduction(II))
  94. Worklist.push_back(II);
  95. break;
  96. }
  97. }
  98. }
  99. for (auto *II : Worklist) {
  100. FastMathFlags FMF =
  101. isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
  102. Intrinsic::ID ID = II->getIntrinsicID();
  103. RecurKind RK = getRK(ID);
  104. Value *Rdx = nullptr;
  105. IRBuilder<> Builder(II);
  106. IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
  107. Builder.setFastMathFlags(FMF);
  108. switch (ID) {
  109. default: llvm_unreachable("Unexpected intrinsic!");
  110. case Intrinsic::vector_reduce_fadd:
  111. case Intrinsic::vector_reduce_fmul: {
  112. // FMFs must be attached to the call, otherwise it's an ordered reduction
  113. // and it can't be handled by generating a shuffle sequence.
  114. Value *Acc = II->getArgOperand(0);
  115. Value *Vec = II->getArgOperand(1);
  116. if (!FMF.allowReassoc())
  117. Rdx = getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), RK);
  118. else {
  119. if (!isPowerOf2_32(
  120. cast<FixedVectorType>(Vec->getType())->getNumElements()))
  121. continue;
  122. Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
  123. Rdx = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(ID),
  124. Acc, Rdx, "bin.rdx");
  125. }
  126. break;
  127. }
  128. case Intrinsic::vector_reduce_add:
  129. case Intrinsic::vector_reduce_mul:
  130. case Intrinsic::vector_reduce_and:
  131. case Intrinsic::vector_reduce_or:
  132. case Intrinsic::vector_reduce_xor:
  133. case Intrinsic::vector_reduce_smax:
  134. case Intrinsic::vector_reduce_smin:
  135. case Intrinsic::vector_reduce_umax:
  136. case Intrinsic::vector_reduce_umin: {
  137. Value *Vec = II->getArgOperand(0);
  138. if (!isPowerOf2_32(
  139. cast<FixedVectorType>(Vec->getType())->getNumElements()))
  140. continue;
  141. Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
  142. break;
  143. }
  144. case Intrinsic::vector_reduce_fmax:
  145. case Intrinsic::vector_reduce_fmin: {
  146. // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
  147. // semantics of the reduction.
  148. Value *Vec = II->getArgOperand(0);
  149. if (!isPowerOf2_32(
  150. cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
  151. !FMF.noNaNs())
  152. continue;
  153. Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
  154. break;
  155. }
  156. }
  157. II->replaceAllUsesWith(Rdx);
  158. II->eraseFromParent();
  159. Changed = true;
  160. }
  161. return Changed;
  162. }
  163. class ExpandReductions : public FunctionPass {
  164. public:
  165. static char ID;
  166. ExpandReductions() : FunctionPass(ID) {
  167. initializeExpandReductionsPass(*PassRegistry::getPassRegistry());
  168. }
  169. bool runOnFunction(Function &F) override {
  170. const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  171. return expandReductions(F, TTI);
  172. }
  173. void getAnalysisUsage(AnalysisUsage &AU) const override {
  174. AU.addRequired<TargetTransformInfoWrapperPass>();
  175. AU.setPreservesCFG();
  176. }
  177. };
  178. }
  179. char ExpandReductions::ID;
  180. INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
  181. "Expand reduction intrinsics", false, false)
  182. INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
  183. INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
  184. "Expand reduction intrinsics", false, false)
  185. FunctionPass *llvm::createExpandReductionsPass() {
  186. return new ExpandReductions();
  187. }
  188. PreservedAnalyses ExpandReductionsPass::run(Function &F,
  189. FunctionAnalysisManager &AM) {
  190. const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  191. if (!expandReductions(F, &TTI))
  192. return PreservedAnalyses::all();
  193. PreservedAnalyses PA;
  194. PA.preserveSet<CFGAnalyses>();
  195. return PA;
  196. }