//===- ARMParallelDSP.cpp - Parallel DSP Pass -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
//
//===----------------------------------------------------------------------===//
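
// For illustration, the canonical candidate for this transformation is a
// 16-bit dot-product reduction such as (C names here are made up):
//
//   int dot(short *a, short *b, int n) {
//     int acc = 0;
//     for (int i = 0; i < n; ++i)
//       acc += (int)a[i] * (int)b[i];
//     return acc;
//   }
//
// Once a loop like this has been unrolled by two, the sequential i16 loads
// and their sign-extended multiplies match the patterns below and can be
// rewritten as a single call to the @llvm.arm.smlad intrinsic.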
#include "ARM.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

static cl::opt<unsigned>
NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
             cl::desc("Limit the number of loads analysed"));

namespace {

struct MulCandidate;
class Reduction;

using MulCandList = SmallVector<std::unique_ptr<MulCandidate>, 8>;
using MemInstList = SmallVectorImpl<LoadInst*>;
using MulPairList = SmallVector<std::pair<MulCandidate*, MulCandidate*>, 8>;

// 'MulCandidate' holds the multiplication instructions that are candidates
// for parallel execution.
struct MulCandidate {
  Instruction *Root;
  Value *LHS;
  Value *RHS;
  bool Exchange = false;
  bool ReadOnly = true;
  bool Paired = false;
  SmallVector<LoadInst*, 2> VecLd; // Container for loads to widen.

  MulCandidate(Instruction *I, Value *lhs, Value *rhs) :
    Root(I), LHS(lhs), RHS(rhs) { }

  bool HasTwoLoadInputs() const {
    return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
  }

  LoadInst *getBaseLoad() const {
    return VecLd.front();
  }
};

/// Represent a sequence of multiply-accumulate operations with the aim to
/// perform the multiplications in parallel.
class Reduction {
  Instruction *Root = nullptr;
  Value *Acc = nullptr;
  MulCandList Muls;
  MulPairList MulPairs;
  SetVector<Instruction*> Adds;

public:
  Reduction() = delete;

  Reduction(Instruction *Add) : Root(Add) { }

  /// Record an Add instruction that is part of this reduction.
  void InsertAdd(Instruction *I) { Adds.insert(I); }

  /// Create MulCandidates, each rooted at a Mul instruction, that are part
  /// of this reduction.
  void InsertMuls() {
    auto GetMulOperand = [](Value *V) -> Instruction* {
      if (auto *SExt = dyn_cast<SExtInst>(V)) {
        if (auto *I = dyn_cast<Instruction>(SExt->getOperand(0)))
          if (I->getOpcode() == Instruction::Mul)
            return I;
      } else if (auto *I = dyn_cast<Instruction>(V)) {
        if (I->getOpcode() == Instruction::Mul)
          return I;
      }
      return nullptr;
    };

    auto InsertMul = [this](Instruction *I) {
      Value *LHS = cast<Instruction>(I->getOperand(0))->getOperand(0);
      Value *RHS = cast<Instruction>(I->getOperand(1))->getOperand(0);
      Muls.push_back(std::make_unique<MulCandidate>(I, LHS, RHS));
    };

    for (auto *Add : Adds) {
      if (Add == Acc)
        continue;
      if (auto *Mul = GetMulOperand(Add->getOperand(0)))
        InsertMul(Mul);
      if (auto *Mul = GetMulOperand(Add->getOperand(1)))
        InsertMul(Mul);
    }
  }

  /// Add the incoming accumulator value; return true if a value had not
  /// already been added. Returning false signals to the user that this
  /// reduction already has a value to initialise the accumulator.
  bool InsertAcc(Value *V) {
    if (Acc)
      return false;
    Acc = V;
    return true;
  }

  /// Set two MulCandidates, rooted at muls, that can be executed as a single
  /// parallel operation.
  void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1,
                  bool Exchange = false) {
    LLVM_DEBUG(dbgs() << "Pairing:\n"
               << *Mul0->Root << "\n"
               << *Mul1->Root << "\n");
    Mul0->Paired = true;
    Mul1->Paired = true;
    if (Exchange)
      Mul1->Exchange = true;
    MulPairs.push_back(std::make_pair(Mul0, Mul1));
  }
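
  // For reference, with 32-bit operands Rn and Rm and accumulator Ra, the
  // instructions we target compute (halfword lanes written out explicitly):
  //   smlad : Ra + Rn[15:0] * Rm[15:0]  + Rn[31:16] * Rm[31:16]
  //   smladx: Ra + Rn[15:0] * Rm[31:16] + Rn[31:16] * Rm[15:0]
  // so setting Exchange on the second candidate selects the 'x' variant,
  // which swaps the halfwords of the second operand before multiplying.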

  /// Return true if enough mul operations are found that can be executed in
  /// parallel.
  bool CreateParallelPairs();

  /// Return the add instruction which is the root of the reduction.
  Instruction *getRoot() { return Root; }

  bool is64Bit() const { return Root->getType()->isIntegerTy(64); }

  Type *getType() const { return Root->getType(); }

  /// Return the incoming value to be accumulated. This may be null.
  Value *getAccumulator() { return Acc; }

  /// Return the set of adds that comprise the reduction.
  SetVector<Instruction*> &getAdds() { return Adds; }

  /// Return the MulCandidates, each rooted at a mul instruction, that
  /// comprise the reduction.
  MulCandList &getMuls() { return Muls; }

  /// Return the MulCandidates, rooted at mul instructions, that have been
  /// paired for parallel execution.
  MulPairList &getMulPairs() { return MulPairs; }

  /// To finalise, replace the uses of the root with the intrinsic call.
  void UpdateRoot(Instruction *SMLAD) {
    Root->replaceAllUsesWith(SMLAD);
  }

  void dump() {
    LLVM_DEBUG(dbgs() << "Reduction:\n";
      for (auto *Add : Adds)
        LLVM_DEBUG(dbgs() << *Add << "\n");
      for (auto &Mul : Muls)
        LLVM_DEBUG(dbgs() << *Mul->Root << "\n"
                   << " " << *Mul->LHS << "\n"
                   << " " << *Mul->RHS << "\n");
      LLVM_DEBUG(if (Acc) dbgs() << "Acc in: " << *Acc << "\n")
    );
  }
};

class WidenedLoad {
  LoadInst *NewLd = nullptr;
  SmallVector<LoadInst*, 4> Loads;

public:
  WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
    : NewLd(Wide) {
    append_range(Loads, Lds);
  }
  LoadInst *getLoad() {
    return NewLd;
  }
};

class ARMParallelDSP : public FunctionPass {
  ScalarEvolution *SE;
  AliasAnalysis *AA;
  TargetLibraryInfo *TLI;
  DominatorTree *DT;
  const DataLayout *DL;
  Module *M;
  std::map<LoadInst*, LoadInst*> LoadPairs;
  SmallPtrSet<LoadInst*, 4> OffsetLoads;
  std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

  template<unsigned>
  bool IsNarrowSequence(Value *V);
  bool Search(Value *V, BasicBlock *BB, Reduction &R);
  bool RecordMemoryOps(BasicBlock *BB);
  void InsertParallelMACs(Reduction &Reduction);
  bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
  LoadInst* CreateWideLoad(MemInstList &Loads, IntegerType *LoadTy);
  bool CreateParallelPairs(Reduction &R);

  /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
  /// Dual performs two signed 16x16-bit multiplications. It adds the
  /// products to a 32-bit accumulate operand. Optionally, the instruction can
  /// exchange the halfwords of the second operand before performing the
  /// arithmetic.
  bool MatchSMLAD(Function &F);

public:
  static char ID;

  ARMParallelDSP() : FunctionPass(ID) { }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    FunctionPass::getAnalysisUsage(AU);
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetPassConfig>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.setPreservesCFG();
  }

  bool runOnFunction(Function &F) override {
    if (DisableParallelDSP)
      return false;
    if (skipFunction(F))
      return false;

    SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
    TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &TPC = getAnalysis<TargetPassConfig>();

    M = F.getParent();
    DL = &M->getDataLayout();

    auto &TM = TPC.getTM<TargetMachine>();
    auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

    if (!ST->allowsUnalignedMem()) {
      LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                           "running pass ARMParallelDSP\n");
      return false;
    }

    if (!ST->hasDSP()) {
      LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                           "ARMParallelDSP\n");
      return false;
    }

    if (!ST->isLittle()) {
      LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                        << "ARMParallelDSP\n");
      return false;
    }

    LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
    LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

    bool Changes = MatchSMLAD(F);
    return Changes;
  }
};

} // end anonymous namespace

bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!Ld0 || !Ld1)
    return false;

  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
             dbgs() << "Ld0:"; Ld0->dump();
             dbgs() << "Ld1:"; Ld1->dump();
            );

  VecMem.clear();
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}
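
// Note that argument order matters above: the check only succeeds when Ld1
// was recorded by RecordMemoryOps as the offset load immediately following
// Ld0, e.g. loads of adjacent i16 elements such as a[i] and a[i+1].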

// MaxBitWidth: the maximum supported bitwidth of the elements in the DSP
// instructions, which is set to 16. So here we should collect all i8 and i16
// narrow operations.
// TODO: we currently only collect i16, and will support i8 later, so that's
// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
template<unsigned MaxBitWidth>
bool ARMParallelDSP::IsNarrowSequence(Value *V) {
  if (auto *SExt = dyn_cast<SExtInst>(V)) {
    if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
      // Check that this load could be paired.
      return LoadPairs.count(Ld) || OffsetLoads.count(Ld);
    }
  }
  return false;
}
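
// For illustration, IsNarrowSequence<16> accepts the %sext operand of a mul
// in IR such as the following, provided %ld was recorded as half of a
// consecutive load pair (names here are made up):
//
//   %ld = load i16, i16* %addr
//   %sext = sext i16 %ld to i32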

/// Iterate through the block and record base, offset pairs of loads which can
/// be widened into a single load.
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
  SmallVector<LoadInst*, 8> Loads;
  SmallVector<Instruction*, 8> Writes;
  LoadPairs.clear();
  WideLoads.clear();

  // Collect loads and instructions that may write to memory. For now we only
  // record loads which are simple, sign-extended and have a single user.
  // TODO: Allow zero-extended loads.
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }

  if (Loads.empty() || Loads.size() > NumLoadLimit)
    return false;

  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
  DepMap RAWDeps;

  // Record any writes that may alias a load.
  const auto Size = LocationSize::beforeOrAfterPointer();
  for (auto *Write : Writes) {
    for (auto *Read : Loads) {
      MemoryLocation ReadLoc =
        MemoryLocation(Read->getPointerOperand(), Size);

      if (!isModOrRefSet(AA->getModRefInfo(Write, ReadLoc)))
        continue;
      if (Write->comesBefore(Read))
        RAWDeps[Read].insert(Write);
    }
  }

  // Check that there isn't a write between the two loads which would prevent
  // them from being safely merged.
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    bool BaseFirst = Base->comesBefore(Offset);
    LoadInst *Dominator = BaseFirst ? Base : Offset;
    LoadInst *Dominated = BaseFirst ? Offset : Base;

    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];

      for (auto *Before : WritesBefore) {
        // We can't move the second load backward, past a write, to merge
        // with the first load.
        if (Dominator->comesBefore(Before))
          return false;
      }
    }
    return true;
  };

  // Record base, offset load pairs.
  for (auto *Base : Loads) {
    for (auto *Offset : Loads) {
      if (Base == Offset || OffsetLoads.count(Offset))
        continue;

      if (isConsecutiveAccess(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        OffsetLoads.insert(Offset);
        break;
      }
    }
  }

  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs) {
                 LLVM_DEBUG(dbgs() << *MapIt.first << ", "
                            << *MapIt.second << "\n");
               }
             });
  return LoadPairs.size() > 1;
}
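
// As an illustrative example, given two adjacent i16 loads in a block:
//
//   %ld0 = load i16, i16* %ptr
//   %addr1 = getelementptr i16, i16* %ptr, i32 1
//   %ld1 = load i16, i16* %addr1
//
// RecordMemoryOps records LoadPairs[%ld0] = %ld1, provided both loads are
// simple, sign-extended, single-use, and no intervening store may alias them.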

// Search recursively back through the operands to find a tree of values that
// form a multiply-accumulate chain. The search records the Add and Mul
// instructions that form the reduction and allows us to find a single value
// to be used as the initial input to the accumulator.
bool ARMParallelDSP::Search(Value *V, BasicBlock *BB, Reduction &R) {
  // If we find a non-instruction, try to use it as the initial accumulator
  // value. This may have already been found during the search, in which case
  // this function will return false, signaling a failed search.
  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return R.InsertAcc(V);

  if (I->getParent() != BB)
    return false;

  switch (I->getOpcode()) {
  default:
    break;
  case Instruction::PHI:
    // Could be the accumulator value.
    return R.InsertAcc(V);
  case Instruction::Add: {
    // Adds should be adding together two muls, or another add and a mul to
    // be within the mac chain. One of the operands may also be the
    // accumulator value at which point we should stop searching.
    R.InsertAdd(I);
    Value *LHS = I->getOperand(0);
    Value *RHS = I->getOperand(1);
    bool ValidLHS = Search(LHS, BB, R);
    bool ValidRHS = Search(RHS, BB, R);

    if (ValidLHS && ValidRHS)
      return true;

    // Ensure we don't add the root as the incoming accumulator.
    if (R.getRoot() == I)
      return false;

    return R.InsertAcc(I);
  }
  case Instruction::Mul: {
    Value *MulOp0 = I->getOperand(0);
    Value *MulOp1 = I->getOperand(1);
    return IsNarrowSequence<16>(MulOp0) && IsNarrowSequence<16>(MulOp1);
  }
  case Instruction::SExt:
    return Search(I->getOperand(0), BB, R);
  }
  return false;
}

// The pass needs to identify integer add/sub reductions of 16-bit vector
// multiplications.
// To use SMLAD:
// 1) we first need to find an integer add, then look for this pattern:
//
//      acc0 = ...
//      ld0 = load i16
//      sext0 = sext i16 %ld0 to i32
//      ld1 = load i16
//      sext1 = sext i16 %ld1 to i32
//      mul0 = mul %sext0, %sext1
//      ld2 = load i16
//      sext2 = sext i16 %ld2 to i32
//      ld3 = load i16
//      sext3 = sext i16 %ld3 to i32
//      mul1 = mul i32 %sext2, %sext3
//      add0 = add i32 %mul0, %acc0
//      acc1 = add i32 %add0, %mul1
//
// Which can be selected to:
//
//      ldr r0
//      ldr r1
//      smlad r2, r0, r1, r2
//
// If constants are used instead of loads, these will need to be hoisted
// out and into a register.
//
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
//
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  bool Changed = false;

  for (auto &BB : F) {
    SmallPtrSet<Instruction*, 4> AllAdds;
    if (!RecordMemoryOps(&BB))
      continue;

    for (Instruction &I : reverse(BB)) {
      if (I.getOpcode() != Instruction::Add)
        continue;

      if (AllAdds.count(&I))
        continue;

      const auto *Ty = I.getType();
      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
        continue;

      Reduction R(&I);
      if (!Search(&I, &BB, R))
        continue;

      R.InsertMuls();
      LLVM_DEBUG(dbgs() << "After search, Reduction:\n"; R.dump());

      if (!CreateParallelPairs(R))
        continue;

      InsertParallelMACs(R);
      Changed = true;
      AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
      LLVM_DEBUG(dbgs() << "BB after inserting parallel MACs:\n" << BB);
    }
  }

  return Changed;
}

bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {

  // Not enough mul operations to make a pair.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly upon sign extended loads.
  for (auto &MulCand : R.getMuls()) {
    if (!MulCand->HasTwoLoadInputs())
      return false;
  }

  auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
    // The first elements of each vector should be loads with sexts. If we
    // find that these are two pairs of consecutive loads, then they can be
    // transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
    auto Ld0 = static_cast<LoadInst*>(PMul0->LHS);
    auto Ld1 = static_cast<LoadInst*>(PMul1->LHS);
    auto Ld2 = static_cast<LoadInst*>(PMul0->RHS);
    auto Ld3 = static_cast<LoadInst*>(PMul1->RHS);

    // Check that each mul is operating on two different loads.
    if (Ld0 == Ld2 || Ld1 == Ld3)
      return false;

    if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
      if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        R.AddMulPair(PMul0, PMul1);
        return true;
      } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
        R.AddMulPair(PMul0, PMul1, true);
        return true;
      }
    } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
               AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
      LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
      LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
      LLVM_DEBUG(dbgs() << "    and swapping muls\n");
      // Only the second operand can be exchanged, so swap the muls.
      R.AddMulPair(PMul1, PMul0, true);
      return true;
    }
    return false;
  };

  MulCandList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  for (unsigned i = 0; i < Elems; ++i) {
    MulCandidate *PMul0 = static_cast<MulCandidate*>(Muls[i].get());
    if (PMul0->Paired)
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      MulCandidate *PMul1 = static_cast<MulCandidate*>(Muls[j].get());
      if (PMul1->Paired)
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1))
        break;
    }
  }

  return !R.getMulPairs().empty();
}
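
// As an illustrative example, muls over loads of a[i]*b[i] and a[i+1]*b[i+1]
// give two sequential load pairs, (a[i], a[i+1]) and (b[i], b[i+1]), which
// pair up for a plain smlad. If the candidates are instead a[i]*b[i+1] and
// a[i+1]*b[i], the second operands only pair up in reverse order, which is
// what the Exchange flag records and smladx implements.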

void ARMParallelDSP::InsertParallelMACs(Reduction &R) {

  auto CreateSMLAD = [&](LoadInst* WideLd0, LoadInst *WideLd1,
                         Value *Acc, bool Exchange,
                         Instruction *InsertAfter) {
    // Replace the reduction chain with an intrinsic call
    Value* Args[] = { WideLd0, WideLd1, Acc };
    Function *SMLAD = nullptr;
    if (Exchange)
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
    else
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                BasicBlock::iterator(InsertAfter));
    Instruction *Call = Builder.CreateCall(SMLAD, Args);
    NumSMLAD++;
    return Call;
  };

  // Return the instruction after the dominated instruction.
  auto GetInsertPoint = [this](Value *A, Value *B) {
    assert((isa<Instruction>(A) || isa<Instruction>(B)) &&
           "expected at least one instruction");

    Value *V = nullptr;
    if (!isa<Instruction>(A))
      V = B;
    else if (!isa<Instruction>(B))
      V = A;
    else
      V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A;

    return &*++BasicBlock::iterator(cast<Instruction>(V));
  };

  Value *Acc = R.getAccumulator();

  // For any muls that were discovered but not paired, accumulate their values
  // as before.
  IRBuilder<NoFolder> Builder(R.getRoot()->getParent());
  MulCandList &MulCands = R.getMuls();
  for (auto &MulCand : MulCands) {
    if (MulCand->Paired)
      continue;

    Instruction *Mul = cast<Instruction>(MulCand->Root);
    LLVM_DEBUG(dbgs() << "Accumulating unpaired mul: " << *Mul << "\n");

    if (R.getType() != Mul->getType()) {
      assert(R.is64Bit() && "expected 64-bit result");
      Builder.SetInsertPoint(&*++BasicBlock::iterator(Mul));
      Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType()));
    }

    if (!Acc) {
      Acc = Mul;
      continue;
    }

    // If Acc is the original incoming value to the reduction, it could be a
    // phi. But the phi will dominate Mul, meaning that Mul will be the
    // insertion point.
    Builder.SetInsertPoint(GetInsertPoint(Mul, Acc));
    Acc = Builder.CreateAdd(Mul, Acc);
  }

  if (!Acc) {
    Acc = R.is64Bit() ?
      ConstantInt::get(IntegerType::get(M->getContext(), 64), 0) :
      ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
  } else if (Acc->getType() != R.getType()) {
    Builder.SetInsertPoint(R.getRoot());
    Acc = Builder.CreateSExt(Acc, R.getType());
  }

  // Roughly sort the mul pairs in their program order.
  llvm::sort(R.getMulPairs(), [](auto &PairA, auto &PairB) {
    const Instruction *A = PairA.first->Root;
    const Instruction *B = PairB.first->Root;
    return A->comesBefore(B);
  });

  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
  for (auto &Pair : R.getMulPairs()) {
    MulCandidate *LHSMul = Pair.first;
    MulCandidate *RHSMul = Pair.second;
    LoadInst *BaseLHS = LHSMul->getBaseLoad();
    LoadInst *BaseRHS = RHSMul->getBaseLoad();
    LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
      WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty);
    LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
      WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);

    Instruction *InsertAfter = GetInsertPoint(WideLHS, WideRHS);
    InsertAfter = GetInsertPoint(InsertAfter, Acc);
    Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}
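
// As an illustrative sketch, for a 32-bit reduction each paired set of muls
// over wide loads %wide0 and %wide1 with incoming accumulator %acc becomes:
//
//   %res = call i32 @llvm.arm.smlad(i32 %wide0, i32 %wide1, i32 %acc)
//
// and %res then feeds the next pair, with the final call replacing the uses
// of the root add.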

LoadInst* ARMParallelDSP::CreateWideLoad(MemInstList &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");

  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &Op : Source->operands())
        MoveBefore(Op, Source);
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Bitcast the pointer to a wider type and create the wide load, while making
  // sure to maintain the original alignment as this prevents ldrd from being
  // generated when it could be illegal due to memory alignment.
  const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
  Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
                                    LoadTy->getPointerTo(AddrSpace));
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr, Base->getAlign());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads.
  // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
  // TODO: Support big-endian as well.
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  Value *NewBaseSExt = IRB.CreateSExt(Bottom, BaseSExt->getType());
  BaseSExt->replaceAllUsesWith(NewBaseSExt);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  Value *NewOffsetSExt = IRB.CreateSExt(Trunc, OffsetSExt->getType());
  OffsetSExt->replaceAllUsesWith(NewOffsetSExt);

  LLVM_DEBUG(dbgs() << "From Base and Offset:\n"
             << *Base << "\n" << *Offset << "\n"
             << "Created Wide Load:\n"
             << *WideLoad << "\n"
             << *Bottom << "\n"
             << *NewBaseSExt << "\n"
             << *Top << "\n"
             << *Trunc << "\n"
             << *NewOffsetSExt << "\n");
  WideLoads.emplace(std::make_pair(Base,
                    std::make_unique<WidenedLoad>(Loads, WideLoad)));
  return WideLoad;
}
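
// As an illustrative example, widening two adjacent i16 loads with
// LoadTy = i32 emits something like (the pointer bitcast omitted for brevity):
//
//   %wide = load i32, i32* %vecptr, align 2
//   %bottom = trunc i32 %wide to i16
//   %sext0 = sext i16 %bottom to i32
//   %top = lshr i32 %wide, 16
//   %trunc = trunc i32 %top to i16
//   %sext1 = sext i16 %trunc to i32
//
// with %sext0 and %sext1 replacing all uses of the original sign extensions.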

Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform functions to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform functions to use DSP intrinsics", false, false)