//===- ARMParallelDSP.cpp - Parallel DSP Pass -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
//
//===----------------------------------------------------------------------===//
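//
// As an illustrative example (not taken from a particular test), a 16-bit
// dot-product kernel is the kind of reduction this pass targets:
//
//   int32_t dot(const int16_t *a, const int16_t *b, unsigned n) {
//     int32_t acc = 0;
//     for (unsigned i = 0; i < n; ++i)
//       acc += (int32_t)a[i] * b[i];
//     return acc;
//   }
//
// When a block contains two such multiply-accumulates on adjacent i16 loads
// (e.g. after the loop has been unrolled), the loads can be widened and the
// accumulation chain replaced with a call to the smlad intrinsic.
//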

#include "ARM.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

static cl::opt<unsigned>
NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
             cl::desc("Limit the number of loads analysed"));

namespace {
  struct MulCandidate;
  class Reduction;

  using MulCandList = SmallVector<std::unique_ptr<MulCandidate>, 8>;
  using MemInstList = SmallVectorImpl<LoadInst*>;
  using MulPairList = SmallVector<std::pair<MulCandidate*, MulCandidate*>, 8>;

  // 'MulCandidate' holds the multiplication instructions that are candidates
  // for parallel execution.
  struct MulCandidate {
    Instruction *Root;
    Value *LHS;
    Value *RHS;
    bool Exchange = false;
    bool ReadOnly = true;
    bool Paired = false;
    SmallVector<LoadInst*, 2> VecLd;  // Container for loads to widen.

    MulCandidate(Instruction *I, Value *lhs, Value *rhs) :
      Root(I), LHS(lhs), RHS(rhs) { }

    bool HasTwoLoadInputs() const {
      return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
    }

    LoadInst *getBaseLoad() const {
      return VecLd.front();
    }
  };

  /// Represent a sequence of multiply-accumulate operations with the aim to
  /// perform the multiplications in parallel.
  class Reduction {
    Instruction *Root = nullptr;
    Value *Acc = nullptr;
    MulCandList Muls;
    MulPairList MulPairs;
    SetVector<Instruction*> Adds;

  public:
    Reduction() = delete;

    Reduction(Instruction *Add) : Root(Add) { }

    /// Record an Add instruction that is a part of this reduction.
    void InsertAdd(Instruction *I) { Adds.insert(I); }

    /// Create MulCandidates, each rooted at a Mul instruction, that are a
    /// part of this reduction.
    void InsertMuls() {
      auto GetMulOperand = [](Value *V) -> Instruction* {
        if (auto *SExt = dyn_cast<SExtInst>(V)) {
          if (auto *I = dyn_cast<Instruction>(SExt->getOperand(0)))
            if (I->getOpcode() == Instruction::Mul)
              return I;
        } else if (auto *I = dyn_cast<Instruction>(V)) {
          if (I->getOpcode() == Instruction::Mul)
            return I;
        }
        return nullptr;
      };

      auto InsertMul = [this](Instruction *I) {
        Value *LHS = cast<Instruction>(I->getOperand(0))->getOperand(0);
        Value *RHS = cast<Instruction>(I->getOperand(1))->getOperand(0);
        Muls.push_back(std::make_unique<MulCandidate>(I, LHS, RHS));
      };

      for (auto *Add : Adds) {
        if (Add == Acc)
          continue;
        if (auto *Mul = GetMulOperand(Add->getOperand(0)))
          InsertMul(Mul);
        if (auto *Mul = GetMulOperand(Add->getOperand(1)))
          InsertMul(Mul);
      }
    }

    /// Add the incoming accumulator value, returns true if a value had not
    /// already been added. Returning false signals to the user that this
    /// reduction already has a value to initialise the accumulator.
    bool InsertAcc(Value *V) {
      if (Acc)
        return false;
      Acc = V;
      return true;
    }

    /// Set two MulCandidates, rooted at muls, that can be executed as a
    /// single parallel operation.
    void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1,
                    bool Exchange = false) {
      LLVM_DEBUG(dbgs() << "Pairing:\n"
                 << *Mul0->Root << "\n"
                 << *Mul1->Root << "\n");
      Mul0->Paired = true;
      Mul1->Paired = true;
      if (Exchange)
        Mul1->Exchange = true;
      MulPairs.push_back(std::make_pair(Mul0, Mul1));
    }

    /// Return true if enough mul operations are found that can be executed in
    /// parallel.
    bool CreateParallelPairs();

    /// Return the add instruction which is the root of the reduction.
    Instruction *getRoot() { return Root; }

    bool is64Bit() const { return Root->getType()->isIntegerTy(64); }

    Type *getType() const { return Root->getType(); }

    /// Return the incoming value to be accumulated. This may be null.
    Value *getAccumulator() { return Acc; }

    /// Return the set of adds that comprise the reduction.
    SetVector<Instruction*> &getAdds() { return Adds; }

    /// Return the MulCandidates, rooted at mul instructions, that comprise
    /// the reduction.
    MulCandList &getMuls() { return Muls; }

    /// Return the MulCandidates, rooted at mul instructions, that have been
    /// paired for parallel execution.
    MulPairList &getMulPairs() { return MulPairs; }

    /// To finalise, replace the uses of the root with the intrinsic call.
    void UpdateRoot(Instruction *SMLAD) {
      Root->replaceAllUsesWith(SMLAD);
    }

    void dump() {
      LLVM_DEBUG(dbgs() << "Reduction:\n";
        for (auto *Add : Adds)
          LLVM_DEBUG(dbgs() << *Add << "\n");
        for (auto &Mul : Muls)
          LLVM_DEBUG(dbgs() << *Mul->Root << "\n"
                     << " " << *Mul->LHS << "\n"
                     << " " << *Mul->RHS << "\n");
        LLVM_DEBUG(if (Acc) dbgs() << "Acc in: " << *Acc << "\n")
      );
    }
  };

  class WidenedLoad {
    LoadInst *NewLd = nullptr;
    SmallVector<LoadInst*, 4> Loads;

  public:
    WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
      : NewLd(Wide) {
      append_range(Loads, Lds);
    }
    LoadInst *getLoad() {
      return NewLd;
    }
  };

  class ARMParallelDSP : public FunctionPass {
    ScalarEvolution *SE;
    AliasAnalysis *AA;
    TargetLibraryInfo *TLI;
    DominatorTree *DT;
    const DataLayout *DL;
    Module *M;
    std::map<LoadInst*, LoadInst*> LoadPairs;
    SmallPtrSet<LoadInst*, 4> OffsetLoads;
    std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

    template<unsigned>
    bool IsNarrowSequence(Value *V);
    bool Search(Value *V, BasicBlock *BB, Reduction &R);
    bool RecordMemoryOps(BasicBlock *BB);
    void InsertParallelMACs(Reduction &Reduction);
    bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
    LoadInst* CreateWideLoad(MemInstList &Loads, IntegerType *LoadTy);
    bool CreateParallelPairs(Reduction &R);

    /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
    /// Dual performs two signed 16x16-bit multiplications. It adds the
    /// products to a 32-bit accumulate operand. Optionally, the instruction
    /// can exchange the halfwords of the second operand before performing the
    /// arithmetic.
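    ///
    /// As a rough sketch of the underlying instruction semantics (32-bit
    /// accumulator form; the smlald variants accumulate into 64 bits):
    ///
    ///   smlad  Rd, Rn, Rm, Ra : Rd = Ra + Rn[15:0]*Rm[15:0] + Rn[31:16]*Rm[31:16]
    ///   smladx Rd, Rn, Rm, Ra : Rd = Ra + Rn[15:0]*Rm[31:16] + Rn[31:16]*Rm[15:0]
    ///
    /// where each halfword product is a signed 16x16->32 multiplication.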
    bool MatchSMLAD(Function &F);

  public:
    static char ID;

    ARMParallelDSP() : FunctionPass(ID) { }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      FunctionPass::getAnalysisUsage(AU);
      AU.addRequired<AssumptionCacheTracker>();
      AU.addRequired<ScalarEvolutionWrapperPass>();
      AU.addRequired<AAResultsWrapperPass>();
      AU.addRequired<TargetLibraryInfoWrapperPass>();
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<TargetPassConfig>();
      AU.addPreserved<ScalarEvolutionWrapperPass>();
      AU.addPreserved<GlobalsAAWrapperPass>();
      AU.setPreservesCFG();
    }

    bool runOnFunction(Function &F) override {
      if (DisableParallelDSP)
        return false;
      if (skipFunction(F))
        return false;

      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
      AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
      DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
      auto &TPC = getAnalysis<TargetPassConfig>();

      M = F.getParent();
      DL = &M->getDataLayout();

      auto &TM = TPC.getTM<TargetMachine>();
      auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

      if (!ST->allowsUnalignedMem()) {
        LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!ST->hasDSP()) {
        LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      if (!ST->isLittle()) {
        LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                          << "ARMParallelDSP\n");
        return false;
      }

      LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
      LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

      bool Changes = MatchSMLAD(F);
      return Changes;
    }
  };
}

template<typename MemInst>
static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
                                  const DataLayout &DL, ScalarEvolution &SE) {
  if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE))
    return true;
  return false;
}

bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!Ld0 || !Ld1)
    return false;

  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
    dbgs() << "Ld0:"; Ld0->dump();
    dbgs() << "Ld1:"; Ld1->dump();
  );

  VecMem.clear();
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}

// MaxBitwidth: the maximum supported bitwidth of the elements in the DSP
// instructions, which is set to 16. So here we should collect all i8 and i16
// narrow operations.
// TODO: we currently only collect i16, and will support i8 later, so that's
// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
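//
// For example (an illustrative value, not from a specific test), a mul
// operand of the form:
//
//   %ld   = load i16, i16* %addr
//   %sext = sext i16 %ld to i32
//
// is accepted here, provided %ld has been recorded as part of a load pair by
// RecordMemoryOps.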
template<unsigned MaxBitWidth>
bool ARMParallelDSP::IsNarrowSequence(Value *V) {
  if (auto *SExt = dyn_cast<SExtInst>(V)) {
    if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
      // Check that this load could be paired.
      return LoadPairs.count(Ld) || OffsetLoads.count(Ld);
    }
  }
  return false;
}

/// Iterate through the block and record base, offset pairs of loads which can
/// be widened into a single load.
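/// For example (illustrative IR, not a specific test), two adjacent i16
/// loads such as:
///
///   %ld0   = load i16, i16* %a
///   %addr1 = getelementptr i16, i16* %a, i32 1
///   %ld1   = load i16, i16* %addr1
///
/// would be recorded as the pair %ld0 -> %ld1, with %ld0 as the base and
/// %ld1 as the offset load, provided no intervening store prevents merging.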
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
  SmallVector<LoadInst*, 8> Loads;
  SmallVector<Instruction*, 8> Writes;
  LoadPairs.clear();
  WideLoads.clear();

  // Collect loads and instructions that may write to memory. For now we only
  // record loads which are simple, sign-extended and have a single user.
  // TODO: Allow zero-extended loads.
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }

  if (Loads.empty() || Loads.size() > NumLoadLimit)
    return false;

  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
  DepMap RAWDeps;

  // Record any writes that may alias a load.
  const auto Size = LocationSize::beforeOrAfterPointer();
  for (auto Write : Writes) {
    for (auto Read : Loads) {
      MemoryLocation ReadLoc =
        MemoryLocation(Read->getPointerOperand(), Size);

      if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
                                         ModRefInfo::ModRef)))
        continue;
      if (Write->comesBefore(Read))
        RAWDeps[Read].insert(Write);
    }
  }

  // Check that there isn't a write between the two loads which would prevent
  // them from being safely merged.
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    bool BaseFirst = Base->comesBefore(Offset);
    LoadInst *Dominator = BaseFirst ? Base : Offset;
    LoadInst *Dominated = BaseFirst ? Offset : Base;

    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];

      for (auto Before : WritesBefore) {
        // We can't move the second load backward, past a write, to merge
        // with the first load.
        if (Dominator->comesBefore(Before))
          return false;
      }
    }
    return true;
  };

  // Record base, offset load pairs.
  for (auto *Base : Loads) {
    for (auto *Offset : Loads) {
      if (Base == Offset || OffsetLoads.count(Offset))
        continue;

      if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        OffsetLoads.insert(Offset);
        break;
      }
    }
  }

  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs) {
                 LLVM_DEBUG(dbgs() << *MapIt.first << ", "
                            << *MapIt.second << "\n");
               }
             });
  return LoadPairs.size() > 1;
}

// Search recursively back through the operands to find a tree of values that
// form a multiply-accumulate chain. The search records the Add and Mul
// instructions that form the reduction and allows us to find a single value
// to be used as the initial input to the accumulator.
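//
// For a chain such as acc1 = add(add(mul0, acc0), mul1), for example, the
// search records both adds, checks (via IsNarrowSequence) that each mul
// operates on sign-extended i16 loads, and takes acc0 as the incoming
// accumulator value.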
bool ARMParallelDSP::Search(Value *V, BasicBlock *BB, Reduction &R) {
  // If we find a non-instruction, try to use it as the initial accumulator
  // value. This may have already been found during the search in which case
  // this function will return false, signaling a search fail.
  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return R.InsertAcc(V);

  if (I->getParent() != BB)
    return false;

  switch (I->getOpcode()) {
  default:
    break;
  case Instruction::PHI:
    // Could be the accumulator value.
    return R.InsertAcc(V);
  case Instruction::Add: {
    // Adds should be adding together two muls, or another add and a mul to
    // be within the mac chain. One of the operands may also be the
    // accumulator value at which point we should stop searching.
    R.InsertAdd(I);
    Value *LHS = I->getOperand(0);
    Value *RHS = I->getOperand(1);
    bool ValidLHS = Search(LHS, BB, R);
    bool ValidRHS = Search(RHS, BB, R);

    if (ValidLHS && ValidRHS)
      return true;

    return R.InsertAcc(I);
  }
  case Instruction::Mul: {
    Value *MulOp0 = I->getOperand(0);
    Value *MulOp1 = I->getOperand(1);
    return IsNarrowSequence<16>(MulOp0) && IsNarrowSequence<16>(MulOp1);
  }
  case Instruction::SExt:
    return Search(I->getOperand(0), BB, R);
  }
  return false;
}

// The pass needs to identify integer add/sub reductions of 16-bit vector
// multiplications.
// To use SMLAD:
// 1) we first need to find integer add then look for this pattern:
//
//      acc0 = ...
//      ld0 = load i16
//      sext0 = sext i16 %ld0 to i32
//      ld1 = load i16
//      sext1 = sext i16 %ld1 to i32
//      mul0 = mul %sext0, %sext1
//      ld2 = load i16
//      sext2 = sext i16 %ld2 to i32
//      ld3 = load i16
//      sext3 = sext i16 %ld3 to i32
//      mul1 = mul i32 %sext2, %sext3
//      add0 = add i32 %mul0, %acc0
//      acc1 = add i32 %add0, %mul1
//
//    Which can be selected to:
//
//      ldr r0
//      ldr r1
//      smlad r2, r0, r1, r2
//
// If constants are used instead of loads, these will need to be hoisted
// out and into a register.
//
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
//
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  bool Changed = false;

  for (auto &BB : F) {
    SmallPtrSet<Instruction*, 4> AllAdds;
    if (!RecordMemoryOps(&BB))
      continue;

    for (Instruction &I : reverse(BB)) {
      if (I.getOpcode() != Instruction::Add)
        continue;

      if (AllAdds.count(&I))
        continue;

      const auto *Ty = I.getType();
      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
        continue;

      Reduction R(&I);
      if (!Search(&I, &BB, R))
        continue;

      R.InsertMuls();
      LLVM_DEBUG(dbgs() << "After search, Reduction:\n"; R.dump());

      if (!CreateParallelPairs(R))
        continue;

      InsertParallelMACs(R);
      Changed = true;
      AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
    }
  }

  return Changed;
}

bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {

  // Not enough mul operations to make a pair.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly upon sign extended loads.
  for (auto &MulCand : R.getMuls()) {
    if (!MulCand->HasTwoLoadInputs())
      return false;
  }

  auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
    // The first elements of each vector should be loads with sexts. If we
    // find that these are two pairs of consecutive loads, then they can be
    // transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
    auto Ld0 = static_cast<LoadInst*>(PMul0->LHS);
    auto Ld1 = static_cast<LoadInst*>(PMul1->LHS);
    auto Ld2 = static_cast<LoadInst*>(PMul0->RHS);
    auto Ld3 = static_cast<LoadInst*>(PMul1->RHS);

    // Check that each mul is operating on two different loads.
    if (Ld0 == Ld2 || Ld1 == Ld3)
      return false;

    if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
      if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        R.AddMulPair(PMul0, PMul1);
        return true;
      } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        LLVM_DEBUG(dbgs() << " exchanging Ld2 and Ld3\n");
        R.AddMulPair(PMul0, PMul1, true);
        return true;
      }
    } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
               AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
      LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
      LLVM_DEBUG(dbgs() << " exchanging Ld0 and Ld1\n");
      LLVM_DEBUG(dbgs() << " and swapping muls\n");
      // Only the second operand can be exchanged, so swap the muls.
      R.AddMulPair(PMul1, PMul0, true);
      return true;
    }
    return false;
  };

  MulCandList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  for (unsigned i = 0; i < Elems; ++i) {
    MulCandidate *PMul0 = static_cast<MulCandidate*>(Muls[i].get());
    if (PMul0->Paired)
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      MulCandidate *PMul1 = static_cast<MulCandidate*>(Muls[j].get());
      if (PMul1->Paired)
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1))
        break;
    }
  }

  return !R.getMulPairs().empty();
}
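
// For each pair of MulCandidates, create a wide load for each operand pair
// and chain the accumulation through calls to the DSP intrinsics: smlad /
// smladx for 32-bit reductions and smlald / smlaldx for 64-bit ones (the 'x'
// forms exchange the halfwords of the second operand). Muls that were found
// but could not be paired are accumulated with ordinary adds, as before.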
void ARMParallelDSP::InsertParallelMACs(Reduction &R) {

  auto CreateSMLAD = [&](LoadInst* WideLd0, LoadInst *WideLd1,
                         Value *Acc, bool Exchange,
                         Instruction *InsertAfter) {
    // Replace the reduction chain with an intrinsic call.
    Value* Args[] = { WideLd0, WideLd1, Acc };
    Function *SMLAD = nullptr;
    if (Exchange)
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
    else
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                BasicBlock::iterator(InsertAfter));
    Instruction *Call = Builder.CreateCall(SMLAD, Args);
    NumSMLAD++;
    return Call;
  };

  // Return the instruction after the dominated instruction.
  auto GetInsertPoint = [this](Value *A, Value *B) {
    assert((isa<Instruction>(A) || isa<Instruction>(B)) &&
           "expected at least one instruction");

    Value *V = nullptr;
    if (!isa<Instruction>(A))
      V = B;
    else if (!isa<Instruction>(B))
      V = A;
    else
      V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A;

    return &*++BasicBlock::iterator(cast<Instruction>(V));
  };

  Value *Acc = R.getAccumulator();

  // For any muls that were discovered but not paired, accumulate their values
  // as before.
  IRBuilder<NoFolder> Builder(R.getRoot()->getParent());
  MulCandList &MulCands = R.getMuls();
  for (auto &MulCand : MulCands) {
    if (MulCand->Paired)
      continue;

    Instruction *Mul = cast<Instruction>(MulCand->Root);
    LLVM_DEBUG(dbgs() << "Accumulating unpaired mul: " << *Mul << "\n");

    if (R.getType() != Mul->getType()) {
      assert(R.is64Bit() && "expected 64-bit result");
      Builder.SetInsertPoint(&*++BasicBlock::iterator(Mul));
      Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType()));
    }

    if (!Acc) {
      Acc = Mul;
      continue;
    }

    // If Acc is the original incoming value to the reduction, it could be a
    // phi. But the phi will dominate Mul, meaning that Mul will be the
    // insertion point.
    Builder.SetInsertPoint(GetInsertPoint(Mul, Acc));
    Acc = Builder.CreateAdd(Mul, Acc);
  }

  if (!Acc) {
    Acc = R.is64Bit() ?
      ConstantInt::get(IntegerType::get(M->getContext(), 64), 0) :
      ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
  } else if (Acc->getType() != R.getType()) {
    Builder.SetInsertPoint(R.getRoot());
    Acc = Builder.CreateSExt(Acc, R.getType());
  }

  // Roughly sort the mul pairs in their program order.
  llvm::sort(R.getMulPairs(), [](auto &PairA, auto &PairB) {
    const Instruction *A = PairA.first->Root;
    const Instruction *B = PairB.first->Root;
    return A->comesBefore(B);
  });

  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
  for (auto &Pair : R.getMulPairs()) {
    MulCandidate *LHSMul = Pair.first;
    MulCandidate *RHSMul = Pair.second;
    LoadInst *BaseLHS = LHSMul->getBaseLoad();
    LoadInst *BaseRHS = RHSMul->getBaseLoad();
    LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
      WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty);
    LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
      WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);

    Instruction *InsertAfter = GetInsertPoint(WideLHS, WideRHS);
    InsertAfter = GetInsertPoint(InsertAfter, Acc);
    Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}
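
// Widen two sequential, sign-extended narrow loads into a single wider load
// and create replacement values for the users of the original sext
// instructions. A rough sketch of the IR this emits (illustrative, assuming
// two i16 loads widened into one i32 load):
//
//   %wide   = load i32, i32* %castptr, align 2
//   %bottom = trunc i32 %wide to i16
//   %sext0  = sext i16 %bottom to i32   ; replaces uses of Loads[0]'s sext
//   %shr    = lshr i32 %wide, 16
//   %top    = trunc i32 %shr to i16
//   %sext1  = sext i16 %top to i32      ; replaces uses of Loads[1]'s sext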
LoadInst* ARMParallelDSP::CreateWideLoad(MemInstList &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");

  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &Op : Source->operands())
        MoveBefore(Op, Source);
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Bitcast the pointer to a wider type and create the wide load, while making
  // sure to maintain the original alignment as this prevents ldrd from being
  // generated when it could be illegal due to memory alignment.
  const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
  Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
                                    LoadTy->getPointerTo(AddrSpace));
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr, Base->getAlign());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads.
  // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
  // TODO: Support big-endian as well.
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  Value *NewBaseSExt = IRB.CreateSExt(Bottom, BaseSExt->getType());
  BaseSExt->replaceAllUsesWith(NewBaseSExt);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  Value *NewOffsetSExt = IRB.CreateSExt(Trunc, OffsetSExt->getType());
  OffsetSExt->replaceAllUsesWith(NewOffsetSExt);

  LLVM_DEBUG(dbgs() << "From Base and Offset:\n"
             << *Base << "\n" << *Offset << "\n"
             << "Created Wide Load:\n"
             << *WideLoad << "\n"
             << *Bottom << "\n"
             << *NewBaseSExt << "\n"
             << *Top << "\n"
             << *Trunc << "\n"
             << *NewOffsetSExt << "\n");
  WideLoads.emplace(std::make_pair(Base,
                    std::make_unique<WidenedLoad>(Loads, WideLoad)));
  return WideLoad;
}

Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform functions to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform functions to use DSP intrinsics", false, false)