//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass interleaves around sext/zext/trunc instructions. MVE does not have
// a single sext/zext or trunc instruction that takes the bottom half of a
// vector and extends to a full width, like NEON has with MOVL. Instead it is
// expected that this happens through top/bottom instructions. So the MVE
// equivalent VMOVLT/B instructions take either the even or odd elements of the
// input and extend them to the larger type, producing a vector with half the
// number of elements each of double the bitwidth. As there is no simple
// instruction, we often have to turn sext/zext/trunc into a series of lane
// moves (or stack loads/stores, which we do not do yet).
//
// This pass takes vector code that starts at truncs, looks for interconnected
// blobs of operations that end with sext/zext (or constants/splats) of the
// form:
//   %sa = sext v8i16 %a to v8i32
//   %sb = sext v8i16 %b to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
// and adds shuffles to allow the use of VMOVL/VMOVN instructions:
//   %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sa = sext v8i16 %sha to v8i32
//   %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sb = sext v8i16 %shb to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
//   %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
// which can then be split and lowered to MVE instructions efficiently:
//   %sa_b = VMOVLB.s16 %a
//   %sa_t = VMOVLT.s16 %a
//   %sb_b = VMOVLB.s16 %b
//   %sb_t = VMOVLT.s16 %b
//   %add_b = VADD.i32 %sa_b, %sb_b
//   %add_t = VADD.i32 %sa_t, %sb_t
//   %r = VMOVNT.i16 %add_b, %add_t
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "mve-laneinterleave"

cl::opt<bool> EnableInterleave(
    "enable-mve-interleave", cl::Hidden, cl::init(true),
    cl::desc("Enable interleave MVE vector operation lowering"));

namespace {

class MVELaneInterleaving : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVELaneInterleaving() : FunctionPass(ID) {
    initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override { return "MVE lane interleaving"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    FunctionPass::getAnalysisUsage(AU);
  }
};

} // end anonymous namespace

char MVELaneInterleaving::ID = 0;

INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
                false)

Pass *llvm::createMVELaneInterleavingPass() {
  return new MVELaneInterleaving();
}

static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
                                     SmallSetVector<Instruction *, 4> &Truncs) {
  // This is not always beneficial to transform. Exts can be incorporated into
  // loads, and Truncs can be folded into stores.
  // Truncs are usually the same number of instructions:
  //   VSTRH.32(A); VSTRH.32(B)  vs  VSTRH.16(VMOVNT A, B) with interleaving
  // Exts are unfortunately more instructions in the general case:
  //   A=VLDRH.32; B=VLDRH.32;
  // vs with interleaving:
  //   T=VLDRH.16; A=VMOVLB T; B=VMOVLT T
  // But those VMOVLs may be folded into a VMULL.

  // Expensive extends/truncs are always good to remove, however. FPExts always
  // involve extra VCVTs, so they are always considered beneficial to convert.
  for (auto *E : Exts) {
    if (isa<FPExtInst>(E) || !isa<LoadInst>(E->getOperand(0))) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
      return true;
    }
  }
  for (auto *T : Truncs) {
    if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
      return true;
    }
  }

  // Otherwise, we know we have a load(ext); see if any of the Extends are a
  // vmull. This is a simple heuristic and certainly not perfect.
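  // For example (a hypothetical case): a sext(load) whose only user is a mul
  // can become a VMULL after interleaving, hiding the extra VMOVL cost,
  // whereas a sext(load) feeding only an add cannot, and is rejected below.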
  for (auto *E : Exts) {
    if (!E->hasOneUse() ||
        cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
      LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
      return false;
    }
  }
  return true;
}

static bool tryInterleave(Instruction *Start,
                          SmallPtrSetImpl<Instruction *> &Visited) {
  LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");
  auto *VT = cast<FixedVectorType>(Start->getType());

  if (!isa<Instruction>(Start->getOperand(0)))
    return false;

  // Look for connected operations starting from Ext's, terminating at Truncs.
  std::vector<Instruction *> Worklist;
  Worklist.push_back(Start);
  Worklist.push_back(cast<Instruction>(Start->getOperand(0)));

  SmallSetVector<Instruction *, 4> Truncs;
  SmallSetVector<Instruction *, 4> Exts;
  SmallSetVector<Use *, 4> OtherLeafs;
  SmallSetVector<Instruction *, 4> Ops;
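
  // Walk the use-def graph in both directions: each op pushes both its vector
  // operands (non-instruction leaves go into OtherLeafs) and its users, so the
  // whole connected blob between the Exts and the Truncs gets collected.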
  while (!Worklist.empty()) {
    Instruction *I = Worklist.back();
    Worklist.pop_back();

    switch (I->getOpcode()) {
    // Truncs
    case Instruction::Trunc:
    case Instruction::FPTrunc:
      if (!Truncs.insert(I))
        continue;
      Visited.insert(I);
      break;

    // Extend leafs
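    // (only users are pushed from here; the narrow operand is where the
    // interleaving shuffle will later be inserted)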
    case Instruction::SExt:
    case Instruction::ZExt:
    case Instruction::FPExt:
      if (Exts.count(I))
        continue;
      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      Exts.insert(I);
      break;

    case Instruction::Call: {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
      if (!II)
        return false;

      switch (II->getIntrinsicID()) {
      case Intrinsic::abs:
      case Intrinsic::smin:
      case Intrinsic::smax:
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::sadd_sat:
      case Intrinsic::ssub_sat:
      case Intrinsic::uadd_sat:
      case Intrinsic::usub_sat:
      case Intrinsic::minnum:
      case Intrinsic::maxnum:
      case Intrinsic::fabs:
      case Intrinsic::fma:
      case Intrinsic::ceil:
      case Intrinsic::floor:
      case Intrinsic::rint:
      case Intrinsic::round:
      case Intrinsic::trunc:
        break;
      default:
        return false;
      }
      [[fallthrough]]; // Fall through to treating these like an operator below.
    }
    // Binary/ternary ops
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::AShr:
    case Instruction::LShr:
    case Instruction::Shl:
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::FAdd:
    case Instruction::FMul:
    case Instruction::Select:
      if (!Ops.insert(I))
        continue;

      for (Use &Op : I->operands()) {
        if (!isa<FixedVectorType>(Op->getType()))
          continue;
        if (isa<Instruction>(Op))
          Worklist.push_back(cast<Instruction>(&Op));
        else
          OtherLeafs.insert(&Op);
      }

      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      break;

    case Instruction::ShuffleVector:
      // A shuffle of a splat is a splat.
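      // (every lane holds the same value, so the interleaving shuffles leave
      // it unchanged and it can simply be skipped)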
      if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
        continue;
      [[fallthrough]];

    default:
      LLVM_DEBUG(dbgs() << "  Unhandled instruction: " << *I << "\n");
      return false;
    }
  }

  if (Exts.empty() && OtherLeafs.empty())
    return false;

  LLVM_DEBUG({
    dbgs() << "Found group:\n  Exts:";
    for (auto *I : Exts)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  Ops:";
    for (auto *I : Ops)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  OtherLeafs:";
    for (auto *I : OtherLeafs)
      dbgs() << "  " << *I->get() << " of " << *I->getUser() << "\n";
    dbgs() << "Truncs:";
    for (auto *I : Truncs)
      dbgs() << "  " << *I << "\n";
  });

  assert(!Truncs.empty() && "Expected some truncs");

  // Check types
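  // BaseElts is the element count of a single 128-bit MVE vector of the
  // narrow type (8 x i16 or 16 x i8); wider IR vectors must be a whole
  // multiple of one such vector.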
  unsigned NumElts = VT->getNumElements();
  unsigned BaseElts = VT->getScalarSizeInBits() == 16
                          ? 8
                          : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
  if (BaseElts == 0 || NumElts % BaseElts != 0) {
    LLVM_DEBUG(dbgs() << "  Type is unsupported\n");
    return false;
  }
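  // Only handle extends/truncs between adjacent element sizes (e.g.
  // i16 <-> i32, not i8 <-> i32).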
  if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
      VT->getScalarSizeInBits() * 2) {
    LLVM_DEBUG(dbgs() << "  Type not double sized\n");
    return false;
  }
  for (Instruction *I : Exts)
    if (I->getOperand(0)->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }
  for (Instruction *I : Truncs)
    if (I->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }

  // Check that it looks beneficial
  if (!isProfitableToInterleave(Exts, Truncs))
    return false;

  // Create new shuffles around the extends / truncs / other leaves.
  IRBuilder<> Builder(Start);

  SmallVector<int, 16> LeafMask;
  SmallVector<int, 16> TruncMask;
  // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7   8, 10, 12, 14, 9, 11, 13, 15
  // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7   8, 12, 9, 13, 10, 14, 11, 15
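  // i.e. LeafMask groups the even lanes and then the odd lanes of each
  // BaseElts block (so the extend can lower to VMOVLB/VMOVLT), and TruncMask
  // re-interleaves the two halves back into the original lane order (so the
  // trunc can lower to VMOVN moves).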
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2);
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2 + 1);
  }
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++) {
      TruncMask.push_back(Base + i);
      TruncMask.push_back(Base + i + BaseElts / 2);
    }
  }

  for (Instruction *I : Exts) {
    LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
    Builder.SetInsertPoint(I);
    Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
    bool FPext = isa<FPExtInst>(I);
    bool Sext = isa<SExtInst>(I);
    Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
                 : Sext ? Builder.CreateSExt(Shuffle, I->getType())
                        : Builder.CreateZExt(Shuffle, I->getType());
    I->replaceAllUsesWith(Ext);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  for (Use *I : OtherLeafs) {
    LLVM_DEBUG(dbgs() << "Replacing leaf " << *I << "\n");
    Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
    Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
    I->getUser()->setOperand(I->getOperandNo(), Shuffle);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  for (Instruction *I : Truncs) {
    LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");
    Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
    Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
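    // replaceAllUsesWith also rewrites the operand of the shuffle we just
    // created, so point it back at the original trunc afterwards.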
    I->replaceAllUsesWith(Shuf);
    cast<Instruction>(Shuf)->setOperand(0, I);
    LLVM_DEBUG(dbgs() << "  with " << *Shuf << "\n");
  }

  return true;
}

bool MVELaneInterleaving::runOnFunction(Function &F) {
  if (!EnableInterleave)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;

  bool Changed = false;

  SmallPtrSet<Instruction *, 16> Visited;
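  // tryInterleave marks the truncs it absorbs as Visited, so each connected
  // group is only transformed once even though every trunc is a candidate
  // starting point.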
  for (Instruction &I : reverse(instructions(F))) {
    if (I.getType()->isVectorTy() &&
        (isa<TruncInst>(I) || isa<FPTruncInst>(I)) && !Visited.count(&I))
      Changed |= tryInterleave(&I, Visited);
  }

  return Changed;
}