MVEGatherScatterLowering.cpp

  1. //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. /// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
  10. /// intrinsics to arm.mve.vldr.gather and arm.mve.vstr.scatter intrinsics,
  11. /// optimising the code to produce a better final result as we go.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "ARM.h"
  15. #include "ARMBaseInstrInfo.h"
  16. #include "ARMSubtarget.h"
  17. #include "llvm/Analysis/LoopInfo.h"
  18. #include "llvm/Analysis/TargetTransformInfo.h"
  19. #include "llvm/Analysis/ValueTracking.h"
  20. #include "llvm/CodeGen/TargetLowering.h"
  21. #include "llvm/CodeGen/TargetPassConfig.h"
  22. #include "llvm/CodeGen/TargetSubtargetInfo.h"
  23. #include "llvm/InitializePasses.h"
  24. #include "llvm/IR/BasicBlock.h"
  25. #include "llvm/IR/Constant.h"
  26. #include "llvm/IR/Constants.h"
  27. #include "llvm/IR/DerivedTypes.h"
  28. #include "llvm/IR/Function.h"
  29. #include "llvm/IR/InstrTypes.h"
  30. #include "llvm/IR/Instruction.h"
  31. #include "llvm/IR/Instructions.h"
  32. #include "llvm/IR/IntrinsicInst.h"
  33. #include "llvm/IR/Intrinsics.h"
  34. #include "llvm/IR/IntrinsicsARM.h"
  35. #include "llvm/IR/IRBuilder.h"
  36. #include "llvm/IR/PatternMatch.h"
  37. #include "llvm/IR/Type.h"
  38. #include "llvm/IR/Value.h"
  39. #include "llvm/Pass.h"
  40. #include "llvm/Support/Casting.h"
  41. #include "llvm/Transforms/Utils/Local.h"
  42. #include <algorithm>
  43. #include <cassert>
  44. using namespace llvm;
  45. #define DEBUG_TYPE "arm-mve-gather-scatter-lowering"
  46. cl::opt<bool> EnableMaskedGatherScatters(
  47. "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
  48. cl::desc("Enable the generation of masked gathers and scatters"));
  49. namespace {
  50. class MVEGatherScatterLowering : public FunctionPass {
  51. public:
  52. static char ID; // Pass identification, replacement for typeid
  53. explicit MVEGatherScatterLowering() : FunctionPass(ID) {
  54. initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  55. }
  56. bool runOnFunction(Function &F) override;
  57. StringRef getPassName() const override {
  58. return "MVE gather/scatter lowering";
  59. }
  60. void getAnalysisUsage(AnalysisUsage &AU) const override {
  61. AU.setPreservesCFG();
  62. AU.addRequired<TargetPassConfig>();
  63. AU.addRequired<LoopInfoWrapperPass>();
  64. FunctionPass::getAnalysisUsage(AU);
  65. }
  66. private:
  67. LoopInfo *LI = nullptr;
  68. const DataLayout *DL;
  69. // Check this is a valid gather with correct alignment
  70. bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
  71. Align Alignment);
  72. // Check whether Ptr is hidden behind a bitcast and look through it
  73. void lookThroughBitcast(Value *&Ptr);
  74. // Decompose a ptr into Base and Offsets, potentially using a GEP to return a
  75. // scalar base and vector offsets, or else fallback to using a base of 0 and
  76. // offset of Ptr where possible.
  77. Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
  78. FixedVectorType *Ty, Type *MemoryTy,
  79. IRBuilder<> &Builder);
  80. // Check for a getelementptr and deduce base and offsets from it, on success
  81. // returning the base directly and the offsets indirectly using the Offsets
  82. // argument
  83. Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty,
  84. GetElementPtrInst *GEP, IRBuilder<> &Builder);
  85. // Compute the scale of this gather/scatter instruction
  86. int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
  87. // If the value is a constant, or derived from constants via additions
  88. // and multiplications, return its numeric value
  89. Optional<int64_t> getIfConst(const Value *V);
  90. // If Inst is an add instruction, check whether one summand is a
  91. // constant. If so, scale this constant and return it together with
  92. // the other summand.
  93. std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
  94. Instruction *lowerGather(IntrinsicInst *I);
  95. // Create a gather from a base + vector of offsets
  96. Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
  97. Instruction *&Root,
  98. IRBuilder<> &Builder);
  99. // Create a gather from a vector of pointers
  100. Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
  101. IRBuilder<> &Builder,
  102. int64_t Increment = 0);
  103. // Create an incrementing gather from a vector of pointers
  104. Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
  105. IRBuilder<> &Builder,
  106. int64_t Increment = 0);
  107. Instruction *lowerScatter(IntrinsicInst *I);
  108. // Create a scatter to a base + vector of offsets
  109. Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
  110. IRBuilder<> &Builder);
  111. // Create a scatter to a vector of pointers
  112. Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
  113. IRBuilder<> &Builder,
  114. int64_t Increment = 0);
  115. // Create an incrementing scatter from a vector of pointers
  116. Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
  117. IRBuilder<> &Builder,
  118. int64_t Increment = 0);
  119. // QI gathers and scatters can increment their offsets on their own if
  120. // the increment is a constant value (digit)
  121. Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
  122. IRBuilder<> &Builder);
  123. // QI gathers/scatters can increment their offsets on their own if the
  124. // increment is a constant value (digit) - this creates a writeback QI
  125. // gather/scatter
  126. Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
  127. Value *Ptr, unsigned TypeScale,
  128. IRBuilder<> &Builder);
  129. // Optimise the base and offsets of the given address
  130. bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
  131. // Try to fold consecutive geps together into one
  132. Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder);
  133. // Check whether these offsets could be moved out of the loop they're in
  134. bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  135. // Pushes the given add out of the loop
  136. void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  137. // Pushes the given mul or shl out of the loop
  138. void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
  139. Value *OffsSecondOperand, unsigned LoopIncrement,
  140. IRBuilder<> &Builder);
  141. };
  142. } // end anonymous namespace
  143. char MVEGatherScatterLowering::ID = 0;
  144. INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
  145. "MVE gather/scattering lowering pass", false, false)
  146. Pass *llvm::createMVEGatherScatterLoweringPass() {
  147. return new MVEGatherScatterLowering();
  148. }
  149. bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
  150. unsigned ElemSize,
  151. Align Alignment) {
  152. if (((NumElements == 4 &&
  153. (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
  154. (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
  155. (NumElements == 16 && ElemSize == 8)) &&
  156. Alignment >= ElemSize / 8)
  157. return true;
  158. LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
  159. << "valid alignment or vector type \n");
  160. return false;
  161. }
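// Illustrative examples of the check above: an 8-element gather/scatter of
// 16-bit values is legal when the alignment is at least 2 bytes, and a
// 16 x 8-bit one only needs byte alignment; 64-bit element types never match
// any of the accepted (NumElements, ElemSize) pairs and are rejected.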
  162. static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  163. // Offsets that are not of type <N x i32> are sign extended by the
  164. // getelementptr instruction, and MVE gathers/scatters treat the offset as
  165. // unsigned. Thus, if the element size is smaller than 32, we can only allow
  166. // positive offsets - i.e., the offsets are not allowed to be variables we
  167. // can't look into.
  168. // Additionally, <N x i32> offsets have to either originate from a zext of a
  169. // vector with element types smaller than or equal to the element type of the
  170. // gather we're looking at, or consist of constants that we can check are
  171. // small enough to fit into the gather type.
  172. // Thus we check that 0 <= value < 2^TargetElemSize.
  173. unsigned TargetElemSize = 128 / TargetElemCount;
  174. unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
  175. ->getElementType()
  176. ->getScalarSizeInBits();
  177. if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
  178. Constant *ConstOff = dyn_cast<Constant>(Offsets);
  179. if (!ConstOff)
  180. return false;
  181. int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
  182. auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
  183. ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
  184. if (!OConst)
  185. return false;
  186. int SExtValue = OConst->getSExtValue();
  187. if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
  188. return false;
  189. return true;
  190. };
  191. if (isa<FixedVectorType>(ConstOff->getType())) {
  192. for (unsigned i = 0; i < TargetElemCount; i++) {
  193. if (!CheckValueSize(ConstOff->getAggregateElement(i)))
  194. return false;
  195. }
  196. } else {
  197. if (!CheckValueSize(ConstOff))
  198. return false;
  199. }
  200. }
  201. return true;
  202. }
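// Worked example of the bound above: a 4-lane gather with <4 x i32> offsets
// passes straight through (TargetElemSize == OffsetElemSize == 32); an
// 8-lane gather has TargetElemSize = 128 / 8 = 16, so its offsets must be
// constants with every element in [0, 1 << 16), otherwise the GEP is rejected.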
  203. Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
  204. int &Scale, FixedVectorType *Ty,
  205. Type *MemoryTy,
  206. IRBuilder<> &Builder) {
  207. if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
  208. if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
  209. Scale =
  210. computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
  211. MemoryTy->getScalarSizeInBits());
  212. return Scale == -1 ? nullptr : V;
  213. }
  214. }
  215. // If we couldn't use the GEP (or it doesn't exist), attempt to use a
  216. // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4
  217. // elements.
  218. FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
  219. if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32)
  220. return nullptr;
  221. Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
  222. Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getInt8PtrTy());
  223. Offsets = Builder.CreatePtrToInt(
  224. Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
  225. Scale = 0;
  226. return BasePtr;
  227. }
  228. Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
  229. FixedVectorType *Ty,
  230. GetElementPtrInst *GEP,
  231. IRBuilder<> &Builder) {
  232. if (!GEP) {
  233. LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
  234. << "found\n");
  235. return nullptr;
  236. }
  237. LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
  238. << " Looking at intrinsic for base + vector of offsets\n");
  239. Value *GEPPtr = GEP->getPointerOperand();
  240. Offsets = GEP->getOperand(1);
  241. if (GEPPtr->getType()->isVectorTy() ||
  242. !isa<FixedVectorType>(Offsets->getType()))
  243. return nullptr;
  244. if (GEP->getNumOperands() != 2) {
  245. LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
  246. << " operands. Expanding.\n");
  247. return nullptr;
  248. }
  249. Offsets = GEP->getOperand(1);
  250. unsigned OffsetsElemCount =
  251. cast<FixedVectorType>(Offsets->getType())->getNumElements();
  252. // Paranoid check whether the number of parallel lanes is the same
  253. assert(Ty->getNumElements() == OffsetsElemCount);
  254. ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  255. if (ZextOffs)
  256. Offsets = ZextOffs->getOperand(0);
  257. FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());
  258. // If the offsets are already being zext-ed to <N x i32>, that relieves us of
  259. // having to make sure that they won't overflow.
  260. if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
  261. ->getElementType()
  262. ->getScalarSizeInBits() != 32)
  263. if (!checkOffsetSize(Offsets, OffsetsElemCount))
  264. return nullptr;
  265. // The offset sizes have been checked; if any truncating or zext-ing is
  266. // required to fix them, do that now
  267. if (Ty != Offsets->getType()) {
  268. if ((Ty->getElementType()->getScalarSizeInBits() <
  269. OffsetType->getElementType()->getScalarSizeInBits())) {
  270. Offsets = Builder.CreateTrunc(Offsets, Ty);
  271. } else {
  272. Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
  273. }
  274. }
  275. // If none of the checks failed, return the gep's base pointer
  276. LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  277. return GEPPtr;
  278. }
  279. void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  280. // Look through bitcast instruction if #elements is the same
  281. if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
  282. auto *BCTy = cast<FixedVectorType>(BitCast->getType());
  283. auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
  284. if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
  285. LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
  286. << "bitcast\n");
  287. Ptr = BitCast->getOperand(0);
  288. }
  289. }
  290. }
  291. int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
  292. unsigned MemoryElemSize) {
  293. // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2,
  294. // or an 8bit, 16bit or 32bit load/store scaled by 1
  295. if (GEPElemSize == 32 && MemoryElemSize == 32)
  296. return 2;
  297. else if (GEPElemSize == 16 && MemoryElemSize == 16)
  298. return 1;
  299. else if (GEPElemSize == 8)
  300. return 0;
  301. LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
  302. << "create intrinsic\n");
  303. return -1;
  304. }
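// For example, a GEP over i32 elements feeding a 32-bit gather/scatter yields
// a scale of 2 (the hardware multiplies each offset by 1 << 2 = 4 bytes), a
// GEP over i16 feeding a 16-bit access yields 1, and a GEP over i8 yields 0
// (unscaled byte offsets); any other combination returns -1 and is expanded.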
  305. Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  306. const Constant *C = dyn_cast<Constant>(V);
  307. if (C && C->getSplatValue())
  308. return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  309. if (!isa<Instruction>(V))
  310. return Optional<int64_t>{};
  311. const Instruction *I = cast<Instruction>(V);
  312. if (I->getOpcode() == Instruction::Add || I->getOpcode() == Instruction::Or ||
  313. I->getOpcode() == Instruction::Mul ||
  314. I->getOpcode() == Instruction::Shl) {
  315. Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
  316. Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
  317. if (!Op0 || !Op1)
  318. return Optional<int64_t>{};
  319. if (I->getOpcode() == Instruction::Add)
  320. return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
  321. if (I->getOpcode() == Instruction::Mul)
  322. return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
  323. if (I->getOpcode() == Instruction::Shl)
  324. return Optional<int64_t>{Op0.getValue() << Op1.getValue()};
  325. if (I->getOpcode() == Instruction::Or)
  326. return Optional<int64_t>{Op0.getValue() | Op1.getValue()};
  327. }
  328. return Optional<int64_t>{};
  329. }
  330. // Return true if I is an Or instruction that is equivalent to an add, due to
  331. // the operands having no common bits set.
  332. static bool isAddLikeOr(Instruction *I, const DataLayout &DL) {
  333. return I->getOpcode() == Instruction::Or &&
  334. haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL);
  335. }
  336. std::pair<Value *, int64_t>
  337. MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  338. std::pair<Value *, int64_t> ReturnFalse =
  339. std::pair<Value *, int64_t>(nullptr, 0);
  340. // At this point, the instruction we're looking at must be an add or an
  341. // add-like-or.
  342. Instruction *Add = dyn_cast<Instruction>(Inst);
  343. if (Add == nullptr ||
  344. (Add->getOpcode() != Instruction::Add && !isAddLikeOr(Add, *DL)))
  345. return ReturnFalse;
  346. Value *Summand;
  347. Optional<int64_t> Const;
  348. // Find out which operand is the value that is being incremented
  349. if ((Const = getIfConst(Add->getOperand(0))))
  350. Summand = Add->getOperand(1);
  351. else if ((Const = getIfConst(Add->getOperand(1))))
  352. Summand = Add->getOperand(0);
  353. else
  354. return ReturnFalse;
  355. // Check that the constant is small enough for an incrementing gather
  356. int64_t Immediate = Const.getValue() << TypeScale;
  357. if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
  358. return ReturnFalse;
  359. return std::pair<Value *, int64_t>(Summand, Immediate);
  360. }
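// For instance, with TypeScale == 2 a constant summand of 4 becomes an
// immediate of 4 << 2 = 16, which is accepted (a multiple of 4 within
// [-512, 512]); a summand of 200 would become 800 and (nullptr, 0) is
// returned instead.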
  361. Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  362. using namespace PatternMatch;
  363. LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
  364. << *I << "\n");
  365. // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  366. // Attempt to turn the masked gather in I into a MVE intrinsic
  367. // Potentially optimising the addressing modes as we do so.
  368. auto *Ty = cast<FixedVectorType>(I->getType());
  369. Value *Ptr = I->getArgOperand(0);
  370. Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  371. Value *Mask = I->getArgOperand(2);
  372. Value *PassThru = I->getArgOperand(3);
  373. if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
  374. Alignment))
  375. return nullptr;
  376. lookThroughBitcast(Ptr);
  377. assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
  378. IRBuilder<> Builder(I->getContext());
  379. Builder.SetInsertPoint(I);
  380. Builder.SetCurrentDebugLocation(I->getDebugLoc());
  381. Instruction *Root = I;
  382. Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
  383. if (!Load)
  384. Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  385. if (!Load)
  386. Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  387. if (!Load)
  388. return nullptr;
  389. if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
  390. LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
  391. << "creating select\n");
  392. Load = SelectInst::Create(Mask, Load, PassThru);
  393. Builder.Insert(Load);
  394. }
  395. Root->replaceAllUsesWith(Load);
  396. Root->eraseFromParent();
  397. if (Root != I)
  398. // If this was an extending gather, we need to get rid of the sext/zext
  399. // as well as of the gather itself
  400. I->eraseFromParent();
  401. LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
  402. << *Load << "\n");
  403. return Load;
  404. }
  405. Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
  406. IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  407. using namespace PatternMatch;
  408. auto *Ty = cast<FixedVectorType>(I->getType());
  409. LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  410. if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
  411. // Can't build an intrinsic for this
  412. return nullptr;
  413. Value *Mask = I->getArgOperand(2);
  414. if (match(Mask, m_One()))
  415. return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
  416. {Ty, Ptr->getType()},
  417. {Ptr, Builder.getInt32(Increment)});
  418. else
  419. return Builder.CreateIntrinsic(
  420. Intrinsic::arm_mve_vldr_gather_base_predicated,
  421. {Ty, Ptr->getType(), Mask->getType()},
  422. {Ptr, Builder.getInt32(Increment), Mask});
  423. }
  424. Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
  425. IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  426. using namespace PatternMatch;
  427. auto *Ty = cast<FixedVectorType>(I->getType());
  428. LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
  429. << "writeback\n");
  430. if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
  431. // Can't build an intrinsic for this
  432. return nullptr;
  433. Value *Mask = I->getArgOperand(2);
  434. if (match(Mask, m_One()))
  435. return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
  436. {Ty, Ptr->getType()},
  437. {Ptr, Builder.getInt32(Increment)});
  438. else
  439. return Builder.CreateIntrinsic(
  440. Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
  441. {Ty, Ptr->getType(), Mask->getType()},
  442. {Ptr, Builder.getInt32(Increment), Mask});
  443. }
  444. Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
  445. IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  446. using namespace PatternMatch;
  447. Type *MemoryTy = I->getType();
  448. Type *ResultTy = MemoryTy;
  449. unsigned Unsigned = 1;
  450. // The size of the gather was already checked in isLegalTypeAndAlignment;
  451. // if it was not a full vector width an appropriate extend should follow.
  452. auto *Extend = Root;
  453. bool TruncResult = false;
  454. if (MemoryTy->getPrimitiveSizeInBits() < 128) {
  455. if (I->hasOneUse()) {
  456. // If the gather has a single extend of the correct type, use an extending
  457. // gather and replace the ext, in which case the correct root to replace
  458. // is not the CallInst itself but the instruction which extends it.
  459. Instruction* User = cast<Instruction>(*I->users().begin());
  460. if (isa<SExtInst>(User) &&
  461. User->getType()->getPrimitiveSizeInBits() == 128) {
  462. LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
  463. << *User << "\n");
  464. Extend = User;
  465. ResultTy = User->getType();
  466. Unsigned = 0;
  467. } else if (isa<ZExtInst>(User) &&
  468. User->getType()->getPrimitiveSizeInBits() == 128) {
  469. LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
  470. << *ResultTy << "\n");
  471. Extend = User;
  472. ResultTy = User->getType();
  473. }
  474. }
  475. // If an extend hasn't been found and the type is an integer, create an
  476. // extending gather and truncate back to the original type.
  477. if (ResultTy->getPrimitiveSizeInBits() < 128 &&
  478. ResultTy->isIntOrIntVectorTy()) {
  479. ResultTy = ResultTy->getWithNewBitWidth(
  480. 128 / cast<FixedVectorType>(ResultTy)->getNumElements());
  481. TruncResult = true;
  482. LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: "
  483. << *ResultTy << "\n");
  484. }
  485. // The final size of the gather must be a full vector width
  486. if (ResultTy->getPrimitiveSizeInBits() != 128) {
  487. LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
  488. "from the correct type. Expanding\n");
  489. return nullptr;
  490. }
  491. }
  492. Value *Offsets;
  493. int Scale;
  494. Value *BasePtr = decomposePtr(
  495. Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
  496. if (!BasePtr)
  497. return nullptr;
  498. Root = Extend;
  499. Value *Mask = I->getArgOperand(2);
  500. Instruction *Load = nullptr;
  501. if (!match(Mask, m_One()))
  502. Load = Builder.CreateIntrinsic(
  503. Intrinsic::arm_mve_vldr_gather_offset_predicated,
  504. {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
  505. {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
  506. Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  507. else
  508. Load = Builder.CreateIntrinsic(
  509. Intrinsic::arm_mve_vldr_gather_offset,
  510. {ResultTy, BasePtr->getType(), Offsets->getType()},
  511. {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
  512. Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
  513. if (TruncResult) {
  514. Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
  515. Builder.Insert(Load);
  516. }
  517. return Load;
  518. }
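// Two illustrative cases for the extension handling above: a <4 x i16> gather
// whose only user is a sext to <4 x i32> becomes a single signed extending
// gather whose root is the sext; a <4 x i16> gather with no such user is
// built as a zero-extending <4 x i32> gather followed by a trunc back to
// <4 x i16> (TruncResult).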
  519. Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  520. using namespace PatternMatch;
  521. LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
  522. << *I << "\n");
  523. // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
  524. // Attempt to turn the masked scatter in I into a MVE intrinsic
  525. // Potentially optimising the addressing modes as we do so.
  526. Value *Input = I->getArgOperand(0);
  527. Value *Ptr = I->getArgOperand(1);
  528. Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
  529. auto *Ty = cast<FixedVectorType>(Input->getType());
  530. if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
  531. Alignment))
  532. return nullptr;
  533. lookThroughBitcast(Ptr);
  534. assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
  535. IRBuilder<> Builder(I->getContext());
  536. Builder.SetInsertPoint(I);
  537. Builder.SetCurrentDebugLocation(I->getDebugLoc());
  538. Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
  539. if (!Store)
  540. Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  541. if (!Store)
  542. Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  543. if (!Store)
  544. return nullptr;
  545. LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
  546. << *Store << "\n");
  547. I->eraseFromParent();
  548. return Store;
  549. }
  550. Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
  551. IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  552. using namespace PatternMatch;
  553. Value *Input = I->getArgOperand(0);
  554. auto *Ty = cast<FixedVectorType>(Input->getType());
  555. // Only QR variants allow truncating
  556. if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
  557. // Can't build an intrinsic for this
  558. return nullptr;
  559. }
  560. Value *Mask = I->getArgOperand(3);
  561. // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  562. LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  563. if (match(Mask, m_One()))
  564. return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
  565. {Ptr->getType(), Input->getType()},
  566. {Ptr, Builder.getInt32(Increment), Input});
  567. else
  568. return Builder.CreateIntrinsic(
  569. Intrinsic::arm_mve_vstr_scatter_base_predicated,
  570. {Ptr->getType(), Input->getType(), Mask->getType()},
  571. {Ptr, Builder.getInt32(Increment), Input, Mask});
  572. }
  573. Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
  574. IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  575. using namespace PatternMatch;
  576. Value *Input = I->getArgOperand(0);
  577. auto *Ty = cast<FixedVectorType>(Input->getType());
  578. LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
  579. << "with writeback\n");
  580. if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
  581. // Can't build an intrinsic for this
  582. return nullptr;
  583. Value *Mask = I->getArgOperand(3);
  584. if (match(Mask, m_One()))
  585. return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
  586. {Ptr->getType(), Input->getType()},
  587. {Ptr, Builder.getInt32(Increment), Input});
  588. else
  589. return Builder.CreateIntrinsic(
  590. Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
  591. {Ptr->getType(), Input->getType(), Mask->getType()},
  592. {Ptr, Builder.getInt32(Increment), Input, Mask});
  593. }
  594. Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
  595. IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  596. using namespace PatternMatch;
  597. Value *Input = I->getArgOperand(0);
  598. Value *Mask = I->getArgOperand(3);
  599. Type *InputTy = Input->getType();
  600. Type *MemoryTy = InputTy;
  601. LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
  602. << " to base + vector of offsets\n");
  603. // If the input has been truncated, try to integrate that trunc into the
  604. // scatter instruction (we don't care about alignment here)
  605. if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
  606. Value *PreTrunc = Trunc->getOperand(0);
  607. Type *PreTruncTy = PreTrunc->getType();
  608. if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
  609. Input = PreTrunc;
  610. InputTy = PreTruncTy;
  611. }
  612. }
  613. bool ExtendInput = false;
  614. if (InputTy->getPrimitiveSizeInBits() < 128 &&
  615. InputTy->isIntOrIntVectorTy()) {
  616. // If we can't find a trunc to incorporate into the instruction, create an
  617. // implicit one with a zext, so that we can still create a scatter. We know
  618. // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type
  619. // smaller than 128 bits will divide evenly into a 128bit vector.
  620. InputTy = InputTy->getWithNewBitWidth(
  621. 128 / cast<FixedVectorType>(InputTy)->getNumElements());
  622. ExtendInput = true;
  623. LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
  624. << *Input << "\n");
  625. }
  626. if (InputTy->getPrimitiveSizeInBits() != 128) {
  627. LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
  628. "non-standard input types. Expanding.\n");
  629. return nullptr;
  630. }
  631. Value *Offsets;
  632. int Scale;
  633. Value *BasePtr = decomposePtr(
  634. Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder);
  635. if (!BasePtr)
  636. return nullptr;
  637. if (ExtendInput)
  638. Input = Builder.CreateZExt(Input, InputTy);
  639. if (!match(Mask, m_One()))
  640. return Builder.CreateIntrinsic(
  641. Intrinsic::arm_mve_vstr_scatter_offset_predicated,
  642. {BasePtr->getType(), Offsets->getType(), Input->getType(),
  643. Mask->getType()},
  644. {BasePtr, Offsets, Input,
  645. Builder.getInt32(MemoryTy->getScalarSizeInBits()),
  646. Builder.getInt32(Scale), Mask});
  647. else
  648. return Builder.CreateIntrinsic(
  649. Intrinsic::arm_mve_vstr_scatter_offset,
  650. {BasePtr->getType(), Offsets->getType(), Input->getType()},
  651. {BasePtr, Offsets, Input,
  652. Builder.getInt32(MemoryTy->getScalarSizeInBits()),
  653. Builder.getInt32(Scale)});
  654. }
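// Roughly: storing `trunc <4 x i32> %x to <4 x i16>` lets the trunc be folded
// away and the scatter writes 16-bit elements straight from %x, while a bare
// <4 x i16> input with no visible trunc is zero-extended to <4 x i32> first
// (ExtendInput) so that a full-width register feeds the 16-bit store.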
  655. Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
  656. IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  657. FixedVectorType *Ty;
  658. if (I->getIntrinsicID() == Intrinsic::masked_gather)
  659. Ty = cast<FixedVectorType>(I->getType());
  660. else
  661. Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
  662. // Incrementing gathers only exist for v4i32
  663. if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
  664. return nullptr;
  665. // Incrementing gathers are not beneficial outside of a loop
  666. Loop *L = LI->getLoopFor(I->getParent());
  667. if (L == nullptr)
  668. return nullptr;
  669. // Decompose the GEP into Base and Offsets
  670. GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  671. Value *Offsets;
  672. Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder);
  673. if (!BasePtr)
  674. return nullptr;
  675. LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
  676. "wb gather/scatter\n");
  677. // The gep was in charge of making sure the offsets are scaled correctly
  678. // - calculate that factor so it can be applied by hand
  679. int TypeScale =
  680. computeScale(DL->getTypeSizeInBits(GEP->getOperand(0)->getType()),
  681. DL->getTypeSizeInBits(GEP->getType()) /
  682. cast<FixedVectorType>(GEP->getType())->getNumElements());
  683. if (TypeScale == -1)
  684. return nullptr;
  685. if (GEP->hasOneUse()) {
  686. // Only in this case do we want to build a wb gather, because the wb will
  687. // change the phi, which would affect other users of the gep (they would
  688. // still be using the phi in the old way)
  689. if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
  690. TypeScale, Builder))
  691. return Load;
  692. }
  693. LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
  694. "non-wb gather/scatter\n");
  695. std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  696. if (Add.first == nullptr)
  697. return nullptr;
  698. Value *OffsetsIncoming = Add.first;
  699. int64_t Immediate = Add.second;
  700. // Make sure the offsets are scaled correctly
  701. Instruction *ScaledOffsets = BinaryOperator::Create(
  702. Instruction::Shl, OffsetsIncoming,
  703. Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)),
  704. "ScaledIndex", I);
  705. // Add the base to the offsets
  706. OffsetsIncoming = BinaryOperator::Create(
  707. Instruction::Add, ScaledOffsets,
  708. Builder.CreateVectorSplat(
  709. Ty->getNumElements(),
  710. Builder.CreatePtrToInt(
  711. BasePtr,
  712. cast<VectorType>(ScaledOffsets->getType())->getElementType())),
  713. "StartIndex", I);
  714. if (I->getIntrinsicID() == Intrinsic::masked_gather)
  715. return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
  716. else
  717. return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
  718. }
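// As a sketch of the non-writeback path: if the GEP's offsets are
// `%iv + splat(4)` (an add recognised by getVarAndConst), the variable part
// is shifted left by TypeScale and added to a splat of the base address,
// and the constant part (4 << TypeScale bytes) becomes the immediate
// increment passed to tryCreateMaskedGatherBase/tryCreateMaskedScatterBase.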
  719. Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
  720. IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
  721. IRBuilder<> &Builder) {
  722. // Check whether this gather's offset is incremented by a constant - if so,
  723. // and the load is of the right type, we can merge this into a QI gather
  724. Loop *L = LI->getLoopFor(I->getParent());
  725. // Offsets that are worth merging into this instruction will be incremented
  726. // by a constant, thus we're looking for an add of a phi and a constant
  727. PHINode *Phi = dyn_cast<PHINode>(Offsets);
  728. if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
  729. Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
  730. // No phi means no IV to write back to; if there is a phi, we expect it
  731. // to have exactly two incoming values; the only phis we are interested in
  732. // will be loop IV's and have exactly two uses, one in their increment and
  733. // one in the gather's gep
  734. return nullptr;
  735. unsigned IncrementIndex =
  736. Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  737. // Look through the phi to the phi increment
  738. Offsets = Phi->getIncomingValue(IncrementIndex);
  739. std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  740. if (Add.first == nullptr)
  741. return nullptr;
  742. Value *OffsetsIncoming = Add.first;
  743. int64_t Immediate = Add.second;
  744. if (OffsetsIncoming != Phi)
  745. // Then the increment we are looking at is not an increment of the
  746. // induction variable, and we don't want to do a writeback
  747. return nullptr;
  748. Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  749. unsigned NumElems =
  750. cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();
  751. // Make sure the offsets are scaled correctly
  752. Instruction *ScaledOffsets = BinaryOperator::Create(
  753. Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
  754. Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
  755. "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  756. // Add the base to the offsets
  757. OffsetsIncoming = BinaryOperator::Create(
  758. Instruction::Add, ScaledOffsets,
  759. Builder.CreateVectorSplat(
  760. NumElems,
  761. Builder.CreatePtrToInt(
  762. BasePtr,
  763. cast<VectorType>(ScaledOffsets->getType())->getElementType())),
  764. "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  765. // The gather is pre-incrementing
  766. OffsetsIncoming = BinaryOperator::Create(
  767. Instruction::Sub, OffsetsIncoming,
  768. Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
  769. "PreIncrementStartIndex",
  770. &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  771. Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);
  772. Builder.SetInsertPoint(I);
  773. Instruction *EndResult;
  774. Instruction *NewInduction;
  775. if (I->getIntrinsicID() == Intrinsic::masked_gather) {
  776. // Build the incrementing gather
  777. Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
  778. // One value to be handed to whoever uses the gather, one is the loop
  779. // increment
  780. EndResult = ExtractValueInst::Create(Load, 0, "Gather");
  781. NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement");
  782. Builder.Insert(EndResult);
  783. Builder.Insert(NewInduction);
  784. } else {
  785. // Build the incrementing scatter
  786. EndResult = NewInduction =
  787. tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
  788. }
  789. Instruction *AddInst = cast<Instruction>(Offsets);
  790. AddInst->replaceAllUsesWith(NewInduction);
  791. AddInst->eraseFromParent();
  792. Phi->setIncomingValue(IncrementIndex, NewInduction);
  793. return EndResult;
  794. }
  795. void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
  796. Value *OffsSecondOperand,
  797. unsigned StartIndex) {
  798. LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  799. Instruction *InsertionPoint =
  800. &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  801. // Initialize the phi with a vector that contains a sum of the constants
  802. Instruction *NewIndex = BinaryOperator::Create(
  803. Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
  804. "PushedOutAdd", InsertionPoint);
  805. unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
  806. // Order such that start index comes first (this reduces mov's)
  807. Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  808. Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
  809. Phi->getIncomingBlock(IncrementIndex));
  810. Phi->removeIncomingValue(IncrementIndex);
  811. Phi->removeIncomingValue(StartIndex);
  812. }
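// In effect, an offset of the form `phi + C` inside the loop is replaced by a
// phi that already starts at `start + C`, so the add no longer has to be
// executed on every iteration.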
  813. void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi,
  814. Value *IncrementPerRound,
  815. Value *OffsSecondOperand,
  816. unsigned LoopIncrement,
  817. IRBuilder<> &Builder) {
  818. LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
  819. // Create a new scalar add outside of the loop and transform it to a splat
  820. // by which loop variable can be incremented
  821. Instruction *InsertionPoint = &cast<Instruction>(
  822. Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());
  823. // Create a new index
  824. Value *StartIndex =
  825. BinaryOperator::Create((Instruction::BinaryOps)Opcode,
  826. Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
  827. OffsSecondOperand, "PushedOutMul", InsertionPoint);
  828. Instruction *Product =
  829. BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound,
  830. OffsSecondOperand, "Product", InsertionPoint);
  831. // Increment NewIndex by Product instead of the multiplication
  832. Instruction *NewIncrement = BinaryOperator::Create(
  833. Instruction::Add, Phi, Product, "IncrementPushedOutMul",
  834. cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
  835. .getPrevNode());
  836. Phi->addIncoming(StartIndex,
  837. Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  838. Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  839. Phi->removeIncomingValue((unsigned)0);
  840. Phi->removeIncomingValue((unsigned)0);
  841. }
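// In effect, an offset of the form `phi * C` (or `phi << C`) is replaced by a
// phi that starts at `start * C` and is incremented by `increment * C` each
// round, which is equivalent to (start + n * increment) * C but keeps the
// multiply out of the loop body.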
  842. // Check whether all usages of this instruction are as offsets of
  843. // gathers/scatters or simple arithmetic only used by gathers/scatters
  844. static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL) {
  845. if (I->hasNUses(0)) {
  846. return false;
  847. }
  848. bool Gatscat = true;
  849. for (User *U : I->users()) {
  850. if (!isa<Instruction>(U))
  851. return false;
  852. if (isa<GetElementPtrInst>(U) ||
  853. isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
  854. return Gatscat;
  855. } else {
  856. unsigned OpCode = cast<Instruction>(U)->getOpcode();
  857. if ((OpCode == Instruction::Add || OpCode == Instruction::Mul ||
  858. OpCode == Instruction::Shl ||
  859. isAddLikeOr(cast<Instruction>(U), DL)) &&
  860. hasAllGatScatUsers(cast<Instruction>(U), DL)) {
  861. continue;
  862. }
  863. return false;
  864. }
  865. }
  866. return Gatscat;
  867. }
  868. bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
  869. LoopInfo *LI) {
  870. LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"
  871. << *Offsets << "\n");
  872. // Optimise the addresses of gathers/scatters by moving invariant
  873. // calculations out of the loop
  874. if (!isa<Instruction>(Offsets))
  875. return false;
  876. Instruction *Offs = cast<Instruction>(Offsets);
  877. if (Offs->getOpcode() != Instruction::Add && !isAddLikeOr(Offs, *DL) &&
  878. Offs->getOpcode() != Instruction::Mul &&
  879. Offs->getOpcode() != Instruction::Shl)
  880. return false;
  881. Loop *L = LI->getLoopFor(BB);
  882. if (L == nullptr)
  883. return false;
  884. if (!Offs->hasOneUse()) {
  885. if (!hasAllGatScatUsers(Offs, *DL))
  886. return false;
  887. }
  888. // Find out which, if any, operand of the instruction
  889. // is a phi node
  890. PHINode *Phi;
  891. int OffsSecondOp;
  892. if (isa<PHINode>(Offs->getOperand(0))) {
  893. Phi = cast<PHINode>(Offs->getOperand(0));
  894. OffsSecondOp = 1;
  895. } else if (isa<PHINode>(Offs->getOperand(1))) {
  896. Phi = cast<PHINode>(Offs->getOperand(1));
  897. OffsSecondOp = 0;
  898. } else {
  899. bool Changed = false;
  900. if (isa<Instruction>(Offs->getOperand(0)) &&
  901. L->contains(cast<Instruction>(Offs->getOperand(0))))
  902. Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
  903. if (isa<Instruction>(Offs->getOperand(1)) &&
  904. L->contains(cast<Instruction>(Offs->getOperand(1))))
  905. Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
  906. if (!Changed)
  907. return false;
  908. if (isa<PHINode>(Offs->getOperand(0))) {
  909. Phi = cast<PHINode>(Offs->getOperand(0));
  910. OffsSecondOp = 1;
  911. } else if (isa<PHINode>(Offs->getOperand(1))) {
  912. Phi = cast<PHINode>(Offs->getOperand(1));
  913. OffsSecondOp = 0;
  914. } else {
  915. return false;
  916. }
  917. }
  918. // A phi node we want to perform this function on should be from the
  919. // loop header.
  920. if (Phi->getParent() != L->getHeader())
  921. return false;
  922. // We're looking for a simple add recurrence.
  923. BinaryOperator *IncInstruction;
  924. Value *Start, *IncrementPerRound;
  925. if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
  926. IncInstruction->getOpcode() != Instruction::Add)
  927. return false;
  928. int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;
  929. // Get the value that is added to/multiplied with the phi
  930. Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
  931. if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
  932. !L->isLoopInvariant(OffsSecondOperand))
  933. // Something has gone wrong, abort
  934. return false;
  935. // Only proceed if the increment per round is a constant or an instruction
  936. // which does not originate from within the loop
  937. if (!isa<Constant>(IncrementPerRound) &&
  938. !(isa<Instruction>(IncrementPerRound) &&
  939. !L->contains(cast<Instruction>(IncrementPerRound))))
  940. return false;
  941. // If the phi is not used by anything else, we can just adapt it when
  942. // replacing the instruction; if it is, we'll have to duplicate it
  943. PHINode *NewPhi;
  944. if (Phi->getNumUses() == 2) {
  945. // No other users -> reuse existing phi (One user is the instruction
  946. // we're looking at, the other is the phi increment)
  947. if (IncInstruction->getNumUses() != 1) {
  948. // If the incrementing instruction does have more users than
  949. // our phi, we need to copy it
  950. IncInstruction = BinaryOperator::Create(
  951. Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
  952. IncrementPerRound, "LoopIncrement", IncInstruction);
  953. Phi->setIncomingValue(IncrementingBlock, IncInstruction);
  954. }
  955. NewPhi = Phi;
  956. } else {
  957. // There are other users -> create a new phi
  958. NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi);
  959. // Copy the incoming values of the old phi
  960. NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
  961. Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
  962. IncInstruction = BinaryOperator::Create(
  963. Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
  964. IncrementPerRound, "LoopIncrement", IncInstruction);
  965. NewPhi->addIncoming(IncInstruction,
  966. Phi->getIncomingBlock(IncrementingBlock));
  967. IncrementingBlock = 1;
  968. }
  969. IRBuilder<> Builder(BB->getContext());
  970. Builder.SetInsertPoint(Phi);
  971. Builder.SetCurrentDebugLocation(Offs->getDebugLoc());
  972. switch (Offs->getOpcode()) {
  973. case Instruction::Add:
  974. case Instruction::Or:
  975. pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
  976. break;
  977. case Instruction::Mul:
  978. case Instruction::Shl:
  979. pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound,
  980. OffsSecondOperand, IncrementingBlock, Builder);
  981. break;
  982. default:
  983. return false;
  984. }
  985. LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
  986. << "add/mul\n");
  987. // The instruction has now been "absorbed" into the phi value
  988. Offs->replaceAllUsesWith(NewPhi);
  989. if (Offs->hasNUses(0))
  990. Offs->eraseFromParent();
  991. // Clean up the old increment in case it's unused because we built a new
  992. // one
  993. if (IncInstruction->hasNUses(0))
  994. IncInstruction->eraseFromParent();
  995. return true;
  996. }
  997. static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP,
  998. IRBuilder<> &Builder) {
  999. // Splat the non-vector value to a vector of the given type - if the value is
  1000. // a constant (and its value isn't too big), we can even use this opportunity
  1001. // to scale it to the size of the vector elements
  1002. auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
  1003. ConstantInt *Const;
  1004. if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
  1005. VT->getElementType() != NonVectorVal->getType()) {
  1006. unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
  1007. uint64_t N = Const->getZExtValue();
  1008. if (N < (unsigned)(1 << (TargetElemSize - 1))) {
  1009. NonVectorVal = Builder.CreateVectorSplat(
  1010. VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
  1011. return;
  1012. }
  1013. }
  1014. NonVectorVal =
  1015. Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  1016. };
  1017. FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  1018. FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  1019. // If one of X, Y is not a vector, we have to splat it in order
  1020. // to add the two of them.
  1021. if (XElType && !YElType) {
  1022. FixSummands(XElType, Y);
  1023. YElType = cast<FixedVectorType>(Y->getType());
  1024. } else if (YElType && !XElType) {
  1025. FixSummands(YElType, X);
  1026. XElType = cast<FixedVectorType>(X->getType());
  1027. }
  1028. assert(XElType && YElType && "Unknown vector types");
  1029. // Check that the summands are of compatible types
  1030. if (XElType != YElType) {
  1031. LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
  1032. return nullptr;
  1033. }
  1034. if (XElType->getElementType()->getScalarSizeInBits() != 32) {
  1035. // Check that by adding the vectors we do not accidentally
  1036. // create an overflow
  1037. Constant *ConstX = dyn_cast<Constant>(X);
  1038. Constant *ConstY = dyn_cast<Constant>(Y);
  1039. if (!ConstX || !ConstY)
  1040. return nullptr;
  1041. unsigned TargetElemSize = 128 / XElType->getNumElements();
  1042. for (unsigned i = 0; i < XElType->getNumElements(); i++) {
  1043. ConstantInt *ConstXEl =
  1044. dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
  1045. ConstantInt *ConstYEl =
  1046. dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
  1047. if (!ConstXEl || !ConstYEl ||
  1048. ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >=
  1049. (unsigned)(1 << (TargetElemSize - 1)))
  1050. return nullptr;
  1051. }
  1052. }
  1053. Value *Add = Builder.CreateAdd(X, Y);
  1054. FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType());
  1055. if (checkOffsetSize(Add, GEPType->getNumElements()))
  1056. return Add;
  1057. else
  1058. return nullptr;
  1059. }
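// Overflow example for the narrow-element case above: merging two <8 x i16>
// constant offset vectors is only allowed if each pair of elements sums to
// less than 1 << (16 - 1) = 32768, since larger sums could change sign when
// the GEP later sign-extends the i16 offsets.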
  1060. Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
  1061. Value *&Offsets,
  1062. IRBuilder<> &Builder) {
  1063. Value *GEPPtr = GEP->getPointerOperand();
  1064. Offsets = GEP->getOperand(1);
  1065. // We only merge geps with constant offsets, because only for those
  1066. // we can make sure that we do not cause an overflow
  1067. if (!isa<Constant>(Offsets))
  1068. return nullptr;
  1069. GetElementPtrInst *BaseGEP;
  1070. if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) {
  1071. // Merge the two geps into one
  1072. Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder);
  1073. if (!BaseBasePtr)
  1074. return nullptr;
  1075. Offsets =
  1076. CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder);
  1077. if (Offsets == nullptr)
  1078. return nullptr;
  1079. return BaseBasePtr;
  1080. }
  1081. return GEPPtr;
  1082. }
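// For example, gep(gep(%base, <0, 1, 2, 3>), <4, 4, 4, 4>) with constant
// <4 x i32> offsets folds into a single gep(%base, <4, 5, 6, 7>), provided
// CheckAndCreateOffsetAdd confirms the summed offsets stay in range.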
  1083. bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
  1084. LoopInfo *LI) {
  1085. GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  1086. if (!GEP)
  1087. return false;
  1088. bool Changed = false;
  1089. if (GEP->hasOneUse() && isa<GetElementPtrInst>(GEP->getPointerOperand())) {
  1090. IRBuilder<> Builder(GEP->getContext());
  1091. Builder.SetInsertPoint(GEP);
  1092. Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
  1093. Value *Offsets;
  1094. Value *Base = foldGEP(GEP, Offsets, Builder);
  1095. // We only want to merge the geps if there is a real chance that they can be
  1096. // used by an MVE gather; thus the offset has to have the correct size
  1097. // (always i32 if it is not of vector type) and the base has to be a
  1098. // pointer.
  1099. if (Offsets && Base && Base != GEP) {
  1100. GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
  1101. GEP->getSourceElementType(), Base, Offsets, "gep.merged", GEP);
  1102. GEP->replaceAllUsesWith(NewAddress);
  1103. GEP = NewAddress;
  1104. Changed = true;
  1105. }
  1106. }
  1107. Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  1108. return Changed;
  1109. }
  1110. bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  1111. if (!EnableMaskedGatherScatters)
  1112. return false;
  1113. auto &TPC = getAnalysis<TargetPassConfig>();
  1114. auto &TM = TPC.getTM<TargetMachine>();
  1115. auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  1116. if (!ST->hasMVEIntegerOps())
  1117. return false;
  1118. LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  1119. DL = &F.getParent()->getDataLayout();
  1120. SmallVector<IntrinsicInst *, 4> Gathers;
  1121. SmallVector<IntrinsicInst *, 4> Scatters;
  1122. bool Changed = false;
  1123. for (BasicBlock &BB : F) {
  1124. Changed |= SimplifyInstructionsInBlock(&BB);
  1125. for (Instruction &I : BB) {
  1126. IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
  1127. if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
  1128. isa<FixedVectorType>(II->getType())) {
  1129. Gathers.push_back(II);
  1130. Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
  1131. } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
  1132. isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
  1133. Scatters.push_back(II);
  1134. Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
  1135. }
  1136. }
  1137. }
  1138. for (unsigned i = 0; i < Gathers.size(); i++) {
  1139. IntrinsicInst *I = Gathers[i];
  1140. Instruction *L = lowerGather(I);
  1141. if (L == nullptr)
  1142. continue;
  1143. // Get rid of any now dead instructions
  1144. SimplifyInstructionsInBlock(L->getParent());
  1145. Changed = true;
  1146. }
  1147. for (unsigned i = 0; i < Scatters.size(); i++) {
  1148. IntrinsicInst *I = Scatters[i];
  1149. Instruction *S = lowerScatter(I);
  1150. if (S == nullptr)
  1151. continue;
  1152. // Get rid of any now dead instructions
  1153. SimplifyInstructionsInBlock(S->getParent());
  1154. Changed = true;
  1155. }
  1156. return Changed;
  1157. }