//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
/// intrinsics to arm.mve.gather and arm.mve.scatter intrinsics, optimising the
/// code to produce a better final result as we go.
//
//===----------------------------------------------------------------------===//
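
// As a rough sketch of the transformation (exact mangled names and operands
// depend on the input IR), a masked gather of four 32-bit values through a
// vector of pointers such as
//   %g = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> %ptrs,
//            i32 4, <4 x i1> %mask, <4 x i32> undef)
// is lowered to a single MVE gather intrinsic:
//   %g = call <4 x i32> @llvm.arm.mve.vldr.gather.base.predicated.v4i32.v4p0i32.v4i1(
//            <4 x i32*> %ptrs, i32 0, <4 x i1> %mask)
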
#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
    cl::desc("Enable the generation of masked gathers and scatters"));

namespace {

class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  LoopInfo *LI = nullptr;

  // Check this is a valid gather with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               Align Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *checkGEP(Value *&Offsets, FixedVectorType *Ty, GetElementPtrInst *GEP,
                  IRBuilder<> &Builder);
  // Compute the scale of this gather/scatter instruction
  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
  // If the value is a constant, or derived from constants via additions
  // and multiplications, return its numeric value
  Optional<int64_t> getIfConst(const Value *V);
  // If Inst is an add instruction, check whether one summand is a
  // constant. If so, scale this constant and return it together with
  // the other summand.
  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);

  Value *lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                     Instruction *&Root, IRBuilder<> &Builder);
  // Create a gather from a vector of pointers
  Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                   IRBuilder<> &Builder, int64_t Increment = 0);
  // Create an incrementing gather from a vector of pointers
  Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
                                     IRBuilder<> &Builder,
                                     int64_t Increment = 0);

  Value *lowerScatter(IntrinsicInst *I);
  // Create a scatter to a base + vector of offsets
  Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
                                      IRBuilder<> &Builder);
  // Create a scatter to a vector of pointers
  Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                    IRBuilder<> &Builder,
                                    int64_t Increment = 0);
  // Create an incrementing scatter from a vector of pointers
  Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
                                      IRBuilder<> &Builder,
                                      int64_t Increment = 0);
  // QI gathers and scatters can increment their offsets on their own if
  // the increment is a constant value (immediate)
  Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr,
                                      Value *Ptr, GetElementPtrInst *GEP,
                                      IRBuilder<> &Builder);
  // QI gathers/scatters can increment their offsets on their own if the
  // increment is a constant value (immediate) - this creates a writeback QI
  // gather/scatter
  Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
                                        Value *Ptr, unsigned TypeScale,
                                        IRBuilder<> &Builder);

  // Optimise the base and offsets of the given address
  bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
  // Try to fold consecutive geps together into one
  Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder);
  // Check whether these offsets could be moved out of the loop they're in
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  // Pushes the given add out of the loop
  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  // Pushes the given mul out of the loop
  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
                  Value *OffsSecondOperand, unsigned LoopIncrement,
                  IRBuilder<> &Builder);
};

} // end anonymous namespace

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scatter lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}

bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       Align Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      Alignment >= ElemSize / 8)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type\n");
  return false;
}

static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  // Offsets that are not of type <N x i32> are sign extended by the
  // getelementptr instruction, and MVE gathers/scatters treat the offset as
  // unsigned. Thus, if the element size is smaller than 32, we can only allow
  // positive offsets - i.e., the offsets are not allowed to be variables we
  // can't look into.
  // Additionally, <N x i32> offsets have to either originate from a zext of a
  // vector with element types smaller than or equal to the type of the gather
  // we're looking at, or consist of constants that we can check are small
  // enough to fit into the gather type.
  // Thus we check that 0 <= value < 2^TargetElemSize.
  unsigned TargetElemSize = 128 / TargetElemCount;
  unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
                                ->getElementType()
                                ->getScalarSizeInBits();
  if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
    Constant *ConstOff = dyn_cast<Constant>(Offsets);
    if (!ConstOff)
      return false;
    int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
    auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
      ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
      if (!OConst)
        return false;
      int64_t SExtValue = OConst->getSExtValue();
      if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
        return false;
      return true;
    };
    if (isa<FixedVectorType>(ConstOff->getType())) {
      for (unsigned i = 0; i < TargetElemCount; i++) {
        if (!CheckValueSize(ConstOff->getAggregateElement(i)))
          return false;
      }
    } else {
      if (!CheckValueSize(ConstOff))
        return false;
    }
  }
  return true;
}
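
// checkGEP returns the gep's scalar base pointer on success and, via Offsets,
// a vector of offsets that has been range-checked and (if necessary)
// truncated or zero-extended to match the gather/scatter element count; a
// nullptr return means the gep cannot be used and the intrinsic is left to be
// expanded instead.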
Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, FixedVectorType *Ty,
                                          GetElementPtrInst *GEP,
                                          IRBuilder<> &Builder) {
  if (!GEP) {
    LLVM_DEBUG(
        dbgs() << "masked gathers/scatters: no getelementpointer found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  if (GEPPtr->getType()->isVectorTy() ||
      !isa<FixedVectorType>(Offsets->getType()))
    return nullptr;
  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  Offsets = GEP->getOperand(1);
  unsigned OffsetsElemCount =
      cast<FixedVectorType>(Offsets->getType())->getNumElements();
  // Paranoid check whether the number of parallel lanes is the same
  assert(Ty->getNumElements() == OffsetsElemCount);
  ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  if (ZextOffs)
    Offsets = ZextOffs->getOperand(0);
  FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());
  // If the offsets are already being zext-ed to <N x i32>, that relieves us of
  // having to make sure that they won't overflow.
  if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
                           ->getElementType()
                           ->getScalarSizeInBits() != 32)
    if (!checkOffsetSize(Offsets, OffsetsElemCount))
      return nullptr;
  // The offset sizes have been checked; if any truncating or zext-ing is
  // required to fix them, do that now
  if (Ty != Offsets->getType()) {
    if ((Ty->getElementType()->getScalarSizeInBits() <
         OffsetType->getElementType()->getScalarSizeInBits())) {
      Offsets = Builder.CreateTrunc(Offsets, Ty);
    } else {
      Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
    }
  }
  // If none of the checks failed, return the gep's base pointer
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}

void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  // Look through bitcast instruction if #elements is the same
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<FixedVectorType>(BitCast->getType());
    auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(
          dbgs() << "masked gathers/scatters: looking through bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}
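
// Note: the value returned here is the log2 of the scale that the MVE
// gather/scatter offset intrinsics apply to each offset element (0 -> x1,
// 1 -> x2, 2 -> x4); -1 means there is no legal scale for this combination of
// element sizes.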
int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // This can be a 32-bit load/store scaled by 4, a 16-bit load/store scaled by
  // 2, or an 8-bit, 16-bit or 32-bit load/store scaled by 1
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}

Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  const Constant *C = dyn_cast<Constant>(V);
  if (C != nullptr)
    return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  if (!isa<Instruction>(V))
    return Optional<int64_t>{};
  const Instruction *I = cast<Instruction>(V);
  if (I->getOpcode() == Instruction::Add ||
      I->getOpcode() == Instruction::Mul) {
    Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
    Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
    if (!Op0 || !Op1)
      return Optional<int64_t>{};
    if (I->getOpcode() == Instruction::Add)
      return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
    if (I->getOpcode() == Instruction::Mul)
      return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
  }
  return Optional<int64_t>{};
}

std::pair<Value *, int64_t>
MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  std::pair<Value *, int64_t> ReturnFalse =
      std::pair<Value *, int64_t>(nullptr, 0);
  // At this point, the instruction we're looking at must be an add or we
  // bail out
  Instruction *Add = dyn_cast<Instruction>(Inst);
  if (Add == nullptr || Add->getOpcode() != Instruction::Add)
    return ReturnFalse;
  Value *Summand;
  Optional<int64_t> Const;
  // Find out which operand the value that is increased is
  if ((Const = getIfConst(Add->getOperand(0))))
    Summand = Add->getOperand(1);
  else if ((Const = getIfConst(Add->getOperand(1))))
    Summand = Add->getOperand(0);
  else
    return ReturnFalse;
  // Check that the constant is small enough for an incrementing gather
  int64_t Immediate = Const.getValue() << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return ReturnFalse;
  return std::pair<Value *, int64_t>(Summand, Immediate);
}
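
// lowerGather/lowerScatter return the MVE replacement value on success; a
// nullptr return leaves the masked intrinsic untouched so that it can be
// expanded elsewhere (see the "Expanding" debug messages below).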
Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");
  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  // Attempt to turn the masked gather in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);
  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());
  Instruction *Root = I;
  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;
  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = Builder.CreateSelect(Mask, Load, PassThru);
  }
  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // as well as of the gather itself
    I->eraseFromParent();
  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
  return Load;
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I,
                                                           Value *Ptr,
                                                           IRBuilder<> &Builder,
                                                           int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(
      dbgs()
      << "masked gathers: loading from vector of pointers with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}
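
// A gather whose result is narrower than 128 bits is only handled here when
// its single user is a sext/zext up to a full 128-bit vector: the extend is
// folded into the gather (the Unsigned flag below) and that extend becomes the
// instruction that actually gets replaced (Root).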
Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Type *OriginalTy = I->getType();
  Type *ResultTy = OriginalTy;
  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width an appropriate extend should follow.
  auto *Extend = Root;
  if (OriginalTy->getPrimitiveSizeInBits() < 128) {
    // Only transform gathers with exactly one use
    if (!I->hasOneUse())
      return nullptr;
    // The correct root to replace is not the CallInst itself, but the
    // instruction which extends it
    Extend = cast<Instruction>(*I->users().begin());
    if (isa<SExtInst>(Extend)) {
      Unsigned = 0;
    } else if (!isa<ZExtInst>(Extend)) {
      LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
                        << "Expanding\n");
      return nullptr;
    }
    LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
    ResultTy = Extend->getType();
    // The final size of the gather must be a full vector width
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
                        << "Expanding\n");
      return nullptr;
    }
  }
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr =
      checkGEP(Offsets, cast<FixedVectorType>(ResultTy), GEP, Builder);
  if (!BasePtr)
    return nullptr;
  // Check whether the offset is a constant increment that could be merged into
  // a QI gather
  Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
  if (Load)
    return Load;
  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      OriginalTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;
  Root = Extend;
  Value *Mask = I->getArgOperand(2);
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
}

Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n");
  // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
  // Attempt to turn the masked scatter in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
  auto *Ty = cast<FixedVectorType>(Input->getType());
  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());
  Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;
  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n");
  I->eraseFromParent();
  return Store;
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  // Only QR variants allow truncating
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    // Can't build an intrinsic for this
    return nullptr;
  }
  Value *Mask = I->getArgOperand(3);
  // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  LLVM_DEBUG(
      dbgs()
      << "masked scatters: storing to a vector of pointers with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(3);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}
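
// Truncating stores: if the value being scattered is itself a trunc from a
// full 128-bit vector, the trunc is folded into the scatter by storing the
// pre-truncated value with the narrower memory element size.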
Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;
  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");
  // If the input has been truncated, try to integrate that trunc into the
  // scatter instruction (we don't care about alignment here)
  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(
        dbgs() << "masked scatters: cannot create scatters for non-standard"
               << " input types. Expanding.\n");
    return nullptr;
  }
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr =
      checkGEP(Offsets, cast<FixedVectorType>(InputTy), GEP, Builder);
  if (!BasePtr)
    return nullptr;
  // Check whether the offset is a constant increment that could be merged into
  // a QI gather
  Value *Store =
      tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
  if (Store)
    return Store;
  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      MemoryTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset,
        {BasePtr->getType(), Offsets->getType(), Input->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale)});
}

Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP,
    IRBuilder<> &Builder) {
  FixedVectorType *Ty;
  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    Ty = cast<FixedVectorType>(I->getType());
  else
    Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
  // Incrementing gathers only exist for v4i32
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Loop *L = LI->getLoopFor(I->getParent());
  if (L == nullptr)
    // Incrementing gathers are not beneficial outside of a loop
    return nullptr;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "wb gather/scatter\n");
  // The gep was in charge of making sure the offsets are scaled correctly
  // - calculate that factor so it can be applied by hand
  DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
  int TypeScale =
      computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
                   DT.getTypeSizeInBits(GEP->getType()) /
                       cast<FixedVectorType>(GEP->getType())->getNumElements());
  if (TypeScale == -1)
    return nullptr;
  if (GEP->hasOneUse()) {
    // Only in this case do we want to build a wb gather, because the wb will
    // change the phi which does affect other users of the gep (which will still
    // be using the phi in the old way)
    Value *Load =
        tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder);
    if (Load != nullptr)
      return Load;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "non-wb gather/scatter\n");
  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, OffsetsIncoming,
      Builder.CreateVectorSplat(Ty->getNumElements(),
                                Builder.getInt32(TypeScale)),
      "ScaledIndex", I);
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          Ty->getNumElements(),
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", I);
  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    return cast<IntrinsicInst>(
        tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate));
  else
    return cast<IntrinsicInst>(
        tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate));
}
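
// The writeback ("wb") gather returns two results (the loaded data and the
// updated offset vector) and the wb scatter returns the updated offset vector;
// the code below rewires the loop's offset phi to use that updated vector so
// the separate increment add can be removed.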
Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  // Check whether this gather's offset is incremented by a constant - if so,
  // and the load is of the right type, we can merge this into a QI gather
  Loop *L = LI->getLoopFor(I->getParent());
  // Offsets that are worth merging into this instruction will be incremented
  // by a constant, thus we're looking for an add of a phi and a constant
  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
    // No phi means no IV to write back to; if there is a phi, we expect it
    // to have exactly two incoming values; the only phis we are interested in
    // will be loop IV's and have exactly two uses, one in their increment and
    // one in the gather's gep
    return nullptr;
  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  // Look through the phi to the phi increment
  Offsets = Phi->getIncomingValue(IncrementIndex);
  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    // Then the increment we are looking at is not an increment of the
    // induction variable, and we don't want to do a writeback
    return nullptr;
  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();
  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // The gather is pre-incrementing
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);
  Builder.SetInsertPoint(I);
  Value *EndResult;
  Value *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // Build the incrementing gather
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    // One value to be handed to whoever uses the gather, one is the loop
    // increment
    EndResult = Builder.CreateExtractValue(Load, 0, "Gather");
    NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement");
  } else {
    // Build the incrementing scatter
    NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
    EndResult = NewInduction;
  }
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);
  return EndResult;
}

void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  Instruction *InsertionPoint =
      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  // Initialize the phi with a vector that contains a sum of the constants
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
  // Order such that start index comes first (this reduces mov's)
  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  Phi->removeIncomingValue(IncrementIndex);
  Phi->removeIncomingValue(StartIndex);
}

void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
  // Create a new scalar add outside of the loop and transform it to a splat
  // by which loop variable can be incremented
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());
  // Create a new index
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);
  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment NewIndex by Product instead of the multiplication
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());
  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}

// Check whether all usages of this instruction are as offsets of
// gathers/scatters or simple arithmetics only used by gathers/scatters
static bool hasAllGatScatUsers(Instruction *I) {
  if (I->hasNUses(0)) {
    return false;
  }
  bool Gatscat = true;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
      return Gatscat;
    } else {
      unsigned OpCode = cast<Instruction>(U)->getOpcode();
      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
          hasAllGatScatUsers(cast<Instruction>(U))) {
        continue;
      }
      return false;
    }
  }
  return Gatscat;
}
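
// As a sketch of the intended transformation, an offset computation such as
//   %iv  = phi <4 x i32> [ %start, %preheader ], [ %iv.next, %loop ]
//   %off = mul <4 x i32> %iv, %stride      ; %stride is loop-invariant
// is rewritten so that the multiply happens only once, outside the loop:
//   %iv2      = phi <4 x i32> [ %start.mul, %preheader ], [ %iv2.next, %loop ]
//   %iv2.next = add <4 x i32> %iv2, %step.mul
// where %start.mul = %start * %stride and %step.mul = <increment per round> *
// %stride are computed outside the loop, and %off is replaced by %iv2.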
bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add &&
      Offs->getOpcode() != Instruction::Mul)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  if (!Offs->hasOneUse()) {
    if (!hasAllGatScatUsers(Offs))
      return false;
  }
  // Find out which, if any, operand of the instruction
  // is a phi node
  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    bool Changed = false;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed) {
      return false;
    } else {
      if (isa<PHINode>(Offs->getOperand(0))) {
        Phi = cast<PHINode>(Offs->getOperand(0));
        OffsSecondOp = 1;
      } else if (isa<PHINode>(Offs->getOperand(1))) {
        Phi = cast<PHINode>(Offs->getOperand(1));
        OffsSecondOp = 0;
      } else {
        return false;
      }
    }
  }
  // A phi node we want to perform this function on should be from the
  // loop header, and shouldn't have more than 2 incoming values
  if (Phi->getParent() != L->getHeader() ||
      Phi->getNumIncomingValues() != 2)
    return false;
  // The phi must be an induction variable
  int IncrementingBlock = -1;
  for (int i = 0; i < 2; i++)
    if (auto *Op = dyn_cast<Instruction>(Phi->getIncomingValue(i)))
      if (Op->getOpcode() == Instruction::Add &&
          (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
        IncrementingBlock = i;
  if (IncrementingBlock == -1)
    return false;
  Instruction *IncInstruction =
      cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));
  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  Value *IncrementPerRound = IncInstruction->getOperand(
      (IncInstruction->getOperand(0) == Phi) ? 1 : 0);
  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
  if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
      !L->isLoopInvariant(OffsSecondOperand))
    // Something has gone wrong, abort
    return false;
  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;
  if (Phi->getNumUses() == 2) {
    // No other users -> reuse existing phi (One user is the instruction
    // we're looking at, the other is the phi increment)
    if (IncInstruction->getNumUses() != 1) {
      // If the incrementing instruction does have more users than
      // our phi, we need to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
    std::vector<Value *> Increases;
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }
  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());
  switch (Offs->getOpcode()) {
  case Instruction::Add:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
               Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(
      dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n");
  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();
  return true;
}

static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP,
                                      IRBuilder<> &Builder) {
  // Splat the non-vector value to a vector of the given type - if the value is
  // a constant (and its value isn't too big), we can even use this opportunity
  // to scale it to the size of the vector elements
  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
    ConstantInt *Const;
    if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
        VT->getElementType() != NonVectorVal->getType()) {
      unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
      uint64_t N = Const->getZExtValue();
      if (N < (unsigned)(1 << (TargetElemSize - 1))) {
        NonVectorVal = Builder.CreateVectorSplat(
            VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
        return;
      }
    }
    NonVectorVal =
        Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  };
  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  // If one of X, Y is not a vector, we have to splat it in order
  // to add the two of them.
  if (XElType && !YElType) {
    FixSummands(XElType, Y);
    YElType = cast<FixedVectorType>(Y->getType());
  } else if (YElType && !XElType) {
    FixSummands(YElType, X);
    XElType = cast<FixedVectorType>(X->getType());
  }
  assert(XElType && YElType && "Unknown vector types");
  // Check that the summands are of compatible types
  if (XElType != YElType) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
    return nullptr;
  }
  if (XElType->getElementType()->getScalarSizeInBits() != 32) {
    // Check that by adding the vectors we do not accidentally
    // create an overflow
    Constant *ConstX = dyn_cast<Constant>(X);
    Constant *ConstY = dyn_cast<Constant>(Y);
    if (!ConstX || !ConstY)
      return nullptr;
    unsigned TargetElemSize = 128 / XElType->getNumElements();
    for (unsigned i = 0; i < XElType->getNumElements(); i++) {
      ConstantInt *ConstXEl =
          dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
      ConstantInt *ConstYEl =
          dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
      if (!ConstXEl || !ConstYEl ||
          ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >=
              (unsigned)(1 << (TargetElemSize - 1)))
        return nullptr;
    }
  }
  Value *Add = Builder.CreateAdd(X, Y);
  FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType());
  if (checkOffsetSize(Add, GEPType->getNumElements()))
    return Add;
  else
    return nullptr;
}
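
// foldGEP sketch: a chain of geps whose offsets are all constants (a gep of a
// gep of some base) is collapsed into a single gep of the innermost base with
// the element-wise sum of the offsets, after checking that the sums cannot
// overflow the offset element type.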
Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
                                         Value *&Offsets,
                                         IRBuilder<> &Builder) {
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  // We only merge geps with constant offsets, because only for those
  // we can make sure that we do not cause an overflow
  if (!isa<Constant>(Offsets))
    return nullptr;
  GetElementPtrInst *BaseGEP;
  if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) {
    // Merge the two geps into one
    Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder);
    if (!BaseBasePtr)
      return nullptr;
    Offsets =
        CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder);
    if (Offsets == nullptr)
      return nullptr;
    return BaseBasePtr;
  }
  return GEPPtr;
}

bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
                                               LoopInfo *LI) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  if (!GEP)
    return false;
  bool Changed = false;
  if (GEP->hasOneUse() &&
      dyn_cast<GetElementPtrInst>(GEP->getPointerOperand())) {
    IRBuilder<> Builder(GEP->getContext());
    Builder.SetInsertPoint(GEP);
    Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
    Value *Offsets;
    Value *Base = foldGEP(GEP, Offsets, Builder);
    // We only want to merge the geps if there is a real chance that they can be
    // used by an MVE gather; thus the offset has to have the correct size
    // (always i32 if it is not of vector type) and the base has to be a
    // pointer.
    if (Offsets && Base && Base != GEP) {
      PointerType *BaseType = cast<PointerType>(Base->getType());
      GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
          BaseType->getPointerElementType(), Base, Offsets, "gep.merged", GEP);
      GEP->replaceAllUsesWith(NewAddress);
      GEP = NewAddress;
      Changed = true;
    }
  }
  Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  return Changed;
}
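
// Driver: simplify each block first, collect all masked gathers/scatters and
// optimise their addresses while the IR is still in its generic form, then
// lower each collected intrinsic to its MVE equivalent and clean up whatever
// becomes dead in the process.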
bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;
  bool Changed = false;
  for (BasicBlock &BB : F) {
    Changed |= SimplifyInstructionsInBlock(&BB);
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
      }
    }
  }
  for (unsigned i = 0; i < Gathers.size(); i++) {
    IntrinsicInst *I = Gathers[i];
    Value *L = lowerGather(I);
    if (L == nullptr)
      continue;
    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
    Changed = true;
  }
  for (unsigned i = 0; i < Scatters.size(); i++) {
    IntrinsicInst *I = Scatters[i];
    Value *S = lowerScatter(I);
    if (S == nullptr)
      continue;
    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
    Changed = true;
  }
  return Changed;
}