InstCombineShifts.cpp

  1. //===- InstCombineShifts.cpp ----------------------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file implements the visitShl, visitLShr, and visitAShr functions.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "InstCombineInternal.h"
  13. #include "llvm/Analysis/InstructionSimplify.h"
  14. #include "llvm/IR/IntrinsicInst.h"
  15. #include "llvm/IR/PatternMatch.h"
  16. #include "llvm/Transforms/InstCombine/InstCombiner.h"
  17. using namespace llvm;
  18. using namespace PatternMatch;
  19. #define DEBUG_TYPE "instcombine"
  20. bool canTryToConstantAddTwoShiftAmounts(Value *Sh0, Value *ShAmt0, Value *Sh1,
  21. Value *ShAmt1) {
  22. // We have two shift amounts from two different shifts. The types of those
  23. // shift amounts may not match. If that's the case, let's bail out now.
  24. if (ShAmt0->getType() != ShAmt1->getType())
  25. return false;
  26. // As input, we have the following pattern:
  27. // Sh0 (Sh1 X, Q), K
  28. // We want to rewrite that as:
  29. // Sh x, (Q+K) iff (Q+K) u< bitwidth(x)
  30. // While we know that originally (Q+K) would not overflow
  31. // (because 2 * (N-1) u<= iN -1), we have looked past extensions of
  32. // shift amounts, so it may now overflow in a smaller bit width.
  33. // To ensure that does not happen, we need to ensure that the total maximal
  34. // shift amount is still representable in that smaller bit width.
  35. unsigned MaximalPossibleTotalShiftAmount =
  36. (Sh0->getType()->getScalarSizeInBits() - 1) +
  37. (Sh1->getType()->getScalarSizeInBits() - 1);
  38. APInt MaximalRepresentableShiftAmount =
  39. APInt::getAllOnes(ShAmt0->getType()->getScalarSizeInBits());
  40. return MaximalRepresentableShiftAmount.uge(MaximalPossibleTotalShiftAmount);
  41. }
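// Illustrative example (not part of the original source): for two i64 shifts
// whose shift amounts were both zext'd from i8, the maximal possible total
// shift amount is 63 + 63 = 126, which still fits in the unsigned i8 range
// (255), so the two amounts may safely be added in i8. Had the pre-extension
// type been i6 (maximum value 63), 126 would not be representable and we bail.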
  42. // Given pattern:
  43. // (x shiftopcode Q) shiftopcode K
  44. // we should rewrite it as
  45. // x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) and
  46. //
  47. // This is valid for any shift, but they must be identical, and we must be
  48. // careful in case we have (zext(Q)+zext(K)) and look past extensions,
  49. // (Q+K) must not overflow or else (Q+K) u< bitwidth(x) is bogus.
  50. //
  51. // AnalyzeForSignBitExtraction indicates that we will only analyze whether this
  52. // pattern has any 2 right-shifts that sum to 1 less than the original bit width.
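//
// A minimal IR sketch of the intended rewrite (illustrative, not from this
// file), assuming i32 operands and constant shift amounts:
//   %t = lshr i32 %x, 3
//   %r = lshr i32 %t, 5
// becomes
//   %r = lshr i32 %x, 8        ; legal because 3 + 5 u< 32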
  53. Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts(
  54. BinaryOperator *Sh0, const SimplifyQuery &SQ,
  55. bool AnalyzeForSignBitExtraction) {
  56. // Look for a shift of some instruction, ignore zext of shift amount if any.
  57. Instruction *Sh0Op0;
  58. Value *ShAmt0;
  59. if (!match(Sh0,
  60. m_Shift(m_Instruction(Sh0Op0), m_ZExtOrSelf(m_Value(ShAmt0)))))
  61. return nullptr;
  62. // If there is a truncation between the two shifts, we must make note of it
  63. // and look through it. The truncation imposes additional constraints on the
  64. // transform.
  65. Instruction *Sh1;
  66. Value *Trunc = nullptr;
  67. match(Sh0Op0,
  68. m_CombineOr(m_CombineAnd(m_Trunc(m_Instruction(Sh1)), m_Value(Trunc)),
  69. m_Instruction(Sh1)));
  70. // Inner shift: (x shiftopcode ShAmt1)
  71. // As with the other shift, ignore a zext of the shift amount, if any.
  72. Value *X, *ShAmt1;
  73. if (!match(Sh1, m_Shift(m_Value(X), m_ZExtOrSelf(m_Value(ShAmt1)))))
  74. return nullptr;
  75. // Verify that it would be safe to try to add those two shift amounts.
  76. if (!canTryToConstantAddTwoShiftAmounts(Sh0, ShAmt0, Sh1, ShAmt1))
  77. return nullptr;
  78. // We are only looking for signbit extraction if we have two right shifts.
  79. bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
  80. match(Sh1, m_Shr(m_Value(), m_Value()));
  81. // ... and if it's not two right-shifts, we know the answer already.
  82. if (AnalyzeForSignBitExtraction && !HadTwoRightShifts)
  83. return nullptr;
  84. // The shift opcodes must be identical, unless we are just checking whether
  85. // this pattern can be interpreted as a sign-bit-extraction.
  86. Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode();
  87. bool IdenticalShOpcodes = Sh0->getOpcode() == Sh1->getOpcode();
  88. if (!IdenticalShOpcodes && !AnalyzeForSignBitExtraction)
  89. return nullptr;
  90. // If we saw truncation, we'll need to produce an extra instruction,
  91. // and for that one of the operands of the shift must be one-use,
  92. // unless of course we don't actually plan to produce any instructions here.
  93. if (Trunc && !AnalyzeForSignBitExtraction &&
  94. !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
  95. return nullptr;
  96. // Can we fold (ShAmt0+ShAmt1) ?
  97. auto *NewShAmt = dyn_cast_or_null<Constant>(
  98. simplifyAddInst(ShAmt0, ShAmt1, /*isNSW=*/false, /*isNUW=*/false,
  99. SQ.getWithInstruction(Sh0)));
  100. if (!NewShAmt)
  101. return nullptr; // Did not simplify.
  102. unsigned NewShAmtBitWidth = NewShAmt->getType()->getScalarSizeInBits();
  103. unsigned XBitWidth = X->getType()->getScalarSizeInBits();
  104. // Is the new shift amount smaller than the bit width of inner/new shift?
  105. if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
  106. APInt(NewShAmtBitWidth, XBitWidth))))
  107. return nullptr; // FIXME: could perform constant-folding.
  108. // If there was a truncation, and we have a right-shift, we can only fold if
  109. // we are left with the original sign bit. Likewise, if we were just checking
  110. // that this is a sign-bit extraction, this is the place to check it.
  111. // FIXME: zero shift amount is also legal here, but we can't *easily* check
  112. // more than one predicate so it's not really worth it.
  113. if (HadTwoRightShifts && (Trunc || AnalyzeForSignBitExtraction)) {
  114. // If it's not a sign bit extraction, then we're done.
  115. if (!match(NewShAmt,
  116. m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
  117. APInt(NewShAmtBitWidth, XBitWidth - 1))))
  118. return nullptr;
  119. // If it is, and that was the question, return the base value.
  120. if (AnalyzeForSignBitExtraction)
  121. return X;
  122. }
  123. assert(IdenticalShOpcodes && "Should not get here with different shifts.");
  124. // All good, we can do this fold.
  125. NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType());
  126. BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt);
  127. // The flags can only be propagated if there wasn't a trunc.
  128. if (!Trunc) {
  129. // If the pattern did not involve trunc, and both of the original shifts
  130. // had the same flag set, preserve the flag.
  131. if (ShiftOpcode == Instruction::BinaryOps::Shl) {
  132. NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() &&
  133. Sh1->hasNoUnsignedWrap());
  134. NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() &&
  135. Sh1->hasNoSignedWrap());
  136. } else {
  137. NewShift->setIsExact(Sh0->isExact() && Sh1->isExact());
  138. }
  139. }
  140. Instruction *Ret = NewShift;
  141. if (Trunc) {
  142. Builder.Insert(NewShift);
  143. Ret = CastInst::Create(Instruction::Trunc, NewShift, Sh0->getType());
  144. }
  145. return Ret;
  146. }
  147. // If we have some pattern that leaves only some low bits set, and then performs
  148. // a left-shift of those bits, we can omit the mask as long as none of the bits
  149. // that remain after the final shift are modified by the mask.
  150. //
  151. // There are many variants to this pattern:
  152. // a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt
  153. // b) (x & (~(-1 << MaskShAmt))) << ShiftShAmt
  154. // c) (x & (-1 l>> MaskShAmt)) << ShiftShAmt
  155. // d) (x & ((-1 << MaskShAmt) l>> MaskShAmt)) << ShiftShAmt
  156. // e) ((x << MaskShAmt) l>> MaskShAmt) << ShiftShAmt
  157. // f) ((x << MaskShAmt) a>> MaskShAmt) << ShiftShAmt
  158. // All these patterns can be simplified to just:
  159. // x << ShiftShAmt
  160. // iff:
  161. // a,b) (MaskShAmt+ShiftShAmt) u>= bitwidth(x)
  162. // c,d,e,f) (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt)
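// For instance (illustrative, not from this file), variant (a) with i32,
// MaskShAmt == 5 and ShiftShAmt == 27 satisfies 5 + 27 u>= 32:
//   %m = and i32 %x, 31        ; (1 << 5) - 1
//   %r = shl i32 %m, 27
// simplifies to
//   %r = shl i32 %x, 27        ; the masked-off bits are shifted out anyway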
  163. static Instruction *
  164. dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
  165. const SimplifyQuery &Q,
  166. InstCombiner::BuilderTy &Builder) {
  167. assert(OuterShift->getOpcode() == Instruction::BinaryOps::Shl &&
  168. "The input must be 'shl'!");
  169. Value *Masked, *ShiftShAmt;
  170. match(OuterShift,
  171. m_Shift(m_Value(Masked), m_ZExtOrSelf(m_Value(ShiftShAmt))));
  172. // *If* there is a truncation between the outer shift and the possible mask,
  173. // then said truncation *must* be one-use, else we can't perform the fold.
  174. Value *Trunc;
  175. if (match(Masked, m_CombineAnd(m_Trunc(m_Value(Masked)), m_Value(Trunc))) &&
  176. !Trunc->hasOneUse())
  177. return nullptr;
  178. Type *NarrowestTy = OuterShift->getType();
  179. Type *WidestTy = Masked->getType();
  180. bool HadTrunc = WidestTy != NarrowestTy;
  181. // The mask must be computed in a type twice as wide to ensure
  182. // that no bits are lost if the sum-of-shifts is wider than the base type.
  183. Type *ExtendedTy = WidestTy->getExtendedType();
  184. Value *MaskShAmt;
  185. // ((1 << MaskShAmt) - 1)
  186. auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes());
  187. // (~(-1 << MaskShAmt))
  188. auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes());
  189. // (-1 l>> MaskShAmt)
  190. auto MaskC = m_LShr(m_AllOnes(), m_Value(MaskShAmt));
  191. // ((-1 << MaskShAmt) l>> MaskShAmt)
  192. auto MaskD =
  193. m_LShr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt));
  194. Value *X;
  195. Constant *NewMask;
  196. if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) {
  197. // Peek through an optional zext of the shift amount.
  198. match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt)));
  199. // Verify that it would be safe to try to add those two shift amounts.
  200. if (!canTryToConstantAddTwoShiftAmounts(OuterShift, ShiftShAmt, Masked,
  201. MaskShAmt))
  202. return nullptr;
  203. // Can we simplify (MaskShAmt+ShiftShAmt) ?
  204. auto *SumOfShAmts = dyn_cast_or_null<Constant>(simplifyAddInst(
  205. MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
  206. if (!SumOfShAmts)
  207. return nullptr; // Did not simplify.
  208. // In this pattern SumOfShAmts correlates with the number of low bits
  209. // that shall remain in the root value (OuterShift).
  210. // An extend of an undef value becomes zero because the high bits are never
  211. // completely unknown. Replace the `undef` shift amounts with final
  212. // shift bitwidth to ensure that the value remains undef when creating the
  213. // subsequent shift op.
  214. SumOfShAmts = Constant::replaceUndefsWith(
  215. SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
  216. ExtendedTy->getScalarSizeInBits()));
  217. auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
  218. // And compute the mask as usual: ~(-1 << (SumOfShAmts))
  219. auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
  220. auto *ExtendedInvertedMask =
  221. ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts);
  222. NewMask = ConstantExpr::getNot(ExtendedInvertedMask);
  223. } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) ||
  224. match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)),
  225. m_Deferred(MaskShAmt)))) {
  226. // Peek through an optional zext of the shift amount.
  227. match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt)));
  228. // Verify that it would be safe to try to add those two shift amounts.
  229. if (!canTryToConstantAddTwoShiftAmounts(OuterShift, ShiftShAmt, Masked,
  230. MaskShAmt))
  231. return nullptr;
  232. // Can we simplify (ShiftShAmt-MaskShAmt) ?
  233. auto *ShAmtsDiff = dyn_cast_or_null<Constant>(simplifySubInst(
  234. ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
  235. if (!ShAmtsDiff)
  236. return nullptr; // Did not simplify.
  237. // In this pattern ShAmtsDiff correlates with the number of high bits that
  238. // shall be unset in the root value (OuterShift).
  239. // An extend of an undef value becomes zero because the high bits are never
  240. // completely unknown. Replace the `undef` shift amounts with negated
  241. // bitwidth of innermost shift to ensure that the value remains undef when
  242. // creating the subsequent shift op.
  243. unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits();
  244. ShAmtsDiff = Constant::replaceUndefsWith(
  245. ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(),
  246. -WidestTyBitWidth));
  247. auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
  248. ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(),
  249. WidestTyBitWidth,
  250. /*isSigned=*/false),
  251. ShAmtsDiff),
  252. ExtendedTy);
  253. // And compute the mask as usual: (-1 l>> (NumHighBitsToClear))
  254. auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
  255. NewMask =
  256. ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
  257. } else
  258. return nullptr; // Don't know anything about this pattern.
  259. NewMask = ConstantExpr::getTrunc(NewMask, NarrowestTy);
  260. // Does this mask have any unset bits? If not, then we can just not apply it.
  261. bool NeedMask = !match(NewMask, m_AllOnes());
  262. // If we need to apply a mask, there are several more restrictions we have.
  263. if (NeedMask) {
  264. // The old masking instruction must go away.
  265. if (!Masked->hasOneUse())
  266. return nullptr;
  267. // The original "masking" instruction must not have been `ashr`.
  268. if (match(Masked, m_AShr(m_Value(), m_Value())))
  269. return nullptr;
  270. }
  271. // If we need to apply truncation, let's do it first, since we can.
  272. // We have already ensured that the old truncation will go away.
  273. if (HadTrunc)
  274. X = Builder.CreateTrunc(X, NarrowestTy);
  275. // No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits.
  276. // We didn't change the Type of this outermost shift, so we can just do it.
  277. auto *NewShift = BinaryOperator::Create(OuterShift->getOpcode(), X,
  278. OuterShift->getOperand(1));
  279. if (!NeedMask)
  280. return NewShift;
  281. Builder.Insert(NewShift);
  282. return BinaryOperator::Create(Instruction::And, NewShift, NewMask);
  283. }
  284. /// If we have a shift-by-constant of a bitwise logic op that itself has a
  285. /// shift-by-constant operand with identical opcode, we may be able to convert
  286. /// that into 2 independent shifts followed by the logic op. This eliminates a
  287. /// use of an intermediate value (reduces the dependency chain).
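/// A small IR sketch (illustrative, not from this file), assuming i32 values
/// and shift amounts whose sum stays below the bit width:
///   %s0 = shl i32 %x, 2
///   %l  = or i32 %s0, %y
///   %r  = shl i32 %l, 3
/// may be rewritten to
///   %s1 = shl i32 %x, 5       ; C0 + C1 = 2 + 3
///   %s2 = shl i32 %y, 3
///   %r  = or i32 %s1, %s2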
  288. static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
  289. InstCombiner::BuilderTy &Builder) {
  290. assert(I.isShift() && "Expected a shift as input");
  291. auto *LogicInst = dyn_cast<BinaryOperator>(I.getOperand(0));
  292. if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse())
  293. return nullptr;
  294. Constant *C0, *C1;
  295. if (!match(I.getOperand(1), m_Constant(C1)))
  296. return nullptr;
  297. Instruction::BinaryOps ShiftOpcode = I.getOpcode();
  298. Type *Ty = I.getType();
  299. // Find a matching one-use shift by constant. The fold is not valid if the sum
  300. // of the shift values equals or exceeds bitwidth.
  301. // TODO: Remove the one-use check if the other logic operand (Y) is constant.
  302. Value *X, *Y;
  303. auto matchFirstShift = [&](Value *V) {
  304. APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits());
  305. return match(V,
  306. m_OneUse(m_BinOp(ShiftOpcode, m_Value(X), m_Constant(C0)))) &&
  307. match(ConstantExpr::getAdd(C0, C1),
  308. m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold));
  309. };
  310. // Logic ops are commutative, so check each operand for a match.
  311. if (matchFirstShift(LogicInst->getOperand(0)))
  312. Y = LogicInst->getOperand(1);
  313. else if (matchFirstShift(LogicInst->getOperand(1)))
  314. Y = LogicInst->getOperand(0);
  315. else
  316. return nullptr;
  317. // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
  318. Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1);
  319. Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC);
  320. Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, C1);
  321. return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2);
  322. }
  323. Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
  324. if (Instruction *Phi = foldBinopWithPhiOperands(I))
  325. return Phi;
  326. Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  327. assert(Op0->getType() == Op1->getType());
  328. Type *Ty = I.getType();
  329. // If the shift amount is a one-use `sext`, we can demote it to `zext`.
  330. Value *Y;
  331. if (match(Op1, m_OneUse(m_SExt(m_Value(Y))))) {
  332. Value *NewExt = Builder.CreateZExt(Y, Ty, Op1->getName());
  333. return BinaryOperator::Create(I.getOpcode(), Op0, NewExt);
  334. }
  335. // See if we can fold away this shift.
  336. if (SimplifyDemandedInstructionBits(I))
  337. return &I;
  338. // Try to fold constant and into select arguments.
  339. if (isa<Constant>(Op0))
  340. if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
  341. if (Instruction *R = FoldOpIntoSelect(I, SI))
  342. return R;
  343. if (Constant *CUI = dyn_cast<Constant>(Op1))
  344. if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
  345. return Res;
  346. if (auto *NewShift = cast_or_null<Instruction>(
  347. reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ)))
  348. return NewShift;
  349. // Pre-shift a constant shifted by a variable amount with constant offset:
  350. // C shift (A add nuw C1) --> (C shift C1) shift A
  351. Value *A;
  352. Constant *C, *C1;
  353. if (match(Op0, m_Constant(C)) &&
  354. match(Op1, m_NUWAdd(m_Value(A), m_Constant(C1)))) {
  355. Value *NewC = Builder.CreateBinOp(I.getOpcode(), C, C1);
  356. return BinaryOperator::Create(I.getOpcode(), NewC, A);
  357. }
  358. unsigned BitWidth = Ty->getScalarSizeInBits();
  359. const APInt *AC, *AddC;
  360. // Try to pre-shift a constant shifted by a variable amount added with a
  361. // negative number:
  362. // C << (X - AddC) --> (C >> AddC) << X
  363. // and
  364. // C >> (X - AddC) --> (C << AddC) >> X
  365. if (match(Op0, m_APInt(AC)) && match(Op1, m_Add(m_Value(A), m_APInt(AddC))) &&
  366. AddC->isNegative() && (-*AddC).ult(BitWidth)) {
  367. assert(!AC->isZero() && "Expected simplify of shifted zero");
  368. unsigned PosOffset = (-*AddC).getZExtValue();
  369. auto isSuitableForPreShift = [PosOffset, &I, AC]() {
  370. switch (I.getOpcode()) {
  371. default:
  372. return false;
  373. case Instruction::Shl:
  374. return (I.hasNoSignedWrap() || I.hasNoUnsignedWrap()) &&
  375. AC->eq(AC->lshr(PosOffset).shl(PosOffset));
  376. case Instruction::LShr:
  377. return I.isExact() && AC->eq(AC->shl(PosOffset).lshr(PosOffset));
  378. case Instruction::AShr:
  379. return I.isExact() && AC->eq(AC->shl(PosOffset).ashr(PosOffset));
  380. }
  381. };
  382. if (isSuitableForPreShift()) {
  383. Constant *NewC = ConstantInt::get(Ty, I.getOpcode() == Instruction::Shl
  384. ? AC->lshr(PosOffset)
  385. : AC->shl(PosOffset));
  386. BinaryOperator *NewShiftOp =
  387. BinaryOperator::Create(I.getOpcode(), NewC, A);
  388. if (I.getOpcode() == Instruction::Shl) {
  389. NewShiftOp->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
  390. } else {
  391. NewShiftOp->setIsExact();
  392. }
  393. return NewShiftOp;
  394. }
  395. }
  396. // X shift (A srem C) -> X shift (A and (C - 1)) iff C is a power of 2.
  397. // Because shifts by negative values (which could occur if A were negative)
  398. // are undefined.
  399. if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Constant(C))) &&
  400. match(C, m_Power2())) {
  401. // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't
  402. // demand the sign bit (and many others) here??
  403. Constant *Mask = ConstantExpr::getSub(C, ConstantInt::get(Ty, 1));
  404. Value *Rem = Builder.CreateAnd(A, Mask, Op1->getName());
  405. return replaceOperand(I, 1, Rem);
  406. }
  407. if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder))
  408. return Logic;
  409. return nullptr;
  410. }
  411. /// Return true if we can simplify two logical (either left or right) shifts
  412. /// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
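/// For example (illustrative, not from this file), for i8 values
///   lshr (shl %x, 3), 3   can be evaluated as   and %x, 31
/// because equal shifts in opposite directions only clear the high 3 bits.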
  413. static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
  414. Instruction *InnerShift,
  415. InstCombinerImpl &IC, Instruction *CxtI) {
  416. assert(InnerShift->isLogicalShift() && "Unexpected instruction type");
  417. // We need constant scalar or constant splat shifts.
  418. const APInt *InnerShiftConst;
  419. if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
  420. return false;
  421. // Two logical shifts in the same direction:
  422. // shl (shl X, C1), C2 --> shl X, C1 + C2
  423. // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
  424. bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
  425. if (IsInnerShl == IsOuterShl)
  426. return true;
  427. // Equal shift amounts in opposite directions become bitwise 'and':
  428. // lshr (shl X, C), C --> and X, C'
  429. // shl (lshr X, C), C --> and X, C'
  430. if (*InnerShiftConst == OuterShAmt)
  431. return true;
  432. // If the 2nd shift is bigger than the 1st, we can fold:
  433. // lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3
  434. // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
  435. // but it isn't profitable unless we know the and'd out bits are already zero.
  436. // Also, check that the inner shift is valid (less than the type width) or
  437. // we'll crash trying to produce the bit mask for the 'and'.
  438. unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
  439. if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) {
  440. unsigned InnerShAmt = InnerShiftConst->getZExtValue();
  441. unsigned MaskShift =
  442. IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
  443. APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
  444. if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI))
  445. return true;
  446. }
  447. return false;
  448. }
  449. /// See if we can compute the specified value, but shifted logically to the left
  450. /// or right by some number of bits. This should return true if the expression
  451. /// can be computed for the same cost as the current expression tree. This is
  452. /// used to eliminate extraneous shifting from things like:
  453. /// %C = shl i128 %A, 64
  454. /// %D = shl i128 %B, 96
  455. /// %E = or i128 %C, %D
  456. /// %F = lshr i128 %E, 64
  457. /// where the client will ask if E can be computed shifted right by 64-bits. If
  458. /// this succeeds, getShiftedValue() will be called to produce the value.
  459. static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
  460. InstCombinerImpl &IC, Instruction *CxtI) {
  461. // We can always evaluate constants shifted.
  462. if (isa<Constant>(V))
  463. return true;
  464. Instruction *I = dyn_cast<Instruction>(V);
  465. if (!I) return false;
  466. // We can't mutate something that has multiple uses: doing so would
  467. // require duplicating the instruction in general, which isn't profitable.
  468. if (!I->hasOneUse()) return false;
  469. switch (I->getOpcode()) {
  470. default: return false;
  471. case Instruction::And:
  472. case Instruction::Or:
  473. case Instruction::Xor:
  474. // Bitwise operators can all be arbitrarily evaluated shifted.
  475. return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
  476. canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
  477. case Instruction::Shl:
  478. case Instruction::LShr:
  479. return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);
  480. case Instruction::Select: {
  481. SelectInst *SI = cast<SelectInst>(I);
  482. Value *TrueVal = SI->getTrueValue();
  483. Value *FalseVal = SI->getFalseValue();
  484. return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
  485. canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
  486. }
  487. case Instruction::PHI: {
  488. // We can change a phi if we can change all operands. Note that we never
  489. // get into trouble with cyclic PHIs here because we only consider
  490. // instructions with a single use.
  491. PHINode *PN = cast<PHINode>(I);
  492. for (Value *IncValue : PN->incoming_values())
  493. if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
  494. return false;
  495. return true;
  496. }
  497. case Instruction::Mul: {
  498. const APInt *MulConst;
  499. // We can fold (shr (mul X, -(1 << C)), C) -> (and (neg X), C')
  500. return !IsLeftShift && match(I->getOperand(1), m_APInt(MulConst)) &&
  501. MulConst->isNegatedPowerOf2() &&
  502. MulConst->countTrailingZeros() == NumBits;
  503. }
  504. }
  505. }
  506. /// Fold OuterShift (InnerShift X, C1), C2.
  507. /// See canEvaluateShiftedShift() for the constraints on these instructions.
  508. static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
  509. bool IsOuterShl,
  510. InstCombiner::BuilderTy &Builder) {
  511. bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
  512. Type *ShType = InnerShift->getType();
  513. unsigned TypeWidth = ShType->getScalarSizeInBits();
  514. // We only accept shifts-by-a-constant in canEvaluateShifted().
  515. const APInt *C1;
  516. match(InnerShift->getOperand(1), m_APInt(C1));
  517. unsigned InnerShAmt = C1->getZExtValue();
  518. // Change the shift amount and clear the appropriate IR flags.
  519. auto NewInnerShift = [&](unsigned ShAmt) {
  520. InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt));
  521. if (IsInnerShl) {
  522. InnerShift->setHasNoUnsignedWrap(false);
  523. InnerShift->setHasNoSignedWrap(false);
  524. } else {
  525. InnerShift->setIsExact(false);
  526. }
  527. return InnerShift;
  528. };
  529. // Two logical shifts in the same direction:
  530. // shl (shl X, C1), C2 --> shl X, C1 + C2
  531. // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
  532. if (IsInnerShl == IsOuterShl) {
  533. // If this is an oversized composite shift, then unsigned shifts get 0.
  534. if (InnerShAmt + OuterShAmt >= TypeWidth)
  535. return Constant::getNullValue(ShType);
  536. return NewInnerShift(InnerShAmt + OuterShAmt);
  537. }
  538. // Equal shift amounts in opposite directions become bitwise 'and':
  539. // lshr (shl X, C), C --> and X, C'
  540. // shl (lshr X, C), C --> and X, C'
  541. if (InnerShAmt == OuterShAmt) {
  542. APInt Mask = IsInnerShl
  543. ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt)
  544. : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt);
  545. Value *And = Builder.CreateAnd(InnerShift->getOperand(0),
  546. ConstantInt::get(ShType, Mask));
  547. if (auto *AndI = dyn_cast<Instruction>(And)) {
  548. AndI->moveBefore(InnerShift);
  549. AndI->takeName(InnerShift);
  550. }
  551. return And;
  552. }
  553. assert(InnerShAmt > OuterShAmt &&
  554. "Unexpected opposite direction logical shift pair");
  555. // In general, we would need an 'and' for this transform, but
  556. // canEvaluateShiftedShift() guarantees that the masked-off bits are not used.
  557. // lshr (shl X, C1), C2 --> shl X, C1 - C2
  558. // shl (lshr X, C1), C2 --> lshr X, C1 - C2
  559. return NewInnerShift(InnerShAmt - OuterShAmt);
  560. }
  561. /// When canEvaluateShifted() returns true for an expression, this function
  562. /// inserts the new computation that produces the shifted value.
  563. static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
  564. InstCombinerImpl &IC, const DataLayout &DL) {
  565. // We can always evaluate constants shifted.
  566. if (Constant *C = dyn_cast<Constant>(V)) {
  567. if (isLeftShift)
  568. return IC.Builder.CreateShl(C, NumBits);
  569. else
  570. return IC.Builder.CreateLShr(C, NumBits);
  571. }
  572. Instruction *I = cast<Instruction>(V);
  573. IC.addToWorklist(I);
  574. switch (I->getOpcode()) {
  575. default: llvm_unreachable("Inconsistency with CanEvaluateShifted");
  576. case Instruction::And:
  577. case Instruction::Or:
  578. case Instruction::Xor:
  579. // Bitwise operators can all be arbitrarily evaluated shifted.
  580. I->setOperand(
  581. 0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
  582. I->setOperand(
  583. 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
  584. return I;
  585. case Instruction::Shl:
  586. case Instruction::LShr:
  587. return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift,
  588. IC.Builder);
  589. case Instruction::Select:
  590. I->setOperand(
  591. 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
  592. I->setOperand(
  593. 2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));
  594. return I;
  595. case Instruction::PHI: {
  596. // We can change a phi if we can change all operands. Note that we never
  597. // get into trouble with cyclic PHIs here because we only consider
  598. // instructions with a single use.
  599. PHINode *PN = cast<PHINode>(I);
  600. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
  601. PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits,
  602. isLeftShift, IC, DL));
  603. return PN;
  604. }
  605. case Instruction::Mul: {
  606. assert(!isLeftShift && "Unexpected shift direction!");
  607. auto *Neg = BinaryOperator::CreateNeg(I->getOperand(0));
  608. IC.InsertNewInstWith(Neg, *I);
  609. unsigned TypeWidth = I->getType()->getScalarSizeInBits();
  610. APInt Mask = APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits);
  611. auto *And = BinaryOperator::CreateAnd(Neg,
  612. ConstantInt::get(I->getType(), Mask));
  613. And->takeName(I);
  614. return IC.InsertNewInstWith(And, *I);
  615. }
  616. }
  617. }
  618. // If this is a bitwise operator or add with a constant RHS, we might be able
  619. // to pull it through a shift.
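// For example (illustrative, not from this file), an add pulls through shl:
//   (X + 7) << 2  ==  (X << 2) + 28
// because shl by C is multiplication by 2^C modulo 2^N, which distributes over
// addition; the same does not hold for right shifts, hence the Shl restriction.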
  620. static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift,
  621. BinaryOperator *BO) {
  622. switch (BO->getOpcode()) {
  623. default:
  624. return false; // Do not perform transform!
  625. case Instruction::Add:
  626. return Shift.getOpcode() == Instruction::Shl;
  627. case Instruction::Or:
  628. case Instruction::And:
  629. return true;
  630. case Instruction::Xor:
  631. // Do not change a 'not' of a logical shift because that would create a normal
  632. // 'xor'. The 'not' is likely better for analysis, SCEV, and codegen.
  633. return !(Shift.isLogicalShift() && match(BO, m_Not(m_Value())));
  634. }
  635. }
  636. Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1,
  637. BinaryOperator &I) {
  638. // (C2 << X) << C1 --> (C2 << C1) << X
  639. // (C2 >> X) >> C1 --> (C2 >> C1) >> X
  640. Constant *C2;
  641. Value *X;
  642. if (match(Op0, m_BinOp(I.getOpcode(), m_Constant(C2), m_Value(X))))
  643. return BinaryOperator::Create(
  644. I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X);
  645. bool IsLeftShift = I.getOpcode() == Instruction::Shl;
  646. Type *Ty = I.getType();
  647. unsigned TypeBits = Ty->getScalarSizeInBits();
  648. // (X / +DivC) >> (Width - 1) --> ext (X <= -DivC)
  649. // (X / -DivC) >> (Width - 1) --> ext (X >= +DivC)
  650. const APInt *DivC;
  651. if (!IsLeftShift && match(C1, m_SpecificIntAllowUndef(TypeBits - 1)) &&
  652. match(Op0, m_SDiv(m_Value(X), m_APInt(DivC))) && !DivC->isZero() &&
  653. !DivC->isMinSignedValue()) {
  654. Constant *NegDivC = ConstantInt::get(Ty, -(*DivC));
  655. ICmpInst::Predicate Pred =
  656. DivC->isNegative() ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_SLE;
  657. Value *Cmp = Builder.CreateICmp(Pred, X, NegDivC);
  658. auto ExtOpcode = (I.getOpcode() == Instruction::AShr) ? Instruction::SExt
  659. : Instruction::ZExt;
  660. return CastInst::Create(ExtOpcode, Cmp, Ty);
  661. }
  662. const APInt *Op1C;
  663. if (!match(C1, m_APInt(Op1C)))
  664. return nullptr;
  665. assert(!Op1C->uge(TypeBits) &&
  666. "Shift over the type width should have been removed already");
  667. // See if we can propagate this shift into the input; this covers the trivial
  668. // case of lshr(shl(x,c1),c2) as well as other more complex cases.
  669. if (I.getOpcode() != Instruction::AShr &&
  670. canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) {
  671. LLVM_DEBUG(
  672. dbgs() << "ICE: GetShiftedValue propagating shift through expression"
  673. " to eliminate shift:\n IN: "
  674. << *Op0 << "\n SH: " << I << "\n");
  675. return replaceInstUsesWith(
  676. I, getShiftedValue(Op0, Op1C->getZExtValue(), IsLeftShift, *this, DL));
  677. }
  678. if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I))
  679. return FoldedShift;
  680. if (!Op0->hasOneUse())
  681. return nullptr;
  682. if (auto *Op0BO = dyn_cast<BinaryOperator>(Op0)) {
  683. // If the operand is a bitwise operator with a constant RHS, and the
  684. // shift is the only use, we can pull it out of the shift.
  685. const APInt *Op0C;
  686. if (match(Op0BO->getOperand(1), m_APInt(Op0C))) {
  687. if (canShiftBinOpWithConstantRHS(I, Op0BO)) {
  688. Value *NewRHS =
  689. Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(1), C1);
  690. Value *NewShift =
  691. Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), C1);
  692. NewShift->takeName(Op0BO);
  693. return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, NewRHS);
  694. }
  695. }
  696. }
  697. // If we have a select that conditionally executes some binary operator,
  698. // see if we can pull the select and operator through the shift.
  699. //
  700. // For example, turning:
  701. // shl (select C, (add X, C1), X), C2
  702. // Into:
  703. // Y = shl X, C2
  704. // select C, (add Y, C1 << C2), Y
  705. Value *Cond;
  706. BinaryOperator *TBO;
  707. Value *FalseVal;
  708. if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)),
  709. m_Value(FalseVal)))) {
  710. const APInt *C;
  711. if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal &&
  712. match(TBO->getOperand(1), m_APInt(C)) &&
  713. canShiftBinOpWithConstantRHS(I, TBO)) {
  714. Value *NewRHS =
  715. Builder.CreateBinOp(I.getOpcode(), TBO->getOperand(1), C1);
  716. Value *NewShift = Builder.CreateBinOp(I.getOpcode(), FalseVal, C1);
  717. Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, NewRHS);
  718. return SelectInst::Create(Cond, NewOp, NewShift);
  719. }
  720. }
  721. BinaryOperator *FBO;
  722. Value *TrueVal;
  723. if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal),
  724. m_OneUse(m_BinOp(FBO))))) {
  725. const APInt *C;
  726. if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal &&
  727. match(FBO->getOperand(1), m_APInt(C)) &&
  728. canShiftBinOpWithConstantRHS(I, FBO)) {
  729. Value *NewRHS =
  730. Builder.CreateBinOp(I.getOpcode(), FBO->getOperand(1), C1);
  731. Value *NewShift = Builder.CreateBinOp(I.getOpcode(), TrueVal, C1);
  732. Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, NewRHS);
  733. return SelectInst::Create(Cond, NewShift, NewOp);
  734. }
  735. }
  736. return nullptr;
  737. }
  738. // Tries to perform
  739. // (lshr (add (zext X), (zext Y)), K)
  740. // -> (icmp ult (add X, Y), X)
  741. // where
  742. // - The add's operands are zexts from a K-bits integer to a bigger type.
  743. // - The add is only used by the shr, or by iK (or narrower) truncates.
  744. // - The lshr type has more than 2 bits (other types are boolean math).
  745. // - K > 1
  746. // note that
  747. // - The resulting add cannot have nuw/nsw, else on overflow we get a
  748. // poison value and the transform isn't legal anymore.
  749. Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
  750. assert(I.getOpcode() == Instruction::LShr);
  751. Value *Add = I.getOperand(0);
  752. Value *ShiftAmt = I.getOperand(1);
  753. Type *Ty = I.getType();
  754. if (Ty->getScalarSizeInBits() < 3)
  755. return nullptr;
  756. const APInt *ShAmtAPInt = nullptr;
  757. Value *X = nullptr, *Y = nullptr;
  758. if (!match(ShiftAmt, m_APInt(ShAmtAPInt)) ||
  759. !match(Add,
  760. m_Add(m_OneUse(m_ZExt(m_Value(X))), m_OneUse(m_ZExt(m_Value(Y))))))
  761. return nullptr;
  762. const unsigned ShAmt = ShAmtAPInt->getZExtValue();
  763. if (ShAmt == 1)
  764. return nullptr;
  765. // X/Y are zexts from `ShAmt`-sized ints.
  766. if (X->getType()->getScalarSizeInBits() != ShAmt ||
  767. Y->getType()->getScalarSizeInBits() != ShAmt)
  768. return nullptr;
  769. // Make sure that `Add` is only used by `I` and `ShAmt`-truncates.
  770. if (!Add->hasOneUse()) {
  771. for (User *U : Add->users()) {
  772. if (U == &I)
  773. continue;
  774. TruncInst *Trunc = dyn_cast<TruncInst>(U);
  775. if (!Trunc || Trunc->getType()->getScalarSizeInBits() > ShAmt)
  776. return nullptr;
  777. }
  778. }
  779. // Insert at Add so that the newly created `NarrowAdd` will dominate its
  780. // users (i.e. `Add`'s users).
  781. Instruction *AddInst = cast<Instruction>(Add);
  782. Builder.SetInsertPoint(AddInst);
  783. Value *NarrowAdd = Builder.CreateAdd(X, Y, "add.narrowed");
  784. Value *Overflow =
  785. Builder.CreateICmpULT(NarrowAdd, X, "add.narrowed.overflow");
  786. // Replace the uses of the original add with a zext of the
  787. // NarrowAdd's result. Note that all users at this stage are known to
  788. // be ShAmt-sized truncs, or the lshr itself.
  789. if (!Add->hasOneUse())
  790. replaceInstUsesWith(*AddInst, Builder.CreateZExt(NarrowAdd, Ty));
  791. // Replace the LShr with a zext of the overflow check.
  792. return new ZExtInst(Overflow, Ty);
  793. }
  794. Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
  795. const SimplifyQuery Q = SQ.getWithInstruction(&I);
  796. if (Value *V = simplifyShlInst(I.getOperand(0), I.getOperand(1),
  797. I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), Q))
  798. return replaceInstUsesWith(I, V);
  799. if (Instruction *X = foldVectorBinop(I))
  800. return X;
  801. if (Instruction *V = commonShiftTransforms(I))
  802. return V;
  803. if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, Q, Builder))
  804. return V;
  805. Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  806. Type *Ty = I.getType();
  807. unsigned BitWidth = Ty->getScalarSizeInBits();
  808. const APInt *C;
  809. if (match(Op1, m_APInt(C))) {
  810. unsigned ShAmtC = C->getZExtValue();
  811. // shl (zext X), C --> zext (shl X, C)
  812. // This is only valid if X would have zeros shifted out.
  813. Value *X;
  814. if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
  815. unsigned SrcWidth = X->getType()->getScalarSizeInBits();
  816. if (ShAmtC < SrcWidth &&
  817. MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmtC), 0, &I))
  818. return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty);
  819. }
  820. // (X >> C) << C --> X & (-1 << C)
  821. if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) {
  822. APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
  823. return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
  824. }
  825. const APInt *C1;
  826. if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(C1)))) &&
  827. C1->ult(BitWidth)) {
  828. unsigned ShrAmt = C1->getZExtValue();
  829. if (ShrAmt < ShAmtC) {
  830. // If C1 < C: (X >>?,exact C1) << C --> X << (C - C1)
  831. Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt);
  832. auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
  833. NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
  834. NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
  835. return NewShl;
  836. }
  837. if (ShrAmt > ShAmtC) {
  838. // If C1 > C: (X >>?exact C1) << C --> X >>?exact (C1 - C)
  839. Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmtC);
  840. auto *NewShr = BinaryOperator::Create(
  841. cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff);
  842. NewShr->setIsExact(true);
  843. return NewShr;
  844. }
  845. }
  846. if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_APInt(C1)))) &&
  847. C1->ult(BitWidth)) {
  848. unsigned ShrAmt = C1->getZExtValue();
  849. if (ShrAmt < ShAmtC) {
  850. // If C1 < C: (X >>? C1) << C --> (X << (C - C1)) & (-1 << C)
  851. Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt);
  852. auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
  853. NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
  854. NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
  855. Builder.Insert(NewShl);
  856. APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
  857. return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
  858. }
  859. if (ShrAmt > ShAmtC) {
  860. // If C1 > C: (X >>? C1) << C --> (X >>? (C1 - C)) & (-1 << C)
  861. Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmtC);
  862. auto *OldShr = cast<BinaryOperator>(Op0);
  863. auto *NewShr =
  864. BinaryOperator::Create(OldShr->getOpcode(), X, ShiftDiff);
  865. NewShr->setIsExact(OldShr->isExact());
  866. Builder.Insert(NewShr);
  867. APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
  868. return BinaryOperator::CreateAnd(NewShr, ConstantInt::get(Ty, Mask));
  869. }
  870. }
  871. // Similar to above, but look through an intermediate trunc instruction.
  872. BinaryOperator *Shr;
  873. if (match(Op0, m_OneUse(m_Trunc(m_OneUse(m_BinOp(Shr))))) &&
  874. match(Shr, m_Shr(m_Value(X), m_APInt(C1)))) {
  875. // The larger shift direction survives through the transform.
  876. unsigned ShrAmtC = C1->getZExtValue();
  877. unsigned ShDiff = ShrAmtC > ShAmtC ? ShrAmtC - ShAmtC : ShAmtC - ShrAmtC;
  878. Constant *ShiftDiffC = ConstantInt::get(X->getType(), ShDiff);
  879. auto ShiftOpc = ShrAmtC > ShAmtC ? Shr->getOpcode() : Instruction::Shl;
  880. // If C1 > C:
  881. // (trunc (X >> C1)) << C --> (trunc (X >> (C1 - C))) & (-1 << C)
  882. // If C > C1:
  883. // (trunc (X >> C1)) << C --> (trunc (X << (C - C1))) & (-1 << C)
  884. Value *NewShift = Builder.CreateBinOp(ShiftOpc, X, ShiftDiffC, "sh.diff");
  885. Value *Trunc = Builder.CreateTrunc(NewShift, Ty, "tr.sh.diff");
  886. APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
  887. return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, Mask));
  888. }
  889. if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) {
  890. unsigned AmtSum = ShAmtC + C1->getZExtValue();
  891. // Oversized shifts are simplified to zero in InstSimplify.
  892. if (AmtSum < BitWidth)
  893. // (X << C1) << C2 --> X << (C1 + C2)
  894. return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));
  895. }
  896. // If we have an opposite shift by the same amount, we may be able to
  897. // reorder binops and shifts to eliminate math/logic.
  898. auto isSuitableBinOpcode = [](Instruction::BinaryOps BinOpcode) {
  899. switch (BinOpcode) {
  900. default:
  901. return false;
  902. case Instruction::Add:
  903. case Instruction::And:
  904. case Instruction::Or:
  905. case Instruction::Xor:
  906. case Instruction::Sub:
  907. // NOTE: Sub is not commutable, and the transforms below may not be valid
  908. // when the shift-right is operand 1 (RHS) of the sub.
  909. return true;
  910. }
  911. };
  912. BinaryOperator *Op0BO;
  913. if (match(Op0, m_OneUse(m_BinOp(Op0BO))) &&
  914. isSuitableBinOpcode(Op0BO->getOpcode())) {
  915. // Commute so shift-right is on LHS of the binop.
  916. // (Y bop (X >> C)) << C -> ((X >> C) bop Y) << C
  917. // (Y bop ((X >> C) & CC)) << C -> (((X >> C) & CC) bop Y) << C
  918. Value *Shr = Op0BO->getOperand(0);
  919. Value *Y = Op0BO->getOperand(1);
  920. Value *X;
  921. const APInt *CC;
  922. if (Op0BO->isCommutative() && Y->hasOneUse() &&
  923. (match(Y, m_Shr(m_Value(), m_Specific(Op1))) ||
  924. match(Y, m_And(m_OneUse(m_Shr(m_Value(), m_Specific(Op1))),
  925. m_APInt(CC)))))
  926. std::swap(Shr, Y);
  927. // ((X >> C) bop Y) << C -> (X bop (Y << C)) & (~0 << C)
  928. if (match(Shr, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) {
  929. // Y << C
  930. Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName());
  931. // (X bop (Y << C))
  932. Value *B =
  933. Builder.CreateBinOp(Op0BO->getOpcode(), X, YS, Shr->getName());
  934. unsigned Op1Val = C->getLimitedValue(BitWidth);
  935. APInt Bits = APInt::getHighBitsSet(BitWidth, BitWidth - Op1Val);
  936. Constant *Mask = ConstantInt::get(Ty, Bits);
  937. return BinaryOperator::CreateAnd(B, Mask);
  938. }
  939. // (((X >> C) & CC) bop Y) << C -> (X & (CC << C)) bop (Y << C)
  940. if (match(Shr,
  941. m_OneUse(m_And(m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))),
  942. m_APInt(CC))))) {
  943. // Y << C
  944. Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName());
  945. // X & (CC << C)
  946. Value *M = Builder.CreateAnd(X, ConstantInt::get(Ty, CC->shl(*C)),
  947. X->getName() + ".mask");
  948. return BinaryOperator::Create(Op0BO->getOpcode(), M, YS);
  949. }
  950. }
  951. // (C1 - X) << C --> (C1 << C) - (X << C)
  952. if (match(Op0, m_OneUse(m_Sub(m_APInt(C1), m_Value(X))))) {
  953. Constant *NewLHS = ConstantInt::get(Ty, C1->shl(*C));
  954. Value *NewShift = Builder.CreateShl(X, Op1);
  955. return BinaryOperator::CreateSub(NewLHS, NewShift);
  956. }
  957. // If the shifted-out value is known-zero, then this is a NUW shift.
  958. if (!I.hasNoUnsignedWrap() &&
  959. MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmtC), 0,
  960. &I)) {
  961. I.setHasNoUnsignedWrap();
  962. return &I;
  963. }
  964. // If the shifted-out value is all signbits, then this is a NSW shift.
  965. if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmtC) {
  966. I.setHasNoSignedWrap();
  967. return &I;
  968. }
  969. }
  970. // Transform (x >> y) << y to x & (-1 << y)
  971. // Valid for any type of right-shift.
  972. Value *X;
  973. if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) {
  974. Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
  975. Value *Mask = Builder.CreateShl(AllOnes, Op1);
  976. return BinaryOperator::CreateAnd(Mask, X);
  977. }
  978. Constant *C1;
  979. if (match(Op1, m_Constant(C1))) {
  980. Constant *C2;
  981. Value *X;
  982. // (X * C2) << C1 --> X * (C2 << C1)
  983. if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
  984. return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));
  985. // shl (zext i1 X), C1 --> select (X, 1 << C1, 0)
  986. if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
  987. auto *NewC = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C1);
  988. return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
  989. }
  990. }
  991. if (match(Op0, m_One())) {
  992. // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1
  993. if (match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X))))
  994. return BinaryOperator::CreateLShr(
  995. ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X);
  996. // The only way to shift out the 1 is with an over-shift, so that would
  997. // be poison with or without "nuw". Undef is excluded because (undef << X)
  998. // is not undef (it is zero).
  999. Constant *ConstantOne = cast<Constant>(Op0);
  1000. if (!I.hasNoUnsignedWrap() && !ConstantOne->containsUndefElement()) {
  1001. I.setHasNoUnsignedWrap();
  1002. return &I;
  1003. }
  1004. }
  1005. return nullptr;
  1006. }
  1007. Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
  1008. if (Value *V = simplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
  1009. SQ.getWithInstruction(&I)))
  1010. return replaceInstUsesWith(I, V);
  1011. if (Instruction *X = foldVectorBinop(I))
  1012. return X;
  1013. if (Instruction *R = commonShiftTransforms(I))
  1014. return R;
  1015. Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  1016. Type *Ty = I.getType();
  1017. Value *X;
  1018. const APInt *C;
  1019. unsigned BitWidth = Ty->getScalarSizeInBits();
  1020. // (iN (~X) u>> (N - 1)) --> zext (X > -1)
  1021. if (match(Op0, m_OneUse(m_Not(m_Value(X)))) &&
  1022. match(Op1, m_SpecificIntAllowUndef(BitWidth - 1)))
  1023. return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);
  1024. if (match(Op1, m_APInt(C))) {
  1025. unsigned ShAmtC = C->getZExtValue();
  1026. auto *II = dyn_cast<IntrinsicInst>(Op0);
  1027. if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmtC &&
  1028. (II->getIntrinsicID() == Intrinsic::ctlz ||
  1029. II->getIntrinsicID() == Intrinsic::cttz ||
  1030. II->getIntrinsicID() == Intrinsic::ctpop)) {
  1031. // ctlz.i32(x)>>5 --> zext(x == 0)
  1032. // cttz.i32(x)>>5 --> zext(x == 0)
  1033. // ctpop.i32(x)>>5 --> zext(x == -1)
  1034. bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop;
  1035. Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? -1 : 0);
  1036. Value *Cmp = Builder.CreateICmpEQ(II->getArgOperand(0), RHS);
  1037. return new ZExtInst(Cmp, Ty);
  1038. }
  1039. Value *X;
  1040. const APInt *C1;
  1041. if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) {
  1042. if (C1->ult(ShAmtC)) {
  1043. unsigned ShlAmtC = C1->getZExtValue();
  1044. Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShlAmtC);
  1045. if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
  1046. // (X <<nuw C1) >>u C --> X >>u (C - C1)
  1047. auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff);
  1048. NewLShr->setIsExact(I.isExact());
  1049. return NewLShr;
  1050. }
  1051. if (Op0->hasOneUse()) {
  1052. // (X << C1) >>u C --> (X >>u (C - C1)) & (-1 >> C)
  1053. Value *NewLShr = Builder.CreateLShr(X, ShiftDiff, "", I.isExact());
  1054. APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC));
  1055. return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask));
  1056. }
  1057. } else if (C1->ugt(ShAmtC)) {
  1058. unsigned ShlAmtC = C1->getZExtValue();
  1059. Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmtC - ShAmtC);
  1060. if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
  1061. // (X <<nuw C1) >>u C --> X <<nuw (C1 - C)
  1062. auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
  1063. NewShl->setHasNoUnsignedWrap(true);
  1064. return NewShl;
  1065. }
  1066. if (Op0->hasOneUse()) {
  1067. // (X << C1) >>u C --> X << (C1 - C) & (-1 >> C)
  1068. Value *NewShl = Builder.CreateShl(X, ShiftDiff);
  1069. APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC));
  1070. return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
  1071. }
  1072. } else {
  1073. assert(*C1 == ShAmtC);
  1074. // (X << C) >>u C --> X & (-1 >>u C)
        APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC));
        return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
      }
    }

    // ((X << C) + Y) >>u C --> (X + (Y >>u C)) & (-1 >>u C)
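    // e.g. (i32): lshr (add (shl %x, 5), %y), 5
    //               --> and (add (lshr %y, 5), %x), 0x07FFFFFF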
    // TODO: Consolidate with the more general transform that starts from shl
    //       (the shifts are in the opposite order).
    Value *Y;
    if (match(Op0,
              m_OneUse(m_c_Add(m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))),
                               m_Value(Y))))) {
      Value *NewLshr = Builder.CreateLShr(Y, Op1);
      Value *NewAdd = Builder.CreateAdd(NewLshr, X);
      unsigned Op1Val = C->getLimitedValue(BitWidth);
      APInt Bits = APInt::getLowBitsSet(BitWidth, BitWidth - Op1Val);
      Constant *Mask = ConstantInt::get(Ty, Bits);
      return BinaryOperator::CreateAnd(NewAdd, Mask);
    }

    if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) &&
        (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
      assert(ShAmtC < X->getType()->getScalarSizeInBits() &&
             "Big shift not simplified to zero?");
      // lshr (zext iM X to iN), C --> zext (lshr X, C) to iN
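      // e.g.: lshr (zext i8 %x to i32), 3 --> zext (lshr i8 %x, 3) to i32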
      Value *NewLShr = Builder.CreateLShr(X, ShAmtC);
      return new ZExtInst(NewLShr, Ty);
    }

    if (match(Op0, m_SExt(m_Value(X)))) {
      unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits();
      // lshr (sext i1 X to iN), C --> select (X, -1 >> C, 0)
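      // e.g.: lshr (sext i1 %x to i32), 4 --> select %x, 0x0FFFFFFF, 0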
      if (SrcTyBitWidth == 1) {
        auto *NewC = ConstantInt::get(
            Ty, APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC));
        return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
      }

      if ((!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType())) &&
          Op0->hasOneUse()) {
        // Are we moving the sign bit to the low bit and widening with high
        // zeros? lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN
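        // e.g.: lshr (sext i8 %x to i32), 31 --> zext (lshr i8 %x, 7) to i32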
        if (ShAmtC == BitWidth - 1) {
          Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1);
          return new ZExtInst(NewLShr, Ty);
        }

        // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN
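        // e.g.: lshr (sext i8 %x to i32), 24 --> zext (ashr i8 %x, 7) to i32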
        if (ShAmtC == BitWidth - SrcTyBitWidth) {
          // The new shift amount can't be more than the narrow source type.
          unsigned NewShAmt = std::min(ShAmtC, SrcTyBitWidth - 1);
          Value *AShr = Builder.CreateAShr(X, NewShAmt);
          return new ZExtInst(AShr, Ty);
        }
      }
    }

    if (ShAmtC == BitWidth - 1) {
      // lshr i32 or(X,-X), 31 --> zext (X != 0)
      if (match(Op0, m_OneUse(m_c_Or(m_Neg(m_Value(X)), m_Deferred(X)))))
        return new ZExtInst(Builder.CreateIsNotNull(X), Ty);

      // lshr i32 (X -nsw Y), 31 --> zext (X < Y)
      if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
        return new ZExtInst(Builder.CreateICmpSLT(X, Y), Ty);

      // Check if a number is negative and odd:
      // lshr i32 (srem X, 2), 31 --> and (X >> 31), X
      if (match(Op0, m_OneUse(m_SRem(m_Value(X), m_SpecificInt(2))))) {
        Value *Signbit = Builder.CreateLShr(X, ShAmtC);
        return BinaryOperator::CreateAnd(Signbit, X);
      }
    }

    // (X >>u C1) >>u C --> X >>u (C1 + C)
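    // e.g. (i32): lshr (lshr %x, 3), 4 --> lshr %x, 7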
    if (match(Op0, m_LShr(m_Value(X), m_APInt(C1)))) {
      // Oversized shifts are simplified to zero in InstSimplify.
      unsigned AmtSum = ShAmtC + C1->getZExtValue();
      if (AmtSum < BitWidth)
        return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
    }

    Instruction *TruncSrc;
    if (match(Op0, m_OneUse(m_Trunc(m_Instruction(TruncSrc)))) &&
        match(TruncSrc, m_LShr(m_Value(X), m_APInt(C1)))) {
      unsigned SrcWidth = X->getType()->getScalarSizeInBits();
      unsigned AmtSum = ShAmtC + C1->getZExtValue();

      // If the combined shift fits in the source width:
      // (trunc (X >>u C1)) >>u C --> and (trunc (X >>u (C1 + C))), MaskC
      //
      // If the first shift covers the number of bits truncated, then the
      // mask instruction is eliminated (and so the use check is relaxed).
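      // e.g. (i64 -> i32, single-use chain):
      //   lshr (trunc (lshr i64 %x, 8) to i32), 8
      //     --> and (trunc (lshr i64 %x, 16) to i32), 0x00FFFFFF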
      if (AmtSum < SrcWidth &&
          (TruncSrc->hasOneUse() || C1->uge(SrcWidth - BitWidth))) {
        Value *SumShift = Builder.CreateLShr(X, AmtSum, "sum.shift");
        Value *Trunc = Builder.CreateTrunc(SumShift, Ty, I.getName());

        // If the first shift does not cover the number of bits truncated, then
        // we require a mask to get rid of high bits in the result.
        APInt MaskC = APInt::getAllOnes(BitWidth).lshr(ShAmtC);
        return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, MaskC));
      }
    }

    const APInt *MulC;
    if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC)))) {
      // Look for a "splat" mul pattern - it replicates bits across each half of
      // a value, so a right shift is just a mask of the low bits:
      // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1
      // TODO: Generalize to allow more than just half-width shifts?
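      // e.g.: lshr i32 (mul nuw %x, 65537), 16 --> and i32 %x, 65535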
      if (BitWidth > 2 && ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() &&
          MulC->logBase2() == ShAmtC)
        return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2));

      // The one-use check is not strictly necessary, but codegen may not be
      // able to invert the transform and perf may suffer with an extra mul
      // instruction.
      if (Op0->hasOneUse()) {
        APInt NewMulC = MulC->lshr(ShAmtC);
        // If MulC is divisible by (1 << ShAmtC):
        // lshr (mul nuw x, MulC), ShAmtC -> mul nuw x, (MulC >> ShAmtC)
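        // e.g.: lshr (mul nuw %x, 48), 4 --> mul nuw %x, 3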
        if (MulC->eq(NewMulC.shl(ShAmtC))) {
          auto *NewMul =
              BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
          BinaryOperator *OrigMul = cast<BinaryOperator>(Op0);
          NewMul->setHasNoSignedWrap(OrigMul->hasNoSignedWrap());
          return NewMul;
        }
      }
    }

    // Try to narrow bswap.
    // In the case where the shift amount equals the bitwidth difference, the
    // shift is eliminated.
    if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::bswap>(
                       m_OneUse(m_ZExt(m_Value(X))))))) {
      unsigned SrcWidth = X->getType()->getScalarSizeInBits();
      unsigned WidthDiff = BitWidth - SrcWidth;
      if (SrcWidth % 16 == 0) {
        Value *NarrowSwap = Builder.CreateUnaryIntrinsic(Intrinsic::bswap, X);
        if (ShAmtC >= WidthDiff) {
          // (bswap (zext X)) >> C --> zext (bswap X >> C')
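          // e.g.: lshr (bswap (zext i16 %x to i32)), 16
          //         --> zext (bswap i16 %x) to i32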
          Value *NewShift = Builder.CreateLShr(NarrowSwap, ShAmtC - WidthDiff);
          return new ZExtInst(NewShift, Ty);
        } else {
          // (bswap (zext X)) >> C --> (zext (bswap X)) << C'
          Value *NewZExt = Builder.CreateZExt(NarrowSwap, Ty);
          Constant *ShiftDiff = ConstantInt::get(Ty, WidthDiff - ShAmtC);
          return BinaryOperator::CreateShl(NewZExt, ShiftDiff);
        }
      }
    }

    // Reduce add-carry of bools to logic:
    // ((zext BoolX) + (zext BoolY)) >> 1 --> zext (BoolX && BoolY)
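    // The sum of two bools is at most 2, so bit 1 of the sum is set only when
    // both are true (it is the carry bit).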
    Value *BoolX, *BoolY;
    if (ShAmtC == 1 && match(Op0, m_Add(m_Value(X), m_Value(Y))) &&
        match(X, m_ZExt(m_Value(BoolX))) && match(Y, m_ZExt(m_Value(BoolY))) &&
        BoolX->getType()->isIntOrIntVectorTy(1) &&
        BoolY->getType()->isIntOrIntVectorTy(1) &&
        (X->hasOneUse() || Y->hasOneUse() || Op0->hasOneUse())) {
      Value *And = Builder.CreateAnd(BoolX, BoolY);
      return new ZExtInst(And, Ty);
    }

    // If the shifted-out value is known-zero, then this is an exact shift.
    if (!I.isExact() &&
        MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) {
      I.setIsExact();
      return &I;
    }
  }

  // Transform  (x << y) >> y  to  x & (-1 >> y)
  if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
    Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
    Value *Mask = Builder.CreateLShr(AllOnes, Op1);
    return BinaryOperator::CreateAnd(Mask, X);
  }

  if (Instruction *Overflow = foldLShrOverflowBit(I))
    return Overflow;

  return nullptr;
}

Instruction *
InstCombinerImpl::foldVariableSignZeroExtensionOfVariableHighBitExtract(
    BinaryOperator &OldAShr) {
  assert(OldAShr.getOpcode() == Instruction::AShr &&
         "Must be called with arithmetic right-shift instruction only.");

  // Check that constant C is a splat of the element-wise bitwidth of V.
  auto BitWidthSplat = [](Constant *C, Value *V) {
    return match(
        C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
                              APInt(C->getType()->getScalarSizeInBits(),
                                    V->getType()->getScalarSizeInBits())));
  };

  // It should look like variable-length sign-extension on the outside:
  // (Val << (bitwidth(Val)-Nbits)) a>> (bitwidth(Val)-Nbits)
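  // A minimal illustration of the overall fold (i32, no truncation), with
  //   %skip = sub i32 32, %nbits
  //   %hi   = lshr i32 %x, %skip   ; extract the high %nbits bits
  //   %shl  = shl i32 %hi, %skip
  //   %r    = ashr i32 %shl, %skip ; sign-extend those bits
  // the whole chain simplifies to %r = ashr i32 %x, %skip.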
  Value *NBits;
  Instruction *MaybeTrunc;
  Constant *C1, *C2;
  if (!match(&OldAShr,
             m_AShr(m_Shl(m_Instruction(MaybeTrunc),
                          m_ZExtOrSelf(m_Sub(m_Constant(C1),
                                             m_ZExtOrSelf(m_Value(NBits))))),
                    m_ZExtOrSelf(m_Sub(m_Constant(C2),
                                       m_ZExtOrSelf(m_Deferred(NBits)))))) ||
      !BitWidthSplat(C1, &OldAShr) || !BitWidthSplat(C2, &OldAShr))
    return nullptr;

  // There may or may not be a truncation after outer two shifts.
  Instruction *HighBitExtract;
  match(MaybeTrunc, m_TruncOrSelf(m_Instruction(HighBitExtract)));
  bool HadTrunc = MaybeTrunc != HighBitExtract;

  // And finally, the innermost part of the pattern must be a right-shift.
  Value *X, *NumLowBitsToSkip;
  if (!match(HighBitExtract, m_Shr(m_Value(X), m_Value(NumLowBitsToSkip))))
    return nullptr;

  // Said right-shift must extract high NBits bits - C0 must be its bitwidth.
  Constant *C0;
  if (!match(NumLowBitsToSkip,
             m_ZExtOrSelf(
                 m_Sub(m_Constant(C0), m_ZExtOrSelf(m_Specific(NBits))))) ||
      !BitWidthSplat(C0, HighBitExtract))
    return nullptr;

  // Since NBits is identical for all shifts, if the outermost and
  // innermost shifts are identical, then outermost shifts are redundant.
  // If we had truncation, do keep it though.
  if (HighBitExtract->getOpcode() == OldAShr.getOpcode())
    return replaceInstUsesWith(OldAShr, MaybeTrunc);

  // Else, if there was a truncation, then we need to ensure that one
  // instruction will go away.
  if (HadTrunc && !match(&OldAShr, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
    return nullptr;

  // Finally, bypass two innermost shifts, and perform the outermost shift on
  // the operands of the innermost shift.
  Instruction *NewAShr =
      BinaryOperator::Create(OldAShr.getOpcode(), X, NumLowBitsToSkip);
  NewAShr->copyIRFlags(HighBitExtract); // We can preserve 'exact'-ness.
  if (!HadTrunc)
    return NewAShr;

  Builder.Insert(NewAShr);
  return TruncInst::CreateTruncOrBitCast(NewAShr, OldAShr.getType());
}

Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
  if (Value *V = simplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
                                  SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  if (Instruction *X = foldVectorBinop(I))
    return X;

  if (Instruction *R = commonShiftTransforms(I))
    return R;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  Type *Ty = I.getType();
  unsigned BitWidth = Ty->getScalarSizeInBits();
  const APInt *ShAmtAPInt;
  if (match(Op1, m_APInt(ShAmtAPInt)) && ShAmtAPInt->ult(BitWidth)) {
    unsigned ShAmt = ShAmtAPInt->getZExtValue();

    // If the shift amount equals the difference in width of the destination
    // and source scalar types:
    // ashr (shl (zext X), C), C --> sext X
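    // e.g.: ashr (shl (zext i8 %x to i32), 24), 24 --> sext i8 %x to i32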
    Value *X;
    if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) &&
        ShAmt == BitWidth - X->getType()->getScalarSizeInBits())
      return new SExtInst(X, Ty);

    // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However,
    // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
    const APInt *ShOp1;
    if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1))) &&
        ShOp1->ult(BitWidth)) {
      unsigned ShlAmt = ShOp1->getZExtValue();
      if (ShlAmt < ShAmt) {
        // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1)
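        // e.g. (i32): ashr (shl nsw %x, 2), 5 --> ashr %x, 3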
        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
        auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff);
        NewAShr->setIsExact(I.isExact());
        return NewAShr;
      }
      if (ShlAmt > ShAmt) {
        // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2)
        Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
        auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff);
        NewShl->setHasNoSignedWrap(true);
        return NewShl;
      }
    }

    if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1))) &&
        ShOp1->ult(BitWidth)) {
      unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
      // Oversized arithmetic shifts replicate the sign bit.
      AmtSum = std::min(AmtSum, BitWidth - 1);
      // (X >>s C1) >>s C2 --> X >>s (C1 + C2)
      return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum));
    }

    if (match(Op0, m_OneUse(m_SExt(m_Value(X)))) &&
        (Ty->isVectorTy() || shouldChangeType(Ty, X->getType()))) {
      // ashr (sext X), C --> sext (ashr X, C')
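      // e.g.: ashr (sext i8 %x to i32), 5 --> sext (ashr i8 %x, 5) to i32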
      Type *SrcTy = X->getType();
      ShAmt = std::min(ShAmt, SrcTy->getScalarSizeInBits() - 1);
      Value *NewSh = Builder.CreateAShr(X, ConstantInt::get(SrcTy, ShAmt));
      return new SExtInst(NewSh, Ty);
    }

    if (ShAmt == BitWidth - 1) {
      // ashr i32 or(X,-X), 31 --> sext (X != 0)
      if (match(Op0, m_OneUse(m_c_Or(m_Neg(m_Value(X)), m_Deferred(X)))))
        return new SExtInst(Builder.CreateIsNotNull(X), Ty);

      // ashr i32 (X -nsw Y), 31 --> sext (X < Y)
      Value *Y;
      if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
        return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
    }

    // If the shifted-out value is known-zero, then this is an exact shift.
    if (!I.isExact() &&
        MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
      I.setIsExact();
      return &I;
    }
  }

  // Prefer `-(x & 1)` over `(x << (bitwidth(x)-1)) a>> (bitwidth(x)-1)`
  // as the pattern to splat the lowest bit.
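  // e.g. (i32): ashr (shl %x, 31), 31 --> sub 0, (and %x, 1)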
  // FIXME: iff X is already masked, we don't need the one-use check.
  Value *X;
  if (match(Op1, m_SpecificIntAllowUndef(BitWidth - 1)) &&
      match(Op0, m_OneUse(m_Shl(m_Value(X),
                                m_SpecificIntAllowUndef(BitWidth - 1))))) {
    Constant *Mask = ConstantInt::get(Ty, 1);
    // Retain the knowledge about the ignored lanes.
    Mask = Constant::mergeUndefsWith(
        Constant::mergeUndefsWith(Mask, cast<Constant>(Op1)),
        cast<Constant>(cast<Instruction>(Op0)->getOperand(1)));
    X = Builder.CreateAnd(X, Mask);
    return BinaryOperator::CreateNeg(X);
  }

  if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I))
    return R;

  // See if we can turn a signed shr into an unsigned shr.
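  // If the sign bit of Op0 is known to be zero, the arithmetic shift computes
  // the same value as a logical shift.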
  if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) {
    Instruction *Lshr = BinaryOperator::CreateLShr(Op0, Op1);
    Lshr->setIsExact(I.isExact());
    return Lshr;
  }

  // ashr (xor %x, -1), %y --> xor (ashr %x, %y), -1
  if (match(Op0, m_OneUse(m_Not(m_Value(X))))) {
    // Note that we must drop 'exact'-ness of the shift!
    // Note that we can't keep undef's in -1 vector constant!
    auto *NewAShr = Builder.CreateAShr(X, Op1, Op0->getName() + ".not");
    return BinaryOperator::CreateNot(NewAShr);
  }

  return nullptr;
}