//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
// intrinsics to RISCV intrinsics.
//
//===----------------------------------------------------------------------===//
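
// The overall approach, as implemented below: find masked gathers and scatters
// whose pointer operand is a getelementptr with exactly one vector index,
// recover a scalar base pointer and a byte stride from that index (either
// directly, or by rebuilding a scalar recurrence alongside a loop phi), and
// rewrite the memory operation as a riscv_masked_strided_load/store intrinsic.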

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

  // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
  // used by multiple gathers/scatters, this allows us to reuse the scalar
  // instructions we created for the first gather/scatter for the others.
  DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISCV gather/scatter lowering";
  }

private:
  bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);

  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                 Value *AlignOp);

  std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
                                                     IRBuilder<> &Builder);

  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilder<> &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISCV gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}

bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
                                                         Value *AlignOp) {
  Type *ScalarType = DataType->getScalarType();
  if (!TLI->isLegalElementTypeForRVV(ScalarType))
    return false;

  MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
  if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedValue())
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
  EVT DataVT = TLI->getValueType(*DL, DataType);
  if (!TLI->isTypeLegal(DataVT))
    return false;

  return true;
}

// TODO: Should we consider the mask when looking for a stride?
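//
// Matches a constant vector whose elements form an arithmetic sequence; e.g.
// <i64 0, i64 4, i64 8, i64 12> yields start 0 and stride 4. A non-constant
// element, or any pair of neighbors differing by a different amount, fails
// the match.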
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}
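
// Try to decompose Start into a (scalar start, scalar stride) pair. Three
// forms are recognized below: a strided constant vector, a stepvector
// (start 0, stride 1), and an add of a splat to one of those forms, which
// offsets the start by the splatted value.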
static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilder<> &Builder) {
  // Base case, start is a strided constant.
  auto *StartC = dyn_cast<Constant>(Start);
  if (StartC)
    return matchStridedConstant(StartC);

  // Base case, start is a stepvector.
  if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
    auto *Ty = Start->getType()->getScalarType();
    return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
  }

  // Not a constant, maybe it's a strided constant with a splat added to it.
  auto *BO = dyn_cast<BinaryOperator>(Start);
  if (!BO || BO->getOpcode() != Instruction::Add)
    return std::make_pair(nullptr, nullptr);

  // Look for an operand that is splatted.
  unsigned OtherIndex = 1;
  Value *Splat = getSplatValue(BO->getOperand(0));
  if (!Splat) {
    Splat = getSplatValue(BO->getOperand(1));
    OtherIndex = 0;
  }
  if (!Splat)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  std::tie(Start, Stride) =
      matchStridedStart(BO->getOperand(OtherIndex), Builder);
  if (!Start)
    return std::make_pair(nullptr, nullptr);

  // Add the splat value to the start.
  Builder.SetInsertPoint(BO);
  Builder.SetCurrentDebugLocation(DebugLoc());
  Start = Builder.CreateAdd(Start, Splat);
  return std::make_pair(Start, Stride);
}

// Recursively walk back up the use-def chain until we find a phi with a
// strided start value. Build and update a scalar recurrence as we unwind the
// recursion. We also update the Stride as we unwind. Our goal is to move all
// of the arithmetic out of the loop.
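//
// For example (schematically), given a vector induction variable
//   %vec.ind  = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %ph ],
//                             [ %vec.next, %loop ]
//   %vec.next = add <4 x i64> %vec.ind, <splat of 4>
// we build the equivalent scalar recurrence
//   %scalar      = phi i64 [ 0, %ph ], [ %scalar.next, %loop ]
//   %scalar.next = add i64 %scalar, 4
// and report a Stride of 1, so the vector index never has to be materialized.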
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilder<> &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // Only phis in the loop header can describe the loop-carried recurrence
    // we are looking for.
    if (Phi->getParent() != L->getHeader())
      return false;

    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Only proceed if the step is loop invariant.
    if (!L->isLoopInvariant(Step))
      return false;

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr =
        PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc);
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for a binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  if (BO->getOpcode() != Instruction::Add &&
      BO->getOpcode() != Instruction::Or &&
      BO->getOpcode() != Instruction::Mul &&
      BO->getOpcode() != Instruction::Shl)
    return false;

  // Only support shift by constant.
  if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
    return false;

  // We need to be able to treat Or as Add. That is only sound when the two
  // operands have no set bits in common, since then a | b == a + b.
  if (BO->getOpcode() == Instruction::Or &&
      !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
    return false;

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1)))) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());
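
  // Fold the splat operand of BO into the recurrence we just rebuilt for its
  // other operand: an add (or a no-common-bits or) offsets only the start,
  // while a mul or shl scales the start, the per-iteration step, and the
  // element-to-element stride alike.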
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.

    // If the start value is Zero, just take the SplatOp.
    if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
      Start = SplatOp;
    else
      Start = Builder.CreateAdd(Start, SplatOp, "start");
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Mul: {
    // If the start is zero we don't need to multiply.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateMul(Start, SplatOp, "start");

    Step = Builder.CreateMul(Step, SplatOp, "step");

    // If the Stride is 1 just take the SplatOp.
    if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
      Stride = SplatOp;
    else
      Stride = Builder.CreateMul(Stride, SplatOp, "stride");

    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Shl: {
    // If the start is zero we don't need to shift.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateShl(Start, SplatOp, "start");

    Step = Builder.CreateShl(Step, SplatOp, "step");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");

    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  }

  return true;
}
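
// Given a GEP with a single strided vector index, try to return an equivalent
// scalar base pointer and byte stride, emitting any new scalar arithmetic
// either at the GEP or in the loop preheader. Results are cached in
// StridedAddrs so other gathers/scatters using the same GEP reuse the same
// scalars.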
std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
                                                   IRBuilder<> &Builder) {
  auto I = StridedAddrs.find(GEP);
  if (I != StridedAddrs.end())
    return I->second;

  SmallVector<Value *, 2> Ops(GEP->operands());

  // Base pointer needs to be a scalar.
  if (Ops[0]->getType()->isVectorTy())
    return std::make_pair(nullptr, nullptr);

  std::optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedValue();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy)
    return std::make_pair(nullptr, nullptr);

  // Handle the non-recursive case. This is what we see if the vectorizer
  // decides to use a scalar IV + vid on demand instead of a vector IV.
  auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
  if (Start) {
    assert(Stride);
    Builder.SetInsertPoint(GEP);

    // Replace the vector index with the scalar start and build a scalar GEP.
    Ops[*VecOperand] = Start;
    Type *SourceTy = GEP->getSourceElementType();
    Value *BasePtr =
        Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());

    // Convert stride to pointer size if needed.
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    assert(Stride->getType() == IntPtrTy && "Unexpected type");

    // Scale the stride by the size of the indexed type.
    if (TypeScale != 1)
      Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

    auto P = std::make_pair(BasePtr, Stride);
    StridedAddrs[GEP] = P;
    return P;
  }

  // Make sure we're in a loop and that it has a preheader and a single latch.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
    return std::make_pair(nullptr, nullptr);

  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  auto P = std::make_pair(BasePtr, Stride);
  StridedAddrs[GEP] = P;
  return P;
}
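
// Rewrite a single masked.gather or masked.scatter into a
// riscv_masked_strided_load or riscv_masked_strided_store when its pointer
// operand is a GEP that decomposes into base + stride. The passthru (gather)
// or stored value (scatter) and the mask carry over from the original
// intrinsic's operands.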
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                           Type *DataType,
                                                           Value *Ptr,
                                                           Value *AlignOp) {
  // Make sure the operation will be supported by the backend.
  if (!isLegalTypeAndAlignment(DataType, AlignOp))
    return false;

  // Pointer should be a GEP.
  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return false;

  IRBuilder<> Builder(GEP);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

  CallInst *Call;
  if (II->getIntrinsicID() == Intrinsic::masked_gather)
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
  else
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  if (GEP->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(GEP);

  return true;
}
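
// Candidate gathers/scatters are collected first and only rewritten
// afterwards; this avoids erasing instructions while still iterating over the
// blocks that contain them.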
bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getParent()->getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  StridedAddrs.clear();

  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
        Gathers.push_back(II);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
        Scatters.push_back(II);
      }
    }
  }

  // Rewrite gather/scatter to form strided load/store if possible.
  for (auto *II : Gathers)
    Changed |= tryCreateStridedLoadStore(
        II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
  for (auto *II : Scatters)
    Changed |=
        tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
                                  II->getArgOperand(1), II->getArgOperand(2));

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}