//===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform amx intrinsics to scalar operations.
/// This pass is always enabled, but it only acts on functions compiled at -O0
/// or carrying the optnone attribute. In those cases the defs of the tile
/// shapes sit right next to the amx intrinsics that use them, and we cannot
/// find a point that post-dominates all the shape defs while dominating all
/// the amx intrinsics. To break the dependency on the shapes, we transform
/// the amx intrinsics to scalar operations so that compilation does not fail.
/// In the long term, we should improve fast register allocation to allocate
/// amx registers directly.
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-intrinsics"
#ifndef NDEBUG
static bool isV256I32Ty(Type *Ty) {
  if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
    return FVT->getNumElements() == 256 &&
           FVT->getElementType()->isIntegerTy(32);
  return false;
}
#endif
static cl::opt<bool>
    X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
                    cl::desc("X86: enable AMX scalarization."));
namespace {
class X86LowerAMXIntrinsics {
  Function &Func;

public:
  X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
      : Func(F), DTU(DomTU), LI(LoopI) {}

  bool visit();

private:
  DomTreeUpdater &DTU;
  LoopInfo *LI;
  BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
                         Value *Step, StringRef Name, IRBuilderBase &B,
                         Loop *L);
  template <bool IsTileLoad>
  Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
                                  IRBuilderBase &B, Value *Row, Value *Col,
                                  Value *Ptr, Value *Stride, Value *Tile);
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   Value *>
  createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
                    Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
                    Value *RHS);
  template <bool IsTileLoad>
  bool lowerTileLoadStore(Instruction *TileLoadStore);
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   bool>
  lowerTileDP(Instruction *TileDP);
  bool lowerTileZero(Instruction *TileZero);
};
} // anonymous namespace
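
// Build a canonical loop skeleton between Preheader and Exit:
//   Preheader -> Name.header -> Name.body -> Name.latch -> {Name.header, Exit}
// The header carries an i16 induction variable that starts at 0 and is
// incremented by Step in the latch until it reaches Bound. The returned block
// is the loop body, into which callers emit the per-iteration work.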
BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
                                              BasicBlock *Exit, Value *Bound,
                                              Value *Step, StringRef Name,
                                              IRBuilderBase &B, Loop *L) {
  LLVMContext &Ctx = Preheader->getContext();
  BasicBlock *Header =
      BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
  BasicBlock *Body =
      BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
  BasicBlock *Latch =
      BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);

  Type *I16Ty = Type::getInt16Ty(Ctx);
  BranchInst::Create(Body, Header);
  BranchInst::Create(Latch, Body);
  PHINode *IV =
      PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator());
  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);

  B.SetInsertPoint(Latch);
  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
  BranchInst::Create(Header, Exit, Cond, Latch);
  IV->addIncoming(Inc, Latch);

  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
  PreheaderBr->setSuccessor(0, Header);
  DTU.applyUpdatesPermissive({
      {DominatorTree::Delete, Preheader, Tmp},
      {DominatorTree::Insert, Header, Body},
      {DominatorTree::Insert, Body, Latch},
      {DominatorTree::Insert, Latch, Header},
      {DominatorTree::Insert, Latch, Exit},
      {DominatorTree::Insert, Preheader, Header},
  });

  if (LI) {
    L->addBasicBlockToLoop(Header, *LI);
    L->addBasicBlockToLoop(Body, *LI);
    L->addBasicBlockToLoop(Latch, *LI);
  }
  return Body;
}
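
// Scalarize tileload/tilestore into a two-level loop nest over rows and
// columns. The tile is shadowed by a <256 x i32> vector in which element
// (r, c) lives at index r * 16 + c, since a tile row holds at most 16 i32
// (64 bytes). Col and Stride arrive here already scaled to i32 units; the
// caller shifts the byte counts right by 2.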
template <bool IsTileLoad>
Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
    BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
    Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Type *EltTy = B.getInt32Ty();
  FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);

  // Common part for tileload and tilestore
  // *.scalarize.cols.body:
  // Calculate %idxmem and %idxvec
  B.SetInsertPoint(ColBody->getTerminator());
  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
  Value *Offset =
      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
  unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
  Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
  Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
  if (IsTileLoad) {
    // tileload.scalarize.rows.header:
    // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
    // %tileload.scalarize.rows.latch ]
    B.SetInsertPoint(RowLoopHeader->getTerminator());
    Value *VecZero = Constant::getNullValue(V256I32Ty);
    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
    VecCPhiRowLoop->addIncoming(VecZero, Start);

    // tileload.scalarize.cols.header:
    // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
    // ], [ %ResVec, %tileload.scalarize.cols.latch ]
    B.SetInsertPoint(ColLoopHeader->getTerminator());
    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
    VecPhi->addIncoming(VecCPhiRowLoop, RowBody);

    // tileload.scalarize.cols.body:
    // Calculate %idxmem and %idxvec
    // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
    // %elt = load i32, i32* %eltptr
    // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateLoad(EltTy, EltPtr);
    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
    VecPhi->addIncoming(ResVec, ColLoopLatch);
    VecCPhiRowLoop->addIncoming(ResVec, RowLatch);

    return ResVec;
  } else {
    auto *BitCast = cast<BitCastInst>(Tile);
    Value *Vec = BitCast->getOperand(0);
    assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");

    // tilestore.scalarize.cols.body:
    // %mul = mul i16 %row.iv, i16 16
    // %idx = add i16 %mul, i16 %col.iv
    // %elt = extractelement <256 x i32> %vec, i16 %idx
    // store i32 %elt, i32* %ptr
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateExtractElement(Vec, Idx);
    B.CreateStore(Elt, EltPtr);
    return nullptr;
  }
}
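
// Scalarize a tile dot-product intrinsic into a three-level loop nest:
// rows (m) x cols (n/4) x inner (k/4). Each i32 element of the <256 x i32>
// shadow vectors packs four i8 (integer variants) or two bf16 (tdpbf16ps),
// so the caller pre-divides the byte dimensions N and K by 4 to count
// packed i32 elements.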
template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 Value *>
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
                                         IRBuilderBase &B, Value *Row,
                                         Value *Col, Value *K, Value *Acc,
                                         Value *LHS, Value *RHS) {
  std::string IntrinName;
  switch (IntrID) {
  case Intrinsic::x86_tdpbssd_internal:
    IntrinName = "tiledpbssd";
    break;
  case Intrinsic::x86_tdpbsud_internal:
    IntrinName = "tiledpbsud";
    break;
  case Intrinsic::x86_tdpbusd_internal:
    IntrinName = "tiledpbusd";
    break;
  case Intrinsic::x86_tdpbuud_internal:
    IntrinName = "tiledpbuud";
    break;
  case Intrinsic::x86_tdpbf16ps_internal:
    IntrinName = "tiledpbf16ps";
    break;
  }
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  Loop *InnerLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    InnerLoop = LI->AllocateLoop();
    ColLoop->addChildLoop(InnerLoop);
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);
  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();

  B.SetInsertPoint(ColBody->getTerminator());
  BasicBlock *InnerBody =
      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
                 IntrinName + ".scalarize.inner", B, InnerLoop);

  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Value *CurrentInner = &*InnerLoopHeader->begin();

  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
  auto *BitCastAcc = cast<BitCastInst>(Acc);
  Value *VecC = BitCastAcc->getOperand(0);
  assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
  // TODO else create BitCast from x86amx to v256i32.
  // Store x86amx to memory, and reload from memory
  // to vector. However with -O0, it doesn't happen.
  auto *BitCastLHS = cast<BitCastInst>(LHS);
  Value *VecA = BitCastLHS->getOperand(0);
  assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
  auto *BitCastRHS = cast<BitCastInst>(RHS);
  Value *VecB = BitCastRHS->getOperand(0);
  assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");

  // tiledpbssd.scalarize.rows.header:
  // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
  // %tiledpbssd.scalarize.rows.latch ]
  // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
  // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
  B.SetInsertPoint(RowLoopHeader->getTerminator());
  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
  VecCPhiRowLoop->addIncoming(VecC, Start);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
  VecDPhiRowLoop->addIncoming(VecZero, Start);

  // tiledpbssd.scalarize.cols.header:
  // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
  // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.cols.latch ]
  // %vec.d.phi.col = phi <256 x i32> [
  // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
  // %tiledpbssd.scalarize.cols.latch ]
  // calculate idxc.
  B.SetInsertPoint(ColLoopHeader->getTerminator());
  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
  Value *IdxC =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);

  // tiledpbssd.scalarize.inner.header:
  // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
  // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.inner.latch ]
  B.SetInsertPoint(InnerLoopHeader->getTerminator());
  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);

  B.SetInsertPoint(InnerBody->getTerminator());
  Value *IdxA =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
  Value *IdxB =
      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
  Value *NewVecC = nullptr;
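
  // The inner-loop body below computes one dot-product step per iteration:
  // pull one packed i32 out of A and B, widen its sub-elements, multiply,
  // reduce the products, and accumulate into element idxc of C. For the
  // integer variants, the sext/zext pairing encodes the signedness named by
  // the intrinsic (ss, su, us, uu for the A/B operands respectively).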
  if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
    // tiledpbssd.scalarize.inner.body:
    // calculate idxa, idxb
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav4i8 = bitcast i32 %elta to <4 x i8>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
    // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
    // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
    // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
    // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab)
    // %neweltc = add i32 %eltc, %acc
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    Value *SEXTSubVecB = nullptr;
    Value *SEXTSubVecA = nullptr;
    switch (IntrID) {
    case Intrinsic::x86_tdpbssd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbsud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbusd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbuud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    default:
      llvm_unreachable("Invalid intrinsic ID!");
    }
    Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
    Value *ResElt = B.CreateAdd(EltC, SubVecR);
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  } else {
    // tiledpbf16ps.scalarize.inner.body:
    // calculate idxa, idxb, idxc
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %eltcf32 = bitcast i32 %eltc to float
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav2i16 = bitcast i32 %elta to <2 x i16>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
    // %shufflea = shufflevector <2 x i16> %eltav2i16, <2 x i16>
    // zeroinitializer, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
    // %shuffleb = shufflevector <2 x i16> %eltbv2i16, <2 x i16>
    // zeroinitializer, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
    // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
    // %acc = call float
    // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
    // %neweltc = bitcast float %acc to i32
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
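    // Widen each bf16 to f32 by placing it in the high 16 bits of a 32-bit
    // lane: mask {2, 0, 3, 1} interleaves a zero i16 below each element, so
    // the bitcast to <2 x float> yields the bf16 values with zero-filled low
    // mantissa bits (bf16 is the upper half of an IEEE f32).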
    int ShuffleMask[4] = {2, 0, 3, 1};
    auto ShuffleArray = ArrayRef(ShuffleMask);
    Value *AV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *BV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
    Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  }
  // tiledpbssd.scalarize.cols.latch:
  // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
  // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
  // i16 %idxc
  B.SetInsertPoint(ColLoopLatch->getTerminator());
  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);

  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
  VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);

  return NewVecD;
}
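
// Lower one tile dot-product intrinsic: split its block so the loop nest can
// be emitted in between, pre-divide the packed dimensions, then rewire any
// bitcast users of the result to the scalarized vector and drop the
// intrinsic.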
template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 bool>
X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
  Value *M, *N, *K, *C, *A, *B;
  match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
                                    m_Value(C), m_Value(A), m_Value(B)));
  Instruction *InsertI = TileDP;
  IRBuilder<> PreBuilder(TileDP);
  PreBuilder.SetInsertPoint(TileDP);
  // We visit the loop with (m, n/4, k/4):
  // %n_dword = lshr i16 %n, 2
  // %k_dword = lshr i16 %k, 2
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileDP);
  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
                                            KDWord, C, A, B);
  // We cannot assume there is always a bitcast after the tile dot-product
  // intrinsic, so insert one as required.
  Builder.SetInsertPoint(End->getFirstNonPHI());
  Value *ResAMX =
      Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
  // Delete TileDP intrinsic and do some clean-up.
  for (Use &U : llvm::make_early_inc_range(TileDP->uses())) {
    Instruction *I = cast<Instruction>(U.getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(ResVec);
      I->eraseFromParent();
    }
  }
  TileDP->replaceAllUsesWith(ResAMX);
  TileDP->eraseFromParent();
  return true;
}
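
// Lower tileloadd64/tilestored64 the same way: split the block, emit the
// row/col loops, and for loads forward the scalarized result to bitcast
// users, inserting a v256i32-to-x86amx bitcast for any remaining uses.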
template <bool IsTileLoad>
bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
  Value *M, *N, *Ptr, *Stride, *Tile;
  if (IsTileLoad)
    match(TileLoadStore,
          m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
              m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
  else
    match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
                             m_Value(M), m_Value(N), m_Value(Ptr),
                             m_Value(Stride), m_Value(Tile)));
  Instruction *InsertI = TileLoadStore;
  IRBuilder<> PreBuilder(TileLoadStore);
  PreBuilder.SetInsertPoint(TileLoadStore);
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileLoadStore);
  Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
      Start, End, Builder, M, NDWord, Ptr, StrideDWord,
      IsTileLoad ? nullptr : Tile);
  if (IsTileLoad) {
    // We cannot assume there is always a bitcast after tileload, so insert
    // one as required.
    Builder.SetInsertPoint(End->getFirstNonPHI());
    Value *ResAMX =
        Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
    // Delete tileloadd64 intrinsic and do some clean-up.
    for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) {
      Instruction *I = cast<Instruction>(U.getUser());
      Value *Vec;
      if (match(I, m_BitCast(m_Value(Vec)))) {
        I->replaceAllUsesWith(ResVec);
        I->eraseFromParent();
      }
    }
    TileLoadStore->replaceAllUsesWith(ResAMX);
  }
  TileLoadStore->eraseFromParent();
  return true;
}
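
// Lower tilezero by replacing every bitcast user of the tile with a
// <256 x i32> zeroinitializer and erasing the intrinsic.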
bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
  IRBuilder<> Builder(TileZero);
  FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  for (Use &U : llvm::make_early_inc_range(TileZero->uses())) {
    Instruction *I = cast<Instruction>(U.getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(VecZero);
      I->eraseFromParent();
    }
  }
  TileZero->eraseFromParent();
  return true;
}
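
// Collect all AMX intrinsics into a worklist first, then lower them. The
// two-phase structure matters: lowering splits basic blocks, which would
// otherwise invalidate the block traversal if instructions were rewritten
// on the fly.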
bool X86LowerAMXIntrinsics::visit() {
  bool C = false;
  SmallVector<IntrinsicInst *, 8> WorkList;
  for (BasicBlock *BB : depth_first(&Func)) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
        switch (Inst->getIntrinsicID()) {
        case Intrinsic::x86_tdpbssd_internal:
        case Intrinsic::x86_tdpbsud_internal:
        case Intrinsic::x86_tdpbusd_internal:
        case Intrinsic::x86_tdpbuud_internal:
        case Intrinsic::x86_tileloadd64_internal:
        case Intrinsic::x86_tilestored64_internal:
        case Intrinsic::x86_tilezero_internal:
        case Intrinsic::x86_tdpbf16ps_internal:
          WorkList.push_back(Inst);
          break;
        default:
          break;
        }
      }
    }
  }

  for (auto *Inst : WorkList) {
    switch (Inst->getIntrinsicID()) {
    case Intrinsic::x86_tdpbssd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbsud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbusd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbuud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbf16ps_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tileloadd64_internal:
      C = lowerTileLoadStore<true>(Inst) || C;
      break;
    case Intrinsic::x86_tilestored64_internal:
      C = lowerTileLoadStore<false>(Inst) || C;
      break;
    case Intrinsic::x86_tilezero_internal:
      C = lowerTileZero(Inst) || C;
      break;
    default:
      llvm_unreachable("invalid amx intrinsics!");
    }
  }
  return C;
}
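
// Legacy pass wrapper. The pass is gated behind -enable-x86-scalar-amx and
// bails out unless the function carries optnone or the target is built at
// CodeGenOpt::None, matching the -O0-only rationale in the file header.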
namespace {
class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (!X86ScalarizeAMX)
      return false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
        TM->getOptLevel() != CodeGenOpt::None)
      return false;

    auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
    auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
    auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
    auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);

    X86LowerAMXIntrinsics LAT(F, DTU, LI);
    return LAT.visit();
  }
  StringRef getPassName() const override { return "Lower AMX intrinsics"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<TargetPassConfig>();
  }
};
} // namespace
static const char PassName[] = "Lower AMX intrinsics";
char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                    false, false)

FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
  return new X86LowerAMXIntrinsicsLegacyPass();
}