//===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// Insert a tilecfg for each area of key AMX intrinsics.
/// Every tile operand of a key AMX intrinsic must come from a tile load, and
/// the tile defined by a key AMX intrinsic must be tile-stored.
/// Take tdpbssd for example:
/// --------------------------------------------------------------------------
/// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)                key
/// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)                 |
/// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)                amx
/// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3)         |
/// call void @llvm.x86.tilestored64.internal(... td)                     area
/// --------------------------------------------------------------------------
/// This pass inserts a tilecfg before every key AMX area, something like:
/// --------------------------------------------------------------------------
/// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
/// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
/// ...
/// ... pre-config shape of %t1                                 *
/// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
/// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
/// ...                                                         *
/// ... pre-config shape of %t2                                 * shapes
/// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     *
/// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
/// ...
/// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "pre-amx-config"
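
// Returns true if II is an AMX intrinsic, i.e. it consumes or produces a
// value of type x86_amx.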
static bool isAMXIntrinsic(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return true;
  return II->getType()->isX86_AMXTy();
}
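
// Returns true if II is one of the tile-load intrinsics.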
static bool isTileLoad(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal ||
         II->getIntrinsicID() == Intrinsic::x86_tileloaddt164_internal;
}
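
// Returns true if II is the tile-store intrinsic.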
static bool isTileStore(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
}

#ifndef NDEBUG
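// Returns true if II defines a tile but takes no tile operands (e.g. a tile
// load); used to check that a key AMX area starts at a pure tile def.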
static bool onlyTileDef(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return false;
  return II->getType()->isX86_AMXTy();
}
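
// Returns true if I may break the volatile model, i.e. anything that could
// invalidate the tile configuration between the tile loads and the tile
// store.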
static bool brokenVolatile(Instruction *I) {
  // TODO: Recognizing a plain call this way is fragile.
  if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
    return true;
  return false;
}
#endif

namespace {
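// Finds each key AMX area in function F, pre-writes the tile shapes into a
// config memory, and inserts an ldtilecfg in front of the area.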
class X86PreAMXConfig {
  Function &F;

public:
  X86PreAMXConfig(Function &Func) : F(Func) {}
  bool preTileConfig();
  bool addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
  bool findConfigShapes(
      DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes);
  bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
  bool preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
                       SmallVector<Value *, 8> &Shapes);
  BasicBlock::iterator
  getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                           SmallVector<Value *, 8> &Shapes);
  bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
                          IntrinsicInst *KeyAMX);
};

// Write the shapes into the tilecfg memory in order. This ordering may not be
// final, because the first shape may not correspond to the first tmm
// register; that is fixed up in X86FastTileConfig::materializeTileCfg() after
// register allocation.
// For example:
// --------------------------------------------------------------------------
// zeroinitialize tilecfg's mem (of ldtilecfg)
// --------------------------------------------------------------------------
// ... pre-config shape of %t1                                 *
// %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48   *
// %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16 *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ...                                                         *
// ... pre-config shape of %t2                                 *
// %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49   *
// %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18 *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ...                                                         *
// ... pre-config shape of %t3                                 * of
// %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50   *
// %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20 *
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ...                                                         * tiles
// ... pre-config shape of %td                                 *
// %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51   *
// %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22 *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
// --------------------------------------------------------------------------
// call void @llvm.x86.ldtilecfg(i8* %mem)                     * tile config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
bool X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
                                      SmallVector<Value *, 8> &Shapes) {
  bool Write = false;
  LLVMContext &Ctx = Pos->getParent()->getContext();
  Type *I8Ty = Type::getInt8Ty(Ctx);
  Type *I16Ty = Type::getInt16Ty(Ctx);

  // TODO: Currently we default to Palette = 1; other values may be assigned
  // in the future.
  Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
  Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
  Value *PalettePos =
      GetElementPtrInst::Create(I8Ty, I8Ptr, PaletteOffset, "", Pos);
  new StoreInst(PaletteValue, PalettePos, Pos);

  for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
    // In the tile config layout, the i8 row of tmm I lives at byte 48 + I and
    // its i16 column at byte 16 + 2 * I.
    Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
    Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
    const std::string ShapeName = "amx.tmm." + itostr(I);
    Value *RowPos = GetElementPtrInst::Create(I8Ty, I8Ptr, RowOffset,
                                              ShapeName + ".shape.row", Pos);
    Value *ColPos = GetElementPtrInst::Create(I8Ty, I8Ptr, ColOffset, "", Pos);
    ColPos = new BitCastInst(ColPos, PointerType::get(I16Ty, 0),
                             ShapeName + ".shape.col", Pos);
    Value *Row = Shapes[I * 2];
    Value *Col = Shapes[I * 2 + 1];
    Row = new TruncInst(Row, I8Ty, "", Pos);
    new StoreInst(Row, RowPos, Pos);
    new StoreInst(Col, ColPos, Pos);
    Write = true;
  }
  return Write;
}
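
// Allocate the config memory (<16 x i32>) in the entry block, zero-initialize
// it, pre-write the palette and shapes, and emit the ldtilecfg right before
// ModelStart.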
bool X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
                                    SmallVector<Value *, 8> &Shapes) {
  Module *M = F.getParent();
  IRBuilder<> Builder(ModelStart);
  const DataLayout &DL = M->getDataLayout();
  unsigned AddrSpace = DL.getAllocaAddrSpace();
  LLVMContext &Ctx = Builder.getContext();
  Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
  Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));

  AllocaInst *Addr =
      new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
  Addr->setAlignment(Alignment);
  Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());

  std::array<Value *, 1> Args = {I8Ptr};
  Instruction *Cfg =
      Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, None, Args);

  Value *Val0 = Constant::getNullValue(V512Ty);
  Instruction *Init0 = new StoreInst(Val0, Addr, false, Alignment, Cfg);
  assert(Init0 && "Failed to zero-initialize the cfg mem!");

  preWriteTileCfg(I8Ptr, Cfg, Shapes);

  return Init0;
}

// TODO: We may need to handle the "more than one store" case in the future.
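// Check the volatile model: with no key AMX intrinsic, the single tile load
// must be the value being tile-stored; otherwise every tile load must feed
// the key AMX intrinsic, and its def must be the value being stored.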
bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
                                         IntrinsicInst *Store,
                                         IntrinsicInst *KeyAMX) {
  Value *ST = Store->getOperand(4);

  // Only tile loads and a tile store: the store must consume the single load.
  if (!KeyAMX)
    return (Loads.size() == 1) && Loads.contains(ST);

  // All Loads should be operands of KeyAMX.
  // All tile operands of KeyAMX should come from Loads.
  for (Value *Op : KeyAMX->operands()) {
    if (Op->getType()->isX86_AMXTy())
      if (!Loads.erase(Op))
        return false;
  }

  // The def of KeyAMX should be stored into mem.
  // TODO: Can a key AMX intrinsic have no def?
  return Loads.empty() && (ST == cast<Value>(KeyAMX));
}
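
// Collect the (row, col) shape operands of each tile load feeding KeyAMX.
// If KeyAMX is not itself a tile store, also record the shape of its def,
// taken from its first two operands.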
bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
                                      SmallVector<Value *, 8> &Shapes) {
  for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
    Value *Op = KeyAMX->getOperand(I);
    if (!Op->getType()->isX86_AMXTy())
      continue;
    IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
    assert((TileDef && isTileLoad(TileDef)) &&
           "All of KeyAMX's tile definitions should come from TileLoad!");
    Shapes.push_back(TileDef->getOperand(0));
    Shapes.push_back(TileDef->getOperand(1));
  }
  if (!isTileStore(KeyAMX)) {
    Shapes.push_back(KeyAMX->getOperand(0));
    Shapes.push_back(KeyAMX->getOperand(1));
  }
  return !Shapes.empty();
}

// Collect the shapes and skip the area of the current key AMX intrinsic.
//
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)  record (m,k)
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)  record (k,n)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)  record (m,n)
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,n)
// --------------------------------------------------------------------------
BasicBlock::iterator
X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                                          SmallVector<Value *, 8> &Shapes) {
  IntrinsicInst *KeyAMX = nullptr;
  BasicBlock *BB = Iter->getParent();
  BasicBlock::iterator PosEnd = BB->end();
  SmallSet<Value *, 4> Loads;

  // Treat the TileStore as the "config position end" and check the volatile
  // model.
  for (auto I = Iter, E = BB->end(); I != E; ++I) {
    assert(!brokenVolatile(&*I) && "Did not reach the tile store!");
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
    if (!II || !isAMXIntrinsic(II))
      continue;

    if (isTileLoad(II)) {
      Loads.insert(II);
    } else if (isTileStore(II)) {
      if (!checkVolatileModel(Loads, II, KeyAMX))
        report_fatal_error("Not a volatile AMX model!");
      PosEnd = I;
      break;
    } else {
      assert(!KeyAMX && "Too many key AMX intrinsics!");
      KeyAMX = II;
    }
  }
  assert(PosEnd != BB->end() && "TileStore not found!");

  // Treat the TileStore as the KeyAMX if the area has only TileLoads and a
  // TileStore.
  if (!KeyAMX)
    KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);

  // Get the shapes in order.
  assert(Shapes.empty() && "Shapes should be clean.");
  getKeyAMXShapes(KeyAMX, Shapes);

  return PosEnd;
}

// Record a key AMX area's shapes with its position.
// Use the first tile load as its position.
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)  <--  pos
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)       /
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)    shapes:
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)   (m,k)(k,n)
// call void @llvm.x86.tilestored64.internal(m, n,... td)         (m,n)(m,n)
// --------------------------------------------------------------------------
bool X86PreAMXConfig::findConfigShapes(
    DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes) {
  bool Find = false;
  for (BasicBlock &BB : F) {
    for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
      if (!II)
        continue;
      if (!isAMXIntrinsic(II))
        continue;
      assert(onlyTileDef(II) && "Not volatile model for AMX at O0!");

      I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
      Find = true;
    }
  }
  return Find;
}

// Insert ldtilecfg and pre-config the shapes for each area of key AMX
// intrinsics, e.g. (key amx = tdpbssd):
// --------------------------------------------------------------------------
// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
// ...
// ... pre-config shape of %t1                                 *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ...                                                         *
// ... pre-config shape of %t2                                 *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ...                                                         *
// ... pre-config shape of %t3                                 * of
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ...                                                         * tiles
// ... pre-config shape of %td                                 *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
//
// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * pre-config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
bool X86PreAMXConfig::preTileConfig() {
  DenseMap<Instruction *, SmallVector<Value *, 8>> PosAndShapes;
  bool NeedCfg = findConfigShapes(PosAndShapes);
  if (!NeedCfg)
    return false;
  for (auto &IPAndShapes : PosAndShapes)
    addTileConfig(IPAndShapes.first, IPAndShapes.second);

  return true;
}
} // anonymous namespace

namespace {
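// Legacy pass wrapper: runs the pre-AMX-config transform on each function
// when compiling at O0.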
class X86PreAMXConfigPass : public FunctionPass {
public:
  static char ID;

  X86PreAMXConfigPass() : FunctionPass(ID) {
    initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    bool C = false;

    // Prepare for fast register allocation at O0.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // We pre-config each key AMX intrinsic at O0.
      // In theory, one tile config can cover several AMX intrinsics, but
      // it is very difficult to classify the tile shapes at O0, so to keep
      // things simple we pre-config every key AMX intrinsic.
      X86PreAMXConfig PCFG(F);
      C = PCFG.preTileConfig();
    }
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};
} // anonymous namespace

static const char PassName[] = "Pre AMX Tile Config";
char X86PreAMXConfigPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)

FunctionPass *llvm::createX86PreAMXConfigPass() {
  return new X86PreAMXConfigPass();
}