- //===- MVELaneInterleaving.cpp - Interleave for MVE instructions ---------===//
- //
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
- // See https://llvm.org/LICENSE.txt for license information.
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- //
- //===----------------------------------------------------------------------===//
- //
- // This pass adds interleaving shuffles around sext/zext/trunc instructions.
- // MVE does not have a single sext/zext or trunc instruction that takes the
- // bottom half of a vector and extends it to full width, like NEON has with
- // MOVL. Instead this is expected to happen through top/bottom instructions:
- // the MVE equivalents VMOVLT/VMOVLB take either the odd or even elements of
- // the input and extend them to the larger type, producing a vector with half
- // the number of elements, each of double the bitwidth.
- // instruction, we often have to turn sext/zext/trunc into a series of lane
- // moves (or stack loads/stores, which we do not do yet).
- //
- // This pass starts at vector truncs and looks for interconnected blobs of
- // operations whose leaves are sext/zext instructions (or constants/splats),
- // of the form:
- // %sa = sext v8i16 %a to v8i32
- // %sb = sext v8i16 %b to v8i32
- // %add = add v8i32 %sa, %sb
- // %r = trunc %add to v8i16
- // And adds shuffles to allow the use of VMOVL/VMOVN instructions:
- // %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
- // %sa = sext v8i16 %sha to v8i32
- // %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
- // %sb = sext v8i16 %shb to v8i32
- // %add = add v8i32 %sa, %sb
- // %r = trunc %add to v8i16
- // %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
- // Which can then be split and lowered to MVE instructions efficiently:
- // %sa_b = VMOVLB.s16 %a
- // %sa_t = VMOVLT.s16 %a
- // %sb_b = VMOVLB.s16 %b
- // %sb_t = VMOVLT.s16 %b
- // %add_b = VADD.i32 %sa_b, %sb_b
- // %add_t = VADD.i32 %sa_t, %sb_t
- // %r = VMOVNT.i16 %add_b, %add_t
- //
- //===----------------------------------------------------------------------===//
- #include "ARM.h"
- #include "ARMBaseInstrInfo.h"
- #include "ARMSubtarget.h"
- #include "llvm/Analysis/TargetTransformInfo.h"
- #include "llvm/CodeGen/TargetLowering.h"
- #include "llvm/CodeGen/TargetPassConfig.h"
- #include "llvm/CodeGen/TargetSubtargetInfo.h"
- #include "llvm/IR/BasicBlock.h"
- #include "llvm/IR/Constant.h"
- #include "llvm/IR/Constants.h"
- #include "llvm/IR/DerivedTypes.h"
- #include "llvm/IR/Function.h"
- #include "llvm/IR/IRBuilder.h"
- #include "llvm/IR/InstIterator.h"
- #include "llvm/IR/InstrTypes.h"
- #include "llvm/IR/Instruction.h"
- #include "llvm/IR/Instructions.h"
- #include "llvm/IR/IntrinsicInst.h"
- #include "llvm/IR/Intrinsics.h"
- #include "llvm/IR/IntrinsicsARM.h"
- #include "llvm/IR/PatternMatch.h"
- #include "llvm/IR/Type.h"
- #include "llvm/IR/Value.h"
- #include "llvm/InitializePasses.h"
- #include "llvm/Pass.h"
- #include "llvm/Support/Casting.h"
- #include <algorithm>
- #include <cassert>
- using namespace llvm;
- #define DEBUG_TYPE "mve-laneinterleave"
- cl::opt<bool> EnableInterleave(
- "enable-mve-interleave", cl::Hidden, cl::init(true),
- cl::desc("Enable interleave MVE vector operation lowering"));
- namespace {
- class MVELaneInterleaving : public FunctionPass {
- public:
- static char ID; // Pass identification, replacement for typeid
- explicit MVELaneInterleaving() : FunctionPass(ID) {
- initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
- }
- bool runOnFunction(Function &F) override;
- StringRef getPassName() const override { return "MVE lane interleaving"; }
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addRequired<TargetPassConfig>();
- FunctionPass::getAnalysisUsage(AU);
- }
- };
- } // end anonymous namespace
- char MVELaneInterleaving::ID = 0;
- INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
- false)
- Pass *llvm::createMVELaneInterleavingPass() {
- return new MVELaneInterleaving();
- }
- static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
- SmallSetVector<Instruction *, 4> &Truncs) {
- // This is not always beneficial to transform. Exts can be incorporated into
- // loads, Truncs can be folded into stores.
- // Truncs usually cost the same number of instructions either way:
- // VSTRH.32(A); VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving.
- // Exts are unfortunately more instructions in the general case:
- // A=VLDRH.32; B=VLDRH.32;
- // vs with interleaving:
- // T=VLDRH.16; A=VMOVLB T; B=VMOVLT T
- // But those VMOVL may be folded into a VMULL.
- // Expensive extends/truncs, however, are always good to remove: FPExts always
- // involve extra VCVTs, so they are always considered beneficial to convert.
- for (auto *E : Exts) {
- if (isa<FPExtInst>(E) || !isa<LoadInst>(E->getOperand(0))) {
- LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
- return true;
- }
- }
- for (auto *T : Truncs) {
- if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
- LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
- return true;
- }
- }
- // Otherwise we know each ext is fed by a load; only interleave if every ext
- // has a single use that is a mul, which can then become a VMULL. This is a
- // simple heuristic and certainly not perfect.
- for (auto *E : Exts) {
- if (!E->hasOneUse() ||
- cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
- LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
- return false;
- }
- }
- return true;
- }
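- // Starting from a vector trunc, walk the graph of connected operations (up
- // through operands and down through users), collecting the ext leaves, the
- // intermediate ops, any non-instruction leaves and the truncs that root the
- // group. If the types fit the VMOVL/VMOVN pattern and the transform looks
- // profitable, insert lane-reordering shuffles at the leaves and the inverse
- // shuffle after each trunc.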
- static bool tryInterleave(Instruction *Start,
- SmallPtrSetImpl<Instruction *> &Visited) {
- LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");
- auto *VT = cast<FixedVectorType>(Start->getType());
- if (!isa<Instruction>(Start->getOperand(0)))
- return false;
- // Walk the graph of connected operations from the initial trunc, up through
- // operands and down through users, with exts/constants as the leaves and
- // truncs as the roots.
- std::vector<Instruction *> Worklist;
- Worklist.push_back(Start);
- Worklist.push_back(cast<Instruction>(Start->getOperand(0)));
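- // Exts and Truncs collect the leaves and roots of the blob, Ops the
- // intermediate instructions, and OtherLeafs any non-instruction vector
- // operands (constants, splats) that will also need a leaf shuffle.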
- SmallSetVector<Instruction *, 4> Truncs;
- SmallSetVector<Instruction *, 4> Exts;
- SmallSetVector<Use *, 4> OtherLeafs;
- SmallSetVector<Instruction *, 4> Ops;
- while (!Worklist.empty()) {
- Instruction *I = Worklist.back();
- Worklist.pop_back();
- switch (I->getOpcode()) {
- // Truncs
- case Instruction::Trunc:
- case Instruction::FPTrunc:
- if (Truncs.count(I))
- continue;
- Truncs.insert(I);
- Visited.insert(I);
- break;
- // Extend leaves
- case Instruction::SExt:
- case Instruction::ZExt:
- case Instruction::FPExt:
- if (Exts.count(I))
- continue;
- for (auto *Use : I->users())
- Worklist.push_back(cast<Instruction>(Use));
- Exts.insert(I);
- break;
- case Instruction::Call: {
- IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
- if (!II)
- return false;
- switch (II->getIntrinsicID()) {
- case Intrinsic::abs:
- case Intrinsic::smin:
- case Intrinsic::smax:
- case Intrinsic::umin:
- case Intrinsic::umax:
- case Intrinsic::sadd_sat:
- case Intrinsic::ssub_sat:
- case Intrinsic::uadd_sat:
- case Intrinsic::usub_sat:
- case Intrinsic::minnum:
- case Intrinsic::maxnum:
- case Intrinsic::fabs:
- case Intrinsic::fma:
- case Intrinsic::ceil:
- case Intrinsic::floor:
- case Intrinsic::rint:
- case Intrinsic::round:
- case Intrinsic::trunc:
- break;
- default:
- return false;
- }
- LLVM_FALLTHROUGH; // Fall through to treating these like an operator below.
- }
- // Binary/ternary ops
- case Instruction::Add:
- case Instruction::Sub:
- case Instruction::Mul:
- case Instruction::AShr:
- case Instruction::LShr:
- case Instruction::Shl:
- case Instruction::ICmp:
- case Instruction::FCmp:
- case Instruction::FAdd:
- case Instruction::FMul:
- case Instruction::Select:
- if (Ops.count(I))
- continue;
- Ops.insert(I);
- for (Use &Op : I->operands()) {
- if (!isa<FixedVectorType>(Op->getType()))
- continue;
- if (isa<Instruction>(Op))
- Worklist.push_back(cast<Instruction>(&Op));
- else
- OtherLeafs.insert(&Op);
- }
- for (auto *Use : I->users())
- Worklist.push_back(cast<Instruction>(Use));
- break;
- case Instruction::ShuffleVector:
- // A splat (all lanes equal) is unaffected by lane interleaving, so it can
- // be left as-is.
- if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
- continue;
- LLVM_FALLTHROUGH;
- default:
- LLVM_DEBUG(dbgs() << " Unhandled instruction: " << *I << "\n");
- return false;
- }
- }
- if (Exts.empty() && OtherLeafs.empty())
- return false;
- LLVM_DEBUG({
- dbgs() << "Found group:\n Exts:";
- for (auto *I : Exts)
- dbgs() << " " << *I << "\n";
- dbgs() << " Ops:";
- for (auto *I : Ops)
- dbgs() << " " << *I << "\n";
- dbgs() << " OtherLeafs:";
- for (auto *I : OtherLeafs)
- dbgs() << " " << *I->get() << " of " << *I->getUser() << "\n";
- dbgs() << "Truncs:";
- for (auto *I : Truncs)
- dbgs() << " " << *I << "\n";
- });
- assert(!Truncs.empty() && "Expected some truncs");
- // Check types
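- // Only i8/i16 vectors that fill whole 128-bit MVE registers are handled:
- // BaseElts is the number of narrow elements per 128-bit register, and the
- // wide type feeding the trunc must be exactly twice the element width.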
- unsigned NumElts = VT->getNumElements();
- unsigned BaseElts = VT->getScalarSizeInBits() == 16
- ? 8
- : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
- if (BaseElts == 0 || NumElts % BaseElts != 0) {
- LLVM_DEBUG(dbgs() << " Type is unsupported\n");
- return false;
- }
- if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
- VT->getScalarSizeInBits() * 2) {
- LLVM_DEBUG(dbgs() << " Type not double sized\n");
- return false;
- }
- for (Instruction *I : Exts)
- if (I->getOperand(0)->getType() != VT) {
- LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n");
- return false;
- }
- for (Instruction *I : Truncs)
- if (I->getType() != VT) {
- LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n");
- return false;
- }
- // Check that it looks beneficial
- if (!isProfitableToInterleave(Exts, Truncs))
- return false;
- // Create new shuffles around the extends / truncs / other leaves.
- IRBuilder<> Builder(Start);
- SmallVector<int, 16> LeafMask;
- SmallVector<int, 16> TruncMask;
- // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7,  8, 10, 12, 14, 9, 11, 13, 15
- // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7,  8, 12, 9, 13, 10, 14, 11, 15
- for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
- for (unsigned i = 0; i < BaseElts / 2; i++)
- LeafMask.push_back(Base + i * 2);
- for (unsigned i = 0; i < BaseElts / 2; i++)
- LeafMask.push_back(Base + i * 2 + 1);
- }
- for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
- for (unsigned i = 0; i < BaseElts / 2; i++) {
- TruncMask.push_back(Base + i);
- TruncMask.push_back(Base + i + BaseElts / 2);
- }
- }
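- // For the v8i16 example in the file header (NumElts == BaseElts == 8) this
- // gives LeafMask = <0,2,4,6,1,3,5,7> and TruncMask = <0,4,1,5,2,6,3,7>.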
- for (Instruction *I : Exts) {
- LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
- Builder.SetInsertPoint(I);
- Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
- bool FPext = isa<FPExtInst>(I);
- bool Sext = isa<SExtInst>(I);
- Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
- : Sext ? Builder.CreateSExt(Shuffle, I->getType())
- : Builder.CreateZExt(Shuffle, I->getType());
- I->replaceAllUsesWith(Ext);
- LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n");
- }
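- // Non-instruction leaves (constants and splats) get the same leaf shuffle so
- // their lanes line up with the shuffled extends.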
- for (Use *I : OtherLeafs) {
- LLVM_DEBUG(dbgs() << "Replacing leaf " << *I << "\n");
- Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
- Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
- I->getUser()->setOperand(I->getOperandNo(), Shuffle);
- LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n");
- }
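- // Insert the inverse shuffle directly after each trunc to restore the
- // original lane order. replaceAllUsesWith also rewrites the new shuffle's
- // own operand, so it is pointed back at the trunc afterwards.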
- for (Instruction *I : Truncs) {
- LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");
- Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
- Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
- I->replaceAllUsesWith(Shuf);
- cast<Instruction>(Shuf)->setOperand(0, I);
- LLVM_DEBUG(dbgs() << " with " << *Shuf << "\n");
- }
- return true;
- }
- bool MVELaneInterleaving::runOnFunction(Function &F) {
- if (!EnableInterleave)
- return false;
- auto &TPC = getAnalysis<TargetPassConfig>();
- auto &TM = TPC.getTM<TargetMachine>();
- auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
- if (!ST->hasMVEIntegerOps())
- return false;
- bool Changed = false;
- SmallPtrSet<Instruction *, 16> Visited;
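- // Walk the function in reverse, starting a search at each vector
- // trunc/fptrunc. tryInterleave records the truncs it visits, so overlapping
- // groups are only processed once.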
- for (Instruction &I : reverse(instructions(F))) {
- if (I.getType()->isVectorTy() &&
- (isa<TruncInst>(I) || isa<FPTruncInst>(I)) && !Visited.count(&I))
- Changed |= tryInterleave(&I, Visited);
- }
- return Changed;
- }