- //===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
- //
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
- // See https://llvm.org/LICENSE.txt for license information.
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- //
- //===----------------------------------------------------------------------===//
- //
- /// Insert tilecfg for each area of key AMX intrinsics.
- /// All tile operands of a key AMX intrinsic must come from tile loads, and
- /// the tile defined by a key AMX intrinsic must be stored with tilestore.
- /// Take tdpbssd for example:
- /// --------------------------------------------------------------------------
- /// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...) key
- /// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...) |
- /// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...) amx
- /// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3) |
- /// call void @llvm.x86.tilestored64.internal(... td) area
- /// --------------------------------------------------------------------------
- /// This pass will insert tilecfg before every key-amx-area, something like:
- /// --------------------------------------------------------------------------
- /// %cfgmem = alloca <16 x i32>, align 4 * allocate mem
- /// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem * zero init
- /// ...
- /// ... pre-config shape of %t1 *
- /// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1 *
- /// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2 * pre-config
- /// ... *
- /// ... pre-config shape of %t2 * shapes
- /// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 *
- /// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 *
- /// ...
- /// call void @llvm.x86.ldtilecfg(i8* %cfgmem) * tile config
- //
- //===----------------------------------------------------------------------===//
- //
- #include "X86.h"
- #include "llvm/ADT/SmallSet.h"
- #include "llvm/Analysis/TargetTransformInfo.h"
- #include "llvm/CodeGen/Passes.h"
- #include "llvm/CodeGen/TargetPassConfig.h"
- #include "llvm/CodeGen/ValueTypes.h"
- #include "llvm/IR/DataLayout.h"
- #include "llvm/IR/Function.h"
- #include "llvm/IR/IRBuilder.h"
- #include "llvm/IR/Instructions.h"
- #include "llvm/IR/IntrinsicInst.h"
- #include "llvm/IR/IntrinsicsX86.h"
- #include "llvm/IR/PatternMatch.h"
- #include "llvm/InitializePasses.h"
- #include "llvm/Pass.h"
- #include "llvm/Support/raw_ostream.h"
- #include "llvm/Target/TargetMachine.h"
- using namespace llvm;
- using namespace PatternMatch;
- #define DEBUG_TYPE "pre-amx-config"
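- // Returns true if II produces or consumes an x86_amx tile value.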
- static bool isAMXIntrinsic(IntrinsicInst *II) {
- for (Value *Operand : II->operands())
- if (Operand->getType()->isX86_AMXTy())
- return true;
- return II->getType()->isX86_AMXTy();
- }
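- // Tile loads (tileloadd64 and its T1-hint variant) define a tile from memory
- // and are the only tile defs this pass expects to feed a key AMX intrinsic.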
- static bool isTileLoad(IntrinsicInst *II) {
- return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal ||
- II->getIntrinsicID() == Intrinsic::x86_tileloaddt164_internal;
- }
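- // A tile store writes a tile def back to memory and ends a key amx area.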
- static bool isTileStore(IntrinsicInst *II) {
- return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
- }
- #ifndef NDEBUG
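- // Returns true if II defines a tile but has no tile operands (a pure tile def).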
- static bool onlyTileDef(IntrinsicInst *II) {
- for (Value *Operand : II->operands())
- if (Operand->getType()->isX86_AMXTy())
- return false;
- return II->getType()->isX86_AMXTy();
- }
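- // Conservatively treat non-intrinsic calls and terminators as breaking the
- // load -> key-amx -> store ("volatile") model within a basic block.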
- static bool brokenVolatile(Instruction *I) {
- // Todo: it is weak to identify a normal call here.
- if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
- return true;
- return false;
- }
- #endif
- namespace {
- class X86PreAMXConfig {
- Function &F;
- public:
- X86PreAMXConfig(Function &Func) : F(Func) {}
- bool preTileConfig();
- bool addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
- bool findConfigShapes(
- DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes);
- bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
- bool preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
- SmallVector<Value *, 8> &Shapes);
- BasicBlock::iterator
- getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
- SmallVector<Value *, 8> &Shapes);
- bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
- IntrinsicInst *KeyAMX);
- };
- // Write the shapes into tilecfg's memory in order. This may not be correct,
- // because the first shape may not correspond to the first tmm register, so
- // we need to fix it up in X86FastTileConfig::materializeTileCfg() after
- // register allocation.
- // For example:
- // --------------------------------------------------------------------------
- // zeroinitialize tilecfg's mem (of ldtilecfg)
- // --------------------------------------------------------------------------
- // ... pre-config shape of %t1 *
- // %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48 *
- // %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16 *
- // store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1 *
- // store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2 * pre-config
- // ... *
- // ... pre-config shape of %t2 *
- // %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49 *
- // %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18 *
- // store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 * shapes
- // store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 *
- // ... *
- // ... pre-config shape of %t3 * of
- // %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50 *
- // %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20 *
- // store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1 *
- // store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2 *
- // ... * tiles
- // ... pre-config shape of %td *
- // %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51 *
- // %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22 *
- // store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1 *
- // store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2 *
- // --------------------------------------------------------------------------
- // call void @llvm.x86.ldtilecfg(i8* %mem) * tile config
- // --------------------------------------------------------------------------
- // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
- // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
- // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
- // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
- // call void @llvm.x86.tilestored64.internal(... td) area
- // --------------------------------------------------------------------------
- bool X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
- SmallVector<Value *, 8> &Shapes) {
- bool Write = false;
- LLVMContext &Ctx = Pos->getParent()->getContext();
- Type *I8Ty = Type::getInt8Ty(Ctx);
- Type *I16Ty = Type::getInt16Ty(Ctx);
- // TODO: Currently we default to Palette = 1; it may be assigned a
- // different value in the future.
- Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
- Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
- Value *PalettePos =
- GetElementPtrInst::Create(I8Ty, I8Ptr, PaletteOffset, "", Pos);
- new StoreInst(PaletteValue, PalettePos, Pos);
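- // In the ldtilecfg memory layout, the 16-bit column counts start at byte
- // offset 16 (2 bytes per tile) and the 8-bit row counts at byte offset 48
- // (1 byte per tile); the offsets below follow that layout.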
- for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
- Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
- Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
- const std::string ShapeName = "amx.tmm." + itostr(I);
- Value *RowPos = GetElementPtrInst::Create(I8Ty, I8Ptr, RowOffset,
- ShapeName + ".shape.row", Pos);
- Value *ColPos = GetElementPtrInst::Create(I8Ty, I8Ptr, ColOffset, "", Pos);
- ColPos = new BitCastInst(ColPos, PointerType::get(I16Ty, 0),
- ShapeName + ".shape.col", Pos);
- Value *Row = Shapes[I * 2];
- Value *Col = Shapes[I * 2 + 1];
- Row = new TruncInst(Row, I8Ty, "", Pos);
- new StoreInst(Row, RowPos, Pos);
- new StoreInst(Col, ColPos, Pos);
- Write = true;
- }
- return Write;
- }
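- // Allocate and zero-initialize the config memory, pre-write the shapes into
- // it, and emit ldtilecfg right before the key amx area starting at ModelStart.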
- bool X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
- SmallVector<Value *, 8> &Shapes) {
- Module *M = F.getParent();
- IRBuilder<> Builder(ModelStart);
- const DataLayout &DL = M->getDataLayout();
- unsigned AddrSpace = DL.getAllocaAddrSpace();
- LLVMContext &Ctx = Builder.getContext();
- Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
- Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));
- AllocaInst *Addr =
- new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
- Addr->setAlignment(Alignment);
- Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());
- std::array<Value *, 1> Args = {I8Ptr};
- Instruction *Cfg =
- Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, None, Args);
- Value *Val0 = Constant::getNullValue(V512Ty);
- Instruction *Init0 = new StoreInst(Val0, Addr, false, Alignment, Cfg);
- assert(Init0 && "Failed to zero-initialize the cfg mem!");
- preWriteTileCfg(I8Ptr, Cfg, Shapes);
- return Init0;
- }
- // Todo: We may need to handle the "more than one store" case in the future.
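- // Check the "volatile" model of a key amx area: every tile load must feed the
- // key intrinsic and the key intrinsic's def must be the value that is stored
- // (with no key intrinsic, the single tile load must be stored directly).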
- bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
- IntrinsicInst *Store,
- IntrinsicInst *KeyAMX) {
- Value *ST = Store->getOperand(4);
- // Only has tileload and tilestore.
- if (!KeyAMX)
- return (Loads.size() == 1) && Loads.contains(ST);
- // All Loads should be operands of KeyAMX.
- // All tile operands of KeyAMX should come from Loads.
- for (Value *Op : KeyAMX->operands()) {
- if (Op->getType()->isX86_AMXTy())
- if (!Loads.erase(Op))
- return false;
- }
- // The def of KeyAMX should be stored into mem.
- // Todo: can a key amx intrinsic have no def?
- return Loads.empty() && (ST == cast<Value>(KeyAMX));
- }
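- // Collect the (row, col) shape operands of every tile load feeding KeyAMX,
- // plus the shape of KeyAMX's own tile def when KeyAMX is not a tile store.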
- bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
- SmallVector<Value *, 8> &Shapes) {
- for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
- Value *Op = KeyAMX->getOperand(I);
- if (!Op->getType()->isX86_AMXTy())
- continue;
- IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
- assert((TileDef && isTileLoad(TileDef)) &&
- "All of KeyAMX's tile definitions should come from TileLoad!");
- Shapes.push_back(TileDef->getOperand(0));
- Shapes.push_back(TileDef->getOperand(1));
- }
- if (!isTileStore(KeyAMX)) {
- Shapes.push_back(KeyAMX->getOperand(0));
- Shapes.push_back(KeyAMX->getOperand(1));
- }
- return !Shapes.empty();
- }
- // Collect the shapes and skip the area of current key amx intrinsic.
- //
- // For example:
- // ...
- // --------------------------------------------------------------------------
- // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) record (m,k)
- // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) record (k,n)
- // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) record (m,n)
- // %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
- // call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,n)
- // --------------------------------------------------------------------------
- BasicBlock::iterator
- X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
- SmallVector<Value *, 8> &Shapes) {
- IntrinsicInst *KeyAMX = nullptr;
- BasicBlock *BB = Iter->getParent();
- BasicBlock::iterator PosEnd = BB->end();
- SmallSet<Value *, 4> Loads;
- // See TileStore as "Config Position End" and check volatile model.
- for (auto I = Iter, E = BB->end(); I != E; ++I) {
- assert(!brokenVolatile(&*I) && "Did not reach the tile store!");
- IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
- if (!II || !isAMXIntrinsic(II))
- continue;
- if (isTileLoad(II)) {
- Loads.insert(II);
- } else if (isTileStore(II)) {
- if (!checkVolatileModel(Loads, II, KeyAMX))
- report_fatal_error("Not a volatile AMX model!");
- PosEnd = I;
- break;
- } else {
- assert(!KeyAMX && "Too many key amx intrinsics!");
- KeyAMX = II;
- }
- }
- assert(PosEnd != BB->end() && "TileStore not found!");
- // Treat the TileStore as the KeyAMX when the area has only TileLoad and TileStore.
- if (!KeyAMX)
- KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);
- // Get Shapes in order.
- assert(Shapes.empty() && "Shapes should be clean.");
- getKeyAMXShapes(KeyAMX, Shapes);
- return PosEnd;
- }
- // Record a key amx area's shapes with its position.
- // Use the first tileload as its position.
- // For example:
- // ...
- // --------------------------------------------------------------------------
- // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) <-- pos
- // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) /
- // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) shapes:
- // %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3) (m,k)(k,n)
- // call void @llvm.x86.tilestored64.internal(m, n,... td) (m,n)(m,n)
- // --------------------------------------------------------------------------
- bool X86PreAMXConfig::findConfigShapes(
- DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes) {
- bool Find = false;
- for (BasicBlock &BB : F) {
- for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
- IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
- if (!II)
- continue;
- if (!isAMXIntrinsic(II))
- continue;
- assert(onlyTileDef(II) && "Not volatile model for AMX at O0!");
- I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
- Find = true;
- }
- }
- return Find;
- }
- // Insert ldtilecfg and preconfig the shapes for each area of key AMX intrinsic.
- // e.g. (key amx = tdpbssd)
- // --------------------------------------------------------------------------
- // %cfgmem = alloca <16 x i32>, align 4 * allocate mem
- // store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem * zero init
- // ...
- // ... pre-config shape of %t1 *
- // store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1 *
- // store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2 * pre-config
- // ... *
- // ... pre-config shape of %t2 *
- // store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 * shapes
- // store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 *
- // ... *
- // ... pre-config shape of %t3 * of
- // store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1 *
- // store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2 *
- // ... * tiles
- // ... pre-config shape of %td *
- // store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1 *
- // store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2 *
- //
- // call void @llvm.x86.ldtilecfg(i8* %cfgmem) * pre-config
- // --------------------------------------------------------------------------
- // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
- // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
- // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
- // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
- // call void @llvm.x86.tilestored64.internal(... td) area
- // --------------------------------------------------------------------------
- bool X86PreAMXConfig::preTileConfig() {
- DenseMap<Instruction *, SmallVector<Value *, 8>> PosAndShapes;
- bool NeedCfg = findConfigShapes(PosAndShapes);
- if (!NeedCfg)
- return false;
- for (auto &IPAndShapes : PosAndShapes)
- addTileConfig(IPAndShapes.first, IPAndShapes.second);
- return true;
- }
- } // anonymous namespace
- namespace {
- class X86PreAMXConfigPass : public FunctionPass {
- public:
- static char ID;
- X86PreAMXConfigPass() : FunctionPass(ID) {
- initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
- }
- bool runOnFunction(Function &F) override {
- TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
- bool C = false;
- // Prepare for fast register allocation at O0.
- if (TM->getOptLevel() == CodeGenOpt::None) {
- // We pre-config each key AMX intrinsic at O0.
- // In theory, one tile config can cover several AMX intrinsics, but
- // it is very difficult to classify the tile shapes at O0. So here we
- // keep things simple and pre-configure every key AMX intrinsic.
- X86PreAMXConfig PCFG(F);
- C = PCFG.preTileConfig();
- }
- return C;
- }
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesCFG();
- AU.addRequired<TargetPassConfig>();
- }
- };
- } // anonymous namespace
- static const char PassName[] = "Pre AMX Tile Config";
- char X86PreAMXConfigPass::ID = 0;
- INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
- INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
- INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
- FunctionPass *llvm::createX86PreAMXConfigPass() {
- return new X86PreAMXConfigPass();
- }