123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407 |
//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-configure the shapes of AMX registers.
/// AMX registers need to be configured before use. The shapes of an AMX
/// register are encoded in the 1st and 2nd machine operands of AMX pseudo
/// instructions.
///
/// The instruction ldtilecfg is used to configure the shapes. It must be
/// reachable for all variable shapes. ldtilecfg will be inserted more than once
/// if we cannot find a single point that dominates all AMX instructions.
///
/// The configure register is caller saved according to the ABI. We need to
/// insert ldtilecfg again after a call instruction if the callee clobbers any
/// AMX registers.
///
/// This pass calculates all points where ldtilecfg needs to be inserted and
/// inserts it there. It reports an error if the reachability conditions aren't
/// met.
//
//===----------------------------------------------------------------------===//
- #include "X86.h"
- #include "X86InstrBuilder.h"
- #include "X86MachineFunctionInfo.h"
- #include "X86RegisterInfo.h"
- #include "X86Subtarget.h"
- #include "llvm/CodeGen/MachineFunctionPass.h"
- #include "llvm/CodeGen/MachineInstr.h"
- #include "llvm/CodeGen/MachineLoopInfo.h"
- #include "llvm/CodeGen/MachineRegisterInfo.h"
- #include "llvm/CodeGen/Passes.h"
- #include "llvm/CodeGen/TargetInstrInfo.h"
- #include "llvm/CodeGen/TargetRegisterInfo.h"
- #include "llvm/InitializePasses.h"
- using namespace llvm;
- #define DEBUG_TYPE "tile-pre-config"
// Calls report_fatal_error with the current function's name followed by a
// fixed diagnostic. The expansion is a complete statement (the trailing
// semicolon is part of the macro) and requires a MachineFunction named `MF`
// to be in scope at the point of use.
#define REPORT_CONFIG_FAIL                                                     \
  report_fatal_error(                                                          \
      MF.getName() +                                                           \
      ": Failed to config tile register, please define the shape earlier");
- namespace {
- struct MIRef {
- MachineInstr *MI = nullptr;
- MachineBasicBlock *MBB = nullptr;
- // A virtual position for instruction that will be inserted after MI.
- size_t Pos = 0;
- MIRef() = default;
- MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
- for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
- ++I, ++Pos)
- MI = &*I;
- }
- MIRef(MachineInstr *MI)
- : MI(MI), MBB(MI->getParent()),
- Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
- MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
- : MI(MI), MBB(MBB),
- Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
- MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
- : MI(MI), MBB(MBB), Pos(Pos) {}
- operator bool() const { return MBB != nullptr; }
- bool operator==(const MIRef &RHS) const {
- return MI == RHS.MI && MBB == RHS.MBB;
- }
- bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
- bool operator<(const MIRef &RHS) const {
- // Comparison between different BBs happens when inserting a MIRef into set.
- // So we compare MBB first to make the insertion happy.
- return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
- }
- bool operator>(const MIRef &RHS) const {
- // Comparison between different BBs happens when inserting a MIRef into set.
- // So we compare MBB first to make the insertion happy.
- return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
- }
- };
- struct BBInfo {
- MIRef FirstAMX;
- MIRef LastCall;
- bool HasAMXRegLiveIn = false;
- bool TileCfgForbidden = false;
- bool NeedTileCfgLiveIn = false;
- };
- class X86PreTileConfig : public MachineFunctionPass {
- MachineRegisterInfo *MRI;
- const MachineLoopInfo *MLI;
- SmallSet<MachineInstr *, 8> DefVisited;
- DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
- DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
- /// Check if the callee will clobber AMX registers.
- bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
- auto Iter = llvm::find_if(
- MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
- if (Iter == MI.operands_end())
- return false;
- UsableRegs.clearBitsInMask(Iter->getRegMask());
- return !UsableRegs.none();
- }
- /// Check if MI is AMX pseudo instruction.
- bool isAMXInstruction(MachineInstr &MI) {
- if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
- return false;
- MachineOperand &MO = MI.getOperand(0);
- // We can simply check if it is AMX instruction by its def.
- // But we should exclude old API which uses physical registers.
- if (MO.isReg() && MO.getReg().isVirtual() &&
- MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) {
- collectShapeInfo(MI);
- return true;
- }
- // PTILESTOREDV is the only exception that doesn't def a AMX register.
- return MI.getOpcode() == X86::PTILESTOREDV;
- }
- /// Check if it is an edge from loop bottom to loop head.
- bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
- if (!MLI->isLoopHeader(Header))
- return false;
- auto *ML = MLI->getLoopFor(Header);
- if (ML->contains(Bottom) && ML->isLoopLatch(Bottom))
- return true;
- return false;
- }
- /// Collect the shape def information for later use.
- void collectShapeInfo(MachineInstr &MI);
- /// Try to hoist shapes definded below AMX instructions.
- bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
- MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
- auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
- auto InsertPoint = FirstAMX.MI->getIterator();
- for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
- // Do not hoist instructions that access memory.
- if (I->MI->mayLoadOrStore())
- return false;
- for (auto &MO : I->MI->operands()) {
- if (MO.isDef())
- continue;
- // Do not hoist instructions if the sources' def under AMX instruction.
- // TODO: We can handle isMoveImmediate MI here.
- if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
- return false;
- // TODO: Maybe need more checks here.
- }
- MBB->insert(InsertPoint, I->MI->removeFromParent());
- }
- // We only need to mark the last shape in the BB now.
- Shapes.clear();
- Shapes.push_back(MIRef(&*--InsertPoint, MBB));
- return true;
- }
- public:
- X86PreTileConfig() : MachineFunctionPass(ID) {}
- /// Return the pass name.
- StringRef getPassName() const override {
- return "Tile Register Pre-configure";
- }
- /// X86PreTileConfig analysis usage.
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesAll();
- AU.addRequired<MachineLoopInfo>();
- MachineFunctionPass::getAnalysisUsage(AU);
- }
- /// Clear MF related structures.
- void releaseMemory() override {
- ShapeBBs.clear();
- DefVisited.clear();
- BBVisitedInfo.clear();
- }
- /// Perform ldtilecfg instructions inserting.
- bool runOnMachineFunction(MachineFunction &MF) override;
- static char ID;
- };
- } // end anonymous namespace
- char X86PreTileConfig::ID = 0;
- INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
- "Tile Register Pre-configure", false, false)
- INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
- INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
- "Tile Register Pre-configure", false, false)
- void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
- auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
- MIRef MIR(MI, MBB);
- auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
- if (I == ShapeBBs[MBB].end() || *I != MIR)
- ShapeBBs[MBB].insert(I, MIR);
- };
- SmallVector<Register, 8> WorkList(
- {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
- while (!WorkList.empty()) {
- Register R = WorkList.pop_back_val();
- MachineInstr *DefMI = MRI->getVRegDef(R);
- assert(DefMI && "R must has one define instruction");
- MachineBasicBlock *DefMBB = DefMI->getParent();
- if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
- continue;
- if (DefMI->isPHI()) {
- for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
- if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
- RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
- else
- WorkList.push_back(DefMI->getOperand(I).getReg());
- } else {
- RecordShape(DefMI, DefMBB);
- }
- }
- }
- bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
- const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
- const TargetInstrInfo *TII = ST.getInstrInfo();
- const TargetRegisterInfo *TRI = ST.getRegisterInfo();
- const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
- X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
- BitVector AMXRegs(TRI->getNumRegs());
- for (unsigned I = 0; I < RC->getNumRegs(); I++)
- AMXRegs.set(X86::TMM0 + I);
- // Iterate MF to collect information.
- MRI = &MF.getRegInfo();
- MLI = &getAnalysis<MachineLoopInfo>();
- SmallSet<MIRef, 8> CfgNeedInsert;
- SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
- for (auto &MBB : MF) {
- size_t Pos = 0;
- for (auto &MI : MBB) {
- ++Pos;
- if (isAMXInstruction(MI)) {
- // If there's call before the AMX, we need to reload tile config.
- if (BBVisitedInfo[&MBB].LastCall)
- CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
- else // Otherwise, we need tile config to live in this BB.
- BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
- // Always record the first AMX in case there's shape def after it.
- if (!BBVisitedInfo[&MBB].FirstAMX)
- BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
- } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
- // Record the call only if the callee clobbers all AMX registers.
- BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
- }
- }
- if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
- if (&MBB == &MF.front())
- CfgNeedInsert.insert(MIRef(&MBB));
- else
- CfgLiveInBBs.push_back(&MBB);
- }
- if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
- for (auto *Succ : MBB.successors())
- if (!isLoopBackEdge(Succ, &MBB))
- BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
- }
- // Update NeedTileCfgLiveIn for predecessors.
- while (!CfgLiveInBBs.empty()) {
- MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
- for (auto *Pred : MBB->predecessors()) {
- if (BBVisitedInfo[Pred].LastCall) {
- CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
- } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
- BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
- if (Pred == &MF.front())
- CfgNeedInsert.insert(MIRef(Pred));
- else
- CfgLiveInBBs.push_back(Pred);
- }
- }
- }
- // There's no AMX instruction if we didn't find a tile config live in point.
- if (CfgNeedInsert.empty())
- return false;
- X86FI->setHasVirtualTileReg(true);
- // Avoid to insert ldtilecfg before any shape defs.
- SmallVector<MachineBasicBlock *, 8> WorkList;
- for (auto &I : ShapeBBs) {
- // TODO: We can hoist shapes across BBs here.
- if (BBVisitedInfo[I.first].HasAMXRegLiveIn)
- REPORT_CONFIG_FAIL
- if (BBVisitedInfo[I.first].FirstAMX &&
- BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
- !hoistShapesInBB(I.first, I.second))
- REPORT_CONFIG_FAIL
- WorkList.push_back(I.first);
- }
- while (!WorkList.empty()) {
- MachineBasicBlock *MBB = WorkList.pop_back_val();
- for (auto *Pred : MBB->predecessors()) {
- if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
- BBVisitedInfo[Pred].TileCfgForbidden = true;
- WorkList.push_back(Pred);
- }
- }
- }
- DebugLoc DL;
- SmallSet<MIRef, 8> VisitedOrInserted;
- int SS = MF.getFrameInfo().CreateStackObject(
- ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);
- // Try to insert for the tile config live in points.
- for (const auto &I : CfgNeedInsert) {
- SmallSet<MIRef, 8> InsertPoints;
- SmallVector<MIRef, 8> WorkList({I});
- while (!WorkList.empty()) {
- MIRef I = WorkList.pop_back_val();
- if (!VisitedOrInserted.count(I)) {
- if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
- // If the BB is all shapes reachable, stop sink and try to insert.
- InsertPoints.insert(I);
- } else {
- // Avoid the BB to be multi visited.
- VisitedOrInserted.insert(I);
- // Sink the inserting point along the chain with NeedTileCfgLiveIn =
- // true when MBB isn't all shapes reachable.
- for (auto *Succ : I.MBB->successors())
- if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
- WorkList.push_back(MIRef(Succ));
- }
- }
- }
- // A given point might be forked due to shape conditions are not met.
- for (MIRef I : InsertPoints) {
- // Make sure we insert ldtilecfg after the last shape def in MBB.
- if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
- I = ShapeBBs[I.MBB].back();
- // There're chances the MBB is sunk more than once. Record it to avoid
- // multi insert.
- if (VisitedOrInserted.insert(I).second) {
- auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
- addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::LDTILECFG)),
- SS);
- }
- }
- }
- // Zero stack slot.
- MachineBasicBlock &MBB = MF.front();
- MachineInstr *MI = &*MBB.begin();
- if (ST.hasAVX512()) {
- Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
- BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm)
- .addReg(Zmm, RegState::Undef)
- .addReg(Zmm, RegState::Undef);
- addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
- .addReg(Zmm);
- } else if (ST.hasAVX2()) {
- Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
- BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm)
- .addReg(Ymm, RegState::Undef)
- .addReg(Ymm, RegState::Undef);
- addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
- .addReg(Ymm);
- addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
- .addReg(Ymm);
- } else {
- assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
- Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
- BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm)
- .addReg(Xmm, RegState::Undef)
- .addReg(Xmm, RegState::Undef);
- addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS)
- .addReg(Xmm);
- addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16)
- .addReg(Xmm);
- addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32)
- .addReg(Xmm);
- addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48)
- .addReg(Xmm);
- }
- // Fill in the palette first.
- addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);
- return true;
- }
- FunctionPass *llvm::createX86PreTileConfigPass() {
- return new X86PreTileConfig();
- }
|