ExpandVectorPredication.cpp 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
  1. //===----- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -----===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This pass implements IR expansion for vector predication intrinsics, allowing
  10. // targets to enable vector predication until just before codegen.
  11. //
  12. //===----------------------------------------------------------------------===//
#include "llvm/CodeGen/ExpandVectorPredication.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
  32. using namespace llvm;
  33. using VPLegalization = TargetTransformInfo::VPLegalization;
  34. using VPTransform = TargetTransformInfo::VPLegalization::VPTransform;
  35. // Keep this in sync with TargetTransformInfo::VPLegalization.
  36. #define VPINTERNAL_VPLEGAL_CASES \
  37. VPINTERNAL_CASE(Legal) \
  38. VPINTERNAL_CASE(Discard) \
  39. VPINTERNAL_CASE(Convert)
  40. #define VPINTERNAL_CASE(X) "|" #X
  41. // Override options.
  42. static cl::opt<std::string> EVLTransformOverride(
  43. "expandvp-override-evl-transform", cl::init(""), cl::Hidden,
  44. cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
  45. ". If non-empty, ignore "
  46. "TargetTransformInfo and "
  47. "always use this transformation for the %evl parameter (Used in "
  48. "testing)."));
  49. static cl::opt<std::string> MaskTransformOverride(
  50. "expandvp-override-mask-transform", cl::init(""), cl::Hidden,
  51. cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
  52. ". If non-empty, Ignore "
  53. "TargetTransformInfo and "
  54. "always use this transformation for the %mask parameter (Used in "
  55. "testing)."));
  56. #undef VPINTERNAL_CASE
  57. #define VPINTERNAL_CASE(X) .Case(#X, VPLegalization::X)
  58. static VPTransform parseOverrideOption(const std::string &TextOpt) {
  59. return StringSwitch<VPTransform>(TextOpt) VPINTERNAL_VPLEGAL_CASES;
  60. }
  61. #undef VPINTERNAL_VPLEGAL_CASES
  62. // Whether any override options are set.
  63. static bool anyExpandVPOverridesSet() {
  64. return !EVLTransformOverride.empty() || !MaskTransformOverride.empty();
  65. }
  66. #define DEBUG_TYPE "expandvp"
  67. STATISTIC(NumFoldedVL, "Number of folded vector length params");
  68. STATISTIC(NumLoweredVPOps, "Number of folded vector predication operations");
  69. ///// Helpers {
  70. /// \returns Whether the vector mask \p MaskVal has all lane bits set.
  71. static bool isAllTrueMask(Value *MaskVal) {
  72. auto *ConstVec = dyn_cast<ConstantVector>(MaskVal);
  73. return ConstVec && ConstVec->isAllOnesValue();
  74. }
  75. /// \returns A non-excepting divisor constant for this type.
  76. static Constant *getSafeDivisor(Type *DivTy) {
  77. assert(DivTy->isIntOrIntVectorTy() && "Unsupported divisor type");
  78. return ConstantInt::get(DivTy, 1u, false);
  79. }
  80. /// Transfer operation properties from \p OldVPI to \p NewVal.
  81. static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {
  82. auto *NewInst = dyn_cast<Instruction>(&NewVal);
  83. if (!NewInst || !isa<FPMathOperator>(NewVal))
  84. return;
  85. auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);
  86. if (!OldFMOp)
  87. return;
  88. NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
  89. }
  90. /// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
  91. /// OldVP gets erased.
  92. static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {
  93. transferDecorations(NewOp, OldOp);
  94. OldOp.replaceAllUsesWith(&NewOp);
  95. OldOp.eraseFromParent();
  96. }
  97. //// } Helpers
  98. namespace {
  99. // Expansion pass state at function scope.
  100. struct CachingVPExpander {
  101. Function &F;
  102. const TargetTransformInfo &TTI;
  103. /// \returns A (fixed length) vector with ascending integer indices
  104. /// (<0, 1, ..., NumElems-1>).
  105. /// \p Builder
  106. /// Used for instruction creation.
  107. /// \p LaneTy
  108. /// Integer element type of the result vector.
  109. /// \p NumElems
  110. /// Number of vector elements.
  111. Value *createStepVector(IRBuilder<> &Builder, Type *LaneTy,
  112. unsigned NumElems);
  113. /// \returns A bitmask that is true where the lane position is less-than \p
  114. /// EVLParam
  115. ///
  116. /// \p Builder
  117. /// Used for instruction creation.
  118. /// \p VLParam
  119. /// The explicit vector length parameter to test against the lane
  120. /// positions.
  121. /// \p ElemCount
  122. /// Static (potentially scalable) number of vector elements.
  123. Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
  124. ElementCount ElemCount);
  125. Value *foldEVLIntoMask(VPIntrinsic &VPI);
  126. /// "Remove" the %evl parameter of \p PI by setting it to the static vector
  127. /// length of the operation.
  128. void discardEVLParameter(VPIntrinsic &PI);
  129. /// \brief Lower this VP binary operator to a unpredicated binary operator.
  130. Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
  131. VPIntrinsic &PI);
  132. /// \brief Lower this VP reduction to a call to an unpredicated reduction
  133. /// intrinsic.
  134. Value *expandPredicationInReduction(IRBuilder<> &Builder,
  135. VPReductionIntrinsic &PI);
  136. /// \brief Query TTI and expand the vector predication in \p P accordingly.
  137. Value *expandPredication(VPIntrinsic &PI);
  138. /// \brief Determine how and whether the VPIntrinsic \p VPI shall be
  139. /// expanded. This overrides TTI with the cl::opts listed at the top of this
  140. /// file.
  141. VPLegalization getVPLegalizationStrategy(const VPIntrinsic &VPI) const;
  142. bool UsingTTIOverrides;
  143. public:
  144. CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
  145. : F(F), TTI(TTI), UsingTTIOverrides(anyExpandVPOverridesSet()) {}
  146. bool expandVectorPredication();
  147. };
  148. //// CachingVPExpander {
  149. Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder, Type *LaneTy,
  150. unsigned NumElems) {
  151. // TODO add caching
  152. SmallVector<Constant *, 16> ConstElems;
  153. for (unsigned Idx = 0; Idx < NumElems; ++Idx)
  154. ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));
  155. return ConstantVector::get(ConstElems);
  156. }
  157. Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
  158. Value *EVLParam,
  159. ElementCount ElemCount) {
  160. // TODO add caching
  161. // Scalable vector %evl conversion.
  162. if (ElemCount.isScalable()) {
  163. auto *M = Builder.GetInsertBlock()->getModule();
  164. Type *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
  165. Function *ActiveMaskFunc = Intrinsic::getDeclaration(
  166. M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
  167. // `get_active_lane_mask` performs an implicit less-than comparison.
  168. Value *ConstZero = Builder.getInt32(0);
  169. return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});
  170. }
  171. // Fixed vector %evl conversion.
  172. Type *LaneTy = EVLParam->getType();
  173. unsigned NumElems = ElemCount.getFixedValue();
  174. Value *VLSplat = Builder.CreateVectorSplat(NumElems, EVLParam);
  175. Value *IdxVec = createStepVector(Builder, LaneTy, NumElems);
  176. return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
  177. }
  178. Value *
  179. CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,
  180. VPIntrinsic &VPI) {
  181. assert((isSafeToSpeculativelyExecute(&VPI) ||
  182. VPI.canIgnoreVectorLengthParam()) &&
  183. "Implicitly dropping %evl in non-speculatable operator!");
  184. auto OC = static_cast<Instruction::BinaryOps>(*VPI.getFunctionalOpcode());
  185. assert(Instruction::isBinaryOp(OC));
  186. Value *Op0 = VPI.getOperand(0);
  187. Value *Op1 = VPI.getOperand(1);
  188. Value *Mask = VPI.getMaskParam();
  189. // Blend in safe operands.
  190. if (Mask && !isAllTrueMask(Mask)) {
  191. switch (OC) {
  192. default:
  193. // Can safely ignore the predicate.
  194. break;
  195. // Division operators need a safe divisor on masked-off lanes (1).
  196. case Instruction::UDiv:
  197. case Instruction::SDiv:
  198. case Instruction::URem:
  199. case Instruction::SRem:
  200. // 2nd operand must not be zero.
  201. Value *SafeDivisor = getSafeDivisor(VPI.getType());
  202. Op1 = Builder.CreateSelect(Mask, Op1, SafeDivisor);
  203. }
  204. }
  205. Value *NewBinOp = Builder.CreateBinOp(OC, Op0, Op1, VPI.getName());
  206. replaceOperation(*NewBinOp, VPI);
  207. return NewBinOp;
  208. }
  209. static Value *getNeutralReductionElement(const VPReductionIntrinsic &VPI,
  210. Type *EltTy) {
  211. bool Negative = false;
  212. unsigned EltBits = EltTy->getScalarSizeInBits();
  213. switch (VPI.getIntrinsicID()) {
  214. default:
  215. llvm_unreachable("Expecting a VP reduction intrinsic");
  216. case Intrinsic::vp_reduce_add:
  217. case Intrinsic::vp_reduce_or:
  218. case Intrinsic::vp_reduce_xor:
  219. case Intrinsic::vp_reduce_umax:
  220. return Constant::getNullValue(EltTy);
  221. case Intrinsic::vp_reduce_mul:
  222. return ConstantInt::get(EltTy, 1, /*IsSigned*/ false);
  223. case Intrinsic::vp_reduce_and:
  224. case Intrinsic::vp_reduce_umin:
  225. return ConstantInt::getAllOnesValue(EltTy);
  226. case Intrinsic::vp_reduce_smin:
  227. return ConstantInt::get(EltTy->getContext(),
  228. APInt::getSignedMaxValue(EltBits));
  229. case Intrinsic::vp_reduce_smax:
  230. return ConstantInt::get(EltTy->getContext(),
  231. APInt::getSignedMinValue(EltBits));
  232. case Intrinsic::vp_reduce_fmax:
  233. Negative = true;
  234. LLVM_FALLTHROUGH;
  235. case Intrinsic::vp_reduce_fmin: {
  236. FastMathFlags Flags = VPI.getFastMathFlags();
  237. const fltSemantics &Semantics = EltTy->getFltSemantics();
  238. return !Flags.noNaNs() ? ConstantFP::getQNaN(EltTy, Negative)
  239. : !Flags.noInfs()
  240. ? ConstantFP::getInfinity(EltTy, Negative)
  241. : ConstantFP::get(EltTy,
  242. APFloat::getLargest(Semantics, Negative));
  243. }
  244. case Intrinsic::vp_reduce_fadd:
  245. return ConstantFP::getNegativeZero(EltTy);
  246. case Intrinsic::vp_reduce_fmul:
  247. return ConstantFP::get(EltTy, 1.0);
  248. }
  249. }
  250. Value *
  251. CachingVPExpander::expandPredicationInReduction(IRBuilder<> &Builder,
  252. VPReductionIntrinsic &VPI) {
  253. assert((isSafeToSpeculativelyExecute(&VPI) ||
  254. VPI.canIgnoreVectorLengthParam()) &&
  255. "Implicitly dropping %evl in non-speculatable operator!");
  256. Value *Mask = VPI.getMaskParam();
  257. Value *RedOp = VPI.getOperand(VPI.getVectorParamPos());
  258. // Insert neutral element in masked-out positions
  259. if (Mask && !isAllTrueMask(Mask)) {
  260. auto *NeutralElt = getNeutralReductionElement(VPI, VPI.getType());
  261. auto *NeutralVector = Builder.CreateVectorSplat(
  262. cast<VectorType>(RedOp->getType())->getElementCount(), NeutralElt);
  263. RedOp = Builder.CreateSelect(Mask, RedOp, NeutralVector);
  264. }
  265. Value *Reduction;
  266. Value *Start = VPI.getOperand(VPI.getStartParamPos());
  267. switch (VPI.getIntrinsicID()) {
  268. default:
  269. llvm_unreachable("Impossible reduction kind");
  270. case Intrinsic::vp_reduce_add:
  271. Reduction = Builder.CreateAddReduce(RedOp);
  272. Reduction = Builder.CreateAdd(Reduction, Start);
  273. break;
  274. case Intrinsic::vp_reduce_mul:
  275. Reduction = Builder.CreateMulReduce(RedOp);
  276. Reduction = Builder.CreateMul(Reduction, Start);
  277. break;
  278. case Intrinsic::vp_reduce_and:
  279. Reduction = Builder.CreateAndReduce(RedOp);
  280. Reduction = Builder.CreateAnd(Reduction, Start);
  281. break;
  282. case Intrinsic::vp_reduce_or:
  283. Reduction = Builder.CreateOrReduce(RedOp);
  284. Reduction = Builder.CreateOr(Reduction, Start);
  285. break;
  286. case Intrinsic::vp_reduce_xor:
  287. Reduction = Builder.CreateXorReduce(RedOp);
  288. Reduction = Builder.CreateXor(Reduction, Start);
  289. break;
  290. case Intrinsic::vp_reduce_smax:
  291. Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ true);
  292. Reduction =
  293. Builder.CreateBinaryIntrinsic(Intrinsic::smax, Reduction, Start);
  294. break;
  295. case Intrinsic::vp_reduce_smin:
  296. Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ true);
  297. Reduction =
  298. Builder.CreateBinaryIntrinsic(Intrinsic::smin, Reduction, Start);
  299. break;
  300. case Intrinsic::vp_reduce_umax:
  301. Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ false);
  302. Reduction =
  303. Builder.CreateBinaryIntrinsic(Intrinsic::umax, Reduction, Start);
  304. break;
  305. case Intrinsic::vp_reduce_umin:
  306. Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ false);
  307. Reduction =
  308. Builder.CreateBinaryIntrinsic(Intrinsic::umin, Reduction, Start);
  309. break;
  310. case Intrinsic::vp_reduce_fmax:
  311. Reduction = Builder.CreateFPMaxReduce(RedOp);
  312. transferDecorations(*Reduction, VPI);
  313. Reduction =
  314. Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, Reduction, Start);
  315. break;
  316. case Intrinsic::vp_reduce_fmin:
  317. Reduction = Builder.CreateFPMinReduce(RedOp);
  318. transferDecorations(*Reduction, VPI);
  319. Reduction =
  320. Builder.CreateBinaryIntrinsic(Intrinsic::minnum, Reduction, Start);
  321. break;
  322. case Intrinsic::vp_reduce_fadd:
  323. Reduction = Builder.CreateFAddReduce(Start, RedOp);
  324. break;
  325. case Intrinsic::vp_reduce_fmul:
  326. Reduction = Builder.CreateFMulReduce(Start, RedOp);
  327. break;
  328. }
  329. replaceOperation(*Reduction, VPI);
  330. return Reduction;
  331. }
  332. void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
  333. LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");
  334. if (VPI.canIgnoreVectorLengthParam())
  335. return;
  336. Value *EVLParam = VPI.getVectorLengthParam();
  337. if (!EVLParam)
  338. return;
  339. ElementCount StaticElemCount = VPI.getStaticVectorLength();
  340. Value *MaxEVL = nullptr;
  341. Type *Int32Ty = Type::getInt32Ty(VPI.getContext());
  342. if (StaticElemCount.isScalable()) {
  343. // TODO add caching
  344. auto *M = VPI.getModule();
  345. Function *VScaleFunc =
  346. Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
  347. IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
  348. Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
  349. Value *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
  350. MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
  351. /*NUW*/ true, /*NSW*/ false);
  352. } else {
  353. MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);
  354. }
  355. VPI.setVectorLengthParam(MaxEVL);
  356. }
  357. Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
  358. LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');
  359. IRBuilder<> Builder(&VPI);
  360. // Ineffective %evl parameter and so nothing to do here.
  361. if (VPI.canIgnoreVectorLengthParam())
  362. return &VPI;
  363. // Only VP intrinsics can have an %evl parameter.
  364. Value *OldMaskParam = VPI.getMaskParam();
  365. Value *OldEVLParam = VPI.getVectorLengthParam();
  366. assert(OldMaskParam && "no mask param to fold the vl param into");
  367. assert(OldEVLParam && "no EVL param to fold away");
  368. LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
  369. LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');
  370. // Convert the %evl predication into vector mask predication.
  371. ElementCount ElemCount = VPI.getStaticVectorLength();
  372. Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
  373. Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
  374. VPI.setMaskParam(NewMaskParam);
  375. // Drop the %evl parameter.
  376. discardEVLParameter(VPI);
  377. assert(VPI.canIgnoreVectorLengthParam() &&
  378. "transformation did not render the evl param ineffective!");
  379. // Reassess the modified instruction.
  380. return &VPI;
  381. }
  382. Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
  383. LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');
  384. IRBuilder<> Builder(&VPI);
  385. // Try lowering to a LLVM instruction first.
  386. auto OC = VPI.getFunctionalOpcode();
  387. if (OC && Instruction::isBinaryOp(*OC))
  388. return expandPredicationInBinaryOperator(Builder, VPI);
  389. if (auto *VPRI = dyn_cast<VPReductionIntrinsic>(&VPI))
  390. return expandPredicationInReduction(Builder, *VPRI);
  391. return &VPI;
  392. }
  393. //// } CachingVPExpander
  394. struct TransformJob {
  395. VPIntrinsic *PI;
  396. TargetTransformInfo::VPLegalization Strategy;
  397. TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)
  398. : PI(PI), Strategy(InitStrat) {}
  399. bool isDone() const { return Strategy.shouldDoNothing(); }
  400. };
  401. void sanitizeStrategy(Instruction &I, VPLegalization &LegalizeStrat) {
  402. // Speculatable instructions do not strictly need predication.
  403. if (isSafeToSpeculativelyExecute(&I)) {
  404. // Converting a speculatable VP intrinsic means dropping %mask and %evl.
  405. // No need to expand %evl into the %mask only to ignore that code.
  406. if (LegalizeStrat.OpStrategy == VPLegalization::Convert)
  407. LegalizeStrat.EVLParamStrategy = VPLegalization::Discard;
  408. return;
  409. }
  410. // We have to preserve the predicating effect of %evl for this
  411. // non-speculatable VP intrinsic.
  412. // 1) Never discard %evl.
  413. // 2) If this VP intrinsic will be expanded to non-VP code, make sure that
  414. // %evl gets folded into %mask.
  415. if ((LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) ||
  416. (LegalizeStrat.OpStrategy == VPLegalization::Convert)) {
  417. LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
  418. }
  419. }
  420. VPLegalization
  421. CachingVPExpander::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
  422. auto VPStrat = TTI.getVPLegalizationStrategy(VPI);
  423. if (LLVM_LIKELY(!UsingTTIOverrides)) {
  424. // No overrides - we are in production.
  425. return VPStrat;
  426. }
  427. // Overrides set - we are in testing, the following does not need to be
  428. // efficient.
  429. VPStrat.EVLParamStrategy = parseOverrideOption(EVLTransformOverride);
  430. VPStrat.OpStrategy = parseOverrideOption(MaskTransformOverride);
  431. return VPStrat;
  432. }
  433. /// \brief Expand llvm.vp.* intrinsics as requested by \p TTI.
  434. bool CachingVPExpander::expandVectorPredication() {
  435. SmallVector<TransformJob, 16> Worklist;
  436. // Collect all VPIntrinsics that need expansion and determine their expansion
  437. // strategy.
  438. for (auto &I : instructions(F)) {
  439. auto *VPI = dyn_cast<VPIntrinsic>(&I);
  440. if (!VPI)
  441. continue;
  442. auto VPStrat = getVPLegalizationStrategy(*VPI);
  443. sanitizeStrategy(I, VPStrat);
  444. if (!VPStrat.shouldDoNothing())
  445. Worklist.emplace_back(VPI, VPStrat);
  446. }
  447. if (Worklist.empty())
  448. return false;
  449. // Transform all VPIntrinsics on the worklist.
  450. LLVM_DEBUG(dbgs() << "\n:::: Transforming " << Worklist.size()
  451. << " instructions ::::\n");
  452. for (TransformJob Job : Worklist) {
  453. // Transform the EVL parameter.
  454. switch (Job.Strategy.EVLParamStrategy) {
  455. case VPLegalization::Legal:
  456. break;
  457. case VPLegalization::Discard:
  458. discardEVLParameter(*Job.PI);
  459. break;
  460. case VPLegalization::Convert:
  461. if (foldEVLIntoMask(*Job.PI))
  462. ++NumFoldedVL;
  463. break;
  464. }
  465. Job.Strategy.EVLParamStrategy = VPLegalization::Legal;
  466. // Replace with a non-predicated operation.
  467. switch (Job.Strategy.OpStrategy) {
  468. case VPLegalization::Legal:
  469. break;
  470. case VPLegalization::Discard:
  471. llvm_unreachable("Invalid strategy for operators.");
  472. case VPLegalization::Convert:
  473. expandPredication(*Job.PI);
  474. ++NumLoweredVPOps;
  475. break;
  476. }
  477. Job.Strategy.OpStrategy = VPLegalization::Legal;
  478. assert(Job.isDone() && "incomplete transformation");
  479. }
  480. return true;
  481. }
  482. class ExpandVectorPredication : public FunctionPass {
  483. public:
  484. static char ID;
  485. ExpandVectorPredication() : FunctionPass(ID) {
  486. initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
  487. }
  488. bool runOnFunction(Function &F) override {
  489. const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  490. CachingVPExpander VPExpander(F, *TTI);
  491. return VPExpander.expandVectorPredication();
  492. }
  493. void getAnalysisUsage(AnalysisUsage &AU) const override {
  494. AU.addRequired<TargetTransformInfoWrapperPass>();
  495. AU.setPreservesCFG();
  496. }
  497. };
  498. } // namespace
  499. char ExpandVectorPredication::ID;
  500. INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expandvp",
  501. "Expand vector predication intrinsics", false, false)
  502. INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
  503. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  504. INITIALIZE_PASS_END(ExpandVectorPredication, "expandvp",
  505. "Expand vector predication intrinsics", false, false)
  506. FunctionPass *llvm::createExpandVectorPredicationPass() {
  507. return new ExpandVectorPredication();
  508. }
  509. PreservedAnalyses
  510. ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
  511. const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  512. CachingVPExpander VPExpander(F, TTI);
  513. if (!VPExpander.expandVectorPredication())
  514. return PreservedAnalyses::all();
  515. PreservedAnalyses PA;
  516. PA.preserveSet<CFGAnalyses>();
  517. return PA;
  518. }