ExpandReductions.cpp 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. //===--- ExpandReductions.cpp - Expand experimental reduction intrinsics --===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This pass implements IR expansion for reduction intrinsics, allowing targets
  10. // to enable the intrinsics until just before codegen.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/CodeGen/ExpandReductions.h"
  14. #include "llvm/Analysis/TargetTransformInfo.h"
  15. #include "llvm/CodeGen/Passes.h"
  16. #include "llvm/IR/IRBuilder.h"
  17. #include "llvm/IR/InstIterator.h"
  18. #include "llvm/IR/IntrinsicInst.h"
  19. #include "llvm/IR/Intrinsics.h"
  20. #include "llvm/InitializePasses.h"
  21. #include "llvm/Pass.h"
  22. #include "llvm/Transforms/Utils/LoopUtils.h"
  23. using namespace llvm;
  24. namespace {
  25. unsigned getOpcode(Intrinsic::ID ID) {
  26. switch (ID) {
  27. case Intrinsic::vector_reduce_fadd:
  28. return Instruction::FAdd;
  29. case Intrinsic::vector_reduce_fmul:
  30. return Instruction::FMul;
  31. case Intrinsic::vector_reduce_add:
  32. return Instruction::Add;
  33. case Intrinsic::vector_reduce_mul:
  34. return Instruction::Mul;
  35. case Intrinsic::vector_reduce_and:
  36. return Instruction::And;
  37. case Intrinsic::vector_reduce_or:
  38. return Instruction::Or;
  39. case Intrinsic::vector_reduce_xor:
  40. return Instruction::Xor;
  41. case Intrinsic::vector_reduce_smax:
  42. case Intrinsic::vector_reduce_smin:
  43. case Intrinsic::vector_reduce_umax:
  44. case Intrinsic::vector_reduce_umin:
  45. return Instruction::ICmp;
  46. case Intrinsic::vector_reduce_fmax:
  47. case Intrinsic::vector_reduce_fmin:
  48. return Instruction::FCmp;
  49. default:
  50. llvm_unreachable("Unexpected ID");
  51. }
  52. }
  53. RecurKind getRK(Intrinsic::ID ID) {
  54. switch (ID) {
  55. case Intrinsic::vector_reduce_smax:
  56. return RecurKind::SMax;
  57. case Intrinsic::vector_reduce_smin:
  58. return RecurKind::SMin;
  59. case Intrinsic::vector_reduce_umax:
  60. return RecurKind::UMax;
  61. case Intrinsic::vector_reduce_umin:
  62. return RecurKind::UMin;
  63. case Intrinsic::vector_reduce_fmax:
  64. return RecurKind::FMax;
  65. case Intrinsic::vector_reduce_fmin:
  66. return RecurKind::FMin;
  67. default:
  68. return RecurKind::None;
  69. }
  70. }
  71. bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
  72. bool Changed = false;
  73. SmallVector<IntrinsicInst *, 4> Worklist;
  74. for (auto &I : instructions(F)) {
  75. if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
  76. switch (II->getIntrinsicID()) {
  77. default: break;
  78. case Intrinsic::vector_reduce_fadd:
  79. case Intrinsic::vector_reduce_fmul:
  80. case Intrinsic::vector_reduce_add:
  81. case Intrinsic::vector_reduce_mul:
  82. case Intrinsic::vector_reduce_and:
  83. case Intrinsic::vector_reduce_or:
  84. case Intrinsic::vector_reduce_xor:
  85. case Intrinsic::vector_reduce_smax:
  86. case Intrinsic::vector_reduce_smin:
  87. case Intrinsic::vector_reduce_umax:
  88. case Intrinsic::vector_reduce_umin:
  89. case Intrinsic::vector_reduce_fmax:
  90. case Intrinsic::vector_reduce_fmin:
  91. if (TTI->shouldExpandReduction(II))
  92. Worklist.push_back(II);
  93. break;
  94. }
  95. }
  96. }
  97. for (auto *II : Worklist) {
  98. FastMathFlags FMF =
  99. isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
  100. Intrinsic::ID ID = II->getIntrinsicID();
  101. RecurKind RK = getRK(ID);
  102. Value *Rdx = nullptr;
  103. IRBuilder<> Builder(II);
  104. IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
  105. Builder.setFastMathFlags(FMF);
  106. switch (ID) {
  107. default: llvm_unreachable("Unexpected intrinsic!");
  108. case Intrinsic::vector_reduce_fadd:
  109. case Intrinsic::vector_reduce_fmul: {
  110. // FMFs must be attached to the call, otherwise it's an ordered reduction
  111. // and it can't be handled by generating a shuffle sequence.
  112. Value *Acc = II->getArgOperand(0);
  113. Value *Vec = II->getArgOperand(1);
  114. if (!FMF.allowReassoc())
  115. Rdx = getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), RK);
  116. else {
  117. if (!isPowerOf2_32(
  118. cast<FixedVectorType>(Vec->getType())->getNumElements()))
  119. continue;
  120. Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
  121. Rdx = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(ID),
  122. Acc, Rdx, "bin.rdx");
  123. }
  124. break;
  125. }
  126. case Intrinsic::vector_reduce_add:
  127. case Intrinsic::vector_reduce_mul:
  128. case Intrinsic::vector_reduce_and:
  129. case Intrinsic::vector_reduce_or:
  130. case Intrinsic::vector_reduce_xor:
  131. case Intrinsic::vector_reduce_smax:
  132. case Intrinsic::vector_reduce_smin:
  133. case Intrinsic::vector_reduce_umax:
  134. case Intrinsic::vector_reduce_umin: {
  135. Value *Vec = II->getArgOperand(0);
  136. if (!isPowerOf2_32(
  137. cast<FixedVectorType>(Vec->getType())->getNumElements()))
  138. continue;
  139. Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
  140. break;
  141. }
  142. case Intrinsic::vector_reduce_fmax:
  143. case Intrinsic::vector_reduce_fmin: {
  144. // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
  145. // semantics of the reduction.
  146. Value *Vec = II->getArgOperand(0);
  147. if (!isPowerOf2_32(
  148. cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
  149. !FMF.noNaNs())
  150. continue;
  151. Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
  152. break;
  153. }
  154. }
  155. II->replaceAllUsesWith(Rdx);
  156. II->eraseFromParent();
  157. Changed = true;
  158. }
  159. return Changed;
  160. }
  161. class ExpandReductions : public FunctionPass {
  162. public:
  163. static char ID;
  164. ExpandReductions() : FunctionPass(ID) {
  165. initializeExpandReductionsPass(*PassRegistry::getPassRegistry());
  166. }
  167. bool runOnFunction(Function &F) override {
  168. const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  169. return expandReductions(F, TTI);
  170. }
  171. void getAnalysisUsage(AnalysisUsage &AU) const override {
  172. AU.addRequired<TargetTransformInfoWrapperPass>();
  173. AU.setPreservesCFG();
  174. }
  175. };
  176. }
  177. char ExpandReductions::ID;
  178. INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
  179. "Expand reduction intrinsics", false, false)
  180. INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
  181. INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
  182. "Expand reduction intrinsics", false, false)
  183. FunctionPass *llvm::createExpandReductionsPass() {
  184. return new ExpandReductions();
  185. }
  186. PreservedAnalyses ExpandReductionsPass::run(Function &F,
  187. FunctionAnalysisManager &AM) {
  188. const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  189. if (!expandReductions(F, &TTI))
  190. return PreservedAnalyses::all();
  191. PreservedAnalyses PA;
  192. PA.preserveSet<CFGAnalyses>();
  193. return PA;
  194. }