TruncInstCombine.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. //===- TruncInstCombine.cpp -----------------------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // TruncInstCombine - looks for expression dags post-dominated by TruncInst and
  10. // for each eligible dag, it will create a reduced bit-width expression, replace
  11. // the old expression with this new one and remove the old expression.
  12. // Eligible expression dag is such that:
  13. // 1. Contains only supported instructions.
  14. // 2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
  15. // 3. Can be evaluated into type with reduced legal bit-width.
  16. // 4. All instructions in the dag must not have users outside the dag.
  17. // The only exception is for {ZExt, SExt}Inst with operand type equal to
  18. // the new reduced type evaluated in (3).
  19. //
  20. // The motivation for this optimization is that evaluating an expression using
  21. // smaller bit-width is preferable, especially for vectorization where we can
  22. // fit more values in one vectorized instruction. In addition, this optimization
  23. // may decrease the number of cast instructions, but will not increase it.
  24. //
  25. //===----------------------------------------------------------------------===//
  26. #include "AggressiveInstCombineInternal.h"
  27. #include "llvm/ADT/STLExtras.h"
  28. #include "llvm/ADT/Statistic.h"
  29. #include "llvm/Analysis/ConstantFolding.h"
  30. #include "llvm/Analysis/TargetLibraryInfo.h"
  31. #include "llvm/IR/DataLayout.h"
  32. #include "llvm/IR/Dominators.h"
  33. #include "llvm/IR/IRBuilder.h"
  34. #include "llvm/IR/Instruction.h"
  35. #include "llvm/Support/KnownBits.h"
  36. using namespace llvm;
  37. #define DEBUG_TYPE "aggressive-instcombine"
  38. STATISTIC(
  39. NumDAGsReduced,
  40. "Number of truncations eliminated by reducing bit width of expression DAG");
  41. STATISTIC(NumInstrsReduced,
  42. "Number of instructions whose bit width was reduced");
  43. /// Given an instruction and a container, it fills all the relevant operands of
  44. /// that instruction, with respect to the Trunc expression dag optimizaton.
  45. static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
  46. unsigned Opc = I->getOpcode();
  47. switch (Opc) {
  48. case Instruction::Trunc:
  49. case Instruction::ZExt:
  50. case Instruction::SExt:
  51. // These CastInst are considered leaves of the evaluated expression, thus,
  52. // their operands are not relevent.
  53. break;
  54. case Instruction::Add:
  55. case Instruction::Sub:
  56. case Instruction::Mul:
  57. case Instruction::And:
  58. case Instruction::Or:
  59. case Instruction::Xor:
  60. case Instruction::Shl:
  61. case Instruction::LShr:
  62. case Instruction::AShr:
  63. case Instruction::UDiv:
  64. case Instruction::URem:
  65. case Instruction::InsertElement:
  66. Ops.push_back(I->getOperand(0));
  67. Ops.push_back(I->getOperand(1));
  68. break;
  69. case Instruction::ExtractElement:
  70. Ops.push_back(I->getOperand(0));
  71. break;
  72. case Instruction::Select:
  73. Ops.push_back(I->getOperand(1));
  74. Ops.push_back(I->getOperand(2));
  75. break;
  76. default:
  77. llvm_unreachable("Unreachable!");
  78. }
  79. }
  80. bool TruncInstCombine::buildTruncExpressionDag() {
  81. SmallVector<Value *, 8> Worklist;
  82. SmallVector<Instruction *, 8> Stack;
  83. // Clear old expression dag.
  84. InstInfoMap.clear();
  85. Worklist.push_back(CurrentTruncInst->getOperand(0));
  86. while (!Worklist.empty()) {
  87. Value *Curr = Worklist.back();
  88. if (isa<Constant>(Curr)) {
  89. Worklist.pop_back();
  90. continue;
  91. }
  92. auto *I = dyn_cast<Instruction>(Curr);
  93. if (!I)
  94. return false;
  95. if (!Stack.empty() && Stack.back() == I) {
  96. // Already handled all instruction operands, can remove it from both the
  97. // Worklist and the Stack, and add it to the instruction info map.
  98. Worklist.pop_back();
  99. Stack.pop_back();
  100. // Insert I to the Info map.
  101. InstInfoMap.insert(std::make_pair(I, Info()));
  102. continue;
  103. }
  104. if (InstInfoMap.count(I)) {
  105. Worklist.pop_back();
  106. continue;
  107. }
  108. // Add the instruction to the stack before start handling its operands.
  109. Stack.push_back(I);
  110. unsigned Opc = I->getOpcode();
  111. switch (Opc) {
  112. case Instruction::Trunc:
  113. case Instruction::ZExt:
  114. case Instruction::SExt:
  115. // trunc(trunc(x)) -> trunc(x)
  116. // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest
  117. // trunc(ext(x)) -> trunc(x) if the source type is larger than the new
  118. // dest
  119. break;
  120. case Instruction::Add:
  121. case Instruction::Sub:
  122. case Instruction::Mul:
  123. case Instruction::And:
  124. case Instruction::Or:
  125. case Instruction::Xor:
  126. case Instruction::Shl:
  127. case Instruction::LShr:
  128. case Instruction::AShr:
  129. case Instruction::UDiv:
  130. case Instruction::URem:
  131. case Instruction::InsertElement:
  132. case Instruction::ExtractElement:
  133. case Instruction::Select: {
  134. SmallVector<Value *, 2> Operands;
  135. getRelevantOperands(I, Operands);
  136. append_range(Worklist, Operands);
  137. break;
  138. }
  139. default:
  140. // TODO: Can handle more cases here:
  141. // 1. shufflevector
  142. // 2. sdiv, srem
  143. // 3. phi node(and loop handling)
  144. // ...
  145. return false;
  146. }
  147. }
  148. return true;
  149. }
  150. unsigned TruncInstCombine::getMinBitWidth() {
  151. SmallVector<Value *, 8> Worklist;
  152. SmallVector<Instruction *, 8> Stack;
  153. Value *Src = CurrentTruncInst->getOperand(0);
  154. Type *DstTy = CurrentTruncInst->getType();
  155. unsigned TruncBitWidth = DstTy->getScalarSizeInBits();
  156. unsigned OrigBitWidth =
  157. CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
  158. if (isa<Constant>(Src))
  159. return TruncBitWidth;
  160. Worklist.push_back(Src);
  161. InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth;
  162. while (!Worklist.empty()) {
  163. Value *Curr = Worklist.back();
  164. if (isa<Constant>(Curr)) {
  165. Worklist.pop_back();
  166. continue;
  167. }
  168. // Otherwise, it must be an instruction.
  169. auto *I = cast<Instruction>(Curr);
  170. auto &Info = InstInfoMap[I];
  171. SmallVector<Value *, 2> Operands;
  172. getRelevantOperands(I, Operands);
  173. if (!Stack.empty() && Stack.back() == I) {
  174. // Already handled all instruction operands, can remove it from both, the
  175. // Worklist and the Stack, and update MinBitWidth.
  176. Worklist.pop_back();
  177. Stack.pop_back();
  178. for (auto *Operand : Operands)
  179. if (auto *IOp = dyn_cast<Instruction>(Operand))
  180. Info.MinBitWidth =
  181. std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth);
  182. continue;
  183. }
  184. // Add the instruction to the stack before start handling its operands.
  185. Stack.push_back(I);
  186. unsigned ValidBitWidth = Info.ValidBitWidth;
  187. // Update minimum bit-width before handling its operands. This is required
  188. // when the instruction is part of a loop.
  189. Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth);
  190. for (auto *Operand : Operands)
  191. if (auto *IOp = dyn_cast<Instruction>(Operand)) {
  192. // If we already calculated the minimum bit-width for this valid
  193. // bit-width, or for a smaller valid bit-width, then just keep the
  194. // answer we already calculated.
  195. unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;
  196. if (IOpBitwidth >= ValidBitWidth)
  197. continue;
  198. InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;
  199. Worklist.push_back(IOp);
  200. }
  201. }
  202. unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth;
  203. assert(MinBitWidth >= TruncBitWidth);
  204. if (MinBitWidth > TruncBitWidth) {
  205. // In this case reducing expression with vector type might generate a new
  206. // vector type, which is not preferable as it might result in generating
  207. // sub-optimal code.
  208. if (DstTy->isVectorTy())
  209. return OrigBitWidth;
  210. // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth).
  211. Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth);
  212. // Update minimum bit-width with the new destination type bit-width if
  213. // succeeded to find such, otherwise, with original bit-width.
  214. MinBitWidth = Ty ? Ty->getScalarSizeInBits() : OrigBitWidth;
  215. } else { // MinBitWidth == TruncBitWidth
  216. // In this case the expression can be evaluated with the trunc instruction
  217. // destination type, and trunc instruction can be omitted. However, we
  218. // should not perform the evaluation if the original type is a legal scalar
  219. // type and the target type is illegal.
  220. bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth);
  221. bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth);
  222. if (!DstTy->isVectorTy() && FromLegal && !ToLegal)
  223. return OrigBitWidth;
  224. }
  225. return MinBitWidth;
  226. }
  227. Type *TruncInstCombine::getBestTruncatedType() {
  228. if (!buildTruncExpressionDag())
  229. return nullptr;
  230. // We don't want to duplicate instructions, which isn't profitable. Thus, we
  231. // can't shrink something that has multiple users, unless all users are
  232. // post-dominated by the trunc instruction, i.e., were visited during the
  233. // expression evaluation.
  234. unsigned DesiredBitWidth = 0;
  235. for (auto Itr : InstInfoMap) {
  236. Instruction *I = Itr.first;
  237. if (I->hasOneUse())
  238. continue;
  239. bool IsExtInst = (isa<ZExtInst>(I) || isa<SExtInst>(I));
  240. for (auto *U : I->users())
  241. if (auto *UI = dyn_cast<Instruction>(U))
  242. if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) {
  243. if (!IsExtInst)
  244. return nullptr;
  245. // If this is an extension from the dest type, we can eliminate it,
  246. // even if it has multiple users. Thus, update the DesiredBitWidth and
  247. // validate all extension instructions agrees on same DesiredBitWidth.
  248. unsigned ExtInstBitWidth =
  249. I->getOperand(0)->getType()->getScalarSizeInBits();
  250. if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth)
  251. return nullptr;
  252. DesiredBitWidth = ExtInstBitWidth;
  253. }
  254. }
  255. unsigned OrigBitWidth =
  256. CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
  257. // Initialize MinBitWidth for shift instructions with the minimum number
  258. // that is greater than shift amount (i.e. shift amount + 1).
  259. // For `lshr` adjust MinBitWidth so that all potentially truncated
  260. // bits of the value-to-be-shifted are zeros.
  261. // For `ashr` adjust MinBitWidth so that all potentially truncated
  262. // bits of the value-to-be-shifted are sign bits (all zeros or ones)
  263. // and even one (first) untruncated bit is sign bit.
  264. // Exit early if MinBitWidth is not less than original bitwidth.
  265. for (auto &Itr : InstInfoMap) {
  266. Instruction *I = Itr.first;
  267. if (I->isShift()) {
  268. KnownBits KnownRHS = computeKnownBits(I->getOperand(1));
  269. unsigned MinBitWidth = KnownRHS.getMaxValue()
  270. .uadd_sat(APInt(OrigBitWidth, 1))
  271. .getLimitedValue(OrigBitWidth);
  272. if (MinBitWidth == OrigBitWidth)
  273. return nullptr;
  274. if (I->getOpcode() == Instruction::LShr) {
  275. KnownBits KnownLHS = computeKnownBits(I->getOperand(0));
  276. MinBitWidth =
  277. std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits());
  278. }
  279. if (I->getOpcode() == Instruction::AShr) {
  280. unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0));
  281. MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1);
  282. }
  283. if (MinBitWidth >= OrigBitWidth)
  284. return nullptr;
  285. Itr.second.MinBitWidth = MinBitWidth;
  286. }
  287. if (I->getOpcode() == Instruction::UDiv ||
  288. I->getOpcode() == Instruction::URem) {
  289. unsigned MinBitWidth = 0;
  290. for (const auto &Op : I->operands()) {
  291. KnownBits Known = computeKnownBits(Op);
  292. MinBitWidth =
  293. std::max(Known.getMaxValue().getActiveBits(), MinBitWidth);
  294. if (MinBitWidth >= OrigBitWidth)
  295. return nullptr;
  296. }
  297. Itr.second.MinBitWidth = MinBitWidth;
  298. }
  299. }
  300. // Calculate minimum allowed bit-width allowed for shrinking the currently
  301. // visited truncate's operand.
  302. unsigned MinBitWidth = getMinBitWidth();
  303. // Check that we can shrink to smaller bit-width than original one and that
  304. // it is similar to the DesiredBitWidth is such exists.
  305. if (MinBitWidth >= OrigBitWidth ||
  306. (DesiredBitWidth && DesiredBitWidth != MinBitWidth))
  307. return nullptr;
  308. return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth);
  309. }
  310. /// Given a reduced scalar type \p Ty and a \p V value, return a reduced type
  311. /// for \p V, according to its type, if it vector type, return the vector
  312. /// version of \p Ty, otherwise return \p Ty.
  313. static Type *getReducedType(Value *V, Type *Ty) {
  314. assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type");
  315. if (auto *VTy = dyn_cast<VectorType>(V->getType()))
  316. return VectorType::get(Ty, VTy->getElementCount());
  317. return Ty;
  318. }
  319. Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) {
  320. Type *Ty = getReducedType(V, SclTy);
  321. if (auto *C = dyn_cast<Constant>(V)) {
  322. C = ConstantExpr::getIntegerCast(C, Ty, false);
  323. // If we got a constantexpr back, try to simplify it with DL info.
  324. return ConstantFoldConstant(C, DL, &TLI);
  325. }
  326. auto *I = cast<Instruction>(V);
  327. Info Entry = InstInfoMap.lookup(I);
  328. assert(Entry.NewValue);
  329. return Entry.NewValue;
  330. }
  331. void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
  332. NumInstrsReduced += InstInfoMap.size();
  333. for (auto &Itr : InstInfoMap) { // Forward
  334. Instruction *I = Itr.first;
  335. TruncInstCombine::Info &NodeInfo = Itr.second;
  336. assert(!NodeInfo.NewValue && "Instruction has been evaluated");
  337. IRBuilder<> Builder(I);
  338. Value *Res = nullptr;
  339. unsigned Opc = I->getOpcode();
  340. switch (Opc) {
  341. case Instruction::Trunc:
  342. case Instruction::ZExt:
  343. case Instruction::SExt: {
  344. Type *Ty = getReducedType(I, SclTy);
  345. // If the source type of the cast is the type we're trying for then we can
  346. // just return the source. There's no need to insert it because it is not
  347. // new.
  348. if (I->getOperand(0)->getType() == Ty) {
  349. assert(!isa<TruncInst>(I) && "Cannot reach here with TruncInst");
  350. NodeInfo.NewValue = I->getOperand(0);
  351. continue;
  352. }
  353. // Otherwise, must be the same type of cast, so just reinsert a new one.
  354. // This also handles the case of zext(trunc(x)) -> zext(x).
  355. Res = Builder.CreateIntCast(I->getOperand(0), Ty,
  356. Opc == Instruction::SExt);
  357. // Update Worklist entries with new value if needed.
  358. // There are three possible changes to the Worklist:
  359. // 1. Update Old-TruncInst -> New-TruncInst.
  360. // 2. Remove Old-TruncInst (if New node is not TruncInst).
  361. // 3. Add New-TruncInst (if Old node was not TruncInst).
  362. auto *Entry = find(Worklist, I);
  363. if (Entry != Worklist.end()) {
  364. if (auto *NewCI = dyn_cast<TruncInst>(Res))
  365. *Entry = NewCI;
  366. else
  367. Worklist.erase(Entry);
  368. } else if (auto *NewCI = dyn_cast<TruncInst>(Res))
  369. Worklist.push_back(NewCI);
  370. break;
  371. }
  372. case Instruction::Add:
  373. case Instruction::Sub:
  374. case Instruction::Mul:
  375. case Instruction::And:
  376. case Instruction::Or:
  377. case Instruction::Xor:
  378. case Instruction::Shl:
  379. case Instruction::LShr:
  380. case Instruction::AShr:
  381. case Instruction::UDiv:
  382. case Instruction::URem: {
  383. Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
  384. Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
  385. Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
  386. // Preserve `exact` flag since truncation doesn't change exactness
  387. if (auto *PEO = dyn_cast<PossiblyExactOperator>(I))
  388. if (auto *ResI = dyn_cast<Instruction>(Res))
  389. ResI->setIsExact(PEO->isExact());
  390. break;
  391. }
  392. case Instruction::ExtractElement: {
  393. Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
  394. Value *Idx = I->getOperand(1);
  395. Res = Builder.CreateExtractElement(Vec, Idx);
  396. break;
  397. }
  398. case Instruction::InsertElement: {
  399. Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
  400. Value *NewElt = getReducedOperand(I->getOperand(1), SclTy);
  401. Value *Idx = I->getOperand(2);
  402. Res = Builder.CreateInsertElement(Vec, NewElt, Idx);
  403. break;
  404. }
  405. case Instruction::Select: {
  406. Value *Op0 = I->getOperand(0);
  407. Value *LHS = getReducedOperand(I->getOperand(1), SclTy);
  408. Value *RHS = getReducedOperand(I->getOperand(2), SclTy);
  409. Res = Builder.CreateSelect(Op0, LHS, RHS);
  410. break;
  411. }
  412. default:
  413. llvm_unreachable("Unhandled instruction");
  414. }
  415. NodeInfo.NewValue = Res;
  416. if (auto *ResI = dyn_cast<Instruction>(Res))
  417. ResI->takeName(I);
  418. }
  419. Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy);
  420. Type *DstTy = CurrentTruncInst->getType();
  421. if (Res->getType() != DstTy) {
  422. IRBuilder<> Builder(CurrentTruncInst);
  423. Res = Builder.CreateIntCast(Res, DstTy, false);
  424. if (auto *ResI = dyn_cast<Instruction>(Res))
  425. ResI->takeName(CurrentTruncInst);
  426. }
  427. CurrentTruncInst->replaceAllUsesWith(Res);
  428. // Erase old expression dag, which was replaced by the reduced expression dag.
  429. // We iterate backward, which means we visit the instruction before we visit
  430. // any of its operands, this way, when we get to the operand, we already
  431. // removed the instructions (from the expression dag) that uses it.
  432. CurrentTruncInst->eraseFromParent();
  433. for (auto &I : llvm::reverse(InstInfoMap)) {
  434. // We still need to check that the instruction has no users before we erase
  435. // it, because {SExt, ZExt}Inst Instruction might have other users that was
  436. // not reduced, in such case, we need to keep that instruction.
  437. if (I.first->use_empty())
  438. I.first->eraseFromParent();
  439. }
  440. }
  441. bool TruncInstCombine::run(Function &F) {
  442. bool MadeIRChange = false;
  443. // Collect all TruncInst in the function into the Worklist for evaluating.
  444. for (auto &BB : F) {
  445. // Ignore unreachable basic block.
  446. if (!DT.isReachableFromEntry(&BB))
  447. continue;
  448. for (auto &I : BB)
  449. if (auto *CI = dyn_cast<TruncInst>(&I))
  450. Worklist.push_back(CI);
  451. }
  452. // Process all TruncInst in the Worklist, for each instruction:
  453. // 1. Check if it dominates an eligible expression dag to be reduced.
  454. // 2. Create a reduced expression dag and replace the old one with it.
  455. while (!Worklist.empty()) {
  456. CurrentTruncInst = Worklist.pop_back_val();
  457. if (Type *NewDstSclTy = getBestTruncatedType()) {
  458. LLVM_DEBUG(
  459. dbgs() << "ICE: TruncInstCombine reducing type of expression dag "
  460. "dominated by: "
  461. << CurrentTruncInst << '\n');
  462. ReduceExpressionDag(NewDstSclTy);
  463. ++NumDAGsReduced;
  464. MadeIRChange = true;
  465. }
  466. }
  467. return MadeIRChange;
  468. }