//===- NaryReassociate.cpp - Reassociate n-ary expressions ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass reassociates n-ary add expressions and eliminates the redundancy
// exposed by the reassociation.
//
// A motivating example:
//
//   void foo(int a, int b) {
//     bar(a + b);
//     bar((a + 2) + b);
//   }
//
// An ideal compiler should reassociate (a + 2) + b to (a + b) + 2 and simplify
// the above code to
//
//   int t = a + b;
//   bar(t);
//   bar(t + 2);
//
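// At the IR level, the rewrite looks roughly like this (a hand-written
// sketch of the idea, not literal pass output):
//
//   %ab = add i32 %a, %b     ; recorded under its SCEV when first seen
//   %a2 = add i32 %a, 2
//   %t  = add i32 %a2, %b    ; reassociated to: add i32 %ab, 2
//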
// However, the Reassociate pass is unable to do that because it processes each
// instruction individually and believes (a + 2) + b is the best form according
// to its rank system.
//
// To address this limitation, NaryReassociate reassociates an expression in a
// form that reuses existing instructions. As a result, NaryReassociate can
// reassociate (a + 2) + b in the example to (a + b) + 2 because it detects
// that (a + b) is computed before.
//
// NaryReassociate works as follows. For every instruction in the form of (a +
// b) + c, it checks whether a + c or b + c is already computed by a dominating
// instruction. If so, it then reassociates (a + b) + c into (a + c) + b or
// (b + c) + a and removes the redundancy accordingly. To efficiently look up
// whether an expression has been computed before, we store each instruction
// seen and its SCEV into an SCEV-to-instruction map.
//
// Although the algorithm pattern-matches only ternary additions, it
// automatically handles many >3-ary expressions by walking through the
// function in depth-first order. For example, given
//
//   (a + c) + d
//   ((a + b) + c) + d
//
// NaryReassociate first rewrites (a + b) + c to (a + c) + b, and then rewrites
// ((a + c) + b) + d into ((a + c) + d) + b.
//
// Finally, the above dominator-based algorithm may need to run for multiple
// iterations before emitting optimal code. One source of this need is that we
// only split an operand when it is used only once. The above algorithm can
// eliminate an instruction and decrease the usage count of its operands. As a
// result, an instruction that previously had multiple uses may become a
// single-use instruction and thus eligible for split consideration. For
// example,
//
//   ac = a + c
//   ab = a + b
//   abc = ab + c
//   ab2 = ab + b
//   ab2c = ab2 + c
//
// In the first iteration, we cannot reassociate abc to ac+b because ab is used
// twice. However, we can reassociate ab2c to abc+b in the first iteration. As
// a result, ab2 becomes dead and ab will be used only once in the second
// iteration.
//
// Limitations and TODO items:
//
// 1) We only consider n-ary adds and muls for now. This should be extended
// and generalized.
//
//===----------------------------------------------------------------------===//
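
// Note: for local experimentation, this pass can typically be run on its own
// with, e.g., "opt -passes=nary-reassociate"; the exact invocation may vary
// across LLVM versions.
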
#include "llvm/Transforms/Scalar/NaryReassociate.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <cassert>
#include <cstdint>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "nary-reassociate"

namespace {

class NaryReassociateLegacyPass : public FunctionPass {
public:
  static char ID;

  NaryReassociateLegacyPass() : FunctionPass(ID) {
    initializeNaryReassociateLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool doInitialization(Module &M) override { return false; }

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
    AU.addPreserved<TargetLibraryInfoWrapperPass>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }

private:
  NaryReassociatePass Impl;
};

} // end anonymous namespace

char NaryReassociateLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(NaryReassociateLegacyPass, "nary-reassociate",
                      "Nary reassociation", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(NaryReassociateLegacyPass, "nary-reassociate",
                    "Nary reassociation", false, false)

FunctionPass *llvm::createNaryReassociatePass() {
  return new NaryReassociateLegacyPass();
}

bool NaryReassociateLegacyPass::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
  auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  return Impl.runImpl(F, AC, DT, SE, TLI, TTI);
}

PreservedAnalyses NaryReassociatePass::run(Function &F,
                                           FunctionAnalysisManager &AM) {
  auto *AC = &AM.getResult<AssumptionAnalysis>(F);
  auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
  auto *TLI = &AM.getResult<TargetLibraryAnalysis>(F);
  auto *TTI = &AM.getResult<TargetIRAnalysis>(F);

  if (!runImpl(F, AC, DT, SE, TLI, TTI))
    return PreservedAnalyses::all();
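
  // This pass never changes the CFG, and ScalarEvolution is kept consistent
  // along the way (deleted values are explicitly forgotten), so both can be
  // reported as preserved.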
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<ScalarEvolutionAnalysis>();
  return PA;
}

bool NaryReassociatePass::runImpl(Function &F, AssumptionCache *AC_,
                                  DominatorTree *DT_, ScalarEvolution *SE_,
                                  TargetLibraryInfo *TLI_,
                                  TargetTransformInfo *TTI_) {
  AC = AC_;
  DT = DT_;
  SE = SE_;
  TLI = TLI_;
  TTI = TTI_;
  DL = &F.getParent()->getDataLayout();

  bool Changed = false, ChangedInThisIteration;
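  // Iterate to a fixed point: removing a redundant instruction can turn a
  // multi-use operand into a single-use one, making it eligible for splitting
  // in the next iteration (see the file header for an example).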
  do {
    ChangedInThisIteration = doOneIteration(F);
    Changed |= ChangedInThisIteration;
  } while (ChangedInThisIteration);
  return Changed;
}

bool NaryReassociatePass::doOneIteration(Function &F) {
  bool Changed = false;
  SeenExprs.clear();
  // Process the basic blocks in a depth-first traversal of the dominator
  // tree. This order ensures that all bases of a candidate are in Candidates
  // when we process it.
  SmallVector<WeakTrackingVH, 16> DeadInsts;
  for (const auto Node : depth_first(DT)) {
    BasicBlock *BB = Node->getBlock();
    for (Instruction &OrigI : *BB) {
      const SCEV *OrigSCEV = nullptr;
      if (Instruction *NewI = tryReassociate(&OrigI, OrigSCEV)) {
        Changed = true;
        OrigI.replaceAllUsesWith(NewI);

        // Add 'OrigI' to the list of dead instructions.
        DeadInsts.push_back(WeakTrackingVH(&OrigI));
        // Add the rewritten instruction to SeenExprs; the original instruction
        // is deleted.
        const SCEV *NewSCEV = SE->getSCEV(NewI);
        SeenExprs[NewSCEV].push_back(WeakTrackingVH(NewI));

        // Ideally, NewSCEV should equal OrigSCEV because tryReassociate(I)
        // is equivalent to I. However, ScalarEvolution::getSCEV may
        // weaken nsw, causing NewSCEV not to equal OrigSCEV. For example,
        // suppose we reassociate
        //   I = &a[sext(i +nsw j)] // assuming sizeof(a[0]) = 4
        // to
        //   NewI = &a[sext(i)] + sext(j).
        //
        // ScalarEvolution computes
        //   getSCEV(I)    = a + 4 * sext(i + j)
        //   getSCEV(NewI) = a + 4 * sext(i) + 4 * sext(j)
        // which are different SCEVs.
        //
        // To alleviate this issue of ScalarEvolution not always capturing
        // equivalence, we add I to SeenExprs[OrigSCEV] as well, so that we can
        // map both the SCEV before and the SCEV after tryReassociate(I) to I.
        //
        // This improvement is exercised in @reassociate_gep_nsw in
        // nary-gep.ll.
        if (NewSCEV != OrigSCEV)
          SeenExprs[OrigSCEV].push_back(WeakTrackingVH(NewI));
      } else if (OrigSCEV)
        SeenExprs[OrigSCEV].push_back(WeakTrackingVH(&OrigI));
    }
  }
  // Delete all dead instructions from 'DeadInsts'.
  // Note that ScalarEvolution is updated along the way.
  RecursivelyDeleteTriviallyDeadInstructionsPermissive(
      DeadInsts, TLI, nullptr, [this](Value *V) { SE->forgetValue(V); });
  return Changed;
}
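
// Try to match I as a min/max idiom (a compare feeding a select) with
// predicate kind PredT, and attempt reassociation with the matched operands
// taken in both orders.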
template <typename PredT>
Instruction *
NaryReassociatePass::matchAndReassociateMinOrMax(Instruction *I,
                                                 const SCEV *&OrigSCEV) {
  Value *LHS = nullptr;
  Value *RHS = nullptr;

  auto MinMaxMatcher =
      MaxMin_match<ICmpInst, bind_ty<Value>, bind_ty<Value>, PredT>(
          m_Value(LHS), m_Value(RHS));
  if (match(I, MinMaxMatcher)) {
    OrigSCEV = SE->getSCEV(I);
    if (auto *NewMinMax = dyn_cast_or_null<Instruction>(
            tryReassociateMinOrMax(I, MinMaxMatcher, LHS, RHS)))
      return NewMinMax;
    if (auto *NewMinMax = dyn_cast_or_null<Instruction>(
            tryReassociateMinOrMax(I, MinMaxMatcher, RHS, LHS)))
      return NewMinMax;
  }
  return nullptr;
}

Instruction *NaryReassociatePass::tryReassociate(Instruction *I,
                                                 const SCEV *&OrigSCEV) {
  if (!SE->isSCEVable(I->getType()))
    return nullptr;

  switch (I->getOpcode()) {
  case Instruction::Add:
  case Instruction::Mul:
    OrigSCEV = SE->getSCEV(I);
    return tryReassociateBinaryOp(cast<BinaryOperator>(I));
  case Instruction::GetElementPtr:
    OrigSCEV = SE->getSCEV(I);
    return tryReassociateGEP(cast<GetElementPtrInst>(I));
  default:
    break;
  }

  // Try to match signed/unsigned Min/Max.
  Instruction *ResI = nullptr;
  // TODO: Currently min/max reassociation is restricted to integer types only
  // due to use of SCEVExpander, which may introduce incompatible forms of
  // min/max for pointer types.
  if (I->getType()->isIntegerTy())
    if ((ResI = matchAndReassociateMinOrMax<umin_pred_ty>(I, OrigSCEV)) ||
        (ResI = matchAndReassociateMinOrMax<smin_pred_ty>(I, OrigSCEV)) ||
        (ResI = matchAndReassociateMinOrMax<umax_pred_ty>(I, OrigSCEV)) ||
        (ResI = matchAndReassociateMinOrMax<smax_pred_ty>(I, OrigSCEV)))
      return ResI;

  return nullptr;
}
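
// A GEP is "foldable" if the target can fold its address computation into the
// addressing modes of the memory instructions that use it, i.e. the GEP costs
// nothing by itself (TCC_Free).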
static bool isGEPFoldable(GetElementPtrInst *GEP,
                          const TargetTransformInfo *TTI) {
  SmallVector<const Value *, 4> Indices(GEP->indices());
  return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(),
                         Indices) == TargetTransformInfo::TCC_Free;
}

Instruction *NaryReassociatePass::tryReassociateGEP(GetElementPtrInst *GEP) {
  // Not worth reassociating GEP if it is foldable.
  if (isGEPFoldable(GEP, TTI))
    return nullptr;

  gep_type_iterator GTI = gep_type_begin(*GEP);
  for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) {
    if (GTI.isSequential()) {
      if (auto *NewGEP =
              tryReassociateGEPAtIndex(GEP, I - 1, GTI.getIndexedType())) {
        return NewGEP;
      }
    }
  }
  return nullptr;
}
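
// Return true if Index is narrower than the pointer width, in which case the
// address computation must sign-extend it to the pointer size first.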
bool NaryReassociatePass::requiresSignExtension(Value *Index,
                                                GetElementPtrInst *GEP) {
  unsigned PointerSizeInBits =
      DL->getPointerSizeInBits(GEP->getType()->getPointerAddressSpace());
  return cast<IntegerType>(Index->getType())->getBitWidth() < PointerSizeInBits;
}

GetElementPtrInst *
NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,
                                              unsigned I, Type *IndexedType) {
  Value *IndexToSplit = GEP->getOperand(I + 1);
  if (SExtInst *SExt = dyn_cast<SExtInst>(IndexToSplit)) {
    IndexToSplit = SExt->getOperand(0);
  } else if (ZExtInst *ZExt = dyn_cast<ZExtInst>(IndexToSplit)) {
    // zext can be treated as sext if the source is non-negative.
    if (isKnownNonNegative(ZExt->getOperand(0), *DL, 0, AC, GEP, DT))
      IndexToSplit = ZExt->getOperand(0);
  }

  if (AddOperator *AO = dyn_cast<AddOperator>(IndexToSplit)) {
    // If the I-th index needs sext and the underlying add is not equipped with
    // nsw, we cannot split the add because
    //   sext(LHS + RHS) != sext(LHS) + sext(RHS).
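    // For example, with i8 operands, 127 + 1 wraps to -128, so
    // sext(127 + 1) = -128 while sext(127) + sext(1) = 128.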
    if (requiresSignExtension(IndexToSplit, GEP) &&
        computeOverflowForSignedAdd(AO, *DL, AC, GEP, DT) !=
            OverflowResult::NeverOverflows)
      return nullptr;

    Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);
    // IndexToSplit = LHS + RHS.
    if (auto *NewGEP = tryReassociateGEPAtIndex(GEP, I, LHS, RHS, IndexedType))
      return NewGEP;
    // Symmetrically, try IndexToSplit = RHS + LHS.
    if (LHS != RHS) {
      if (auto *NewGEP =
              tryReassociateGEPAtIndex(GEP, I, RHS, LHS, IndexedType))
        return NewGEP;
    }
  }
  return nullptr;
}

GetElementPtrInst *
NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,
                                              unsigned I, Value *LHS,
                                              Value *RHS, Type *IndexedType) {
  // Look for GEP's closest dominator that has the same SCEV as GEP except that
  // the I-th index is replaced with LHS.
  SmallVector<const SCEV *, 4> IndexExprs;
  for (Use &Index : GEP->indices())
    IndexExprs.push_back(SE->getSCEV(Index));
  // Replace the I-th index with LHS.
  IndexExprs[I] = SE->getSCEV(LHS);
  if (isKnownNonNegative(LHS, *DL, 0, AC, GEP, DT) &&
      DL->getTypeSizeInBits(LHS->getType()).getFixedValue() <
          DL->getTypeSizeInBits(GEP->getOperand(I)->getType())
              .getFixedValue()) {
    // Zero-extend LHS if it is non-negative. InstCombine canonicalizes sext to
    // zext if the source operand is proved non-negative. We should do that
    // consistently so that CandidateExpr more likely appears before. See
    // @reassociate_gep_assume for an example of this canonicalization.
    IndexExprs[I] =
        SE->getZeroExtendExpr(IndexExprs[I], GEP->getOperand(I)->getType());
  }
  const SCEV *CandidateExpr =
      SE->getGEPExpr(cast<GEPOperator>(GEP), IndexExprs);
  Value *Candidate = findClosestMatchingDominator(CandidateExpr, GEP);
  if (Candidate == nullptr)
    return nullptr;

  IRBuilder<> Builder(GEP);
  // Candidate does not necessarily have the same pointer type as GEP. Use
  // bitcast or pointer cast to make sure they have the same type, so that the
  // later RAUW doesn't complain.
  Candidate = Builder.CreateBitOrPointerCast(Candidate, GEP->getType());
  assert(Candidate->getType() == GEP->getType());

  // NewGEP = (char *)Candidate + RHS * sizeof(IndexedType)
  uint64_t IndexedSize = DL->getTypeAllocSize(IndexedType);
  Type *ElementType = GEP->getResultElementType();
  uint64_t ElementSize = DL->getTypeAllocSize(ElementType);
  // Note that because I is not necessarily the last index of the GEP, the size
  // of the type at the I-th index (IndexedSize) is not necessarily divisible
  // by ElementSize. For example,
  //
  //   #pragma pack(1)
  //   struct S {
  //     int a[3];
  //     int64 b[8];
  //   };
  //   #pragma pack()
  //
  // sizeof(S) = 76 is indivisible by sizeof(int64) = 8.
  //
  // TODO: bail out on this case for now. We could emit uglygep.
  if (IndexedSize % ElementSize != 0)
    return nullptr;

  // NewGEP = &Candidate[RHS * (sizeof(IndexedType) / sizeof(Candidate[0]))]
  Type *IntPtrTy = DL->getIntPtrType(GEP->getType());
  if (RHS->getType() != IntPtrTy)
    RHS = Builder.CreateSExtOrTrunc(RHS, IntPtrTy);
  if (IndexedSize != ElementSize) {
    RHS = Builder.CreateMul(
        RHS, ConstantInt::get(IntPtrTy, IndexedSize / ElementSize));
  }
  GetElementPtrInst *NewGEP = cast<GetElementPtrInst>(
      Builder.CreateGEP(GEP->getResultElementType(), Candidate, RHS));
  NewGEP->setIsInBounds(GEP->isInBounds());
  NewGEP->takeName(GEP);
  return NewGEP;
}

Instruction *NaryReassociatePass::tryReassociateBinaryOp(BinaryOperator *I) {
  Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
  // There is no need to reassociate 0.
  if (SE->getSCEV(I)->isZero())
    return nullptr;
  if (auto *NewI = tryReassociateBinaryOp(LHS, RHS, I))
    return NewI;
  if (auto *NewI = tryReassociateBinaryOp(RHS, LHS, I))
    return NewI;
  return nullptr;
}

Instruction *NaryReassociatePass::tryReassociateBinaryOp(Value *LHS, Value *RHS,
                                                         BinaryOperator *I) {
  Value *A = nullptr, *B = nullptr;
  // To be conservative, we reassociate I only when it is the only user of
  // (A op B).
  if (LHS->hasOneUse() && matchTernaryOp(I, LHS, A, B)) {
    // I = (A op B) op RHS
    //   = (A op RHS) op B or (B op RHS) op A
    const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B);
    const SCEV *RHSExpr = SE->getSCEV(RHS);
    if (BExpr != RHSExpr) {
      if (auto *NewI =
              tryReassociatedBinaryOp(getBinarySCEV(I, AExpr, RHSExpr), B, I))
        return NewI;
    }
    if (AExpr != RHSExpr) {
      if (auto *NewI =
              tryReassociatedBinaryOp(getBinarySCEV(I, BExpr, RHSExpr), A, I))
        return NewI;
    }
  }
  return nullptr;
}

Instruction *NaryReassociatePass::tryReassociatedBinaryOp(const SCEV *LHSExpr,
                                                          Value *RHS,
                                                          BinaryOperator *I) {
  // Look for the closest dominator LHS of I that computes LHSExpr, and replace
  // I with LHS op RHS.
  auto *LHS = findClosestMatchingDominator(LHSExpr, I);
  if (LHS == nullptr)
    return nullptr;

  Instruction *NewI = nullptr;
  switch (I->getOpcode()) {
  case Instruction::Add:
    NewI = BinaryOperator::CreateAdd(LHS, RHS, "", I);
    break;
  case Instruction::Mul:
    NewI = BinaryOperator::CreateMul(LHS, RHS, "", I);
    break;
  default:
    llvm_unreachable("Unexpected instruction.");
  }
  NewI->takeName(I);
  return NewI;
}
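
// Return true if V is a binary operation of the same opcode as I, binding its
// operands to Op1 and Op2; this recognizes the inner (A op B) of the ternary
// pattern (A op B) op RHS.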
bool NaryReassociatePass::matchTernaryOp(BinaryOperator *I, Value *V,
                                         Value *&Op1, Value *&Op2) {
  switch (I->getOpcode()) {
  case Instruction::Add:
    return match(V, m_Add(m_Value(Op1), m_Value(Op2)));
  case Instruction::Mul:
    return match(V, m_Mul(m_Value(Op1), m_Value(Op2)));
  default:
    llvm_unreachable("Unexpected instruction.");
  }
  return false;
}

const SCEV *NaryReassociatePass::getBinarySCEV(BinaryOperator *I,
                                               const SCEV *LHS,
                                               const SCEV *RHS) {
  switch (I->getOpcode()) {
  case Instruction::Add:
    return SE->getAddExpr(LHS, RHS);
  case Instruction::Mul:
    return SE->getMulExpr(LHS, RHS);
  default:
    llvm_unreachable("Unexpected instruction.");
  }
  return nullptr;
}
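
// Return the closest instruction that both computes CandidateExpr and
// dominates Dominatee, or nullptr if no seen instruction qualifies.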
Instruction *
NaryReassociatePass::findClosestMatchingDominator(const SCEV *CandidateExpr,
                                                  Instruction *Dominatee) {
  auto Pos = SeenExprs.find(CandidateExpr);
  if (Pos == SeenExprs.end())
    return nullptr;

  auto &Candidates = Pos->second;
  // Because we process the basic blocks in pre-order of the dominator tree, a
  // candidate that doesn't dominate the current instruction won't dominate any
  // future instruction either. Therefore, we pop it out of the stack. This
  // optimization makes the algorithm O(n).
  while (!Candidates.empty()) {
    // Candidates stores WeakTrackingVHs, so a candidate can be nullptr if it's
    // removed during rewriting.
    if (Value *Candidate = Candidates.back()) {
      Instruction *CandidateInstruction = cast<Instruction>(Candidate);
      if (DT->dominates(CandidateInstruction, Dominatee))
        return CandidateInstruction;
    }
    Candidates.pop_back();
  }
  return nullptr;
}
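
// Map the matcher's predicate type to the corresponding SCEV min/max
// expression kind.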
template <typename MaxMinT> static SCEVTypes convertToSCEVType(MaxMinT &MM) {
  if (std::is_same_v<smax_pred_ty, typename MaxMinT::PredType>)
    return scSMaxExpr;
  else if (std::is_same_v<umax_pred_ty, typename MaxMinT::PredType>)
    return scUMaxExpr;
  else if (std::is_same_v<smin_pred_ty, typename MaxMinT::PredType>)
    return scSMinExpr;
  else if (std::is_same_v<umin_pred_ty, typename MaxMinT::PredType>)
    return scUMinExpr;
  llvm_unreachable("Can't convert MinMax pattern to SCEV type");
  return scUnknown;
}

// Parameters:
//   I           - instruction matched by MaxMinMatch matcher
//   MaxMinMatch - min/max idiom matcher
//   LHS         - first operand of I
//   RHS         - second operand of I
template <typename MaxMinT>
Value *NaryReassociatePass::tryReassociateMinOrMax(Instruction *I,
                                                   MaxMinT MaxMinMatch,
                                                   Value *LHS, Value *RHS) {
  Value *A = nullptr, *B = nullptr;
  MaxMinT m_MaxMin(m_Value(A), m_Value(B));

  if (LHS->hasNUsesOrMore(3) ||
      // The optimization is profitable only if LHS can be removed in the end.
      // In other words, LHS should be used (directly or indirectly) by I only.
      llvm::any_of(LHS->users(),
                   [&](auto *U) {
                     return U != I &&
                            !(U->hasOneUser() && *U->users().begin() == I);
                   }) ||
      !match(LHS, m_MaxMin))
    return nullptr;
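
  // tryCombination(A, B, C) looks for an existing dominating instruction that
  // computes (A minmax B) (operand order does not matter, as SCEV min/max
  // expressions are commutative); if one is found as R1, it expands
  // (C minmax R1) right before I and returns that value.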
  auto tryCombination = [&](Value *A, const SCEV *AExpr, Value *B,
                            const SCEV *BExpr, Value *C,
                            const SCEV *CExpr) -> Value * {
    SmallVector<const SCEV *, 2> Ops1{BExpr, AExpr};
    const SCEVTypes SCEVType = convertToSCEVType(m_MaxMin);
    const SCEV *R1Expr = SE->getMinMaxExpr(SCEVType, Ops1);

    Instruction *R1MinMax = findClosestMatchingDominator(R1Expr, I);
    if (!R1MinMax)
      return nullptr;

    LLVM_DEBUG(dbgs() << "NARY: Found common sub-expr: " << *R1MinMax << "\n");

    SmallVector<const SCEV *, 2> Ops2{SE->getUnknown(C),
                                      SE->getUnknown(R1MinMax)};
    const SCEV *R2Expr = SE->getMinMaxExpr(SCEVType, Ops2);

    SCEVExpander Expander(*SE, *DL, "nary-reassociate");
    Value *NewMinMax = Expander.expandCodeFor(R2Expr, I->getType(), I);
    NewMinMax->setName(Twine(I->getName()).concat(".nary"));

    LLVM_DEBUG(dbgs() << "NARY: Deleting: " << *I << "\n"
                      << "NARY: Inserting: " << *NewMinMax << "\n");
    return NewMinMax;
  };

  const SCEV *AExpr = SE->getSCEV(A);
  const SCEV *BExpr = SE->getSCEV(B);
  const SCEV *RHSExpr = SE->getSCEV(RHS);

  if (BExpr != RHSExpr) {
    // Try (A op RHS) op B
    if (auto *NewMinMax = tryCombination(A, AExpr, RHS, RHSExpr, B, BExpr))
      return NewMinMax;
  }

  if (AExpr != RHSExpr) {
    // Try (RHS op B) op A
    if (auto *NewMinMax = tryCombination(RHS, RHSExpr, B, BExpr, A, AExpr))
      return NewMinMax;
  }

  return nullptr;
}