X86PreTileConfig.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. //===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. /// \file Pass to pre-config the shapes of AMX registers
  10. /// AMX register needs to be configured before use. The shapes of AMX register
  11. /// are encoded in the 1st and 2nd machine operand of AMX pseudo instructions.
  12. ///
  13. /// The instruction ldtilecfg is used to config the shapes. It must be reachable
  14. /// for all variable shapes. ldtilecfg will be inserted more than once if we
  15. /// cannot find a dominating point for all AMX instructions.
  16. ///
  17. /// The configure register is caller saved according to ABI. We need to insert
  18. /// ldtilecfg again after the call instruction if callee clobbers any AMX
  19. /// registers.
  20. ///
  21. /// This pass calculates all points that ldtilecfg need to be inserted to and
  22. /// insert them. It reports error if the reachability conditions aren't met.
  23. //
  24. //===----------------------------------------------------------------------===//
  25. #include "X86.h"
  26. #include "X86InstrBuilder.h"
  27. #include "X86MachineFunctionInfo.h"
  28. #include "X86RegisterInfo.h"
  29. #include "X86Subtarget.h"
  30. #include "llvm/CodeGen/MachineFunctionPass.h"
  31. #include "llvm/CodeGen/MachineInstr.h"
  32. #include "llvm/CodeGen/MachineLoopInfo.h"
  33. #include "llvm/CodeGen/MachineModuleInfo.h"
  34. #include "llvm/CodeGen/MachineRegisterInfo.h"
  35. #include "llvm/CodeGen/Passes.h"
  36. #include "llvm/CodeGen/TargetInstrInfo.h"
  37. #include "llvm/CodeGen/TargetRegisterInfo.h"
  38. #include "llvm/InitializePasses.h"
  39. using namespace llvm;
  40. #define DEBUG_TYPE "tile-pre-config"
  41. static void emitErrorMsg(MachineFunction &MF) {
  42. SmallString<32> Str;
  43. Twine ErrorMsg =
  44. MF.getName() +
  45. ": Failed to config tile register, please define the shape earlier";
  46. LLVMContext &Context = MF.getMMI().getModule()->getContext();
  47. Context.emitError(ErrorMsg);
  48. }
  49. namespace {
/// A lightweight reference to a position inside a machine basic block:
/// the instruction MI (may stay null), its owning block, and a numeric
/// position Pos used for ordering references within the same block.
struct MIRef {
  MachineInstr *MI = nullptr;
  MachineBasicBlock *MBB = nullptr;
  // A virtual position for instruction that will be inserted after MI.
  size_t Pos = 0;
  MIRef() = default;
  // Position at the start of MBB, after any leading PHIs: MI ends up
  // pointing at the last PHI (or stays null if there is none) and Pos
  // counts the PHIs skipped.
  MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
    for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
         ++I, ++Pos)
      MI = &*I;
  }
  // Position just after MI in its parent block: Pos is the distance from
  // the block start to the iterator following MI.
  MIRef(MachineInstr *MI)
      : MI(MI), MBB(MI->getParent()),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  // Same as above but with the block supplied by the caller.
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
      : MI(MI), MBB(MBB),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  // Fully specified reference; Pos is trusted as given.
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
      : MI(MI), MBB(MBB), Pos(Pos) {}
  // A MIRef is "valid" once it is bound to a block.
  operator bool() const { return MBB != nullptr; }
  // Note: equality deliberately ignores Pos; MI + MBB identify the point.
  bool operator==(const MIRef &RHS) const {
    return MI == RHS.MI && MBB == RHS.MBB;
  }
  bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
  bool operator<(const MIRef &RHS) const {
    // Comparison between different BBs happens when inserting a MIRef into set.
    // So we compare MBB first to make the insertion happy.
    return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
  }
  bool operator>(const MIRef &RHS) const {
    // Comparison between different BBs happens when inserting a MIRef into set.
    // So we compare MBB first to make the insertion happy.
    return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
  }
};
/// Per-basic-block facts gathered while scanning the function.
struct BBInfo {
  MIRef FirstAMX;                 // First AMX instruction in the block.
  MIRef LastCall;                 // Last AMX-clobbering call in the block.
  bool HasAMXRegLiveIn = false;   // An AMX register is live into this block.
  bool TileCfgForbidden = false;  // ldtilecfg must not be placed here
                                  // (a shape def is still reachable below).
  bool NeedTileCfgLiveIn = false; // Tile config must be live into this block.
};
/// Machine pass that decides where ldtilecfg must be inserted so every AMX
/// instruction is reached by a valid tile configuration, then inserts the
/// configuration loads (see runOnMachineFunction).
class X86PreTileConfig : public MachineFunctionPass {
  MachineRegisterInfo *MRI;
  const MachineLoopInfo *MLI;
  // Shape-def instructions already traced, so each is handled only once.
  SmallSet<MachineInstr *, 8> DefVisited;
  // Facts collected per basic block during the function scan.
  DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
  // Per block: sorted list of positions where shape values are defined.
  DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;

  /// Check if the callee will clobber AMX registers.
  /// \p UsableRegs is taken by value on purpose: it arrives holding the AMX
  /// physical registers and is destructively masked below.
  bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
    auto Iter = llvm::find_if(
        MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
    if (Iter == MI.operands_end())
      return false;
    UsableRegs.clearBitsInMask(Iter->getRegMask());
    // True if any AMX register remains set after applying the call's
    // preserved-register mask, i.e. the call clobbers it.
    return !UsableRegs.none();
  }

  /// Check if MI is AMX pseudo instruction.
  /// Side effect: for a matching instruction, records its shape defs via
  /// collectShapeInfo.
  bool isAMXInstruction(MachineInstr &MI) {
    if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
      return false;
    MachineOperand &MO = MI.getOperand(0);
    // We can simply check if it is AMX instruction by its def.
    // But we should exclude old API which uses physical registers.
    if (MO.isReg() && MO.getReg().isVirtual() &&
        MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) {
      collectShapeInfo(MI);
      return true;
    }
    // PTILESTOREDV is the only exception that doesn't def a AMX register.
    return MI.getOpcode() == X86::PTILESTOREDV;
  }

  /// Check if it is an edge from loop bottom to loop head.
  bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
    if (!MLI->isLoopHeader(Header))
      return false;
    auto *ML = MLI->getLoopFor(Header);
    if (ML->contains(Bottom) && ML->isLoopLatch(Bottom))
      return true;
    return false;
  }

  /// Collect the shape def information for later use.
  void collectShapeInfo(MachineInstr &MI);

  /// Try to hoist shapes defined below AMX instructions.
  /// Returns false when some shape cannot be moved: it touches memory, or
  /// one of its source operands is itself defined after the first AMX
  /// instruction. On success, Shapes is collapsed to the single last shape.
  bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
    MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
    // Shapes is kept sorted, so this splits it at the first shape that
    // currently sits below the first AMX instruction.
    auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
    auto InsertPoint = FirstAMX.MI->getIterator();
    for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
      // Do not hoist instructions that access memory.
      if (I->MI->mayLoadOrStore())
        return false;
      for (auto &MO : I->MI->operands()) {
        if (MO.isDef())
          continue;
        // Do not hoist instructions if the sources' def under AMX instruction.
        // TODO: We can handle isMoveImmediate MI here.
        if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
          return false;
        // TODO: Maybe need more checks here.
      }
      // Move the shape def to just before the first AMX instruction.
      MBB->insert(InsertPoint, I->MI->removeFromParent());
    }
    // We only need to mark the last shape in the BB now.
    // InsertPoint still denotes FirstAMX, so the instruction just before it
    // is the last (possibly hoisted) shape def.
    Shapes.clear();
    Shapes.push_back(MIRef(&*--InsertPoint, MBB));
    return true;
  }

public:
  X86PreTileConfig() : MachineFunctionPass(ID) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Tile Register Pre-configure";
  }

  /// X86PreTileConfig analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<MachineLoopInfo>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Clear MF related structures.
  void releaseMemory() override {
    ShapeBBs.clear();
    DefVisited.clear();
    BBVisitedInfo.clear();
  }

  /// Perform ldtilecfg instructions inserting.
  bool runOnMachineFunction(MachineFunction &MF) override;

  static char ID;
};
  180. } // end anonymous namespace
// Pass identification: the pass registry keys on the *address* of ID, so its
// value is irrelevant.
char X86PreTileConfig::ID = 0;
INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)
  187. void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
  188. auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
  189. MIRef MIR(MI, MBB);
  190. auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
  191. if (I == ShapeBBs[MBB].end() || *I != MIR)
  192. ShapeBBs[MBB].insert(I, MIR);
  193. };
  194. SmallVector<Register, 8> WorkList(
  195. {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
  196. while (!WorkList.empty()) {
  197. Register R = WorkList.pop_back_val();
  198. MachineInstr *DefMI = MRI->getVRegDef(R);
  199. assert(DefMI && "R must has one define instruction");
  200. MachineBasicBlock *DefMBB = DefMI->getParent();
  201. if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
  202. continue;
  203. if (DefMI->isPHI()) {
  204. for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
  205. if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
  206. RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
  207. else
  208. WorkList.push_back(DefMI->getOperand(I).getReg());
  209. } else {
  210. RecordShape(DefMI, DefMBB);
  211. }
  212. }
  213. }
/// Perform ldtilecfg instructions inserting.
///
/// Phases:
///  1. Scan every block, classifying AMX instructions and AMX-clobbering
///     calls, and seeding the set of points where a config must be inserted.
///  2. Propagate NeedTileCfgLiveIn backwards to predecessors.
///  3. Hoist shape defs above the first AMX use where needed (or report an
///     error), and mark blocks where inserting the config is forbidden.
///  4. Sink each insertion point past forbidden blocks and emit PLDTILECFGV.
///  5. Zero-initialize the config stack slot in the entry block and write
///     the palette byte.
bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  const TargetInstrInfo *TII = ST.getInstrInfo();
  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
  const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
  X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
  // The physical tile registers TMM0..TMMn, used to test call clobbers.
  BitVector AMXRegs(TRI->getNumRegs());
  for (unsigned I = 0; I < RC->getNumRegs(); I++)
    AMXRegs.set(X86::TMM0 + I);
  // Iterate MF to collect information.
  MRI = &MF.getRegInfo();
  MLI = &getAnalysis<MachineLoopInfo>();
  SmallSet<MIRef, 8> CfgNeedInsert;
  SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
  for (auto &MBB : MF) {
    size_t Pos = 0;
    for (auto &MI : MBB) {
      ++Pos;
      if (isAMXInstruction(MI)) {
        // If there's call before the AMX, we need to reload tile config.
        if (BBVisitedInfo[&MBB].LastCall)
          CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
        else // Otherwise, we need tile config to live in this BB.
          BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
        // Always record the first AMX in case there's shape def after it.
        if (!BBVisitedInfo[&MBB].FirstAMX)
          BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
      } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
        // Record the call only if the callee clobbers all AMX registers.
        BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
      }
    }
    if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
      // The entry block is itself the ultimate insertion point; other
      // blocks push the live-in requirement to their predecessors below.
      if (&MBB == &MF.front())
        CfgNeedInsert.insert(MIRef(&MBB));
      else
        CfgLiveInBBs.push_back(&MBB);
    }
    // Propagate AMX liveness forward, except around loop back edges.
    if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
      for (auto *Succ : MBB.successors())
        if (!isLoopBackEdge(Succ, &MBB))
          BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
  }
  // Update NeedTileCfgLiveIn for predecessors.
  while (!CfgLiveInBBs.empty()) {
    MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (BBVisitedInfo[Pred].LastCall) {
        // A clobbering call in the predecessor is where the reload goes.
        CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
      } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
        BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
        if (Pred == &MF.front())
          CfgNeedInsert.insert(MIRef(Pred));
        else
          CfgLiveInBBs.push_back(Pred);
      }
    }
  }
  // There's no AMX instruction if we didn't find a tile config live in point.
  if (CfgNeedInsert.empty())
    return false;
  X86FI->setHasVirtualTileReg(true);
  // Avoid to insert ldtilecfg before any shape defs.
  SmallVector<MachineBasicBlock *, 8> WorkList;
  for (auto &I : ShapeBBs) {
    // TODO: We can hoist shapes across BBs here.
    if (BBVisitedInfo[I.first].HasAMXRegLiveIn) {
      // We are not able to config tile registers since the shape to config
      // is not defined yet. Emit error message and continue. The function
      // would not config tile registers.
      emitErrorMsg(MF);
      return false;
    }
    // A shape def below the first AMX use must be hoisted above it, or we
    // cannot place the config; error out if hoisting fails.
    if (BBVisitedInfo[I.first].FirstAMX &&
        BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
        !hoistShapesInBB(I.first, I.second)) {
      emitErrorMsg(MF);
      return false;
    }
    WorkList.push_back(I.first);
  }
  // Propagate TileCfgForbidden backwards: any block from which a shape def
  // is still reachable (other than via a loop back edge) must not host the
  // config itself.
  while (!WorkList.empty()) {
    MachineBasicBlock *MBB = WorkList.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
        BBVisitedInfo[Pred].TileCfgForbidden = true;
        WorkList.push_back(Pred);
      }
    }
  }
  DebugLoc DL;
  SmallSet<MIRef, 8> VisitedOrInserted;
  // Stack slot that holds the tile configuration data read by ldtilecfg.
  int SS = MF.getFrameInfo().CreateStackObject(
      ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);
  // Try to insert for the tile config live in points.
  for (const auto &I : CfgNeedInsert) {
    SmallSet<MIRef, 8> InsertPoints;
    SmallVector<MIRef, 8> WorkList({I});
    while (!WorkList.empty()) {
      MIRef I = WorkList.pop_back_val();
      if (!VisitedOrInserted.count(I)) {
        if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
          // If the BB is all shapes reachable, stop sink and try to insert.
          InsertPoints.insert(I);
        } else {
          // Avoid the BB to be multi visited.
          VisitedOrInserted.insert(I);
          // Sink the inserting point along the chain with NeedTileCfgLiveIn =
          // true when MBB isn't all shapes reachable.
          for (auto *Succ : I.MBB->successors())
            if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
              WorkList.push_back(MIRef(Succ));
        }
      }
    }
    // A given point might be forked due to shape conditions are not met.
    for (MIRef I : InsertPoints) {
      // Make sure we insert ldtilecfg after the last shape def in MBB.
      if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
        I = ShapeBBs[I.MBB].back();
      // There're chances the MBB is sunk more than once. Record it to avoid
      // multi insert.
      if (VisitedOrInserted.insert(I).second) {
        // Insert after I.MI when set, otherwise at the block start.
        auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
        addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::PLDTILECFGV)),
                          SS);
      }
    }
  }
  // Zero stack slot.
  // Use the widest available vector store to clear the config area in the
  // entry block, before any other instruction.
  MachineBasicBlock &MBB = MF.front();
  MachineInstr *MI = &*MBB.begin();
  if (ST.hasAVX512()) {
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::AVX512_512_SET0), Zmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
        .addReg(Zmm);
  } else if (ST.hasAVX2()) {
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::AVX_SET0), Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
        .addReg(Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
        .addReg(Ymm);
  } else {
    assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
    unsigned StoreOpc = ST.hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr;
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::V_SET0), Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS).addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 16)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 32)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 48)
        .addReg(Xmm);
  }
  // Fill in the palette first.
  addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);
  return true;
}
  375. FunctionPass *llvm::createX86PreTileConfigPass() {
  376. return new X86PreTileConfig();
  377. }