X86TileConfig.cpp 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
//===-- X86TileConfig.cpp - Tile Register Configure----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to configure the shape of AMX physical registers.
/// AMX registers need to be configured before use. The X86PreTileConfig pass
/// inserts the pldtilecfg instruction, but at that point the shape of each
/// physical tile register is not yet known, because register allocation has
/// not been done. This pass runs after register allocation. It collects the
/// shape information of each physical tile register and stores it in the
/// stack slot that is allocated for loading the tile config register.
//
//===----------------------------------------------------------------------===//
#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86MachineFunctionInfo.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TileShapeInfo.h"
#include "llvm/CodeGen/VirtRegMap.h"
#include "llvm/InitializePasses.h"
#include <climits>
#include <cstdint>
#include <iterator>
  35. using namespace llvm;
  36. #define DEBUG_TYPE "tileconfig"
  37. namespace {
  38. struct X86TileConfig : public MachineFunctionPass {
  39. X86TileConfig() : MachineFunctionPass(ID) {}
  40. /// Return the pass name.
  41. StringRef getPassName() const override { return "Tile Register Configure"; }
  42. /// X86TileConfig analysis usage.
  43. void getAnalysisUsage(AnalysisUsage &AU) const override {
  44. AU.setPreservesAll();
  45. AU.addRequired<VirtRegMap>();
  46. AU.addRequired<LiveIntervals>();
  47. MachineFunctionPass::getAnalysisUsage(AU);
  48. }
  49. /// Perform register allocation.
  50. bool runOnMachineFunction(MachineFunction &mf) override;
  51. MachineFunctionProperties getRequiredProperties() const override {
  52. return MachineFunctionProperties().set(
  53. MachineFunctionProperties::Property::NoPHIs);
  54. }
  55. static char ID;
  56. };
  57. } // end anonymous namespace
  58. char X86TileConfig::ID = 0;
  59. INITIALIZE_PASS_BEGIN(X86TileConfig, DEBUG_TYPE, "Tile Register Configure",
  60. false, false)
  61. INITIALIZE_PASS_DEPENDENCY(VirtRegMap)
  62. INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false,
  63. false)
  64. bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
  65. const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  66. const TargetRegisterInfo *TRI = ST.getRegisterInfo();
  67. const TargetInstrInfo *TII = ST.getInstrInfo();
  68. MachineRegisterInfo &MRI = MF.getRegInfo();
  69. LiveIntervals &LIS = getAnalysis<LiveIntervals>();
  70. VirtRegMap &VRM = getAnalysis<VirtRegMap>();
  71. if (VRM.isShapeMapEmpty())
  72. return false;
  73. int SS = INT_MAX;
  74. for (MachineBasicBlock &MBB : MF) {
  75. for (MachineInstr &MI : MBB) {
  76. if (MI.getOpcode() == X86::PLDTILECFGV) {
  77. SS = MI.getOperand(0).getIndex();
  78. break;
  79. }
  80. }
  81. if (SS != INT_MAX)
  82. break;
  83. }
  84. // Didn't find PLDTILECFGV, just return false;
  85. if (SS == INT_MAX)
  86. return false;
  87. // Try to find a point to insert MIs for constant shapes.
  88. // Here we are leveraging the palette id inserted in PreRA pass.
  89. unsigned ConstPos = 0;
  90. MachineInstr *ConstMI = nullptr;
  91. for (MachineInstr &MI : MF.front()) {
  92. if (MI.getOpcode() == X86::MOV8mi && SS == MI.getOperand(0).getIndex()) {
  93. ConstMI = &MI;
  94. break;
  95. }
  96. ++ConstPos;
  97. }
  98. assert(ConstMI && "Cannot find an insertion point");
  99. unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs();
  100. SmallVector<Register, 8> Phys2Virt(AMXRegNum, 0);
  101. for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
  102. Register VirtReg = Register::index2VirtReg(I);
  103. if (MRI.reg_nodbg_empty(VirtReg))
  104. continue;
  105. if (MRI.getRegClass(VirtReg)->getID() != X86::TILERegClassID)
  106. continue;
  107. if (VRM.getPhys(VirtReg) == VirtRegMap::NO_PHYS_REG)
  108. continue;
  109. unsigned Index = VRM.getPhys(VirtReg) - X86::TMM0;
  110. if (!Phys2Virt[Index])
  111. Phys2Virt[Index] = VirtReg;
  112. }
  113. // Fill in the shape of each tile physical register.
  114. for (unsigned I = 0; I < AMXRegNum; ++I) {
  115. if (!Phys2Virt[I])
  116. continue;
  117. DebugLoc DL;
  118. bool IsRow = true;
  119. MachineInstr *NewMI = nullptr;
  120. ShapeT Shape = VRM.getShape(Phys2Virt[I]);
  121. for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) {
  122. // Here is the data format for the tile config.
  123. // 0 palette
  124. // 1 start_row
  125. // 2-15 reserved, must be zero
  126. // 16-17 tile0.colsb Tile 0 bytes per row.
  127. // 18-19 tile1.colsb Tile 1 bytes per row.
  128. // 20-21 tile2.colsb Tile 2 bytes per row.
  129. // ... (sequence continues)
  130. // 30-31 tile7.colsb Tile 7 bytes per row.
  131. // 32-47 reserved, must be zero
  132. // 48 tile0.rows Tile 0 rows.
  133. // 49 tile1.rows Tile 1 rows.
  134. // 50 tile2.rows Tile 2 rows.
  135. // ... (sequence continues)
  136. // 55 tile7.rows Tile 7 rows.
  137. // 56-63 reserved, must be zero
  138. int64_t Imm = INT64_MAX;
  139. int Offset = IsRow ? 48 + I : 16 + I * 2;
  140. for (auto &DefMI : MRI.def_instructions(R)) {
  141. MachineBasicBlock &MBB = *DefMI.getParent();
  142. if (DefMI.isMoveImmediate()) {
  143. if (Imm != INT64_MAX) {
  144. // FIXME: We should handle this case in future.
  145. assert(Imm == DefMI.getOperand(1).getImm() &&
  146. "Cannot initialize with different shapes");
  147. continue;
  148. }
  149. Imm = DefMI.getOperand(1).getImm();
  150. NewMI = addFrameReference(
  151. BuildMI(MF.front(), ++ConstMI->getIterator(), DL,
  152. TII->get(IsRow ? X86::MOV8mi : X86::MOV16mi)),
  153. SS, Offset)
  154. .addImm(Imm);
  155. ConstMI = NewMI;
  156. LIS.InsertMachineInstrInMaps(*NewMI);
  157. } else {
  158. unsigned SubIdx = IsRow ? X86::sub_8bit : X86::sub_16bit;
  159. unsigned RegSize = TRI->getRegSizeInBits(*MRI.getRegClass(R));
  160. if ((IsRow && RegSize == 8) || (!IsRow && RegSize == 16))
  161. SubIdx = 0;
  162. auto Iter = DefMI.getIterator();
  163. if (&MBB == &MF.front() &&
  164. (unsigned)std::distance(MBB.instr_begin(), Iter) < ConstPos)
  165. Iter = ConstMI->getIterator();
  166. NewMI = addFrameReference(
  167. BuildMI(MBB, ++Iter, DL,
  168. TII->get(IsRow ? X86::MOV8mr : X86::MOV16mr)),
  169. SS, Offset)
  170. .addReg(R, 0, SubIdx);
  171. SlotIndex SIdx = LIS.InsertMachineInstrInMaps(*NewMI);
  172. LIS.extendToIndices(LIS.getInterval(R), {SIdx.getRegSlot()});
  173. }
  174. }
  175. IsRow = false;
  176. }
  177. }
  178. return true;
  179. }
  180. FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); }