//===- ARMParallelDSP.cpp - Parallel DSP Pass -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
//
//===----------------------------------------------------------------------===//
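//
// As an illustrative example (hypothetical source, not taken from the pass
// itself), a dot-product loop over 16-bit data such as:
//
//   int16_t *a, *b;
//   int32_t acc = 0;
//   for (int i = 0; i < N; i += 2)
//     acc += a[i] * b[i] + a[i + 1] * b[i + 1];
//
// produces the sign-extended multiply-accumulate chains this pass matches;
// each pair of 16x16-bit products can then be computed by one SMLAD acting
// on two 32-bit (paired halfword) loads.
//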
#include "ARM.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

static cl::opt<unsigned>
NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
             cl::desc("Limit the number of loads analysed"));
namespace {
  struct MulCandidate;
  class Reduction;

  using MulCandList = SmallVector<std::unique_ptr<MulCandidate>, 8>;
  using MemInstList = SmallVectorImpl<LoadInst*>;
  using MulPairList = SmallVector<std::pair<MulCandidate*, MulCandidate*>, 8>;

  // 'MulCandidate' holds the multiplication instructions that are candidates
  // for parallel execution.
  struct MulCandidate {
    Instruction *Root;
    Value* LHS;
    Value* RHS;
    bool Exchange = false;
    bool ReadOnly = true;
    bool Paired = false;
    SmallVector<LoadInst*, 2> VecLd;    // Container for loads to widen.

    MulCandidate(Instruction *I, Value *lhs, Value *rhs) :
      Root(I), LHS(lhs), RHS(rhs) { }

    bool HasTwoLoadInputs() const {
      return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
    }

    LoadInst *getBaseLoad() const {
      return VecLd.front();
    }
  };

  /// Represent a sequence of multiply-accumulate operations with the aim to
  /// perform the multiplications in parallel.
  class Reduction {
    Instruction *Root = nullptr;
    Value *Acc = nullptr;
    MulCandList Muls;
    MulPairList MulPairs;
    SetVector<Instruction*> Adds;

  public:
    Reduction() = delete;

    Reduction (Instruction *Add) : Root(Add) { }

    /// Record an Add instruction that is a part of this reduction.
    void InsertAdd(Instruction *I) { Adds.insert(I); }

    /// Create MulCandidates, each rooted at a Mul instruction, that are part
    /// of this reduction.
    void InsertMuls() {
      auto GetMulOperand = [](Value *V) -> Instruction* {
        if (auto *SExt = dyn_cast<SExtInst>(V)) {
          if (auto *I = dyn_cast<Instruction>(SExt->getOperand(0)))
            if (I->getOpcode() == Instruction::Mul)
              return I;
        } else if (auto *I = dyn_cast<Instruction>(V)) {
          if (I->getOpcode() == Instruction::Mul)
            return I;
        }
        return nullptr;
      };

      auto InsertMul = [this](Instruction *I) {
        Value *LHS = cast<Instruction>(I->getOperand(0))->getOperand(0);
        Value *RHS = cast<Instruction>(I->getOperand(1))->getOperand(0);
        Muls.push_back(std::make_unique<MulCandidate>(I, LHS, RHS));
      };

      for (auto *Add : Adds) {
        if (Add == Acc)
          continue;
        if (auto *Mul = GetMulOperand(Add->getOperand(0)))
          InsertMul(Mul);
        if (auto *Mul = GetMulOperand(Add->getOperand(1)))
          InsertMul(Mul);
      }
    }

    /// Add the incoming accumulator value, returns true if a value had not
    /// already been added. Returning false signals to the user that this
    /// reduction already has a value to initialise the accumulator.
    bool InsertAcc(Value *V) {
      if (Acc)
        return false;
      Acc = V;
      return true;
    }

    /// Set two MulCandidates, rooted at muls, that can be executed as a single
    /// parallel operation.
    void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1,
                    bool Exchange = false) {
      LLVM_DEBUG(dbgs() << "Pairing:\n"
                 << *Mul0->Root << "\n"
                 << *Mul1->Root << "\n");
      Mul0->Paired = true;
      Mul1->Paired = true;
      if (Exchange)
        Mul1->Exchange = true;
      MulPairs.push_back(std::make_pair(Mul0, Mul1));
    }

    /// Return true if enough mul operations are found that can be executed in
    /// parallel.
    bool CreateParallelPairs();

    /// Return the add instruction which is the root of the reduction.
    Instruction *getRoot() { return Root; }

    bool is64Bit() const { return Root->getType()->isIntegerTy(64); }

    Type *getType() const { return Root->getType(); }

    /// Return the incoming value to be accumulated. This may be null.
    Value *getAccumulator() { return Acc; }

    /// Return the set of adds that comprise the reduction.
    SetVector<Instruction*> &getAdds() { return Adds; }

    /// Return the MulCandidates, each rooted at a mul instruction, that
    /// comprise the reduction.
    MulCandList &getMuls() { return Muls; }

    /// Return the MulCandidates, rooted at mul instructions, that have been
    /// paired for parallel execution.
    MulPairList &getMulPairs() { return MulPairs; }

    /// To finalise, replace the uses of the root with the intrinsic call.
    void UpdateRoot(Instruction *SMLAD) {
      Root->replaceAllUsesWith(SMLAD);
    }

    void dump() {
      LLVM_DEBUG(dbgs() << "Reduction:\n";
        for (auto *Add : Adds)
          LLVM_DEBUG(dbgs() << *Add << "\n");
        for (auto &Mul : Muls)
          LLVM_DEBUG(dbgs() << *Mul->Root << "\n"
                     << "  " << *Mul->LHS << "\n"
                     << "  " << *Mul->RHS << "\n");
        LLVM_DEBUG(if (Acc) dbgs() << "Acc in: " << *Acc << "\n")
      );
    }
  };

  class WidenedLoad {
    LoadInst *NewLd = nullptr;
    SmallVector<LoadInst*, 4> Loads;

  public:
    WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
      : NewLd(Wide) {
      append_range(Loads, Lds);
    }
    LoadInst *getLoad() {
      return NewLd;
    }
  };

  class ARMParallelDSP : public FunctionPass {
    ScalarEvolution   *SE;
    AliasAnalysis     *AA;
    TargetLibraryInfo *TLI;
    DominatorTree     *DT;
    const DataLayout  *DL;
    Module            *M;
    std::map<LoadInst*, LoadInst*> LoadPairs;
    SmallPtrSet<LoadInst*, 4> OffsetLoads;
    std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

    template<unsigned>
    bool IsNarrowSequence(Value *V);
    bool Search(Value *V, BasicBlock *BB, Reduction &R);
    bool RecordMemoryOps(BasicBlock *BB);
    void InsertParallelMACs(Reduction &Reduction);
    bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
    LoadInst* CreateWideLoad(MemInstList &Loads, IntegerType *LoadTy);
    bool CreateParallelPairs(Reduction &R);

    /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
    /// Dual performs two signed 16x16-bit multiplications. It adds the
    /// products to a 32-bit accumulate operand. Optionally, the instruction can
    /// exchange the halfwords of the second operand before performing the
    /// arithmetic.
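    ///
    /// As a rough sketch of the underlying operations, for 32-bit operands
    /// x and y and accumulator acc, with the halfwords treated as signed:
    ///   smlad:  res = acc + (x[15:0] * y[15:0]) + (x[31:16] * y[31:16])
    ///   smladx: res = acc + (x[15:0] * y[31:16]) + (x[31:16] * y[15:0])
    /// The long variants, SMLALD and SMLALDX, accumulate into a 64-bit
    /// operand instead.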
    bool MatchSMLAD(Function &F);

  public:
    static char ID;

    ARMParallelDSP() : FunctionPass(ID) { }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      FunctionPass::getAnalysisUsage(AU);
      AU.addRequired<AssumptionCacheTracker>();
      AU.addRequired<ScalarEvolutionWrapperPass>();
      AU.addRequired<AAResultsWrapperPass>();
      AU.addRequired<TargetLibraryInfoWrapperPass>();
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<TargetPassConfig>();
      AU.addPreserved<ScalarEvolutionWrapperPass>();
      AU.addPreserved<GlobalsAAWrapperPass>();
      AU.setPreservesCFG();
    }

    bool runOnFunction(Function &F) override {
      if (DisableParallelDSP)
        return false;
      if (skipFunction(F))
        return false;

      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
      AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
      DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
      auto &TPC = getAnalysis<TargetPassConfig>();

      M = F.getParent();
      DL = &M->getDataLayout();

      auto &TM = TPC.getTM<TargetMachine>();
      auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

      if (!ST->allowsUnalignedMem()) {
        LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!ST->hasDSP()) {
        LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      if (!ST->isLittle()) {
        LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                          << "ARMParallelDSP\n");
        return false;
      }

      LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
      LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

      bool Changes = MatchSMLAD(F);
      return Changes;
    }
  };
}

bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!Ld0 || !Ld1)
    return false;

  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
    dbgs() << "Ld0:"; Ld0->dump();
    dbgs() << "Ld1:"; Ld1->dump();
  );

  VecMem.clear();
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}

// MaxBitWidth: the maximum supported bitwidth of the elements in the DSP
// instructions, which is set to 16. So here we should collect all i8 and i16
// narrow operations.
// TODO: we currently only collect i16, and will support i8 later, so that's
// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
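//
// For example (a sketch, assuming MaxBitWidth == 16), the following is a
// narrow sequence, provided %ld was recorded as pairable by RecordMemoryOps:
//
//   %ld = load i16, i16* %addr
//   %sext = sext i16 %ld to i32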
template<unsigned MaxBitWidth>
bool ARMParallelDSP::IsNarrowSequence(Value *V) {
  if (auto *SExt = dyn_cast<SExtInst>(V)) {
    if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
      // Check that this load could be paired.
      return LoadPairs.count(Ld) || OffsetLoads.count(Ld);
    }
  }
  return false;
}

/// Iterate through the block and record base, offset pairs of loads which can
/// be widened into a single load.
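///
/// For example, two simple i16 loads from adjacent addresses, each with a
/// single sign-extending user, can be recorded as a (base, offset) pair
/// (a sketch; %a is a hypothetical base pointer):
///
///   %addr1 = getelementptr i16, i16* %a, i32 1
///   %ld0 = load i16, i16* %a          ; base
///   %ld1 = load i16, i16* %addr1      ; offset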
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
  SmallVector<LoadInst*, 8> Loads;
  SmallVector<Instruction*, 8> Writes;
  LoadPairs.clear();
  WideLoads.clear();

  // Collect loads and instructions that may write to memory. For now we only
  // record loads which are simple, sign-extended and have a single user.
  // TODO: Allow zero-extended loads.
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }

  if (Loads.empty() || Loads.size() > NumLoadLimit)
    return false;

  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
  DepMap RAWDeps;

  // Record any writes that may alias a load.
  const auto Size = LocationSize::beforeOrAfterPointer();
  for (auto Write : Writes) {
    for (auto Read : Loads) {
      MemoryLocation ReadLoc =
        MemoryLocation(Read->getPointerOperand(), Size);

      if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
                                         ModRefInfo::ModRef)))
        continue;
      if (Write->comesBefore(Read))
        RAWDeps[Read].insert(Write);
    }
  }

  // Check that there isn't a write between the two loads which would
  // prevent them from being safely merged.
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    bool BaseFirst = Base->comesBefore(Offset);
    LoadInst *Dominator = BaseFirst ? Base : Offset;
    LoadInst *Dominated = BaseFirst ? Offset : Base;

    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];

      for (auto Before : WritesBefore) {
        // We can't move the second load backward, past a write, to merge
        // with the first load.
        if (Dominator->comesBefore(Before))
          return false;
      }
    }
    return true;
  };

  // Record base, offset load pairs.
  for (auto *Base : Loads) {
    for (auto *Offset : Loads) {
      if (Base == Offset || OffsetLoads.count(Offset))
        continue;

      if (isConsecutiveAccess(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        OffsetLoads.insert(Offset);
        break;
      }
    }
  }

  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs) {
                 LLVM_DEBUG(dbgs() << *MapIt.first << ", "
                            << *MapIt.second << "\n");
               }
             });
  return LoadPairs.size() > 1;
}

// Search recursively back through the operands to find a tree of values that
// form a multiply-accumulate chain. The search records the Add and Mul
// instructions that form the reduction and allows us to find a single value
// to be used as the initial input to the accumulator.
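//
// For example, searching back from the final add of a chain such as:
//
//   %add0 = add i32 %mul0, %acc
//   %add1 = add i32 %add0, %mul1
//
// records both adds and takes %acc as the incoming accumulator value
// (a simplified sketch of the pattern shown in full before MatchSMLAD).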
bool ARMParallelDSP::Search(Value *V, BasicBlock *BB, Reduction &R) {
  // If we find a non-instruction, try to use it as the initial accumulator
  // value. This may have already been found during the search in which case
  // this function will return false, signaling a search fail.
  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return R.InsertAcc(V);

  if (I->getParent() != BB)
    return false;

  switch (I->getOpcode()) {
  default:
    break;
  case Instruction::PHI:
    // Could be the accumulator value.
    return R.InsertAcc(V);
  case Instruction::Add: {
    // Adds should be adding together two muls, or another add and a mul to
    // be within the mac chain. One of the operands may also be the
    // accumulator value at which point we should stop searching.
    R.InsertAdd(I);
    Value *LHS = I->getOperand(0);
    Value *RHS = I->getOperand(1);
    bool ValidLHS = Search(LHS, BB, R);
    bool ValidRHS = Search(RHS, BB, R);

    if (ValidLHS && ValidRHS)
      return true;

    return R.InsertAcc(I);
  }
  case Instruction::Mul: {
    Value *MulOp0 = I->getOperand(0);
    Value *MulOp1 = I->getOperand(1);
    return IsNarrowSequence<16>(MulOp0) && IsNarrowSequence<16>(MulOp1);
  }
  case Instruction::SExt:
    return Search(I->getOperand(0), BB, R);
  }
  return false;
}

// The pass needs to identify integer add/sub reductions of 16-bit vector
// multiplications.
// To use SMLAD:
// 1) we first need to find an integer add, then look for this pattern:
//
//      acc0 = ...
//      ld0 = load i16
//      sext0 = sext i16 %ld0 to i32
//      ld1 = load i16
//      sext1 = sext i16 %ld1 to i32
//      mul0 = mul %sext0, %sext1
//      ld2 = load i16
//      sext2 = sext i16 %ld2 to i32
//      ld3 = load i16
//      sext3 = sext i16 %ld3 to i32
//      mul1 = mul i32 %sext2, %sext3
//      add0 = add i32 %mul0, %acc0
//      acc1 = add i32 %add0, %mul1
//
// Which can be selected to:
//
//      ldr r0
//      ldr r1
//      smlad r2, r0, r1, r2
//
// If constants are used instead of loads, these will need to be hoisted
// out and into a register.
//
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
//
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  bool Changed = false;

  for (auto &BB : F) {
    SmallPtrSet<Instruction*, 4> AllAdds;
    if (!RecordMemoryOps(&BB))
      continue;

    for (Instruction &I : reverse(BB)) {
      if (I.getOpcode() != Instruction::Add)
        continue;

      if (AllAdds.count(&I))
        continue;

      const auto *Ty = I.getType();
      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
        continue;

      Reduction R(&I);
      if (!Search(&I, &BB, R))
        continue;

      R.InsertMuls();
      LLVM_DEBUG(dbgs() << "After search, Reduction:\n"; R.dump());

      if (!CreateParallelPairs(R))
        continue;

      InsertParallelMACs(R);
      Changed = true;
      AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
    }
  }

  return Changed;
}

bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {

  // Not enough mul operations to make a pair.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly upon sign extended loads.
  for (auto &MulCand : R.getMuls()) {
    if (!MulCand->HasTwoLoadInputs())
      return false;
  }

  auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
    // The first elements of each vector should be loads with sexts. If we
    // find that it's two pairs of consecutive loads, then these can be
    // transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
    auto Ld0 = static_cast<LoadInst*>(PMul0->LHS);
    auto Ld1 = static_cast<LoadInst*>(PMul1->LHS);
    auto Ld2 = static_cast<LoadInst*>(PMul0->RHS);
    auto Ld3 = static_cast<LoadInst*>(PMul1->RHS);

    // Check that each mul is operating on two different loads.
    if (Ld0 == Ld2 || Ld1 == Ld3)
      return false;

    if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
      if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        R.AddMulPair(PMul0, PMul1);
        return true;
      } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
        R.AddMulPair(PMul0, PMul1, true);
        return true;
      }
    } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
               AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
      LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
      LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
      LLVM_DEBUG(dbgs() << "    and swapping muls\n");
      // Only the second operand can be exchanged, so swap the muls.
      R.AddMulPair(PMul1, PMul0, true);
      return true;
    }
    return false;
  };

  MulCandList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  for (unsigned i = 0; i < Elems; ++i) {
    MulCandidate *PMul0 = static_cast<MulCandidate*>(Muls[i].get());
    if (PMul0->Paired)
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      MulCandidate *PMul1 = static_cast<MulCandidate*>(Muls[j].get());
      if (PMul1->Paired)
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1))
        break;
    }
  }

  return !R.getMulPairs().empty();
}
void ARMParallelDSP::InsertParallelMACs(Reduction &R) {

  auto CreateSMLAD = [&](LoadInst* WideLd0, LoadInst *WideLd1,
                         Value *Acc, bool Exchange,
                         Instruction *InsertAfter) {
    // Replace the reduction chain with an intrinsic call.
    Value* Args[] = { WideLd0, WideLd1, Acc };
    Function *SMLAD = nullptr;
    if (Exchange)
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
    else
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                BasicBlock::iterator(InsertAfter));
    Instruction *Call = Builder.CreateCall(SMLAD, Args);
    NumSMLAD++;
    return Call;
  };

  // Return the instruction after the dominated instruction.
  auto GetInsertPoint = [this](Value *A, Value *B) {
    assert((isa<Instruction>(A) || isa<Instruction>(B)) &&
           "expected at least one instruction");

    Value *V = nullptr;
    if (!isa<Instruction>(A))
      V = B;
    else if (!isa<Instruction>(B))
      V = A;
    else
      V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A;

    return &*++BasicBlock::iterator(cast<Instruction>(V));
  };

  Value *Acc = R.getAccumulator();

  // For any muls that were discovered but not paired, accumulate their values
  // as before.
  IRBuilder<NoFolder> Builder(R.getRoot()->getParent());
  MulCandList &MulCands = R.getMuls();
  for (auto &MulCand : MulCands) {
    if (MulCand->Paired)
      continue;

    Instruction *Mul = cast<Instruction>(MulCand->Root);
    LLVM_DEBUG(dbgs() << "Accumulating unpaired mul: " << *Mul << "\n");

    if (R.getType() != Mul->getType()) {
      assert(R.is64Bit() && "expected 64-bit result");
      Builder.SetInsertPoint(&*++BasicBlock::iterator(Mul));
      Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType()));
    }

    if (!Acc) {
      Acc = Mul;
      continue;
    }

    // If Acc is the original incoming value to the reduction, it could be a
    // phi. But the phi will dominate Mul, meaning that Mul will be the
    // insertion point.
    Builder.SetInsertPoint(GetInsertPoint(Mul, Acc));
    Acc = Builder.CreateAdd(Mul, Acc);
  }

  if (!Acc) {
    Acc = R.is64Bit() ?
      ConstantInt::get(IntegerType::get(M->getContext(), 64), 0) :
      ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
  } else if (Acc->getType() != R.getType()) {
    Builder.SetInsertPoint(R.getRoot());
    Acc = Builder.CreateSExt(Acc, R.getType());
  }

  // Roughly sort the mul pairs in their program order.
  llvm::sort(R.getMulPairs(), [](auto &PairA, auto &PairB) {
    const Instruction *A = PairA.first->Root;
    const Instruction *B = PairB.first->Root;
    return A->comesBefore(B);
  });

  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
  for (auto &Pair : R.getMulPairs()) {
    MulCandidate *LHSMul = Pair.first;
    MulCandidate *RHSMul = Pair.second;
    LoadInst *BaseLHS = LHSMul->getBaseLoad();
    LoadInst *BaseRHS = RHSMul->getBaseLoad();
    LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
      WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty);
    LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
      WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);

    Instruction *InsertAfter = GetInsertPoint(WideLHS, WideRHS);
    InsertAfter = GetInsertPoint(InsertAfter, Acc);
    Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}
LoadInst* ARMParallelDSP::CreateWideLoad(MemInstList &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");

  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &Op : Source->operands())
        MoveBefore(Op, Source);
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Bitcast the pointer to a wider type and create the wide load, while making
  // sure to maintain the original alignment as this prevents ldrd from being
  // generated when it could be illegal due to memory alignment.
  const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
  Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
                                    LoadTy->getPointerTo(AddrSpace));
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr, Base->getAlign());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads.
  // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
  // TODO: Support big-endian as well.
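  //
  // Roughly, for two i16 loads widened into one i32 load, the replacement
  // values created below are:
  //
  //   %wide = load i32, i32* %vec.ptr
  //   %bottom = trunc i32 %wide to i16      ; replaces Loads[0]
  //   %top = lshr i32 %wide, 16
  //   %trunc = trunc i32 %top to i16        ; replaces Loads[1]
  //
  // each followed by a sext that replaces the original extending user.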
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  Value *NewBaseSExt = IRB.CreateSExt(Bottom, BaseSExt->getType());
  BaseSExt->replaceAllUsesWith(NewBaseSExt);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  Value *NewOffsetSExt = IRB.CreateSExt(Trunc, OffsetSExt->getType());
  OffsetSExt->replaceAllUsesWith(NewOffsetSExt);

  LLVM_DEBUG(dbgs() << "From Base and Offset:\n"
             << *Base << "\n" << *Offset << "\n"
             << "Created Wide Load:\n"
             << *WideLoad << "\n"
             << *Bottom << "\n"
             << *NewBaseSExt << "\n"
             << *Top << "\n"
             << *Trunc << "\n"
             << *NewOffsetSExt << "\n");
  WideLoads.emplace(std::make_pair(Base,
                    std::make_unique<WidenedLoad>(Loads, WideLoad)));
  return WideLoad;
}

Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform functions to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform functions to use DSP intrinsics", false, false)