//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

#define DEBUG_TYPE "vector-combine"
#include "llvm/Transforms/Utils/InstructionWorklist.h"

using namespace llvm;
using namespace llvm::PatternMatch;

STATISTIC(NumVecLoad, "Number of vector loads formed");
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
STATISTIC(NumScalarBO, "Number of scalar binops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

static cl::opt<unsigned> MaxInstrsToScan(
    "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
    cl::desc("Max number of instructions to scan for vector combining."));

static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();

namespace {
class VectorCombine {
public:
  VectorCombine(Function &F, const TargetTransformInfo &TTI,
                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
                bool ScalarizationOnly)
      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC),
        ScalarizationOnly(ScalarizationOnly) {}

  bool run();

private:
  Function &F;
  IRBuilder<> Builder;
  const TargetTransformInfo &TTI;
  const DominatorTree &DT;
  AAResults &AA;
  AssumptionCache &AC;

  /// If true only perform scalarization combines and do not introduce new
  /// vector operations.
  bool ScalarizationOnly;

  InstructionWorklist Worklist;

  bool vectorizeLoadInsert(Instruction &I);
  ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                        ExtractElementInst *Ext1,
                                        unsigned PreferredExtractIndex) const;
  bool isExtractExtractCheap(ExtractElementInst *Ext0,
                             ExtractElementInst *Ext1, const Instruction &I,
                             ExtractElementInst *&ConvertToShuffle,
                             unsigned PreferredExtractIndex);
  void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                     Instruction &I);
  void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                       Instruction &I);
  bool foldExtractExtract(Instruction &I);
  bool foldBitcastShuf(Instruction &I);
  bool scalarizeBinopOrCmp(Instruction &I);
  bool foldExtractedCmps(Instruction &I);
  bool foldSingleElementStore(Instruction &I);
  bool scalarizeLoadExtract(Instruction &I);
  bool foldShuffleOfBinops(Instruction &I);

  void replaceValue(Value &Old, Value &New) {
    Old.replaceAllUsesWith(&New);
    New.takeName(&Old);
    if (auto *NewI = dyn_cast<Instruction>(&New)) {
      Worklist.pushUsersToWorkList(*NewI);
      Worklist.pushValue(NewI);
    }
    Worklist.pushValue(&Old);
  }

  void eraseInstruction(Instruction &I) {
    for (Value *Op : I.operands())
      Worklist.pushValue(Op);
    Worklist.remove(&I);
    I.eraseFromParent();
  }
};
} // namespace

bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
  // Match insert into fixed vector of scalar value.
  // TODO: Handle non-zero insert index.
  auto *Ty = dyn_cast<FixedVectorType>(I.getType());
  Value *Scalar;
  if (!Ty || !match(&I, m_InsertElt(m_Undef(), m_Value(Scalar), m_ZeroInt())) ||
      !Scalar->hasOneUse())
    return false;

  // Optionally match an extract from another vector.
  Value *X;
  bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
  if (!HasExtract)
    X = Scalar;

  // Match source value as load of scalar or vector.
  // Do not vectorize scalar load (widening) if atomic/volatile or under
  // asan/hwasan/memtag/tsan. The widened load may load data from dirty regions
  // or create data races non-existent in the source.
  auto *Load = dyn_cast<LoadInst>(X);
  if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
      Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
      mustSuppressSpeculation(*Load))
    return false;

  const DataLayout &DL = I.getModule()->getDataLayout();
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
  unsigned AS = Load->getPointerAddressSpace();

  // We are potentially transforming byte-sized (8-bit) memory accesses, so make
  // sure we have all of our type-based constraints in place for this target.
  Type *ScalarTy = Scalar->getType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
  if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
      ScalarSize % 8 != 0)
    return false;

  // Check safety of replacing the scalar load with a larger vector load.
  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
  unsigned MinVecNumElts = MinVectorSize / ScalarSize;
  auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
  unsigned OffsetEltIndex = 0;
  Align Alignment = Load->getAlign();
  if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &DT)) {
    // It is not safe to load directly from the pointer, but we can still peek
    // through gep offsets and check if it is safe to load from a base address
    // with updated alignment. If it is, we can shuffle the element(s) into
    // place after loading.
    unsigned OffsetBitWidth = DL.getIndexTypeSizeInBits(SrcPtr->getType());
    APInt Offset(OffsetBitWidth, 0);
    SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(DL, Offset);

    // We want to shuffle the result down from a high element of a vector, so
    // the offset must be positive.
    if (Offset.isNegative())
      return false;

    // The offset must be a multiple of the scalar element to shuffle cleanly
    // in the element's size.
    uint64_t ScalarSizeInBytes = ScalarSize / 8;
    if (Offset.urem(ScalarSizeInBytes) != 0)
      return false;

    // If we load MinVecNumElts, will our target element still be loaded?
    OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
    if (OffsetEltIndex >= MinVecNumElts)
      return false;

    if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &DT))
      return false;

    // Update alignment with offset value. Note that the offset could be negated
    // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
    // negation does not change the result of the alignment calculation.
    Alignment = commonAlignment(Alignment, Offset.getZExtValue());
  }

  // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
  // Use the greater of the alignment on the load or its source pointer.
  Alignment = std::max(SrcPtr->getPointerAlignment(DL), Alignment);
  Type *LoadTy = Load->getType();
  InstructionCost OldCost =
      TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS);
  APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
  OldCost += TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
                                          /* Insert */ true, HasExtract);

  // New pattern: load VecPtr
  InstructionCost NewCost =
      TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS);
  // Optionally, we are shuffling the loaded vector element(s) into place.
  // For the mask set everything but element 0 to undef to prevent poison from
  // propagating from the extra loaded memory. This will also optionally
  // shrink/grow the vector from the loaded size to the output size.
  // We assume this operation has no cost in codegen if there was no offset.
  // Note that we could use freeze to avoid poison problems, but then we might
  // still need a shuffle to change the vector size.
  unsigned OutputNumElts = Ty->getNumElements();
  SmallVector<int, 16> Mask(OutputNumElts, UndefMaskElem);
  assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
  Mask[0] = OffsetEltIndex;
  if (OffsetEltIndex)
    NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, MinVecTy, Mask);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // It is safe and potentially profitable to load a vector directly:
  // inselt undef, load Scalar, 0 --> load VecPtr
  IRBuilder<> Builder(Load);
  Value *CastedPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
      SrcPtr, MinVecTy->getPointerTo(AS));
  Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
  VecLd = Builder.CreateShuffleVector(VecLd, Mask);

  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}

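// Illustrative IR sketch for vectorizeLoadInsert, assuming a target whose
// minimum vector register width is 128 bits (so <4 x float> is the widened
// type); names and alignments are only examples:
//   %s = load float, float* %p, align 16
//   %r = insertelement <4 x float> undef, float %s, i32 0
// -->
//   %vp = bitcast float* %p to <4 x float>*
//   %v  = load <4 x float>, <4 x float>* %vp, align 16
//   %r  = shufflevector <4 x float> %v, <4 x float> poison,
//                       <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef>
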
/// Determine which, if any, of the inputs should be replaced by a shuffle
/// followed by extract from a different index.
ExtractElementInst *VectorCombine::getShuffleExtract(
    ExtractElementInst *Ext0, ExtractElementInst *Ext1,
    unsigned PreferredExtractIndex = InvalidIndex) const {
  assert(isa<ConstantInt>(Ext0->getIndexOperand()) &&
         isa<ConstantInt>(Ext1->getIndexOperand()) &&
         "Expected constant extract indexes");
  unsigned Index0 = cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue();
  unsigned Index1 = cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue();

  // If the extract indexes are identical, no shuffle is needed.
  if (Index0 == Index1)
    return nullptr;

  Type *VecTy = Ext0->getVectorOperand()->getType();
  assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
  InstructionCost Cost0 =
      TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  InstructionCost Cost1 =
      TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);

  // If both costs are invalid no shuffle is needed
  if (!Cost0.isValid() && !Cost1.isValid())
    return nullptr;

  // We are extracting from 2 different indexes, so one operand must be shuffled
  // before performing a vector operation and/or extract. The more expensive
  // extract will be replaced by a shuffle.
  if (Cost0 > Cost1)
    return Ext0;
  if (Cost1 > Cost0)
    return Ext1;

  // If the costs are equal and there is a preferred extract index, shuffle the
  // opposite operand.
  if (PreferredExtractIndex == Index0)
    return Ext1;
  if (PreferredExtractIndex == Index1)
    return Ext0;

  // Otherwise, replace the extract with the higher index.
  return Index0 > Index1 ? Ext0 : Ext1;
}

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                          ExtractElementInst *Ext1,
                                          const Instruction &I,
                                          ExtractElementInst *&ConvertToShuffle,
                                          unsigned PreferredExtractIndex) {
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
  unsigned Opcode = I.getOpcode();
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
  InstructionCost ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
    ScalarOpCost = TTI.getCmpSelInstrCost(
        Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
    VectorOpCost = TTI.getCmpSelInstrCost(
        Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();
  InstructionCost Extract0Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext0Index);
  InstructionCost Extract1Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  //       check the cost of creating a broadcast shuffle and shuffling both
  //       operands to element 0.
  InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  InstructionCost OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
  if (ConvertToShuffle) {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // undefined except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { undef, undef, 0, undef }
    // TODO: The cost model has an option for a "broadcast" shuffle
    //       (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Create a shuffle that translates (shifts) 1 element from the input vector
/// to a new element location.
static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
                                 unsigned NewIndex, IRBuilder<> &Builder) {
  // The shuffle mask is undefined except for 1 lane that is being translated
  // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
  // ShufMask = { 2, undef, undef, undef }
  auto *VecTy = cast<FixedVectorType>(Vec->getType());
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
  ShufMask[NewIndex] = OldIndex;
  return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
}

/// Given an extract element instruction with constant index operand, shuffle
/// the source vector (shift the scalar element) to a NewIndex for extraction.
/// Return null if the input can be constant folded, so that we are not creating
/// unnecessary instructions.
static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
                                            unsigned NewIndex,
                                            IRBuilder<> &Builder) {
  // If the extract can be constant-folded, this code is unsimplified. Defer
  // to other passes to handle that.
  Value *X = ExtElt->getVectorOperand();
  Value *C = ExtElt->getIndexOperand();
  assert(isa<ConstantInt>(C) && "Expected a constant index operand");
  if (isa<Constant>(X))
    return nullptr;

  Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
                                   NewIndex, Builder);
  return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
                                  ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<CmpInst>(&I) && "Expected a compare");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
  Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
                                    ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

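// Illustrative IR sketch for foldExtExtBinop (types, names, and indexes are
// only examples):
//   %e0 = extractelement <4 x i32> %v0, i32 1
//   %e1 = extractelement <4 x i32> %v1, i32 1
//   %r  = add i32 %e0, %e1
// -->
//   %bo = add <4 x i32> %v0, %v1
//   %r  = extractelement <4 x i32> %bo, i32 1
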
/// Match an instruction with extracted vector operands.
bool VectorCombine::foldExtractExtract(Instruction &I) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *I0, *I1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
      !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
      !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  //       probably becomes unnecessary.
  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  uint64_t InsertIndex = InvalidIndex;
  if (I.hasOneUse())
    match(I.user_back(),
          m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));

  ExtractElementInst *ExtractToChange;
  if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
    return false;

  if (ExtractToChange) {
    unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
    ExtractElementInst *NewExtract =
        translateExtract(ExtractToChange, CheapExtractIdx, Builder);
    if (!NewExtract)
      return false;
    if (ExtractToChange == Ext0)
      Ext0 = NewExtract;
    else
      Ext1 = NewExtract;
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I);
  else
    foldExtExtBinop(Ext0, Ext1, I);

  Worklist.push(Ext0);
  Worklist.push(Ext1);
  return true;
}

/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
bool VectorCombine::foldBitcastShuf(Instruction &I) {
  Value *V;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(
                     m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))))))
    return false;

  // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
  //    scalable type is unknown; Second, we cannot reason if the narrowed
  //    shuffle mask for scalable type is a splat or not.
  // 2) Disallow non-vector casts and length-changing shuffles.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
  auto *SrcTy = dyn_cast<FixedVectorType>(V->getType());
  if (!SrcTy || !DestTy || I.getOperand(0)->getType() != SrcTy)
    return false;

  unsigned DestNumElts = DestTy->getNumElements();
  unsigned SrcNumElts = SrcTy->getNumElements();
  SmallVector<int, 16> NewMask;
  if (SrcNumElts <= DestNumElts) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = DestNumElts / SrcNumElts;
    narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = SrcNumElts / DestNumElts;
    if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
      return false;
  }

  // The new shuffle must not cost more than the old shuffle. The bitcast is
  // moved ahead of the shuffle, so assume that it has the same cost as before.
  InstructionCost DestCost = TTI.getShuffleCost(
      TargetTransformInfo::SK_PermuteSingleSrc, DestTy, NewMask);
  InstructionCost SrcCost =
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy, Mask);
  if (DestCost > SrcCost || !DestCost.isValid())
    return false;

  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
  ++NumShufOfBitcast;
  Value *CastV = Builder.CreateBitCast(V, DestTy);
  Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask);
  replaceValue(I, *Shuf);
  return true;
}

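// Illustrative IR sketch for foldBitcastShuf with a bitcast to narrower
// elements (types and mask are only examples): the mask is narrowed to match
// the new element count.
//   %shuf = shufflevector <2 x i64> %v, <2 x i64> undef, <2 x i32> <i32 1, i32 0>
//   %cast = bitcast <2 x i64> %shuf to <4 x i32>
// -->
//   %cast = bitcast <2 x i64> %v to <4 x i32>
//   %shuf = shufflevector <4 x i32> %cast, <4 x i32> poison,
//                         <4 x i32> <i32 2, i32 3, i32 0, i32 1>
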
/// Match a vector binop or compare instruction with at least one inserted
/// scalar operand and convert to scalar binop/cmp followed by insertelement.
bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  Value *Ins0, *Ins1;
  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
    return false;

  // Do not convert the vector condition of a vector select into a scalar
  // condition. That may cause problems for codegen because of differences in
  // boolean formats and register-file transfers.
  // TODO: Can we account for that in the cost model?
  bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
  if (IsCmp)
    for (User *U : I.users())
      if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
        return false;

  // Match against one or both scalar values being inserted into constant
  // vectors:
  // vec_op VecC0, (inselt VecC1, V1, Index)
  // vec_op (inselt VecC0, V0, Index), VecC1
  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
  // TODO: Deal with mismatched index constants and variable indexes?
  Constant *VecC0 = nullptr, *VecC1 = nullptr;
  Value *V0 = nullptr, *V1 = nullptr;
  uint64_t Index0 = 0, Index1 = 0;
  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
                               m_ConstantInt(Index0))) &&
      !match(Ins0, m_Constant(VecC0)))
    return false;
  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
                               m_ConstantInt(Index1))) &&
      !match(Ins1, m_Constant(VecC1)))
    return false;

  bool IsConst0 = !V0;
  bool IsConst1 = !V1;
  if (IsConst0 && IsConst1)
    return false;
  if (!IsConst0 && !IsConst1 && Index0 != Index1)
    return false;

  // Bail for single insertion if it is a load.
  // TODO: Handle this once getVectorInstrCost can cost for load/stores.
  auto *I0 = dyn_cast_or_null<Instruction>(V0);
  auto *I1 = dyn_cast_or_null<Instruction>(V1);
  if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
      (IsConst1 && I0 && I0->mayReadFromMemory()))
    return false;

  uint64_t Index = IsConst0 ? Index1 : Index0;
  Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
  Type *VecTy = I.getType();
  assert(VecTy->isVectorTy() &&
         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
         (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
          ScalarTy->isPointerTy()) &&
         "Unexpected types for insert element into binop or cmp");

  unsigned Opcode = I.getOpcode();
  InstructionCost ScalarOpCost, VectorOpCost;
  if (IsCmp) {
    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
    ScalarOpCost = TTI.getCmpSelInstrCost(
        Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
    VectorOpCost = TTI.getCmpSelInstrCost(
        Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
  } else {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  }

  // Get cost estimate for the insert element. This cost will factor into
  // both sequences.
  InstructionCost InsertCost =
      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, Index);
  InstructionCost OldCost =
      (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
  InstructionCost NewCost = ScalarOpCost + InsertCost +
                            (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
                            (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
  // inselt NewVecC, (scalar_op V0, V1), Index
  if (IsCmp)
    ++NumScalarCmp;
  else
    ++NumScalarBO;

  // For constant cases, extract the scalar element, this should constant fold.
  if (IsConst0)
    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
  if (IsConst1)
    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));

  Value *Scalar =
      IsCmp ? Builder.CreateCmp(Pred, V0, V1)
            : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);

  Scalar->setName(I.getName() + ".scalar");

  // All IR flags are safe to back-propagate. There is no potential for extra
  // poison to be created by the scalar instruction.
  if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
    ScalarInst->copyIRFlags(&I);

  // Fold the vector constants in the original vectors into a new base vector.
  Constant *NewVecC = IsCmp ? ConstantExpr::getCompare(Pred, VecC0, VecC1)
                            : ConstantExpr::get(Opcode, VecC0, VecC1);
  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
  replaceValue(I, *Insert);
  return true;
}

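// Illustrative IR sketch for scalarizeBinopOrCmp (constants, types, and index
// are only examples): a binop whose operands are insertions into constant
// vectors becomes a scalar binop inserted into the folded constant vector.
//   %i0 = insertelement <4 x i32> <i32 1, i32 2, i32 3, i32 4>, i32 %x, i64 2
//   %i1 = insertelement <4 x i32> <i32 5, i32 6, i32 7, i32 8>, i32 %y, i64 2
//   %r  = add <4 x i32> %i0, %i1
// -->
//   %s = add i32 %x, %y
//   %r = insertelement <4 x i32> <i32 6, i32 8, i32 10, i32 12>, i32 %s, i64 2
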
/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
/// a vector into vector operations followed by extract. Note: The SLP pass
/// may miss this pattern because of implementation problems.
bool VectorCombine::foldExtractedCmps(Instruction &I) {
  // We are looking for a scalar binop of booleans.
  // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
  if (!I.isBinaryOp() || !I.getType()->isIntegerTy(1))
    return false;

  // The compare predicates should match, and each compare should have a
  // constant operand.
  // TODO: Relax the one-use constraints.
  Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
  Instruction *I0, *I1;
  Constant *C0, *C1;
  CmpInst::Predicate P0, P1;
  if (!match(B0, m_OneUse(m_Cmp(P0, m_Instruction(I0), m_Constant(C0)))) ||
      !match(B1, m_OneUse(m_Cmp(P1, m_Instruction(I1), m_Constant(C1)))) ||
      P0 != P1)
    return false;

  // The compare operands must be extracts of the same vector with constant
  // extract indexes.
  // TODO: Relax the one-use constraints.
  Value *X;
  uint64_t Index0, Index1;
  if (!match(I0, m_OneUse(m_ExtractElt(m_Value(X), m_ConstantInt(Index0)))) ||
      !match(I1, m_OneUse(m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))))
    return false;

  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1);
  if (!ConvertToShuf)
    return false;

  // The original scalar pattern is:
  // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
  CmpInst::Predicate Pred = P0;
  unsigned CmpOpcode = CmpInst::isFPPredicate(Pred) ? Instruction::FCmp
                                                    : Instruction::ICmp;
  auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
  if (!VecTy)
    return false;

  InstructionCost OldCost =
      TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  OldCost += TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);
  OldCost +=
      TTI.getCmpSelInstrCost(CmpOpcode, I0->getType(),
                             CmpInst::makeCmpResultType(I0->getType()), Pred) *
      2;
  OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());

  // The proposed vector pattern is:
  // vcmp = cmp Pred X, VecC
  // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
  int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
  int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
  auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
  InstructionCost NewCost = TTI.getCmpSelInstrCost(
      CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred);
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
  ShufMask[CheapIndex] = ExpensiveIndex;
  NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
                                ShufMask);
  NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy);
  NewCost += TTI.getVectorInstrCost(Ext0->getOpcode(), CmpTy, CheapIndex);

  // Aggressively form vector ops if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // Create a vector constant from the 2 scalar constants.
  SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
                                   UndefValue::get(VecTy->getElementType()));
  CmpC[Index0] = C0;
  CmpC[Index1] = C1;
  Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));

  Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
  Value *VecLogic = Builder.CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
                                        VCmp, Shuf);
  Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
  replaceValue(I, *NewExt);
  ++NumVecCmpBO;
  return true;
}

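// Illustrative IR sketch for foldExtractedCmps (types, indexes, constants,
// and the shuffled lane are only examples; the actual choice depends on the
// target's extract costs):
//   %e0 = extractelement <4 x i32> %x, i32 0
//   %e1 = extractelement <4 x i32> %x, i32 3
//   %c0 = icmp sgt i32 %e0, 42
//   %c1 = icmp sgt i32 %e1, 7
//   %r  = and i1 %c0, %c1
// -->
//   %vcmp = icmp sgt <4 x i32> %x, <i32 42, i32 undef, i32 undef, i32 7>
//   %shift = shufflevector <4 x i1> %vcmp, <4 x i1> poison,
//                          <4 x i32> <i32 3, i32 undef, i32 undef, i32 undef>
//   %vand = and <4 x i1> %vcmp, %shift
//   %r = extractelement <4 x i1> %vand, i64 0
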
// Check if memory loc modified between two instrs in the same BB
static bool isMemModifiedBetween(BasicBlock::iterator Begin,
                                 BasicBlock::iterator End,
                                 const MemoryLocation &Loc, AAResults &AA) {
  unsigned NumScanned = 0;
  return std::any_of(Begin, End, [&](const Instruction &Instr) {
    return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
           ++NumScanned > MaxInstrsToScan;
  });
}

/// Helper class to indicate whether a vector index can be safely scalarized and
/// if a freeze needs to be inserted.
class ScalarizationResult {
  enum class StatusTy { Unsafe, Safe, SafeWithFreeze };

  StatusTy Status;
  Value *ToFreeze;

  ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
      : Status(Status), ToFreeze(ToFreeze) {}

public:
  ScalarizationResult(const ScalarizationResult &Other) = default;
  ~ScalarizationResult() {
    assert(!ToFreeze && "freeze() not called with ToFreeze being set");
  }

  static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
  static ScalarizationResult safe() { return {StatusTy::Safe}; }
  static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
    return {StatusTy::SafeWithFreeze, ToFreeze};
  }

  /// Returns true if the index can be scalarized without requiring a freeze.
  bool isSafe() const { return Status == StatusTy::Safe; }
  /// Returns true if the index cannot be scalarized.
  bool isUnsafe() const { return Status == StatusTy::Unsafe; }
  /// Returns true if the index can be scalarized, but requires inserting a
  /// freeze.
  bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }

  /// Reset the state to Unsafe and clear ToFreeze if set.
  void discard() {
    ToFreeze = nullptr;
    Status = StatusTy::Unsafe;
  }

  /// Freeze the ToFreeze and update the use in \p User to use it.
  void freeze(IRBuilder<> &Builder, Instruction &UserI) {
    assert(isSafeWithFreeze() &&
           "should only be used when freezing is required");
    assert(is_contained(ToFreeze->users(), &UserI) &&
           "UserI must be a user of ToFreeze");
    IRBuilder<>::InsertPointGuard Guard(Builder);
    Builder.SetInsertPoint(cast<Instruction>(&UserI));
    Value *Frozen =
        Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
    for (Use &U : make_early_inc_range((UserI.operands())))
      if (U.get() == ToFreeze)
        U.set(Frozen);

    ToFreeze = nullptr;
  }
};

/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
/// Idx. \p Idx must access a valid vector element.
static ScalarizationResult canScalarizeAccess(FixedVectorType *VecTy,
                                              Value *Idx, Instruction *CtxI,
                                              AssumptionCache &AC,
                                              const DominatorTree &DT) {
  if (auto *C = dyn_cast<ConstantInt>(Idx)) {
    if (C->getValue().ult(VecTy->getNumElements()))
      return ScalarizationResult::safe();
    return ScalarizationResult::unsafe();
  }

  unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
  APInt Zero(IntWidth, 0);
  APInt MaxElts(IntWidth, VecTy->getNumElements());
  ConstantRange ValidIndices(Zero, MaxElts);
  ConstantRange IdxRange(IntWidth, true);

  if (isGuaranteedNotToBePoison(Idx, &AC)) {
    if (ValidIndices.contains(computeConstantRange(Idx, /* ForSigned */ false,
                                                   true, &AC, CtxI, &DT)))
      return ScalarizationResult::safe();
    return ScalarizationResult::unsafe();
  }

  // If the index may be poison, check if we can insert a freeze before the
  // range of the index is restricted.
  Value *IdxBase;
  ConstantInt *CI;
  if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
    IdxRange = IdxRange.binaryAnd(CI->getValue());
  } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
    IdxRange = IdxRange.urem(CI->getValue());
  }

  if (ValidIndices.contains(IdxRange))
    return ScalarizationResult::safeWithFreeze(IdxBase);
  return ScalarizationResult::unsafe();
}

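// Illustrative IR sketch for the SafeWithFreeze case above (names and widths
// are only examples): an index masked to the valid range is safe once its
// base is frozen, so for a 4-element vector
//   %idx = and i64 %i, 3        ; always in [0, 3]
// can be scalarized after rewriting it to use a frozen base:
//   %i.frozen = freeze i64 %i
//   %idx = and i64 %i.frozen, 3
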
/// The memory operation on a vector of \p ScalarType had alignment of
/// \p VectorAlignment. Compute the maximal, but conservatively correct,
/// alignment that will be valid for the memory operation on a single scalar
/// element of the same type with index \p Idx.
static Align computeAlignmentAfterScalarization(Align VectorAlignment,
                                                Type *ScalarType, Value *Idx,
                                                const DataLayout &DL) {
  if (auto *C = dyn_cast<ConstantInt>(Idx))
    return commonAlignment(VectorAlignment,
                           C->getZExtValue() * DL.getTypeStoreSize(ScalarType));
  return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType));
}

// Combine patterns like:
//   %0 = load <4 x i32>, <4 x i32>* %a
//   %1 = insertelement <4 x i32> %0, i32 %b, i32 1
//   store <4 x i32> %1, <4 x i32>* %a
// to:
//   %0 = bitcast <4 x i32>* %a to i32*
//   %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
//   store i32 %b, i32* %1
bool VectorCombine::foldSingleElementStore(Instruction &I) {
  StoreInst *SI = dyn_cast<StoreInst>(&I);
  if (!SI || !SI->isSimple() ||
      !isa<FixedVectorType>(SI->getValueOperand()->getType()))
    return false;

  // TODO: Combine more complicated patterns (multiple insert) by referencing
  // TargetTransformInfo.
  Instruction *Source;
  Value *NewElement;
  Value *Idx;
  if (!match(SI->getValueOperand(),
             m_InsertElt(m_Instruction(Source), m_Value(NewElement),
                         m_Value(Idx))))
    return false;

  if (auto *Load = dyn_cast<LoadInst>(Source)) {
    auto VecTy = cast<FixedVectorType>(SI->getValueOperand()->getType());
    const DataLayout &DL = I.getModule()->getDataLayout();
    Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
    // Don't optimize for atomic/volatile load or store. Ensure memory is not
    // modified between, vector type matches store size, and index is inbounds.
    if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
        !DL.typeSizeEqualsStoreSize(Load->getType()) ||
        SrcAddr != SI->getPointerOperand()->stripPointerCasts())
      return false;

    auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
    if (ScalarizableIdx.isUnsafe() ||
        isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
                             MemoryLocation::get(SI), AA))
      return false;

    if (ScalarizableIdx.isSafeWithFreeze())
      ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
    Value *GEP = Builder.CreateInBoundsGEP(
        SI->getValueOperand()->getType(), SI->getPointerOperand(),
        {ConstantInt::get(Idx->getType(), 0), Idx});
    StoreInst *NSI = Builder.CreateStore(NewElement, GEP);
    NSI->copyMetadata(*SI);
    Align ScalarOpAlignment = computeAlignmentAfterScalarization(
        std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx,
        DL);
    NSI->setAlignment(ScalarOpAlignment);
    replaceValue(I, *NSI);
    eraseInstruction(I);
    return true;
  }

  return false;
}

/// Try to scalarize vector loads feeding extractelement instructions.
bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
  Value *Ptr;
  if (!match(&I, m_Load(m_Value(Ptr))))
    return false;

  auto *LI = cast<LoadInst>(&I);
  const DataLayout &DL = I.getModule()->getDataLayout();
  if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(LI->getType()))
    return false;

  auto *FixedVT = dyn_cast<FixedVectorType>(LI->getType());
  if (!FixedVT)
    return false;

  InstructionCost OriginalCost =
      TTI.getMemoryOpCost(Instruction::Load, LI->getType(), LI->getAlign(),
                          LI->getPointerAddressSpace());
  InstructionCost ScalarizedCost = 0;

  Instruction *LastCheckedInst = LI;
  unsigned NumInstChecked = 0;
  // Check if all users of the load are extracts with no memory modifications
  // between the load and the extract. Compute the cost of both the original
  // code and the scalarized version.
  for (User *U : LI->users()) {
    auto *UI = dyn_cast<ExtractElementInst>(U);
    if (!UI || UI->getParent() != LI->getParent())
      return false;

    if (!isGuaranteedNotToBePoison(UI->getOperand(1), &AC, LI, &DT))
      return false;

    // Check if any instruction between the load and the extract may modify
    // memory.
    if (LastCheckedInst->comesBefore(UI)) {
      for (Instruction &I :
           make_range(std::next(LI->getIterator()), UI->getIterator())) {
        // Bail out if we reached the check limit or the instruction may write
        // to memory.
        if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
          return false;
        NumInstChecked++;
      }
    }

    if (!LastCheckedInst)
      LastCheckedInst = UI;
    else if (LastCheckedInst->comesBefore(UI))
      LastCheckedInst = UI;

    auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
    if (!ScalarIdx.isSafe()) {
      // TODO: Freeze index if it is safe to do so.
      ScalarIdx.discard();
      return false;
    }

    auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
    OriginalCost +=
        TTI.getVectorInstrCost(Instruction::ExtractElement, LI->getType(),
                               Index ? Index->getZExtValue() : -1);
    ScalarizedCost +=
        TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(),
                            Align(1), LI->getPointerAddressSpace());
    ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType());
  }

  if (ScalarizedCost >= OriginalCost)
    return false;

  // Replace extracts with narrow scalar loads.
  for (User *U : LI->users()) {
    auto *EI = cast<ExtractElementInst>(U);
    Builder.SetInsertPoint(EI);

    Value *Idx = EI->getOperand(1);
    Value *GEP =
        Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx});
    auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
        FixedVT->getElementType(), GEP, EI->getName() + ".scalar"));

    Align ScalarOpAlignment = computeAlignmentAfterScalarization(
        LI->getAlign(), FixedVT->getElementType(), Idx, DL);
    NewLoad->setAlignment(ScalarOpAlignment);

    replaceValue(*EI, *NewLoad);
  }

  return true;
}

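// Illustrative IR sketch for scalarizeLoadExtract (types, names, and the
// resulting alignments are only examples): when every user of a vector load
// is an extractelement, each extract can become a narrow scalar load through
// a GEP.
//   %v  = load <4 x i32>, <4 x i32>* %p, align 16
//   %e0 = extractelement <4 x i32> %v, i32 0
//   %e2 = extractelement <4 x i32> %v, i32 2
// -->
//   %p0 = getelementptr inbounds <4 x i32>, <4 x i32>* %p, i32 0, i32 0
//   %e0.scalar = load i32, i32* %p0, align 16
//   %p2 = getelementptr inbounds <4 x i32>, <4 x i32>* %p, i32 0, i32 2
//   %e2.scalar = load i32, i32* %p2, align 8
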
/// Try to convert "shuffle (binop), (binop)" with a shared binop operand into
/// "binop (shuffle), (shuffle)".
bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
  auto *VecTy = dyn_cast<FixedVectorType>(I.getType());
  if (!VecTy)
    return false;

  BinaryOperator *B0, *B1;
  ArrayRef<int> Mask;
  if (!match(&I, m_Shuffle(m_OneUse(m_BinOp(B0)), m_OneUse(m_BinOp(B1)),
                           m_Mask(Mask))) ||
      B0->getOpcode() != B1->getOpcode() || B0->getType() != VecTy)
    return false;

  // Try to replace a binop with a shuffle if the shuffle is not costly.
  // The new shuffle will choose from a single, common operand, so it may be
  // cheaper than the existing two-operand shuffle.
  SmallVector<int> UnaryMask = createUnaryMask(Mask, Mask.size());
  Instruction::BinaryOps Opcode = B0->getOpcode();
  InstructionCost BinopCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  InstructionCost ShufCost = TTI.getShuffleCost(
      TargetTransformInfo::SK_PermuteSingleSrc, VecTy, UnaryMask);
  if (ShufCost > BinopCost)
    return false;

  // If we have something like "add X, Y" and "add Z, X", swap ops to match.
  Value *X = B0->getOperand(0), *Y = B0->getOperand(1);
  Value *Z = B1->getOperand(0), *W = B1->getOperand(1);
  if (BinaryOperator::isCommutative(Opcode) && X != Z && Y != W)
    std::swap(X, Y);

  Value *Shuf0, *Shuf1;
  if (X == Z) {
    // shuf (bo X, Y), (bo X, W) --> bo (shuf X), (shuf Y, W)
    Shuf0 = Builder.CreateShuffleVector(X, UnaryMask);
    Shuf1 = Builder.CreateShuffleVector(Y, W, Mask);
  } else if (Y == W) {
    // shuf (bo X, Y), (bo Z, Y) --> bo (shuf X, Z), (shuf Y)
    Shuf0 = Builder.CreateShuffleVector(X, Z, Mask);
    Shuf1 = Builder.CreateShuffleVector(Y, UnaryMask);
  } else {
    return false;
  }

  Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
  // Intersect flags from the old binops.
  if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
    NewInst->copyIRFlags(B0);
    NewInst->andIRFlags(B1);
  }
  replaceValue(I, *NewBO);
  return true;
}

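// Illustrative IR sketch for foldShuffleOfBinops when both binops share their
// first operand (mask and types are only examples):
//   %b0 = add <4 x i32> %x, %y
//   %b1 = add <4 x i32> %x, %w
//   %r  = shufflevector <4 x i32> %b0, <4 x i32> %b1,
//                       <4 x i32> <i32 0, i32 1, i32 4, i32 5>
// -->
//   %s0 = shufflevector <4 x i32> %x, <4 x i32> poison,
//                       <4 x i32> <i32 0, i32 1, i32 0, i32 1>
//   %s1 = shufflevector <4 x i32> %y, <4 x i32> %w,
//                       <4 x i32> <i32 0, i32 1, i32 4, i32 5>
//   %r  = add <4 x i32> %s0, %s1
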
/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
  if (DisableVectorCombine)
    return false;

  // Don't attempt vectorization if the target does not support vectors.
  if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
    return false;

  bool MadeChange = false;
  auto FoldInst = [this, &MadeChange](Instruction &I) {
    Builder.SetInsertPoint(&I);
    if (!ScalarizationOnly) {
      MadeChange |= vectorizeLoadInsert(I);
      MadeChange |= foldExtractExtract(I);
      MadeChange |= foldBitcastShuf(I);
      MadeChange |= foldExtractedCmps(I);
      MadeChange |= foldShuffleOfBinops(I);
    }
    MadeChange |= scalarizeBinopOrCmp(I);
    MadeChange |= scalarizeLoadExtract(I);
    MadeChange |= foldSingleElementStore(I);
  };
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Use early increment range so that we can erase instructions in loop.
    for (Instruction &I : make_early_inc_range(BB)) {
      if (I.isDebugOrPseudoInst())
        continue;
      FoldInst(I);
    }
  }

  while (!Worklist.isEmpty()) {
    Instruction *I = Worklist.removeOne();
    if (!I)
      continue;
    if (isInstructionTriviallyDead(I)) {
      eraseInstruction(*I);
      continue;
    }
    FoldInst(*I);
  }

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<AAResultsWrapperPass>();
    AU.addPreserved<BasicAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    VectorCombine Combiner(F, TTI, DT, AA, AC, false);
    return Combiner.run();
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  auto &AC = FAM.getResult<AssumptionAnalysis>(F);
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  AAResults &AA = FAM.getResult<AAManager>(F);
  VectorCombine Combiner(F, TTI, DT, AA, AC, ScalarizationOnly);
  if (!Combiner.run())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}