//===----- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements IR expansion for vector predication intrinsics, allowing
// targets to enable vector predication until just before codegen.
//
//===----------------------------------------------------------------------===//
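
// Illustrative sketch (value names are illustrative, not produced verbatim by
// this pass): assuming the target requests full expansion of both the %evl
// parameter and the operation itself, a predicated division such as
//
//   %r = call <4 x i32> @llvm.vp.sdiv.v4i32(<4 x i32> %a, <4 x i32> %b,
//                                           <4 x i1> %mask, i32 %evl)
//
// is conceptually rewritten into plain IR along these lines:
//
//   %evl.ins   = insertelement <4 x i32> poison, i32 %evl, i64 0
//   %evl.splat = shufflevector <4 x i32> %evl.ins, <4 x i32> poison,
//                              <4 x i32> zeroinitializer
//   %evlmask   = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, %evl.splat
//   %newmask   = and <4 x i1> %evlmask, %mask
//   %safe.b    = select <4 x i1> %newmask, <4 x i32> %b,
//                       <4 x i32> <i32 1, i32 1, i32 1, i32 1>
//   %r         = sdiv <4 x i32> %a, %safe.b
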
#include "llvm/CodeGen/ExpandVectorPredication.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include <optional>

using namespace llvm;

using VPLegalization = TargetTransformInfo::VPLegalization;
using VPTransform = TargetTransformInfo::VPLegalization::VPTransform;

// Keep this in sync with TargetTransformInfo::VPLegalization.
#define VPINTERNAL_VPLEGAL_CASES                                               \
  VPINTERNAL_CASE(Legal)                                                       \
  VPINTERNAL_CASE(Discard)                                                     \
  VPINTERNAL_CASE(Convert)

#define VPINTERNAL_CASE(X) "|" #X

// Override options.
static cl::opt<std::string> EVLTransformOverride(
    "expandvp-override-evl-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %evl parameter (Used in "
             "testing)."));
static cl::opt<std::string> MaskTransformOverride(
    "expandvp-override-mask-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %mask parameter (Used in "
             "testing)."));
#undef VPINTERNAL_CASE
#define VPINTERNAL_CASE(X) .Case(#X, VPLegalization::X)

static VPTransform parseOverrideOption(const std::string &TextOpt) {
  return StringSwitch<VPTransform>(TextOpt) VPINTERNAL_VPLEGAL_CASES;
}

#undef VPINTERNAL_VPLEGAL_CASES

// Whether any override options are set.
static bool anyExpandVPOverridesSet() {
  return !EVLTransformOverride.empty() || !MaskTransformOverride.empty();
}
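
// For testing, the two overrides above force a strategy regardless of what
// TargetTransformInfo reports. An illustrative invocation (assuming the pass
// is exposed to `opt` under its registered name "expandvp"):
//
//   opt -passes=expandvp -expandvp-override-evl-transform=Discard \
//       -expandvp-override-mask-transform=Convert -S < input.ll
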
#define DEBUG_TYPE "expandvp"

STATISTIC(NumFoldedVL, "Number of folded vector length params");
STATISTIC(NumLoweredVPOps, "Number of lowered vector predication operations");
///// Helpers {

/// \returns Whether the vector mask \p MaskVal has all lane bits set.
static bool isAllTrueMask(Value *MaskVal) {
  if (Value *SplattedVal = getSplatValue(MaskVal))
    if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
      return ConstValue->isAllOnesValue();

  return false;
}

/// \returns A non-excepting divisor constant for this type.
static Constant *getSafeDivisor(Type *DivTy) {
  assert(DivTy->isIntOrIntVectorTy() && "Unsupported divisor type");
  return ConstantInt::get(DivTy, 1u, false);
}
/// Transfer operation properties from \p VPI to \p NewVal.
static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {
  auto *NewInst = dyn_cast<Instruction>(&NewVal);
  if (!NewInst || !isa<FPMathOperator>(NewVal))
    return;

  auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);
  if (!OldFMOp)
    return;

  NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
}

/// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
/// \p OldOp gets erased.
static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {
  transferDecorations(NewOp, OldOp);
  OldOp.replaceAllUsesWith(&NewOp);
  OldOp.eraseFromParent();
}
static bool maySpeculateLanes(VPIntrinsic &VPI) {
  // The result of VP reductions depends on the mask and evl.
  if (isa<VPReductionIntrinsic>(VPI))
    return false;
  // Fallback to whether the intrinsic is speculatable.
  std::optional<unsigned> OpcOpt = VPI.getFunctionalOpcode();
  unsigned FunctionalOpc = OpcOpt.value_or((unsigned)Instruction::Call);
  return isSafeToSpeculativelyExecuteWithOpcode(FunctionalOpc, &VPI);
}
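
// For example, the functional opcode of llvm.vp.add (Instruction::Add) is
// safe to execute on lanes the mask disables, so vp.add may speculate lanes.
// The functional opcode of llvm.vp.udiv (Instruction::UDiv) can trap on a
// zero divisor in a masked-off lane, so it may not.
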
//// } Helpers

namespace {

// Expansion pass state at function scope.
struct CachingVPExpander {
  Function &F;
  const TargetTransformInfo &TTI;

  /// \returns A (fixed length) vector with ascending integer indices
  /// (<0, 1, ..., NumElems-1>).
  /// \p Builder
  ///    Used for instruction creation.
  /// \p LaneTy
  ///    Integer element type of the result vector.
  /// \p NumElems
  ///    Number of vector elements.
  Value *createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                          unsigned NumElems);

  /// \returns A bitmask that is true where the lane position is less-than \p
  /// EVLParam
  ///
  /// \p Builder
  ///    Used for instruction creation.
  /// \p EVLParam
  ///    The explicit vector length parameter to test against the lane
  ///    positions.
  /// \p ElemCount
  ///    Static (potentially scalable) number of vector elements.
  Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
                          ElementCount ElemCount);

  Value *foldEVLIntoMask(VPIntrinsic &VPI);

  /// "Remove" the %evl parameter of \p PI by setting it to the static vector
  /// length of the operation.
  void discardEVLParameter(VPIntrinsic &PI);

  /// Lower this VP binary operator to an unpredicated binary operator.
  Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                           VPIntrinsic &PI);

  /// Lower this VP reduction to a call to an unpredicated reduction intrinsic.
  Value *expandPredicationInReduction(IRBuilder<> &Builder,
                                      VPReductionIntrinsic &PI);

  /// Lower this VP memory operation to a non-VP intrinsic.
  Value *expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                            VPIntrinsic &VPI);

  /// Lower this VP comparison to a call to an unpredicated comparison.
  Value *expandPredicationInComparison(IRBuilder<> &Builder,
                                       VPCmpIntrinsic &PI);

  /// Query TTI and expand the vector predication in \p PI accordingly.
  Value *expandPredication(VPIntrinsic &PI);

  /// Determine how and whether the VPIntrinsic \p VPI shall be expanded. This
  /// overrides TTI with the cl::opts listed at the top of this file.
  VPLegalization getVPLegalizationStrategy(const VPIntrinsic &VPI) const;
  bool UsingTTIOverrides;

public:
  CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
      : F(F), TTI(TTI), UsingTTIOverrides(anyExpandVPOverridesSet()) {}

  bool expandVectorPredication();
};

//// CachingVPExpander {

Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                                           unsigned NumElems) {
  // TODO add caching
  SmallVector<Constant *, 16> ConstElems;

  for (unsigned Idx = 0; Idx < NumElems; ++Idx)
    ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));

  return ConstantVector::get(ConstElems);
}

Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
                                           Value *EVLParam,
                                           ElementCount ElemCount) {
  // TODO add caching
  // Scalable vector %evl conversion.
  if (ElemCount.isScalable()) {
    auto *M = Builder.GetInsertBlock()->getModule();
    Type *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
    Function *ActiveMaskFunc = Intrinsic::getDeclaration(
        M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
    // `get_active_lane_mask` performs an implicit less-than comparison.
    Value *ConstZero = Builder.getInt32(0);
    return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});
  }

  // Fixed vector %evl conversion.
  Type *LaneTy = EVLParam->getType();
  unsigned NumElems = ElemCount.getFixedValue();
  Value *VLSplat = Builder.CreateVectorSplat(NumElems, EVLParam);
  Value *IdxVec = createStepVector(Builder, LaneTy, NumElems);
  return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
}
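
// Illustrative example: for a fixed <4 x i32> operation with %evl == 2 the
// fixed-width path above computes
//
//   icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>,
//            <i32 2, i32 2, i32 2, i32 2>
//     --> <i1 true, i1 true, i1 false, i1 false>
//
// For a scalable <vscale x 4 x i32> operation the same predicate is produced
// by @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 %evl).
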
Value *
CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                                     VPIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  auto OC = static_cast<Instruction::BinaryOps>(*VPI.getFunctionalOpcode());
  assert(Instruction::isBinaryOp(OC));

  Value *Op0 = VPI.getOperand(0);
  Value *Op1 = VPI.getOperand(1);
  Value *Mask = VPI.getMaskParam();

  // Blend in safe operands.
  if (Mask && !isAllTrueMask(Mask)) {
    switch (OC) {
    default:
      // Can safely ignore the predicate.
      break;

    // Division operators need a safe divisor on masked-off lanes (1).
    case Instruction::UDiv:
    case Instruction::SDiv:
    case Instruction::URem:
    case Instruction::SRem:
      // 2nd operand must not be zero.
      Value *SafeDivisor = getSafeDivisor(VPI.getType());
      Op1 = Builder.CreateSelect(Mask, Op1, SafeDivisor);
    }
  }

  Value *NewBinOp = Builder.CreateBinOp(OC, Op0, Op1, VPI.getName());

  replaceOperation(*NewBinOp, VPI);
  return NewBinOp;
}

static Value *getNeutralReductionElement(const VPReductionIntrinsic &VPI,
                                         Type *EltTy) {
  bool Negative = false;
  unsigned EltBits = EltTy->getScalarSizeInBits();
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Expecting a VP reduction intrinsic");
  case Intrinsic::vp_reduce_add:
  case Intrinsic::vp_reduce_or:
  case Intrinsic::vp_reduce_xor:
  case Intrinsic::vp_reduce_umax:
    return Constant::getNullValue(EltTy);
  case Intrinsic::vp_reduce_mul:
    return ConstantInt::get(EltTy, 1, /*IsSigned*/ false);
  case Intrinsic::vp_reduce_and:
  case Intrinsic::vp_reduce_umin:
    return ConstantInt::getAllOnesValue(EltTy);
  case Intrinsic::vp_reduce_smin:
    return ConstantInt::get(EltTy->getContext(),
                            APInt::getSignedMaxValue(EltBits));
  case Intrinsic::vp_reduce_smax:
    return ConstantInt::get(EltTy->getContext(),
                            APInt::getSignedMinValue(EltBits));
  case Intrinsic::vp_reduce_fmax:
    Negative = true;
    [[fallthrough]];
  case Intrinsic::vp_reduce_fmin: {
    FastMathFlags Flags = VPI.getFastMathFlags();
    const fltSemantics &Semantics = EltTy->getFltSemantics();
    return !Flags.noNaNs() ? ConstantFP::getQNaN(EltTy, Negative)
           : !Flags.noInfs()
               ? ConstantFP::getInfinity(EltTy, Negative)
               : ConstantFP::get(EltTy,
                                 APFloat::getLargest(Semantics, Negative));
  }
  case Intrinsic::vp_reduce_fadd:
    return ConstantFP::getNegativeZero(EltTy);
  case Intrinsic::vp_reduce_fmul:
    return ConstantFP::get(EltTy, 1.0);
  }
}
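
// The neutral element is chosen so that blending it into masked-off lanes
// cannot change the reduction result. For example, the neutral element of
// vp.reduce.umin is all-ones (UINT_MAX), since umin(x, UINT_MAX) == x, and
// the neutral element of vp.reduce.add is 0, since x + 0 == x.
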
Value *
CachingVPExpander::expandPredicationInReduction(IRBuilder<> &Builder,
                                                VPReductionIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  Value *Mask = VPI.getMaskParam();
  Value *RedOp = VPI.getOperand(VPI.getVectorParamPos());

  // Insert neutral element in masked-out positions
  if (Mask && !isAllTrueMask(Mask)) {
    auto *NeutralElt = getNeutralReductionElement(VPI, VPI.getType());
    auto *NeutralVector = Builder.CreateVectorSplat(
        cast<VectorType>(RedOp->getType())->getElementCount(), NeutralElt);
    RedOp = Builder.CreateSelect(Mask, RedOp, NeutralVector);
  }

  Value *Reduction;
  Value *Start = VPI.getOperand(VPI.getStartParamPos());

  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Impossible reduction kind");
  case Intrinsic::vp_reduce_add:
    Reduction = Builder.CreateAddReduce(RedOp);
    Reduction = Builder.CreateAdd(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_mul:
    Reduction = Builder.CreateMulReduce(RedOp);
    Reduction = Builder.CreateMul(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_and:
    Reduction = Builder.CreateAndReduce(RedOp);
    Reduction = Builder.CreateAnd(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_or:
    Reduction = Builder.CreateOrReduce(RedOp);
    Reduction = Builder.CreateOr(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_xor:
    Reduction = Builder.CreateXorReduce(RedOp);
    Reduction = Builder.CreateXor(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_smax:
    Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ true);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::smax, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_smin:
    Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ true);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::smin, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_umax:
    Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ false);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::umax, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_umin:
    Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ false);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::umin, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fmax:
    Reduction = Builder.CreateFPMaxReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fmin:
    Reduction = Builder.CreateFPMinReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::minnum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fadd:
    Reduction = Builder.CreateFAddReduce(Start, RedOp);
    break;
  case Intrinsic::vp_reduce_fmul:
    Reduction = Builder.CreateFMulReduce(Start, RedOp);
    break;
  }

  replaceOperation(*Reduction, VPI);
  return Reduction;
}
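
// Illustrative example (names are illustrative): with the %evl already folded
// into %mask, a call such as
//
//   %s = call i32 @llvm.vp.reduce.add.v4i32(i32 %start, <4 x i32> %v,
//                                           <4 x i1> %mask, i32 4)
//
// becomes
//
//   %blend = select <4 x i1> %mask, <4 x i32> %v, <4 x i32> zeroinitializer
//   %red   = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %blend)
//   %s     = add i32 %red, %start
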
Value *
CachingVPExpander::expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                                      VPIntrinsic &VPI) {
  assert(VPI.canIgnoreVectorLengthParam());

  const auto &DL = F.getParent()->getDataLayout();

  Value *MaskParam = VPI.getMaskParam();
  Value *PtrParam = VPI.getMemoryPointerParam();
  Value *DataParam = VPI.getMemoryDataParam();
  bool IsUnmasked = isAllTrueMask(MaskParam);

  MaybeAlign AlignOpt = VPI.getPointerAlignment();

  Value *NewMemoryInst = nullptr;
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Not a VP memory intrinsic");
  case Intrinsic::vp_store:
    if (IsUnmasked) {
      StoreInst *NewStore =
          Builder.CreateStore(DataParam, PtrParam, /*IsVolatile*/ false);
      if (AlignOpt.has_value())
        NewStore->setAlignment(*AlignOpt);
      NewMemoryInst = NewStore;
    } else
      NewMemoryInst = Builder.CreateMaskedStore(
          DataParam, PtrParam, AlignOpt.valueOrOne(), MaskParam);
    break;
  case Intrinsic::vp_load:
    if (IsUnmasked) {
      LoadInst *NewLoad =
          Builder.CreateLoad(VPI.getType(), PtrParam, /*IsVolatile*/ false);
      if (AlignOpt.has_value())
        NewLoad->setAlignment(*AlignOpt);
      NewMemoryInst = NewLoad;
    } else
      NewMemoryInst = Builder.CreateMaskedLoad(
          VPI.getType(), PtrParam, AlignOpt.valueOrOne(), MaskParam);
    break;
  case Intrinsic::vp_scatter: {
    auto *ElementType =
        cast<VectorType>(DataParam->getType())->getElementType();
    NewMemoryInst = Builder.CreateMaskedScatter(
        DataParam, PtrParam,
        AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam);
    break;
  }
  case Intrinsic::vp_gather: {
    auto *ElementType = cast<VectorType>(VPI.getType())->getElementType();
    NewMemoryInst = Builder.CreateMaskedGather(
        VPI.getType(), PtrParam,
        AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam, nullptr,
        VPI.getName());
    break;
  }
  }

  assert(NewMemoryInst);
  replaceOperation(*NewMemoryInst, VPI);
  return NewMemoryInst;
}
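
// Illustrative example (names and the alignment value are illustrative): once
// the %evl is folded away or known to be ineffective, a predicated load such
// as
//
//   %v = call <4 x i32> @llvm.vp.load.v4i32.p0(ptr %p, <4 x i1> %mask, i32 4)
//
// becomes a masked load,
//
//   %v = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %p, i32 4,
//                                                  <4 x i1> %mask,
//                                                  <4 x i32> poison)
//
// or a plain `load <4 x i32>, ptr %p` when the mask is known all-true.
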
Value *CachingVPExpander::expandPredicationInComparison(IRBuilder<> &Builder,
                                                        VPCmpIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  assert(*VPI.getFunctionalOpcode() == Instruction::ICmp ||
         *VPI.getFunctionalOpcode() == Instruction::FCmp);

  Value *Op0 = VPI.getOperand(0);
  Value *Op1 = VPI.getOperand(1);
  auto Pred = VPI.getPredicate();

  auto *NewCmp = Builder.CreateCmp(Pred, Op0, Op1);

  replaceOperation(*NewCmp, VPI);
  return NewCmp;
}

void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");

  if (VPI.canIgnoreVectorLengthParam())
    return;

  Value *EVLParam = VPI.getVectorLengthParam();
  if (!EVLParam)
    return;

  ElementCount StaticElemCount = VPI.getStaticVectorLength();
  Value *MaxEVL = nullptr;
  Type *Int32Ty = Type::getInt32Ty(VPI.getContext());
  if (StaticElemCount.isScalable()) {
    // TODO add caching
    auto *M = VPI.getModule();
    Function *VScaleFunc =
        Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
    IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
    Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
    Value *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
    MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
                               /*NUW*/ true, /*NSW*/ false);
  } else {
    MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);
  }
  VPI.setVectorLengthParam(MaxEVL);
}
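
// Illustrative example: for a <vscale x 2 x i32> operation the %evl argument
// is replaced by the full register element count,
//
//   %vscale        = call i32 @llvm.vscale.i32()
//   %scalable_size = mul nuw i32 %vscale, 2
//
// while a fixed <4 x i32> operation simply gets the constant i32 4.
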
Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // Nothing to do if the %evl parameter is ineffective.
  if (VPI.canIgnoreVectorLengthParam())
    return &VPI;

  // Only VP intrinsics can have an %evl parameter.
  Value *OldMaskParam = VPI.getMaskParam();
  Value *OldEVLParam = VPI.getVectorLengthParam();
  assert(OldMaskParam && "no mask param to fold the vl param into");
  assert(OldEVLParam && "no EVL param to fold away");

  LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
  LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');

  // Convert the %evl predication into vector mask predication.
  ElementCount ElemCount = VPI.getStaticVectorLength();
  Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
  Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
  VPI.setMaskParam(NewMaskParam);

  // Drop the %evl parameter.
  discardEVLParameter(VPI);
  assert(VPI.canIgnoreVectorLengthParam() &&
         "transformation did not render the evl param ineffective!");

  // Reassess the modified instruction.
  return &VPI;
}
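
// Illustrative example (names are illustrative): folding %evl into the mask
// rewrites
//
//   %v = call <4 x i32> @llvm.vp.load.v4i32.p0(ptr %p, <4 x i1> %m, i32 %evl)
//
// into the still-predicated form
//
//   %evlmask = ... lane-index < %evl ...   ; see convertEVLToMask above
//   %newmask = and <4 x i1> %evlmask, %m
//   %v = call <4 x i32> @llvm.vp.load.v4i32.p0(ptr %p, <4 x i1> %newmask,
//                                              i32 4)
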
Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // Try lowering to an LLVM instruction first.
  auto OC = VPI.getFunctionalOpcode();

  if (OC && Instruction::isBinaryOp(*OC))
    return expandPredicationInBinaryOperator(Builder, VPI);

  if (auto *VPRI = dyn_cast<VPReductionIntrinsic>(&VPI))
    return expandPredicationInReduction(Builder, *VPRI);

  if (auto *VPCmp = dyn_cast<VPCmpIntrinsic>(&VPI))
    return expandPredicationInComparison(Builder, *VPCmp);

  switch (VPI.getIntrinsicID()) {
  default:
    break;
  case Intrinsic::vp_load:
  case Intrinsic::vp_store:
  case Intrinsic::vp_gather:
  case Intrinsic::vp_scatter:
    return expandPredicationInMemoryIntrinsic(Builder, VPI);
  }

  return &VPI;
}

//// } CachingVPExpander

struct TransformJob {
  VPIntrinsic *PI;
  TargetTransformInfo::VPLegalization Strategy;
  TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)
      : PI(PI), Strategy(InitStrat) {}

  bool isDone() const { return Strategy.shouldDoNothing(); }
};

void sanitizeStrategy(VPIntrinsic &VPI, VPLegalization &LegalizeStrat) {
  // Operations with speculatable lanes do not strictly need predication.
  if (maySpeculateLanes(VPI)) {
    // Converting a speculatable VP intrinsic means dropping %mask and %evl.
    // No need to expand %evl into the %mask only to ignore that code.
    if (LegalizeStrat.OpStrategy == VPLegalization::Convert)
      LegalizeStrat.EVLParamStrategy = VPLegalization::Discard;
    return;
  }

  // We have to preserve the predicating effect of %evl for this
  // non-speculatable VP intrinsic.
  // 1) Never discard %evl.
  // 2) If this VP intrinsic will be expanded to non-VP code, make sure that
  //    %evl gets folded into %mask.
  if ((LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) ||
      (LegalizeStrat.OpStrategy == VPLegalization::Convert)) {
    LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
  }
}
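
// For example, a speculatable vp.add whose OpStrategy is Convert gets its
// EVLParamStrategy set to Discard (the plain `add` result is valid on every
// lane), whereas a non-speculatable vp.udiv that is about to be converted has
// its EVLParamStrategy forced to Convert, so %evl is folded into %mask before
// the predicate is honored via the safe-divisor blend.
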
VPLegalization
CachingVPExpander::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
  auto VPStrat = TTI.getVPLegalizationStrategy(VPI);
  if (LLVM_LIKELY(!UsingTTIOverrides)) {
    // No overrides - we are in production.
    return VPStrat;
  }

  // Overrides set - we are in testing, the following does not need to be
  // efficient.
  VPStrat.EVLParamStrategy = parseOverrideOption(EVLTransformOverride);
  VPStrat.OpStrategy = parseOverrideOption(MaskTransformOverride);
  return VPStrat;
}

/// Expand llvm.vp.* intrinsics as requested by \p TTI.
bool CachingVPExpander::expandVectorPredication() {
  SmallVector<TransformJob, 16> Worklist;

  // Collect all VPIntrinsics that need expansion and determine their expansion
  // strategy.
  for (auto &I : instructions(F)) {
    auto *VPI = dyn_cast<VPIntrinsic>(&I);
    if (!VPI)
      continue;
    auto VPStrat = getVPLegalizationStrategy(*VPI);
    sanitizeStrategy(*VPI, VPStrat);
    if (!VPStrat.shouldDoNothing())
      Worklist.emplace_back(VPI, VPStrat);
  }
  if (Worklist.empty())
    return false;

  // Transform all VPIntrinsics on the worklist.
  LLVM_DEBUG(dbgs() << "\n:::: Transforming " << Worklist.size()
                    << " instructions ::::\n");
  for (TransformJob Job : Worklist) {
    // Transform the EVL parameter.
    switch (Job.Strategy.EVLParamStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      discardEVLParameter(*Job.PI);
      break;
    case VPLegalization::Convert:
      if (foldEVLIntoMask(*Job.PI))
        ++NumFoldedVL;
      break;
    }
    Job.Strategy.EVLParamStrategy = VPLegalization::Legal;

    // Replace with a non-predicated operation.
    switch (Job.Strategy.OpStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      llvm_unreachable("Invalid strategy for operators.");
    case VPLegalization::Convert:
      expandPredication(*Job.PI);
      ++NumLoweredVPOps;
      break;
    }
    Job.Strategy.OpStrategy = VPLegalization::Legal;

    assert(Job.isDone() && "incomplete transformation");
  }

  return true;
}

class ExpandVectorPredication : public FunctionPass {
public:
  static char ID;
  ExpandVectorPredication() : FunctionPass(ID) {
    initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    CachingVPExpander VPExpander(F, *TTI);
    return VPExpander.expandVectorPredication();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

char ExpandVectorPredication::ID;
INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expandvp",
                      "Expand vector predication intrinsics", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ExpandVectorPredication, "expandvp",
                    "Expand vector predication intrinsics", false, false)

FunctionPass *llvm::createExpandVectorPredicationPass() {
  return new ExpandVectorPredication();
}

PreservedAnalyses
ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
  const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  CachingVPExpander VPExpander(F, TTI);
  if (!VPExpander.expandVectorPredication())
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}