//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass interleaves around sext/zext/trunc instructions. MVE does not have
// a single sext/zext or trunc instruction that takes the bottom half of a
// vector and extends to a full width, like NEON has with MOVL. Instead it is
// expected that this happens through top/bottom instructions. So the MVE
// equivalent VMOVLT/B instructions take either the even or odd elements of the
// input and extend them to the larger type, producing a vector with half the
// number of elements each of double the bitwidth. As there is no simple
// instruction, we often have to turn sext/zext/trunc into a series of lane
// moves (or stack loads/stores, which we do not do yet).
//
// This pass takes vector code that starts at truncs, looks for interconnected
// blobs of operations that end with sext/zext (or constants/splats) of the
// form:
//   %sa = sext v8i16 %a to v8i32
//   %sb = sext v8i16 %b to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
// And adds shuffles to allow the use of VMOVL/VMOVN instructions:
//   %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sa = sext v8i16 %sha to v8i32
//   %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sb = sext v8i16 %shb to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
//   %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
// Which can then be split and lowered to MVE instructions efficiently:
//   %sa_b = VMOVLB.s16 %a
//   %sa_t = VMOVLT.s16 %a
//   %sb_b = VMOVLB.s16 %b
//   %sb_t = VMOVLT.s16 %b
//   %add_b = VADD.i32 %sa_b, %sb_b
//   %add_t = VADD.i32 %sa_t, %sb_t
//   %r = VMOVNT.i16 %add_b, %add_t
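//
// Note that the two shuffle masks are inverse permutations of each other, so
// a shuffled trunc feeding a later shuffled ext sees its lanes restored to
// their original order, and both shuffles are expected to fold away into the
// even/odd lane access of the VMOVLB/VMOVLT and VMOVNB/VMOVNT instructions
// during selection.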
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "mve-laneinterleave"

cl::opt<bool> EnableInterleave(
    "enable-mve-interleave", cl::Hidden, cl::init(true),
    cl::desc("Enable interleave MVE vector operation lowering"));

namespace {

class MVELaneInterleaving : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVELaneInterleaving() : FunctionPass(ID) {
    initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override { return "MVE lane interleaving"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    FunctionPass::getAnalysisUsage(AU);
  }
};

} // end anonymous namespace

char MVELaneInterleaving::ID = 0;

INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
                false)

Pass *llvm::createMVELaneInterleavingPass() {
  return new MVELaneInterleaving();
}

static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
                                     SmallSetVector<Instruction *, 4> &Truncs) {
  // This is not always beneficial to transform. Exts can be incorporated into
  // loads, Truncs can be folded into stores.
  // Truncs are usually the same number of instructions,
  //   VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving
  // Exts are unfortunately more instructions in the general case:
  //   A=VLDRH.32; B=VLDRH.32;
  // vs with interleaving:
  //   T=VLDRH.16; A=VMOVLB T; B=VMOVLT T
  // But those VMOVLs may be folded into a VMULL.
  //
  // Expensive extends/truncs are always good to remove: FPExts always involve
  // extra VCVTs, so they are always considered beneficial to convert.
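  // If any ext cannot fold away into a load (or is an FPExt), or any trunc's
  // single use is something other than a store, removing the extra lane moves
  // is worth it on its own.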
  for (auto *E : Exts) {
    if (isa<FPExtInst>(E) || !isa<LoadInst>(E->getOperand(0))) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
      return true;
    }
  }
  for (auto *T : Truncs) {
    if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
      return true;
    }
  }
  // Otherwise, we know we have a load(ext); check that each of the Extends
  // feeds a mul, so the VMOVL can fold into a VMULL. This is a simple
  // heuristic and certainly not perfect.
  for (auto *E : Exts) {
    if (!E->hasOneUse() ||
        cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
      LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
      return false;
    }
  }
  return true;
}

static bool tryInterleave(Instruction *Start,
                          SmallPtrSetImpl<Instruction *> &Visited) {
  LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");
  auto *VT = cast<FixedVectorType>(Start->getType());
  if (!isa<Instruction>(Start->getOperand(0)))
    return false;

  // Look for connected operations starting from Ext's, terminating at Truncs.
  std::vector<Instruction *> Worklist;
  Worklist.push_back(Start);
  Worklist.push_back(cast<Instruction>(Start->getOperand(0)));

  SmallSetVector<Instruction *, 4> Truncs;
  SmallSetVector<Instruction *, 4> Exts;
  SmallSetVector<Use *, 4> OtherLeafs;
  SmallSetVector<Instruction *, 4> Ops;
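
  // Flood-fill outwards from the seed trunc: each handled op pushes both its
  // vector operands (towards the ext/constant leaves) and its users (towards
  // the truncs), so either the whole connected region is collected or an
  // unhandled instruction aborts the walk.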
  while (!Worklist.empty()) {
    Instruction *I = Worklist.back();
    Worklist.pop_back();

    switch (I->getOpcode()) {
    // Truncs
    case Instruction::Trunc:
    case Instruction::FPTrunc:
      if (Truncs.count(I))
        continue;
      Truncs.insert(I);
      Visited.insert(I);
      break;

    // Extend leafs
    case Instruction::SExt:
    case Instruction::ZExt:
    case Instruction::FPExt:
      if (Exts.count(I))
        continue;
      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      Exts.insert(I);
      break;
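
    // Calls are only safe to interleave when they are lane-wise intrinsics,
    // where lane i of the result depends only on lane i of each operand.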
    case Instruction::Call: {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
      if (!II)
        return false;

      switch (II->getIntrinsicID()) {
      case Intrinsic::abs:
      case Intrinsic::smin:
      case Intrinsic::smax:
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::sadd_sat:
      case Intrinsic::ssub_sat:
      case Intrinsic::uadd_sat:
      case Intrinsic::usub_sat:
      case Intrinsic::minnum:
      case Intrinsic::maxnum:
      case Intrinsic::fabs:
      case Intrinsic::fma:
      case Intrinsic::ceil:
      case Intrinsic::floor:
      case Intrinsic::rint:
      case Intrinsic::round:
      case Intrinsic::trunc:
        break;
      default:
        return false;
      }
      LLVM_FALLTHROUGH; // Fall through to treating these like an operator below.
    }

    // Binary/ternary ops
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::AShr:
    case Instruction::LShr:
    case Instruction::Shl:
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::FAdd:
    case Instruction::FMul:
    case Instruction::Select:
      if (Ops.count(I))
        continue;
      Ops.insert(I);

      for (Use &Op : I->operands()) {
        if (!isa<FixedVectorType>(Op->getType()))
          continue;
        if (isa<Instruction>(Op))
          Worklist.push_back(cast<Instruction>(&Op));
        else
          OtherLeafs.insert(&Op);
      }

      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      break;

    case Instruction::ShuffleVector:
      // A shuffle of a splat is a splat.
      if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
        continue;
      LLVM_FALLTHROUGH;

    default:
      LLVM_DEBUG(dbgs() << " Unhandled instruction: " << *I << "\n");
      return false;
    }
  }

  if (Exts.empty() && OtherLeafs.empty())
    return false;

  LLVM_DEBUG({
    dbgs() << "Found group:\n Exts:";
    for (auto *I : Exts)
      dbgs() << " " << *I << "\n";
    dbgs() << " Ops:";
    for (auto *I : Ops)
      dbgs() << " " << *I << "\n";
    dbgs() << " OtherLeafs:";
    for (auto *I : OtherLeafs)
      dbgs() << " " << *I->get() << " of " << *I->getUser() << "\n";
    dbgs() << "Truncs:";
    for (auto *I : Truncs)
      dbgs() << " " << *I << "\n";
  });

  assert(!Truncs.empty() && "Expected some truncs");

  // Check types
  unsigned NumElts = VT->getNumElements();
  unsigned BaseElts = VT->getScalarSizeInBits() == 16
                          ? 8
                          : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
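  // BaseElts is the number of narrow elements that fill one 128-bit MVE
  // register (8 for i16, 16 for i8); wider IR vectors must be a whole
  // multiple of that, so the masks below apply per 128-bit chunk.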
  if (BaseElts == 0 || NumElts % BaseElts != 0) {
    LLVM_DEBUG(dbgs() << " Type is unsupported\n");
    return false;
  }
  if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
      VT->getScalarSizeInBits() * 2) {
    LLVM_DEBUG(dbgs() << " Type not double sized\n");
    return false;
  }
  for (Instruction *I : Exts)
    if (I->getOperand(0)->getType() != VT) {
      LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n");
      return false;
    }
  for (Instruction *I : Truncs)
    if (I->getType() != VT) {
      LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n");
      return false;
    }

  // Check that it looks beneficial
  if (!isProfitableToInterleave(Exts, Truncs))
    return false;

  // Create new shuffles around the extends / truncs / other leaves.
  IRBuilder<> Builder(Start);
  SmallVector<int, 16> LeafMask;
  SmallVector<int, 16> TruncMask;
  // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7   8, 10, 12, 14, 9, 11, 13, 15
  // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7   8, 12, 9, 13, 10, 14, 11, 15
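  // The two masks are inverse permutations of each other: LeafMask
  // deinterleaves each chunk into its even lanes followed by its odd lanes,
  // and TruncMask interleaves them back, so a value flowing from a shuffled
  // leaf through to a shuffled trunc ends up in its original lane.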
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2);
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2 + 1);
  }
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++) {
      TruncMask.push_back(Base + i);
      TruncMask.push_back(Base + i + BaseElts / 2);
    }
  }

  for (Instruction *I : Exts) {
    LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
    Builder.SetInsertPoint(I);
    Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
    bool FPext = isa<FPExtInst>(I);
    bool Sext = isa<SExtInst>(I);
    Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
                       : Sext ? Builder.CreateSExt(Shuffle, I->getType())
                              : Builder.CreateZExt(Shuffle, I->getType());
    I->replaceAllUsesWith(Ext);
    LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n");
  }
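
  // Non-instruction leaves (constants and arguments) get the same
  // deinterleaving shuffle inserted in front of the use.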
  for (Use *I : OtherLeafs) {
    LLVM_DEBUG(dbgs() << "Replacing leaf " << *I->get() << "\n");
    Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
    Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
    I->getUser()->setOperand(I->getOperandNo(), Shuffle);
    LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n");
  }

  for (Instruction *I : Truncs) {
    LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");
    Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
    Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
    I->replaceAllUsesWith(Shuf);
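    // replaceAllUsesWith above also rewrote the shuffle's own input, so point
    // it back at the trunc.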
    cast<Instruction>(Shuf)->setOperand(0, I);
    LLVM_DEBUG(dbgs() << " with " << *Shuf << "\n");
  }

  return true;
}

bool MVELaneInterleaving::runOnFunction(Function &F) {
  if (!EnableInterleave)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;

  bool Changed = false;

  SmallPtrSet<Instruction *, 16> Visited;
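  // Seed the search at vector truncs, walking the function in reverse so each
  // group is reached from its sink first; Visited stops truncs that were
  // already absorbed into an earlier group from being reprocessed.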
  for (Instruction &I : reverse(instructions(F))) {
    if (I.getType()->isVectorTy() &&
        (isa<TruncInst>(I) || isa<FPTruncInst>(I)) && !Visited.count(&I))
      Changed |= tryInterleave(&I, Visited);
  }

  return Changed;
}