//===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// Insert tilecfg for each area of key AMX intrinsics.
/// All tile operands of a key AMX intrinsic must come from tile loads, and the
/// tile defined by a key AMX intrinsic must be written back with a tile store.
/// Take tdpbssd for example:
/// --------------------------------------------------------------------------
/// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)                key
/// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)                 |
/// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)                amx
/// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3)         |
/// call void @llvm.x86.tilestored64.internal(... td)                     area
/// --------------------------------------------------------------------------
/// This pass inserts a tilecfg before every key-amx-area, something like:
/// --------------------------------------------------------------------------
/// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
/// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
/// ...
/// ... pre-config shape of %t1                                 *
/// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
/// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
/// ...                                                         *
/// ... pre-config shape of %t2                                 * shapes
/// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     *
/// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
/// ...
/// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "pre-amx-config"
static bool isAMXIntrinsic(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return true;
  return II->getType()->isX86_AMXTy();
}

static bool isTileLoad(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal ||
         II->getIntrinsicID() == Intrinsic::x86_tileloaddt164_internal;
}

static bool isTileStore(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
}

#ifndef NDEBUG
static bool onlyTileDef(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return false;
  return II->getType()->isX86_AMXTy();
}

static bool brokenVolatile(Instruction *I) {
  // Todo: identifying a normal call this way is fragile.
  if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
    return true;
  return false;
}
#endif
namespace {

class X86PreAMXConfig {
  using PosAndShapesMap = MapVector<Instruction *, SmallVector<Value *, 8>>;

  Function &F;

public:
  X86PreAMXConfig(Function &Func) : F(Func) {}
  bool preTileConfig();
  void addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
  bool findConfigShapes(PosAndShapesMap &PosAndShapes);
  bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
  void preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
                       SmallVector<Value *, 8> &Shapes);
  BasicBlock::iterator
  getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                           SmallVector<Value *, 8> &Shapes);
  bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
                          IntrinsicInst *KeyAMX);
};
// Write the shapes in order into tilecfg's mem. This may not be correct:
// the I-th shape written here may not correspond to the I-th tmm register,
// so we fix that up in X86FastTileConfig::materializeTileCfg() after
// register allocation.
// For example:
// --------------------------------------------------------------------------
// zeroinitialize tilecfg's mem (of ldtilecfg)
// --------------------------------------------------------------------------
// ... pre-config shape of %t1                                     *
// %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48       *
// %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16     *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1         *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2       * pre-config
// ...                                                             *
// ... pre-config shape of %t2                                     *
// %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49       *
// %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18     *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1         * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2       *
// ...                                                             *
// ... pre-config shape of %t3                                     * of
// %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50       *
// %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20     *
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1         *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2       *
// ...                                                             * tiles
// ... pre-config shape of %td                                     *
// %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51       *
// %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22     *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1         *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2       *
// --------------------------------------------------------------------------
// call void @llvm.x86.ldtilecfg(i8* %mem)                         * tile config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
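// Note on the offsets used below (a summary, assuming the 64-byte ldtilecfg
// memory layout): byte 0 holds the palette (set to 1 here), bytes 16..47 hold
// the per-tile column sizes in bytes as i16 values (offset 16 + 2 * N for
// tmm N), and bytes 48..63 hold the per-tile row counts as i8 values
// (offset 48 + N).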
void X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
                                      SmallVector<Value *, 8> &Shapes) {
  LLVMContext &Ctx = Builder.getContext();
  Type *I8Ty = Type::getInt8Ty(Ctx);
  Type *I16Ty = Type::getInt16Ty(Ctx);

  // TODO: Currently we set Palette = 1 by default; it may be assigned another
  // value in the future.
  Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
  Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
  Value *PalettePos = Builder.CreateGEP(I8Ty, I8Ptr, PaletteOffset);
  Builder.CreateStore(PaletteValue, PalettePos);

  for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
    Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
    Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
    const std::string ShapeName = "amx.tmm." + itostr(I);
    Value *RowPos = Builder.CreateGEP(I8Ty, I8Ptr, RowOffset,
                                      ShapeName + ".shape.row");
    Value *ColPos = Builder.CreateGEP(I8Ty, I8Ptr, ColOffset);
    ColPos = Builder.CreateBitCast(ColPos, PointerType::get(I16Ty, 0),
                                   ShapeName + ".shape.col");
    Value *Row = Shapes[I * 2];
    Value *Col = Shapes[I * 2 + 1];
    Row = Builder.CreateTrunc(Row, I8Ty);
    Builder.CreateStore(Row, RowPos);
    Builder.CreateStore(Col, ColPos);
  }
}
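// Allocate a zero-initialized <16 x i32> (64-byte) config area in the entry
// block, write the palette and the collected shapes into it via
// preWriteTileCfg, and emit ldtilecfg right before ModelStart (the first tile
// load of the key AMX area).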
void X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
                                    SmallVector<Value *, 8> &Shapes) {
  Module *M = F.getParent();
  IRBuilder<> Builder(ModelStart);
  const DataLayout &DL = M->getDataLayout();
  unsigned AddrSpace = DL.getAllocaAddrSpace();
  LLVMContext &Ctx = Builder.getContext();
  Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
  Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));

  AllocaInst *Addr =
      new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
  Addr->setAlignment(Alignment);
  Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());

  Builder.CreateAlignedStore(Constant::getNullValue(V512Ty), Addr, Alignment);

  preWriteTileCfg(I8Ptr, Builder, Shapes);

  Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, std::nullopt,
                          {I8Ptr});
}
// Todo: We may need to handle "more than one store" case in the future.
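// A minimal sketch (for illustration) of the pattern this check accepts when
// there is no key AMX intrinsic, i.e. a single tile load whose result is the
// stored tile:
//   %t = call x86_amx @llvm.x86.tileloadd64.internal(...)
//   call void @llvm.x86.tilestored64.internal(..., x86_amx %t)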
bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
                                         IntrinsicInst *Store,
                                         IntrinsicInst *KeyAMX) {
  Value *ST = Store->getOperand(4);

  // Only has tileload and tilestore.
  if (!KeyAMX)
    return (Loads.size() == 1) && Loads.contains(ST);

  // All Loads should be operands of KeyAMX.
  // All tile operands of KeyAMX should come from Loads.
  for (Value *Op : KeyAMX->operands()) {
    if (Op->getType()->isX86_AMXTy())
      if (!Loads.erase(Op))
        return false;
  }

  // The def of KeyAMX should be stored into mem.
  // Todo: can a key AMX intrinsic have no def?
  return Loads.empty() && (ST == cast<Value>(KeyAMX));
}
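// Collect, in operand order, the (row, col) shape pair of every x86_amx
// operand of KeyAMX (taken from the first two operands of its defining tile
// load). If KeyAMX is not itself the tile store, also record the shape of its
// result, taken from KeyAMX's own first two operands.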
bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
                                      SmallVector<Value *, 8> &Shapes) {
  for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
    Value *Op = KeyAMX->getOperand(I);
    if (!Op->getType()->isX86_AMXTy())
      continue;
    IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
    assert((TileDef && isTileLoad(TileDef)) &&
           "All of KeyAMX's tile definitions should come from TileLoad!");
    Shapes.push_back(TileDef->getOperand(0));
    Shapes.push_back(TileDef->getOperand(1));
  }
  if (!isTileStore(KeyAMX)) {
    Shapes.push_back(KeyAMX->getOperand(0));
    Shapes.push_back(KeyAMX->getOperand(1));
  }
  return Shapes.size() != 0;
}
// Collect the shapes and skip the area of the current key AMX intrinsic.
//
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   record (m,k)
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)   record (k,n)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)   record (m,n)
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,n)
// --------------------------------------------------------------------------
BasicBlock::iterator
X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                                          SmallVector<Value *, 8> &Shapes) {
  IntrinsicInst *KeyAMX = nullptr;
  BasicBlock *BB = Iter->getParent();
  BasicBlock::iterator PosEnd = BB->end();
  SmallSet<Value *, 4> Loads;

  // Treat the TileStore as the "config position end" and check the volatile
  // model.
  for (auto I = Iter, E = BB->end(); I != E; ++I) {
    assert(!brokenVolatile(&*I) && "Did not reach the tile store!");
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
    if (!II || !isAMXIntrinsic(II))
      continue;

    if (isTileLoad(II)) {
      Loads.insert(II);
    } else if (isTileStore(II)) {
      if (!checkVolatileModel(Loads, II, KeyAMX))
        report_fatal_error("Not Volatile AMX Model!");
      PosEnd = I;
      break;
    } else {
      assert(!KeyAMX && "Too many key AMX intrinsics!");
      KeyAMX = II;
    }
  }
  assert(PosEnd != BB->end() && "Did not find the TileStore!");

  // Treat KeyAMX as the TileStore if there are only TileLoads and a TileStore.
  if (!KeyAMX)
    KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);

  // Get Shapes in order.
  assert(Shapes.empty() && "Shapes should be clean.");
  getKeyAMXShapes(KeyAMX, Shapes);

  return PosEnd;
}
// Record a key amx area's shapes with its position.
// Use the first tileload as its position.
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   <-- pos
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)        /
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)   shapes:
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)   (m,k)(k,n)
// call void @llvm.x86.tilestored64.internal(m, n,... td)         (m,n)(m,n)
// --------------------------------------------------------------------------
bool X86PreAMXConfig::findConfigShapes(PosAndShapesMap &PosAndShapes) {
  bool Find = false;
  for (BasicBlock &BB : F) {
    for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
      if (!II)
        continue;
      if (!isAMXIntrinsic(II))
        continue;
      assert(onlyTileDef(II) && "Not volatile model for AMX at O0!");

      I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
      Find = true;
    }
  }
  return Find;
}
// Insert ldtilecfg and preconfig the shapes for each area of key AMX
// intrinsic.
// e.g. (key amx = tdpbssd)
// --------------------------------------------------------------------------
// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
// ...
// ... pre-config shape of %t1                                 *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ...                                                         *
// ... pre-config shape of %t2                                 *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ...                                                         *
// ... pre-config shape of %t3                                 * of
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ...                                                         * tiles
// ... pre-config shape of %td                                 *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
//
// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * pre-config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
bool X86PreAMXConfig::preTileConfig() {
  PosAndShapesMap PosAndShapes;
  bool NeedCfg = findConfigShapes(PosAndShapes);
  if (!NeedCfg)
    return false;
  for (auto &IPAndShapes : PosAndShapes)
    addTileConfig(IPAndShapes.first, IPAndShapes.second);

  return true;
}
} // anonymous namespace

namespace {

class X86PreAMXConfigPass : public FunctionPass {
public:
  static char ID;

  X86PreAMXConfigPass() : FunctionPass(ID) {
    initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    bool C = false;

    // Prepare for fast register allocation at O0.
    if (TM->getOptLevel() == CodeGenOpt::None) {

      // We pre-config each key AMX intrinsic at O0.
      // In theory, one tile config can cover several AMX intrinsics, but
      // it is very difficult to classify the tile shapes at O0. So here we
      // keep things simple and pre-config every key AMX intrinsic.
      X86PreAMXConfig PCFG(F);
      C = PCFG.preTileConfig();
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};

} // anonymous namespace
static const char PassName[] = "Pre AMX Tile Config";

char X86PreAMXConfigPass::ID = 0;

INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)

FunctionPass *llvm::createX86PreAMXConfigPass() {
  return new X86PreAMXConfigPass();
}