MVEGatherScatterLowering.cpp

  1. //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
9. /// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
10. /// intrinsics to arm.mve.gather and arm.mve.scatter intrinsics, optimising
11. /// the code to produce a better final result as we go.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "ARM.h"
  15. #include "ARMBaseInstrInfo.h"
  16. #include "ARMSubtarget.h"
  17. #include "llvm/Analysis/LoopInfo.h"
  18. #include "llvm/Analysis/TargetTransformInfo.h"
  19. #include "llvm/Analysis/ValueTracking.h"
  20. #include "llvm/CodeGen/TargetLowering.h"
  21. #include "llvm/CodeGen/TargetPassConfig.h"
  22. #include "llvm/CodeGen/TargetSubtargetInfo.h"
  23. #include "llvm/InitializePasses.h"
  24. #include "llvm/IR/BasicBlock.h"
  25. #include "llvm/IR/Constant.h"
  26. #include "llvm/IR/Constants.h"
  27. #include "llvm/IR/DerivedTypes.h"
  28. #include "llvm/IR/Function.h"
  29. #include "llvm/IR/InstrTypes.h"
  30. #include "llvm/IR/Instruction.h"
  31. #include "llvm/IR/Instructions.h"
  32. #include "llvm/IR/IntrinsicInst.h"
  33. #include "llvm/IR/Intrinsics.h"
  34. #include "llvm/IR/IntrinsicsARM.h"
  35. #include "llvm/IR/IRBuilder.h"
  36. #include "llvm/IR/PatternMatch.h"
  37. #include "llvm/IR/Type.h"
  38. #include "llvm/IR/Value.h"
  39. #include "llvm/Pass.h"
  40. #include "llvm/Support/Casting.h"
  41. #include "llvm/Transforms/Utils/Local.h"
  42. #include <algorithm>
  43. #include <cassert>
  44. using namespace llvm;
  45. #define DEBUG_TYPE "arm-mve-gather-scatter-lowering"
  46. cl::opt<bool> EnableMaskedGatherScatters(
  47. "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
  48. cl::desc("Enable the generation of masked gathers and scatters"));
  49. namespace {
  50. class MVEGatherScatterLowering : public FunctionPass {
  51. public:
  52. static char ID; // Pass identification, replacement for typeid
  53. explicit MVEGatherScatterLowering() : FunctionPass(ID) {
  54. initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  55. }
  56. bool runOnFunction(Function &F) override;
  57. StringRef getPassName() const override {
  58. return "MVE gather/scatter lowering";
  59. }
  60. void getAnalysisUsage(AnalysisUsage &AU) const override {
  61. AU.setPreservesCFG();
  62. AU.addRequired<TargetPassConfig>();
  63. AU.addRequired<LoopInfoWrapperPass>();
  64. FunctionPass::getAnalysisUsage(AU);
  65. }
  66. private:
  67. LoopInfo *LI = nullptr;
  68. const DataLayout *DL;
  69. // Check this is a valid gather with correct alignment
  70. bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
  71. Align Alignment);
  72. // Check whether Ptr is hidden behind a bitcast and look through it
  73. void lookThroughBitcast(Value *&Ptr);
  74. // Decompose a ptr into Base and Offsets, potentially using a GEP to return a
75. // scalar base and vector offsets, or else fall back to using a base of 0 and
  76. // offset of Ptr where possible.
  77. Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
  78. FixedVectorType *Ty, Type *MemoryTy,
  79. IRBuilder<> &Builder);
  80. // Check for a getelementptr and deduce base and offsets from it, on success
  81. // returning the base directly and the offsets indirectly using the Offsets
  82. // argument
  83. Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty,
  84. GetElementPtrInst *GEP, IRBuilder<> &Builder);
  85. // Compute the scale of this gather/scatter instruction
  86. int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
  87. // If the value is a constant, or derived from constants via additions
88. // and multiplications, return its numeric value
  89. std::optional<int64_t> getIfConst(const Value *V);
  90. // If Inst is an add instruction, check whether one summand is a
  91. // constant. If so, scale this constant and return it together with
  92. // the other summand.
  93. std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
  94. Instruction *lowerGather(IntrinsicInst *I);
  95. // Create a gather from a base + vector of offsets
  96. Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
  97. Instruction *&Root,
  98. IRBuilder<> &Builder);
  99. // Create a gather from a vector of pointers
  100. Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
  101. IRBuilder<> &Builder,
  102. int64_t Increment = 0);
  103. // Create an incrementing gather from a vector of pointers
  104. Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
  105. IRBuilder<> &Builder,
  106. int64_t Increment = 0);
  107. Instruction *lowerScatter(IntrinsicInst *I);
  108. // Create a scatter to a base + vector of offsets
  109. Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
  110. IRBuilder<> &Builder);
  111. // Create a scatter to a vector of pointers
  112. Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
  113. IRBuilder<> &Builder,
  114. int64_t Increment = 0);
  115. // Create an incrementing scatter from a vector of pointers
  116. Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
  117. IRBuilder<> &Builder,
  118. int64_t Increment = 0);
119. // QI gathers and scatters can increment their offsets on their own if
120. // the increment is a constant value (immediate)
  121. Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
  122. IRBuilder<> &Builder);
123. // QI gathers/scatters can increment their offsets on their own if the
124. // increment is a constant value (immediate) - this creates a writeback QI
  125. // gather/scatter
  126. Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
  127. Value *Ptr, unsigned TypeScale,
  128. IRBuilder<> &Builder);
  129. // Optimise the base and offsets of the given address
  130. bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
  131. // Try to fold consecutive geps together into one
  132. Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, unsigned &Scale,
  133. IRBuilder<> &Builder);
  134. // Check whether these offsets could be moved out of the loop they're in
  135. bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  136. // Pushes the given add out of the loop
  137. void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  138. // Pushes the given mul or shl out of the loop
  139. void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
  140. Value *OffsSecondOperand, unsigned LoopIncrement,
  141. IRBuilder<> &Builder);
  142. };
  143. } // end anonymous namespace
  144. char MVEGatherScatterLowering::ID = 0;
  145. INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
146. "MVE gather/scatter lowering pass", false, false)
  147. Pass *llvm::createMVEGatherScatterLoweringPass() {
  148. return new MVEGatherScatterLowering();
  149. }
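// A worked example of the legality check below, read off the conditions in
// the function rather than stated as a general ISA rule: a gather/scatter of
// <4 x i32> needs alignment >= 4, <8 x i16> needs alignment >= 2 and
// <16 x i8> needs alignment >= 1, while e.g. a <4 x i16> with alignment 1 or
// any <2 x i64> access is rejected and left for generic expansion.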
  150. bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
  151. unsigned ElemSize,
  152. Align Alignment) {
  153. if (((NumElements == 4 &&
  154. (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
  155. (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
  156. (NumElements == 16 && ElemSize == 8)) &&
  157. Alignment >= ElemSize / 8)
  158. return true;
  159. LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
  160. << "valid alignment or vector type \n");
  161. return false;
  162. }
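// A worked example of the offset-width check below: for an 8-lane gather
// (TargetElemCount == 8, so TargetElemSize == 16) the constant offsets must
// all lie in [0, 2^16); offsets that cannot be proven to be in range,
// including any non-constant offsets that reach this check, cause the
// transform to bail out.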
  163. static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  164. // Offsets that are not of type <N x i32> are sign extended by the
  165. // getelementptr instruction, and MVE gathers/scatters treat the offset as
  166. // unsigned. Thus, if the element size is smaller than 32, we can only allow
  167. // positive offsets - i.e., the offsets are not allowed to be variables we
  168. // can't look into.
  169. // Additionally, <N x i32> offsets have to either originate from a zext of a
170. // vector with element types smaller than or equal to the type of the gather we're
  171. // looking at, or consist of constants that we can check are small enough
  172. // to fit into the gather type.
173. // Thus we check that 0 <= value < 2^TargetElemSize.
  174. unsigned TargetElemSize = 128 / TargetElemCount;
  175. unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
  176. ->getElementType()
  177. ->getScalarSizeInBits();
  178. if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
  179. Constant *ConstOff = dyn_cast<Constant>(Offsets);
  180. if (!ConstOff)
  181. return false;
  182. int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
  183. auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
  184. ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
  185. if (!OConst)
  186. return false;
  187. int SExtValue = OConst->getSExtValue();
  188. if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
  189. return false;
  190. return true;
  191. };
  192. if (isa<FixedVectorType>(ConstOff->getType())) {
  193. for (unsigned i = 0; i < TargetElemCount; i++) {
  194. if (!CheckValueSize(ConstOff->getAggregateElement(i)))
  195. return false;
  196. }
  197. } else {
  198. if (!CheckValueSize(ConstOff))
  199. return false;
  200. }
  201. }
  202. return true;
  203. }
  204. Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
  205. int &Scale, FixedVectorType *Ty,
  206. Type *MemoryTy,
  207. IRBuilder<> &Builder) {
  208. if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
  209. if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
  210. Scale =
  211. computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
  212. MemoryTy->getScalarSizeInBits());
  213. return Scale == -1 ? nullptr : V;
  214. }
  215. }
  216. // If we couldn't use the GEP (or it doesn't exist), attempt to use a
  217. // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4
  218. // elements.
  219. FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
  220. if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32)
  221. return nullptr;
  222. Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
  223. Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getInt8PtrTy());
  224. Offsets = Builder.CreatePtrToInt(
  225. Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
  226. Scale = 0;
  227. return BasePtr;
  228. }
  229. Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
  230. FixedVectorType *Ty,
  231. GetElementPtrInst *GEP,
  232. IRBuilder<> &Builder) {
  233. if (!GEP) {
  234. LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
  235. << "found\n");
  236. return nullptr;
  237. }
  238. LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
  239. << " Looking at intrinsic for base + vector of offsets\n");
  240. Value *GEPPtr = GEP->getPointerOperand();
  241. Offsets = GEP->getOperand(1);
  242. if (GEPPtr->getType()->isVectorTy() ||
  243. !isa<FixedVectorType>(Offsets->getType()))
  244. return nullptr;
  245. if (GEP->getNumOperands() != 2) {
  246. LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
  247. << " operands. Expanding.\n");
  248. return nullptr;
  249. }
  250. Offsets = GEP->getOperand(1);
  251. unsigned OffsetsElemCount =
  252. cast<FixedVectorType>(Offsets->getType())->getNumElements();
  253. // Paranoid check whether the number of parallel lanes is the same
  254. assert(Ty->getNumElements() == OffsetsElemCount);
  255. ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  256. if (ZextOffs)
  257. Offsets = ZextOffs->getOperand(0);
  258. FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());
  259. // If the offsets are already being zext-ed to <N x i32>, that relieves us of
  260. // having to make sure that they won't overflow.
  261. if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
  262. ->getElementType()
  263. ->getScalarSizeInBits() != 32)
  264. if (!checkOffsetSize(Offsets, OffsetsElemCount))
  265. return nullptr;
  266. // The offset sizes have been checked; if any truncating or zext-ing is
  267. // required to fix them, do that now
  268. if (Ty != Offsets->getType()) {
  269. if ((Ty->getElementType()->getScalarSizeInBits() <
  270. OffsetType->getElementType()->getScalarSizeInBits())) {
  271. Offsets = Builder.CreateTrunc(Offsets, Ty);
  272. } else {
  273. Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
  274. }
  275. }
  276. // If none of the checks failed, return the gep's base pointer
  277. LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  278. return GEPPtr;
  279. }
  280. void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  281. // Look through bitcast instruction if #elements is the same
  282. if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
  283. auto *BCTy = cast<FixedVectorType>(BitCast->getType());
  284. auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
  285. if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
  286. LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
  287. << "bitcast\n");
  288. Ptr = BitCast->getOperand(0);
  289. }
  290. }
  291. }
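// A worked example of the scale computation below: a GEP over i32 feeding a
// 32-bit access returns 2 (the hardware shifts each offset left by 2, i.e.
// multiplies by 4 bytes), a GEP over i16 feeding a 16-bit access returns 1,
// and a GEP over i8 always returns 0 (plain byte offsets). Any other
// combination, e.g. an i32 GEP feeding a 16-bit access, returns -1 and the
// transform is abandoned.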
  292. int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
  293. unsigned MemoryElemSize) {
294. // This can be a 32-bit load/store scaled by 4, a 16-bit load/store scaled by 2,
295. // or an 8-bit, 16-bit or 32-bit load/store scaled by 1
  296. if (GEPElemSize == 32 && MemoryElemSize == 32)
  297. return 2;
  298. else if (GEPElemSize == 16 && MemoryElemSize == 16)
  299. return 1;
  300. else if (GEPElemSize == 8)
  301. return 0;
  302. LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
  303. << "create intrinsic\n");
  304. return -1;
  305. }
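// Illustrative results of getIfConst, following directly from the code below:
// a splat constant <4 x i32> of 8 yields 8, (add (splat 8), (splat 4)) yields
// 12, (shl (splat 3), (splat 2)) yields 12, and anything with a non-constant
// operand yields an empty optional.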
  306. std::optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  307. const Constant *C = dyn_cast<Constant>(V);
  308. if (C && C->getSplatValue())
  309. return std::optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  310. if (!isa<Instruction>(V))
  311. return std::optional<int64_t>{};
  312. const Instruction *I = cast<Instruction>(V);
  313. if (I->getOpcode() == Instruction::Add || I->getOpcode() == Instruction::Or ||
  314. I->getOpcode() == Instruction::Mul ||
  315. I->getOpcode() == Instruction::Shl) {
  316. std::optional<int64_t> Op0 = getIfConst(I->getOperand(0));
  317. std::optional<int64_t> Op1 = getIfConst(I->getOperand(1));
  318. if (!Op0 || !Op1)
  319. return std::optional<int64_t>{};
  320. if (I->getOpcode() == Instruction::Add)
  321. return std::optional<int64_t>{*Op0 + *Op1};
  322. if (I->getOpcode() == Instruction::Mul)
  323. return std::optional<int64_t>{*Op0 * *Op1};
  324. if (I->getOpcode() == Instruction::Shl)
  325. return std::optional<int64_t>{*Op0 << *Op1};
  326. if (I->getOpcode() == Instruction::Or)
  327. return std::optional<int64_t>{*Op0 | *Op1};
  328. }
  329. return std::optional<int64_t>{};
  330. }
  331. // Return true if I is an Or instruction that is equivalent to an add, due to
  332. // the operands having no common bits set.
  333. static bool isAddLikeOr(Instruction *I, const DataLayout &DL) {
  334. return I->getOpcode() == Instruction::Or &&
  335. haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL);
  336. }
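// A sketch of the decomposition done by getVarAndConst, using a hypothetical
// value %iv: for Offsets == add <4 x i32> %iv, splat(4) and TypeScale == 2
// the result is {%iv, 16}, since 4 << 2 == 16 is a multiple of 4 in the range
// [-512, 512]. With splat(200) instead, 200 << 2 == 800 is out of range and
// {nullptr, 0} is returned.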
  337. std::pair<Value *, int64_t>
  338. MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  339. std::pair<Value *, int64_t> ReturnFalse =
  340. std::pair<Value *, int64_t>(nullptr, 0);
  341. // At this point, the instruction we're looking at must be an add or an
  342. // add-like-or.
  343. Instruction *Add = dyn_cast<Instruction>(Inst);
  344. if (Add == nullptr ||
  345. (Add->getOpcode() != Instruction::Add && !isAddLikeOr(Add, *DL)))
  346. return ReturnFalse;
  347. Value *Summand;
  348. std::optional<int64_t> Const;
350. // Find out which operand is the constant and which is the variable summand
  350. if ((Const = getIfConst(Add->getOperand(0))))
  351. Summand = Add->getOperand(1);
  352. else if ((Const = getIfConst(Add->getOperand(1))))
  353. Summand = Add->getOperand(0);
  354. else
  355. return ReturnFalse;
  356. // Check that the constant is small enough for an incrementing gather
  357. int64_t Immediate = *Const << TypeScale;
  358. if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
  359. return ReturnFalse;
  360. return std::pair<Value *, int64_t>(Summand, Immediate);
  361. }
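// Schematically (intrinsic mangling omitted and IR names illustrative), a
// gather such as
//   %g = call <4 x i32> @llvm.masked.gather(<4 x ptr> %ptrs, i32 4,
//                                           <4 x i1> %mask, <4 x i32> undef)
// is rebuilt as a single MVE intrinsic along the lines of
//   %g = call <4 x i32> @llvm.arm.mve.vldr.gather.base.predicated(%ptrs, 0, %mask)
// and a select against the passthru is only inserted when the passthru is
// neither undef nor zero.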
  362. Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  363. using namespace PatternMatch;
  364. LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
  365. << *I << "\n");
  366. // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  367. // Attempt to turn the masked gather in I into a MVE intrinsic
  368. // Potentially optimising the addressing modes as we do so.
  369. auto *Ty = cast<FixedVectorType>(I->getType());
  370. Value *Ptr = I->getArgOperand(0);
  371. Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  372. Value *Mask = I->getArgOperand(2);
  373. Value *PassThru = I->getArgOperand(3);
  374. if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
  375. Alignment))
  376. return nullptr;
  377. lookThroughBitcast(Ptr);
  378. assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
  379. IRBuilder<> Builder(I->getContext());
  380. Builder.SetInsertPoint(I);
  381. Builder.SetCurrentDebugLocation(I->getDebugLoc());
  382. Instruction *Root = I;
  383. Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
  384. if (!Load)
  385. Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  386. if (!Load)
  387. Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  388. if (!Load)
  389. return nullptr;
  390. if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
  391. LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
  392. << "creating select\n");
  393. Load = SelectInst::Create(Mask, Load, PassThru);
  394. Builder.Insert(Load);
  395. }
  396. Root->replaceAllUsesWith(Load);
  397. Root->eraseFromParent();
  398. if (Root != I)
399. // If this was an extending gather, we need to get rid of the sext/zext
400. // as well as of the gather itself
  401. I->eraseFromParent();
  402. LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
  403. << *Load << "\n");
  404. return Load;
  405. }
  406. Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
  407. IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  408. using namespace PatternMatch;
  409. auto *Ty = cast<FixedVectorType>(I->getType());
  410. LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  411. if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
  412. // Can't build an intrinsic for this
  413. return nullptr;
  414. Value *Mask = I->getArgOperand(2);
  415. if (match(Mask, m_One()))
  416. return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
  417. {Ty, Ptr->getType()},
  418. {Ptr, Builder.getInt32(Increment)});
  419. else
  420. return Builder.CreateIntrinsic(
  421. Intrinsic::arm_mve_vldr_gather_base_predicated,
  422. {Ty, Ptr->getType(), Mask->getType()},
  423. {Ptr, Builder.getInt32(Increment), Mask});
  424. }
  425. Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
  426. IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  427. using namespace PatternMatch;
  428. auto *Ty = cast<FixedVectorType>(I->getType());
  429. LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
  430. << "writeback\n");
  431. if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
  432. // Can't build an intrinsic for this
  433. return nullptr;
  434. Value *Mask = I->getArgOperand(2);
  435. if (match(Mask, m_One()))
  436. return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
  437. {Ty, Ptr->getType()},
  438. {Ptr, Builder.getInt32(Increment)});
  439. else
  440. return Builder.CreateIntrinsic(
  441. Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
  442. {Ty, Ptr->getType(), Mask->getType()},
  443. {Ptr, Builder.getInt32(Increment), Mask});
  444. }
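// An example of the extend handling below: if a masked gather produces
// <4 x i16> and its only user is a sext to <4 x i32>, the sext is folded into
// an extending gather (a VLDRH.S32 form) and that sext becomes the Root that
// gets replaced. If no suitable extend is found for a small integer type, a
// zero-extending gather is built instead and the result is truncated back to
// the original type.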
  445. Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
  446. IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  447. using namespace PatternMatch;
  448. Type *MemoryTy = I->getType();
  449. Type *ResultTy = MemoryTy;
  450. unsigned Unsigned = 1;
  451. // The size of the gather was already checked in isLegalTypeAndAlignment;
  452. // if it was not a full vector width an appropriate extend should follow.
  453. auto *Extend = Root;
  454. bool TruncResult = false;
  455. if (MemoryTy->getPrimitiveSizeInBits() < 128) {
  456. if (I->hasOneUse()) {
457. // If the gather has a single extend of the correct type, use an extending
458. // gather and replace the ext, in which case the correct root to replace
459. // is not the CallInst itself but the instruction which extends it.
  460. Instruction* User = cast<Instruction>(*I->users().begin());
  461. if (isa<SExtInst>(User) &&
  462. User->getType()->getPrimitiveSizeInBits() == 128) {
  463. LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
  464. << *User << "\n");
  465. Extend = User;
  466. ResultTy = User->getType();
  467. Unsigned = 0;
  468. } else if (isa<ZExtInst>(User) &&
  469. User->getType()->getPrimitiveSizeInBits() == 128) {
  470. LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
  471. << *ResultTy << "\n");
  472. Extend = User;
  473. ResultTy = User->getType();
  474. }
  475. }
  476. // If an extend hasn't been found and the type is an integer, create an
  477. // extending gather and truncate back to the original type.
  478. if (ResultTy->getPrimitiveSizeInBits() < 128 &&
  479. ResultTy->isIntOrIntVectorTy()) {
  480. ResultTy = ResultTy->getWithNewBitWidth(
  481. 128 / cast<FixedVectorType>(ResultTy)->getNumElements());
  482. TruncResult = true;
483. LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncating to: "
  484. << *ResultTy << "\n");
  485. }
  486. // The final size of the gather must be a full vector width
  487. if (ResultTy->getPrimitiveSizeInBits() != 128) {
  488. LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
  489. "from the correct type. Expanding\n");
  490. return nullptr;
  491. }
  492. }
  493. Value *Offsets;
  494. int Scale;
  495. Value *BasePtr = decomposePtr(
  496. Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
  497. if (!BasePtr)
  498. return nullptr;
  499. Root = Extend;
  500. Value *Mask = I->getArgOperand(2);
  501. Instruction *Load = nullptr;
  502. if (!match(Mask, m_One()))
  503. Load = Builder.CreateIntrinsic(
  504. Intrinsic::arm_mve_vldr_gather_offset_predicated,
  505. {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
  506. {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
  507. Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  508. else
  509. Load = Builder.CreateIntrinsic(
  510. Intrinsic::arm_mve_vldr_gather_offset,
  511. {ResultTy, BasePtr->getType(), Offsets->getType()},
  512. {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
  513. Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
  514. if (TruncResult) {
  515. Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
  516. Builder.Insert(Load);
  517. }
  518. return Load;
  519. }
  520. Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  521. using namespace PatternMatch;
  522. LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
  523. << *I << "\n");
  524. // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
  525. // Attempt to turn the masked scatter in I into a MVE intrinsic
  526. // Potentially optimising the addressing modes as we do so.
  527. Value *Input = I->getArgOperand(0);
  528. Value *Ptr = I->getArgOperand(1);
  529. Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
  530. auto *Ty = cast<FixedVectorType>(Input->getType());
  531. if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
  532. Alignment))
  533. return nullptr;
  534. lookThroughBitcast(Ptr);
  535. assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
  536. IRBuilder<> Builder(I->getContext());
  537. Builder.SetInsertPoint(I);
  538. Builder.SetCurrentDebugLocation(I->getDebugLoc());
  539. Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
  540. if (!Store)
  541. Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  542. if (!Store)
  543. Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  544. if (!Store)
  545. return nullptr;
  546. LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
  547. << *Store << "\n");
  548. I->eraseFromParent();
  549. return Store;
  550. }
  551. Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
  552. IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  553. using namespace PatternMatch;
  554. Value *Input = I->getArgOperand(0);
  555. auto *Ty = cast<FixedVectorType>(Input->getType());
  556. // Only QR variants allow truncating
  557. if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
  558. // Can't build an intrinsic for this
  559. return nullptr;
  560. }
  561. Value *Mask = I->getArgOperand(3);
  562. // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  563. LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  564. if (match(Mask, m_One()))
  565. return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
  566. {Ptr->getType(), Input->getType()},
  567. {Ptr, Builder.getInt32(Increment), Input});
  568. else
  569. return Builder.CreateIntrinsic(
  570. Intrinsic::arm_mve_vstr_scatter_base_predicated,
  571. {Ptr->getType(), Input->getType(), Mask->getType()},
  572. {Ptr, Builder.getInt32(Increment), Input, Mask});
  573. }
  574. Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
  575. IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  576. using namespace PatternMatch;
  577. Value *Input = I->getArgOperand(0);
  578. auto *Ty = cast<FixedVectorType>(Input->getType());
  579. LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
  580. << "with writeback\n");
  581. if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
  582. // Can't build an intrinsic for this
  583. return nullptr;
  584. Value *Mask = I->getArgOperand(3);
  585. if (match(Mask, m_One()))
  586. return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
  587. {Ptr->getType(), Input->getType()},
  588. {Ptr, Builder.getInt32(Increment), Input});
  589. else
  590. return Builder.CreateIntrinsic(
  591. Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
  592. {Ptr->getType(), Input->getType(), Mask->getType()},
  593. {Ptr, Builder.getInt32(Increment), Input, Mask});
  594. }
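// An example of the trunc handling below: if the stored value is
//   %t = trunc <4 x i32> %x to <4 x i16>
// the trunc is folded away and a truncating scatter of %x with a 16-bit
// memory size is emitted (a VSTRH.32 form). A small integer input without
// such a trunc is first zero-extended to a full 128-bit vector so that a
// scatter can still be built.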
  595. Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
  596. IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  597. using namespace PatternMatch;
  598. Value *Input = I->getArgOperand(0);
  599. Value *Mask = I->getArgOperand(3);
  600. Type *InputTy = Input->getType();
  601. Type *MemoryTy = InputTy;
  602. LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
  603. << " to base + vector of offsets\n");
  604. // If the input has been truncated, try to integrate that trunc into the
  605. // scatter instruction (we don't care about alignment here)
  606. if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
  607. Value *PreTrunc = Trunc->getOperand(0);
  608. Type *PreTruncTy = PreTrunc->getType();
  609. if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
  610. Input = PreTrunc;
  611. InputTy = PreTruncTy;
  612. }
  613. }
  614. bool ExtendInput = false;
  615. if (InputTy->getPrimitiveSizeInBits() < 128 &&
  616. InputTy->isIntOrIntVectorTy()) {
  617. // If we can't find a trunc to incorporate into the instruction, create an
  618. // implicit one with a zext, so that we can still create a scatter. We know
  619. // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type
  620. // smaller than 128 bits will divide evenly into a 128bit vector.
  621. InputTy = InputTy->getWithNewBitWidth(
  622. 128 / cast<FixedVectorType>(InputTy)->getNumElements());
  623. ExtendInput = true;
  624. LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
  625. << *Input << "\n");
  626. }
  627. if (InputTy->getPrimitiveSizeInBits() != 128) {
  628. LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
  629. "non-standard input types. Expanding.\n");
  630. return nullptr;
  631. }
  632. Value *Offsets;
  633. int Scale;
  634. Value *BasePtr = decomposePtr(
  635. Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder);
  636. if (!BasePtr)
  637. return nullptr;
  638. if (ExtendInput)
  639. Input = Builder.CreateZExt(Input, InputTy);
  640. if (!match(Mask, m_One()))
  641. return Builder.CreateIntrinsic(
  642. Intrinsic::arm_mve_vstr_scatter_offset_predicated,
  643. {BasePtr->getType(), Offsets->getType(), Input->getType(),
  644. Mask->getType()},
  645. {BasePtr, Offsets, Input,
  646. Builder.getInt32(MemoryTy->getScalarSizeInBits()),
  647. Builder.getInt32(Scale), Mask});
  648. else
  649. return Builder.CreateIntrinsic(
  650. Intrinsic::arm_mve_vstr_scatter_offset,
  651. {BasePtr->getType(), Offsets->getType(), Input->getType()},
  652. {BasePtr, Offsets, Input,
  653. Builder.getInt32(MemoryTy->getScalarSizeInBits()),
  654. Builder.getInt32(Scale)});
  655. }
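// A sketch of the incrementing case handled below, for a loop whose gather or
// scatter offsets are "add %phi, splat(C)" (names illustrative): the base and
// the scaled offsets are combined into a single vector of addresses, and C,
// scaled to bytes, becomes the immediate increment operand of the base
// gather/scatter, roughly corresponding to the MVE "[q, #imm]" addressing
// form.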
  656. Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
  657. IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  658. FixedVectorType *Ty;
  659. if (I->getIntrinsicID() == Intrinsic::masked_gather)
  660. Ty = cast<FixedVectorType>(I->getType());
  661. else
  662. Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
  663. // Incrementing gathers only exist for v4i32
  664. if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
  665. return nullptr;
  666. // Incrementing gathers are not beneficial outside of a loop
  667. Loop *L = LI->getLoopFor(I->getParent());
  668. if (L == nullptr)
  669. return nullptr;
  670. // Decompose the GEP into Base and Offsets
  671. GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  672. Value *Offsets;
  673. Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder);
  674. if (!BasePtr)
  675. return nullptr;
  676. LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
  677. "wb gather/scatter\n");
  678. // The gep was in charge of making sure the offsets are scaled correctly
  679. // - calculate that factor so it can be applied by hand
  680. int TypeScale =
  681. computeScale(DL->getTypeSizeInBits(GEP->getOperand(0)->getType()),
  682. DL->getTypeSizeInBits(GEP->getType()) /
  683. cast<FixedVectorType>(GEP->getType())->getNumElements());
  684. if (TypeScale == -1)
  685. return nullptr;
  686. if (GEP->hasOneUse()) {
  687. // Only in this case do we want to build a wb gather, because the wb will
  688. // change the phi which does affect other users of the gep (which will still
  689. // be using the phi in the old way)
  690. if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
  691. TypeScale, Builder))
  692. return Load;
  693. }
  694. LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
  695. "non-wb gather/scatter\n");
  696. std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  697. if (Add.first == nullptr)
  698. return nullptr;
  699. Value *OffsetsIncoming = Add.first;
  700. int64_t Immediate = Add.second;
  701. // Make sure the offsets are scaled correctly
  702. Instruction *ScaledOffsets = BinaryOperator::Create(
  703. Instruction::Shl, OffsetsIncoming,
  704. Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)),
  705. "ScaledIndex", I);
  706. // Add the base to the offsets
  707. OffsetsIncoming = BinaryOperator::Create(
  708. Instruction::Add, ScaledOffsets,
  709. Builder.CreateVectorSplat(
  710. Ty->getNumElements(),
  711. Builder.CreatePtrToInt(
  712. BasePtr,
  713. cast<VectorType>(ScaledOffsets->getType())->getElementType())),
  714. "StartIndex", I);
  715. if (I->getIntrinsicID() == Intrinsic::masked_gather)
  716. return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
  717. else
  718. return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
  719. }
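// The writeback variant built below goes one step further when the offsets
// come straight from the loop's induction phi: the vector of addresses itself
// becomes the induction value, pre-decremented by the immediate so the first
// iteration reads the intended elements, and the writeback result of the
// gather/scatter replaces the old increment of the phi (roughly the
// "[q, #imm]!" form).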
  720. Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
  721. IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
  722. IRBuilder<> &Builder) {
  723. // Check whether this gather's offset is incremented by a constant - if so,
  724. // and the load is of the right type, we can merge this into a QI gather
  725. Loop *L = LI->getLoopFor(I->getParent());
  726. // Offsets that are worth merging into this instruction will be incremented
  727. // by a constant, thus we're looking for an add of a phi and a constant
  728. PHINode *Phi = dyn_cast<PHINode>(Offsets);
  729. if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
  730. Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
  731. // No phi means no IV to write back to; if there is a phi, we expect it
  732. // to have exactly two incoming values; the only phis we are interested in
  733. // will be loop IV's and have exactly two uses, one in their increment and
  734. // one in the gather's gep
  735. return nullptr;
  736. unsigned IncrementIndex =
  737. Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  738. // Look through the phi to the phi increment
  739. Offsets = Phi->getIncomingValue(IncrementIndex);
  740. std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  741. if (Add.first == nullptr)
  742. return nullptr;
  743. Value *OffsetsIncoming = Add.first;
  744. int64_t Immediate = Add.second;
  745. if (OffsetsIncoming != Phi)
  746. // Then the increment we are looking at is not an increment of the
  747. // induction variable, and we don't want to do a writeback
  748. return nullptr;
  749. Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  750. unsigned NumElems =
  751. cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();
  752. // Make sure the offsets are scaled correctly
  753. Instruction *ScaledOffsets = BinaryOperator::Create(
  754. Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
  755. Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
  756. "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  757. // Add the base to the offsets
  758. OffsetsIncoming = BinaryOperator::Create(
  759. Instruction::Add, ScaledOffsets,
  760. Builder.CreateVectorSplat(
  761. NumElems,
  762. Builder.CreatePtrToInt(
  763. BasePtr,
  764. cast<VectorType>(ScaledOffsets->getType())->getElementType())),
  765. "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  766. // The gather is pre-incrementing
  767. OffsetsIncoming = BinaryOperator::Create(
  768. Instruction::Sub, OffsetsIncoming,
  769. Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
  770. "PreIncrementStartIndex",
  771. &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  772. Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);
  773. Builder.SetInsertPoint(I);
  774. Instruction *EndResult;
  775. Instruction *NewInduction;
  776. if (I->getIntrinsicID() == Intrinsic::masked_gather) {
  777. // Build the incrementing gather
  778. Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
  779. // One value to be handed to whoever uses the gather, one is the loop
  780. // increment
  781. EndResult = ExtractValueInst::Create(Load, 0, "Gather");
  782. NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement");
  783. Builder.Insert(EndResult);
  784. Builder.Insert(NewInduction);
  785. } else {
  786. // Build the incrementing scatter
  787. EndResult = NewInduction =
  788. tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
  789. }
  790. Instruction *AddInst = cast<Instruction>(Offsets);
  791. AddInst->replaceAllUsesWith(NewInduction);
  792. AddInst->eraseFromParent();
  793. Phi->setIncomingValue(IncrementIndex, NewInduction);
  794. return EndResult;
  795. }
  796. void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
  797. Value *OffsSecondOperand,
  798. unsigned StartIndex) {
  799. LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  800. Instruction *InsertionPoint =
  801. &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  802. // Initialize the phi with a vector that contains a sum of the constants
  803. Instruction *NewIndex = BinaryOperator::Create(
  804. Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
  805. "PushedOutAdd", InsertionPoint);
  806. unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
  807. // Order such that start index comes first (this reduces mov's)
  808. Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  809. Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
  810. Phi->getIncomingBlock(IncrementIndex));
  811. Phi->removeIncomingValue(IncrementIndex);
  812. Phi->removeIncomingValue(StartIndex);
  813. }
  814. void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi,
  815. Value *IncrementPerRound,
  816. Value *OffsSecondOperand,
  817. unsigned LoopIncrement,
  818. IRBuilder<> &Builder) {
  819. LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
820. // Create a new mul/shl outside of the loop to scale the start value, and
821. // scale the amount by which the loop variable is incremented each round
  822. Instruction *InsertionPoint = &cast<Instruction>(
  823. Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());
  824. // Create a new index
  825. Value *StartIndex =
  826. BinaryOperator::Create((Instruction::BinaryOps)Opcode,
  827. Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
  828. OffsSecondOperand, "PushedOutMul", InsertionPoint);
  829. Instruction *Product =
  830. BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound,
  831. OffsSecondOperand, "Product", InsertionPoint);
832. // Increment the new phi by Product each round instead of redoing the mul/shl
  833. Instruction *NewIncrement = BinaryOperator::Create(
  834. Instruction::Add, Phi, Product, "IncrementPushedOutMul",
  835. cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
  836. .getPrevNode());
  837. Phi->addIncoming(StartIndex,
  838. Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  839. Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  840. Phi->removeIncomingValue((unsigned)0);
  841. Phi->removeIncomingValue((unsigned)0);
  842. }
843. // Check whether all uses of this instruction are as offsets of
844. // gathers/scatters or simple arithmetic only used by gathers/scatters
  845. static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL) {
  846. if (I->hasNUses(0)) {
  847. return false;
  848. }
  849. bool Gatscat = true;
  850. for (User *U : I->users()) {
  851. if (!isa<Instruction>(U))
  852. return false;
  853. if (isa<GetElementPtrInst>(U) ||
  854. isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
  855. return Gatscat;
  856. } else {
  857. unsigned OpCode = cast<Instruction>(U)->getOpcode();
  858. if ((OpCode == Instruction::Add || OpCode == Instruction::Mul ||
  859. OpCode == Instruction::Shl ||
  860. isAddLikeOr(cast<Instruction>(U), DL)) &&
  861. hasAllGatScatUsers(cast<Instruction>(U), DL)) {
  862. continue;
  863. }
  864. return false;
  865. }
  866. }
  867. return Gatscat;
  868. }
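// A worked example of the optimisation below (IR names illustrative): for a
// loop computing "%offs = add <4 x i32> %phi, %invariant", the add is
// absorbed into the phi by starting the recurrence at start + %invariant
// while keeping the same per-round increment, so the loop body can feed the
// phi straight into the gather/scatter. Muls and shls are handled in the same
// way by scaling both the start value and the per-round increment.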
  869. bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
  870. LoopInfo *LI) {
  871. LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize: "
  872. << *Offsets << "\n");
  873. // Optimise the addresses of gathers/scatters by moving invariant
  874. // calculations out of the loop
  875. if (!isa<Instruction>(Offsets))
  876. return false;
  877. Instruction *Offs = cast<Instruction>(Offsets);
  878. if (Offs->getOpcode() != Instruction::Add && !isAddLikeOr(Offs, *DL) &&
  879. Offs->getOpcode() != Instruction::Mul &&
  880. Offs->getOpcode() != Instruction::Shl)
  881. return false;
  882. Loop *L = LI->getLoopFor(BB);
  883. if (L == nullptr)
  884. return false;
  885. if (!Offs->hasOneUse()) {
  886. if (!hasAllGatScatUsers(Offs, *DL))
  887. return false;
  888. }
  889. // Find out which, if any, operand of the instruction
  890. // is a phi node
  891. PHINode *Phi;
  892. int OffsSecondOp;
  893. if (isa<PHINode>(Offs->getOperand(0))) {
  894. Phi = cast<PHINode>(Offs->getOperand(0));
  895. OffsSecondOp = 1;
  896. } else if (isa<PHINode>(Offs->getOperand(1))) {
  897. Phi = cast<PHINode>(Offs->getOperand(1));
  898. OffsSecondOp = 0;
  899. } else {
  900. bool Changed = false;
  901. if (isa<Instruction>(Offs->getOperand(0)) &&
  902. L->contains(cast<Instruction>(Offs->getOperand(0))))
  903. Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
  904. if (isa<Instruction>(Offs->getOperand(1)) &&
  905. L->contains(cast<Instruction>(Offs->getOperand(1))))
  906. Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
  907. if (!Changed)
  908. return false;
  909. if (isa<PHINode>(Offs->getOperand(0))) {
  910. Phi = cast<PHINode>(Offs->getOperand(0));
  911. OffsSecondOp = 1;
  912. } else if (isa<PHINode>(Offs->getOperand(1))) {
  913. Phi = cast<PHINode>(Offs->getOperand(1));
  914. OffsSecondOp = 0;
  915. } else {
  916. return false;
  917. }
  918. }
  919. // A phi node we want to perform this function on should be from the
  920. // loop header.
  921. if (Phi->getParent() != L->getHeader())
  922. return false;
  923. // We're looking for a simple add recurrence.
  924. BinaryOperator *IncInstruction;
  925. Value *Start, *IncrementPerRound;
  926. if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
  927. IncInstruction->getOpcode() != Instruction::Add)
  928. return false;
  929. int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;
  930. // Get the value that is added to/multiplied with the phi
  931. Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
  932. if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
  933. !L->isLoopInvariant(OffsSecondOperand))
  934. // Something has gone wrong, abort
  935. return false;
  936. // Only proceed if the increment per round is a constant or an instruction
  937. // which does not originate from within the loop
  938. if (!isa<Constant>(IncrementPerRound) &&
  939. !(isa<Instruction>(IncrementPerRound) &&
  940. !L->contains(cast<Instruction>(IncrementPerRound))))
  941. return false;
  942. // If the phi is not used by anything else, we can just adapt it when
  943. // replacing the instruction; if it is, we'll have to duplicate it
  944. PHINode *NewPhi;
  945. if (Phi->getNumUses() == 2) {
  946. // No other users -> reuse existing phi (One user is the instruction
  947. // we're looking at, the other is the phi increment)
  948. if (IncInstruction->getNumUses() != 1) {
  949. // If the incrementing instruction does have more users than
  950. // our phi, we need to copy it
  951. IncInstruction = BinaryOperator::Create(
  952. Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
  953. IncrementPerRound, "LoopIncrement", IncInstruction);
  954. Phi->setIncomingValue(IncrementingBlock, IncInstruction);
  955. }
  956. NewPhi = Phi;
  957. } else {
  958. // There are other users -> create a new phi
  959. NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi);
  960. // Copy the incoming values of the old phi
  961. NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
  962. Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
  963. IncInstruction = BinaryOperator::Create(
  964. Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
  965. IncrementPerRound, "LoopIncrement", IncInstruction);
  966. NewPhi->addIncoming(IncInstruction,
  967. Phi->getIncomingBlock(IncrementingBlock));
  968. IncrementingBlock = 1;
  969. }
  970. IRBuilder<> Builder(BB->getContext());
  971. Builder.SetInsertPoint(Phi);
  972. Builder.SetCurrentDebugLocation(Offs->getDebugLoc());
  973. switch (Offs->getOpcode()) {
  974. case Instruction::Add:
  975. case Instruction::Or:
  976. pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
  977. break;
  978. case Instruction::Mul:
  979. case Instruction::Shl:
  980. pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound,
  981. OffsSecondOperand, IncrementingBlock, Builder);
  982. break;
  983. default:
  984. return false;
  985. }
  986. LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
  987. << "add/mul\n");
  988. // The instruction has now been "absorbed" into the phi value
  989. Offs->replaceAllUsesWith(NewPhi);
  990. if (Offs->hasNUses(0))
  991. Offs->eraseFromParent();
  992. // Clean up the old increment in case it's unused because we built a new
  993. // one
  994. if (IncInstruction->hasNUses(0))
  995. IncInstruction->eraseFromParent();
  996. return true;
  997. }
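// An example of the offset merging below: X == <0, 1, 2, 3> with ScaleX == 4
// and Y == splat(4) with ScaleY == 4 produce the byte offsets
// <16, 20, 24, 28>. A non-vector operand is splatted first, and for element
// types narrower than 32 bits only constant operands are accepted so that the
// scaled sum can be proven not to overflow.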
  998. static Value *CheckAndCreateOffsetAdd(Value *X, unsigned ScaleX, Value *Y,
  999. unsigned ScaleY, IRBuilder<> &Builder) {
  1000. // Splat the non-vector value to a vector of the given type - if the value is
  1001. // a constant (and its value isn't too big), we can even use this opportunity
  1002. // to scale it to the size of the vector elements
  1003. auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
  1004. ConstantInt *Const;
  1005. if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
  1006. VT->getElementType() != NonVectorVal->getType()) {
  1007. unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
  1008. uint64_t N = Const->getZExtValue();
  1009. if (N < (unsigned)(1 << (TargetElemSize - 1))) {
  1010. NonVectorVal = Builder.CreateVectorSplat(
  1011. VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
  1012. return;
  1013. }
  1014. }
  1015. NonVectorVal =
  1016. Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  1017. };
  1018. FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  1019. FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  1020. // If one of X, Y is not a vector, we have to splat it in order
  1021. // to add the two of them.
  1022. if (XElType && !YElType) {
  1023. FixSummands(XElType, Y);
  1024. YElType = cast<FixedVectorType>(Y->getType());
  1025. } else if (YElType && !XElType) {
  1026. FixSummands(YElType, X);
  1027. XElType = cast<FixedVectorType>(X->getType());
  1028. }
  1029. assert(XElType && YElType && "Unknown vector types");
  1030. // Check that the summands are of compatible types
  1031. if (XElType != YElType) {
  1032. LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
  1033. return nullptr;
  1034. }
  1035. if (XElType->getElementType()->getScalarSizeInBits() != 32) {
  1036. // Check that by adding the vectors we do not accidentally
  1037. // create an overflow
  1038. Constant *ConstX = dyn_cast<Constant>(X);
  1039. Constant *ConstY = dyn_cast<Constant>(Y);
  1040. if (!ConstX || !ConstY)
  1041. return nullptr;
  1042. unsigned TargetElemSize = 128 / XElType->getNumElements();
  1043. for (unsigned i = 0; i < XElType->getNumElements(); i++) {
  1044. ConstantInt *ConstXEl =
  1045. dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
  1046. ConstantInt *ConstYEl =
  1047. dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
  1048. if (!ConstXEl || !ConstYEl ||
  1049. ConstXEl->getZExtValue() * ScaleX +
  1050. ConstYEl->getZExtValue() * ScaleY >=
  1051. (unsigned)(1 << (TargetElemSize - 1)))
  1052. return nullptr;
  1053. }
  1054. }
  1055. Value *XScale = Builder.CreateVectorSplat(
  1056. XElType->getNumElements(),
  1057. Builder.getIntN(XElType->getScalarSizeInBits(), ScaleX));
  1058. Value *YScale = Builder.CreateVectorSplat(
  1059. YElType->getNumElements(),
  1060. Builder.getIntN(YElType->getScalarSizeInBits(), ScaleY));
  1061. Value *Add = Builder.CreateAdd(Builder.CreateMul(X, XScale),
  1062. Builder.CreateMul(Y, YScale));
  1063. if (checkOffsetSize(Add, XElType->getNumElements()))
  1064. return Add;
  1065. else
  1066. return nullptr;
  1067. }
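// An example of the gep folding below: gep i32, (gep i32, ptr %base,
// <4 x i32> <0, 1, 2, 3>), <4 x i32> splat(4) folds to the byte-addressed
// gep i8, ptr %base, <4 x i32> <16, 20, 24, 28>, which optimiseAddress then
// materialises with Scale == 1. Only geps whose single index is a constant
// are merged, so the scaled sum is known not to overflow.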
  1068. Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
  1069. Value *&Offsets, unsigned &Scale,
  1070. IRBuilder<> &Builder) {
  1071. Value *GEPPtr = GEP->getPointerOperand();
  1072. Offsets = GEP->getOperand(1);
  1073. Scale = DL->getTypeAllocSize(GEP->getSourceElementType());
1074. // We only merge geps with constant offsets, because only for those
1075. // can we make sure that we do not cause an overflow
  1076. if (GEP->getNumIndices() != 1 || !isa<Constant>(Offsets))
  1077. return nullptr;
  1078. if (GetElementPtrInst *BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr)) {
  1079. // Merge the two geps into one
  1080. Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Scale, Builder);
  1081. if (!BaseBasePtr)
  1082. return nullptr;
  1083. Offsets = CheckAndCreateOffsetAdd(
  1084. Offsets, Scale, GEP->getOperand(1),
  1085. DL->getTypeAllocSize(GEP->getSourceElementType()), Builder);
  1086. if (Offsets == nullptr)
  1087. return nullptr;
  1088. Scale = 1; // Scale is always an i8 at this point.
  1089. return BaseBasePtr;
  1090. }
  1091. return GEPPtr;
  1092. }
  1093. bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
  1094. LoopInfo *LI) {
  1095. GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  1096. if (!GEP)
  1097. return false;
  1098. bool Changed = false;
  1099. if (GEP->hasOneUse() && isa<GetElementPtrInst>(GEP->getPointerOperand())) {
  1100. IRBuilder<> Builder(GEP->getContext());
  1101. Builder.SetInsertPoint(GEP);
  1102. Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
  1103. Value *Offsets;
  1104. unsigned Scale;
  1105. Value *Base = foldGEP(GEP, Offsets, Scale, Builder);
  1106. // We only want to merge the geps if there is a real chance that they can be
  1107. // used by an MVE gather; thus the offset has to have the correct size
  1108. // (always i32 if it is not of vector type) and the base has to be a
  1109. // pointer.
  1110. if (Offsets && Base && Base != GEP) {
  1111. assert(Scale == 1 && "Expected to fold GEP to a scale of 1");
  1112. Type *BaseTy = Builder.getInt8PtrTy();
  1113. if (auto *VecTy = dyn_cast<FixedVectorType>(Base->getType()))
  1114. BaseTy = FixedVectorType::get(BaseTy, VecTy);
  1115. GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
  1116. Builder.getInt8Ty(), Builder.CreateBitCast(Base, BaseTy), Offsets,
  1117. "gep.merged", GEP);
  1118. LLVM_DEBUG(dbgs() << "Folded GEP: " << *GEP
  1119. << "\n new : " << *NewAddress << "\n");
  1120. GEP->replaceAllUsesWith(
  1121. Builder.CreateBitCast(NewAddress, GEP->getType()));
  1122. GEP = NewAddress;
  1123. Changed = true;
  1124. }
  1125. }
  1126. Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  1127. return Changed;
  1128. }
  1129. bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  1130. if (!EnableMaskedGatherScatters)
  1131. return false;
  1132. auto &TPC = getAnalysis<TargetPassConfig>();
  1133. auto &TM = TPC.getTM<TargetMachine>();
  1134. auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  1135. if (!ST->hasMVEIntegerOps())
  1136. return false;
  1137. LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  1138. DL = &F.getParent()->getDataLayout();
  1139. SmallVector<IntrinsicInst *, 4> Gathers;
  1140. SmallVector<IntrinsicInst *, 4> Scatters;
  1141. bool Changed = false;
  1142. for (BasicBlock &BB : F) {
  1143. Changed |= SimplifyInstructionsInBlock(&BB);
  1144. for (Instruction &I : BB) {
  1145. IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
  1146. if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
  1147. isa<FixedVectorType>(II->getType())) {
  1148. Gathers.push_back(II);
  1149. Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
  1150. } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
  1151. isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
  1152. Scatters.push_back(II);
  1153. Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
  1154. }
  1155. }
  1156. }
  1157. for (unsigned i = 0; i < Gathers.size(); i++) {
  1158. IntrinsicInst *I = Gathers[i];
  1159. Instruction *L = lowerGather(I);
  1160. if (L == nullptr)
  1161. continue;
  1162. // Get rid of any now dead instructions
  1163. SimplifyInstructionsInBlock(L->getParent());
  1164. Changed = true;
  1165. }
  1166. for (unsigned i = 0; i < Scatters.size(); i++) {
  1167. IntrinsicInst *I = Scatters[i];
  1168. Instruction *S = lowerScatter(I);
  1169. if (S == nullptr)
  1170. continue;
  1171. // Get rid of any now dead instructions
  1172. SimplifyInstructionsInBlock(S->getParent());
  1173. Changed = true;
  1174. }
  1175. return Changed;
  1176. }