//===- X86VZeroUpper.cpp - AVX vzeroupper instruction inserter ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the pass which inserts x86 AVX vzeroupper instructions
// before calls to SSE encoded functions. This avoids transition latency
// penalty when transferring control between AVX encoded instructions and old
// SSE encoding mode.
//
//===----------------------------------------------------------------------===//
#include "X86.h"
#include "X86InstrInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/Function.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
using namespace llvm;

#define DEBUG_TYPE "x86-vzeroupper"

// Command-line switch to disable this pass entirely (enabled by default).
static cl::opt<bool>
UseVZeroUpper("x86-use-vzeroupper", cl::Hidden,
              cl::desc("Minimize AVX to SSE transition penalty"),
              cl::init(true));

// Counts the vzeroupper instructions this pass inserts (shown by -stats).
STATISTIC(NumVZU, "Number of vzeroupper instructions inserted");
  43. namespace {
  44. class VZeroUpperInserter : public MachineFunctionPass {
  45. public:
  46. VZeroUpperInserter() : MachineFunctionPass(ID) {}
  47. bool runOnMachineFunction(MachineFunction &MF) override;
  48. MachineFunctionProperties getRequiredProperties() const override {
  49. return MachineFunctionProperties().set(
  50. MachineFunctionProperties::Property::NoVRegs);
  51. }
  52. StringRef getPassName() const override { return "X86 vzeroupper inserter"; }
  53. private:
  54. void processBasicBlock(MachineBasicBlock &MBB);
  55. void insertVZeroUpper(MachineBasicBlock::iterator I,
  56. MachineBasicBlock &MBB);
  57. void addDirtySuccessor(MachineBasicBlock &MBB);
  58. using BlockExitState = enum { PASS_THROUGH, EXITS_CLEAN, EXITS_DIRTY };
  59. static const char* getBlockExitStateName(BlockExitState ST);
  60. // Core algorithm state:
  61. // BlockState - Each block is either:
  62. // - PASS_THROUGH: There are neither YMM/ZMM dirtying instructions nor
  63. // vzeroupper instructions in this block.
  64. // - EXITS_CLEAN: There is (or will be) a vzeroupper instruction in this
  65. // block that will ensure that YMM/ZMM is clean on exit.
  66. // - EXITS_DIRTY: An instruction in the block dirties YMM/ZMM and no
  67. // subsequent vzeroupper in the block clears it.
  68. //
  69. // AddedToDirtySuccessors - This flag is raised when a block is added to the
  70. // DirtySuccessors list to ensure that it's not
  71. // added multiple times.
  72. //
  73. // FirstUnguardedCall - Records the location of the first unguarded call in
  74. // each basic block that may need to be guarded by a
  75. // vzeroupper. We won't know whether it actually needs
  76. // to be guarded until we discover a predecessor that
  77. // is DIRTY_OUT.
  78. struct BlockState {
  79. BlockExitState ExitState = PASS_THROUGH;
  80. bool AddedToDirtySuccessors = false;
  81. MachineBasicBlock::iterator FirstUnguardedCall;
  82. BlockState() = default;
  83. };
  84. using BlockStateMap = SmallVector<BlockState, 8>;
  85. using DirtySuccessorsWorkList = SmallVector<MachineBasicBlock *, 8>;
  86. BlockStateMap BlockStates;
  87. DirtySuccessorsWorkList DirtySuccessors;
  88. bool EverMadeChange;
  89. bool IsX86INTR;
  90. const TargetInstrInfo *TII;
  91. static char ID;
  92. };
  93. } // end anonymous namespace
  94. char VZeroUpperInserter::ID = 0;
  95. FunctionPass *llvm::createX86IssueVZeroUpperPass() {
  96. return new VZeroUpperInserter();
  97. }
  98. #ifndef NDEBUG
  99. const char* VZeroUpperInserter::getBlockExitStateName(BlockExitState ST) {
  100. switch (ST) {
  101. case PASS_THROUGH: return "Pass-through";
  102. case EXITS_DIRTY: return "Exits-dirty";
  103. case EXITS_CLEAN: return "Exits-clean";
  104. }
  105. llvm_unreachable("Invalid block exit state.");
  106. }
  107. #endif
  108. /// VZEROUPPER cleans state that is related to Y/ZMM0-15 only.
  109. /// Thus, there is no need to check for Y/ZMM16 and above.
  110. static bool isYmmOrZmmReg(unsigned Reg) {
  111. return (Reg >= X86::YMM0 && Reg <= X86::YMM15) ||
  112. (Reg >= X86::ZMM0 && Reg <= X86::ZMM15);
  113. }
  114. static bool checkFnHasLiveInYmmOrZmm(MachineRegisterInfo &MRI) {
  115. for (std::pair<unsigned, unsigned> LI : MRI.liveins())
  116. if (isYmmOrZmmReg(LI.first))
  117. return true;
  118. return false;
  119. }
  120. static bool clobbersAllYmmAndZmmRegs(const MachineOperand &MO) {
  121. for (unsigned reg = X86::YMM0; reg <= X86::YMM15; ++reg) {
  122. if (!MO.clobbersPhysReg(reg))
  123. return false;
  124. }
  125. for (unsigned reg = X86::ZMM0; reg <= X86::ZMM15; ++reg) {
  126. if (!MO.clobbersPhysReg(reg))
  127. return false;
  128. }
  129. return true;
  130. }
  131. static bool hasYmmOrZmmReg(MachineInstr &MI) {
  132. for (const MachineOperand &MO : MI.operands()) {
  133. if (MI.isCall() && MO.isRegMask() && !clobbersAllYmmAndZmmRegs(MO))
  134. return true;
  135. if (!MO.isReg())
  136. continue;
  137. if (MO.isDebug())
  138. continue;
  139. if (isYmmOrZmmReg(MO.getReg()))
  140. return true;
  141. }
  142. return false;
  143. }
  144. /// Check if given call instruction has a RegMask operand.
  145. static bool callHasRegMask(MachineInstr &MI) {
  146. assert(MI.isCall() && "Can only be called on call instructions.");
  147. for (const MachineOperand &MO : MI.operands()) {
  148. if (MO.isRegMask())
  149. return true;
  150. }
  151. return false;
  152. }
  153. /// Insert a vzeroupper instruction before I.
  154. void VZeroUpperInserter::insertVZeroUpper(MachineBasicBlock::iterator I,
  155. MachineBasicBlock &MBB) {
  156. BuildMI(MBB, I, I->getDebugLoc(), TII->get(X86::VZEROUPPER));
  157. ++NumVZU;
  158. EverMadeChange = true;
  159. }
  160. /// Add MBB to the DirtySuccessors list if it hasn't already been added.
  161. void VZeroUpperInserter::addDirtySuccessor(MachineBasicBlock &MBB) {
  162. if (!BlockStates[MBB.getNumber()].AddedToDirtySuccessors) {
  163. DirtySuccessors.push_back(&MBB);
  164. BlockStates[MBB.getNumber()].AddedToDirtySuccessors = true;
  165. }
  166. }
/// Loop over all of the instructions in the basic block, inserting vzeroupper
/// instructions before function calls that are reached with dirty YMM/ZMM
/// state, recording the block's exit state, and remembering the first call
/// whose guarding must be deferred until predecessor states are known.
void VZeroUpperInserter::processBasicBlock(MachineBasicBlock &MBB) {
  // Start by assuming that the block is PASS_THROUGH which implies no unguarded
  // calls.
  BlockExitState CurState = PASS_THROUGH;
  BlockStates[MBB.getNumber()].FirstUnguardedCall = MBB.end();

  for (MachineInstr &MI : MBB) {
    bool IsCall = MI.isCall();
    bool IsReturn = MI.isReturn();
    bool IsControlFlow = IsCall || IsReturn;

    // No need for vzeroupper before iret in interrupt handler function,
    // epilogue will restore YMM/ZMM registers if needed.
    if (IsX86INTR && IsReturn)
      continue;

    // An existing VZERO* instruction resets the state.
    if (MI.getOpcode() == X86::VZEROALL || MI.getOpcode() == X86::VZEROUPPER) {
      CurState = EXITS_CLEAN;
      continue;
    }

    // Shortcut: don't need to check regular instructions in dirty state;
    // only a vzeroupper (handled above) or control flow can change things.
    if (!IsControlFlow && CurState == EXITS_DIRTY)
      continue;

    if (hasYmmOrZmmReg(MI)) {
      // We found a ymm/zmm-using instruction; this could be an AVX/AVX512
      // instruction, or it could be control flow.
      CurState = EXITS_DIRTY;
      continue;
    }

    // Check for control-flow out of the current function (which might
    // indirectly execute SSE instructions).
    if (!IsControlFlow)
      continue;

    // If the call has no RegMask, skip it as well. It usually happens on
    // helper function calls (such as '_chkstk', '_ftol2') where standard
    // calling convention is not used (RegMask is not used to mark register
    // clobbered and register usage (def/implicit-def/use) is well-defined and
    // explicitly specified.
    if (IsCall && !callHasRegMask(MI))
      continue;

    // The VZEROUPPER instruction resets the upper 128 bits of YMM0-YMM15
    // registers. In addition, the processor changes back to Clean state, after
    // which execution of SSE instructions or AVX instructions has no transition
    // penalty. Add the VZEROUPPER instruction before any function call/return
    // that might execute SSE code.
    // FIXME: In some cases, we may want to move the VZEROUPPER into a
    // predecessor block.
    if (CurState == EXITS_DIRTY) {
      // After the inserted VZEROUPPER the state becomes clean again, but
      // other YMM/ZMM may appear before other subsequent calls or even before
      // the end of the BB.
      insertVZeroUpper(MI, MBB);
      CurState = EXITS_CLEAN;
    } else if (CurState == PASS_THROUGH) {
      // If this block is currently in pass-through state and we encounter a
      // call then whether we need a vzeroupper or not depends on whether this
      // block has successors that exit dirty. Record the location of the call,
      // and set the state to EXITS_CLEAN, but do not insert the vzeroupper yet.
      // It will be inserted later if necessary.
      BlockStates[MBB.getNumber()].FirstUnguardedCall = MI;
      CurState = EXITS_CLEAN;
    }
  }

  LLVM_DEBUG(dbgs() << "MBB #" << MBB.getNumber() << " exit state: "
                    << getBlockExitStateName(CurState) << '\n');

  // A dirty exit makes every successor a candidate for guarding; queue them.
  if (CurState == EXITS_DIRTY)
    for (MachineBasicBlock *Succ : MBB.successors())
      addDirtySuccessor(*Succ);

  BlockStates[MBB.getNumber()].ExitState = CurState;
}
  237. /// Loop over all of the basic blocks, inserting vzeroupper instructions before
  238. /// function calls.
  239. bool VZeroUpperInserter::runOnMachineFunction(MachineFunction &MF) {
  240. if (!UseVZeroUpper)
  241. return false;
  242. const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  243. if (!ST.hasAVX() || !ST.insertVZEROUPPER())
  244. return false;
  245. TII = ST.getInstrInfo();
  246. MachineRegisterInfo &MRI = MF.getRegInfo();
  247. EverMadeChange = false;
  248. IsX86INTR = MF.getFunction().getCallingConv() == CallingConv::X86_INTR;
  249. bool FnHasLiveInYmmOrZmm = checkFnHasLiveInYmmOrZmm(MRI);
  250. // Fast check: if the function doesn't use any ymm/zmm registers, we don't
  251. // need to insert any VZEROUPPER instructions. This is constant-time, so it
  252. // is cheap in the common case of no ymm/zmm use.
  253. bool YmmOrZmmUsed = FnHasLiveInYmmOrZmm;
  254. for (const auto *RC : {&X86::VR256RegClass, &X86::VR512_0_15RegClass}) {
  255. if (!YmmOrZmmUsed) {
  256. for (MCPhysReg R : *RC) {
  257. if (!MRI.reg_nodbg_empty(R)) {
  258. YmmOrZmmUsed = true;
  259. break;
  260. }
  261. }
  262. }
  263. }
  264. if (!YmmOrZmmUsed)
  265. return false;
  266. assert(BlockStates.empty() && DirtySuccessors.empty() &&
  267. "X86VZeroUpper state should be clear");
  268. BlockStates.resize(MF.getNumBlockIDs());
  269. // Process all blocks. This will compute block exit states, record the first
  270. // unguarded call in each block, and add successors of dirty blocks to the
  271. // DirtySuccessors list.
  272. for (MachineBasicBlock &MBB : MF)
  273. processBasicBlock(MBB);
  274. // If any YMM/ZMM regs are live-in to this function, add the entry block to
  275. // the DirtySuccessors list
  276. if (FnHasLiveInYmmOrZmm)
  277. addDirtySuccessor(MF.front());
  278. // Re-visit all blocks that are successors of EXITS_DIRTY blocks. Add
  279. // vzeroupper instructions to unguarded calls, and propagate EXITS_DIRTY
  280. // through PASS_THROUGH blocks.
  281. while (!DirtySuccessors.empty()) {
  282. MachineBasicBlock &MBB = *DirtySuccessors.back();
  283. DirtySuccessors.pop_back();
  284. BlockState &BBState = BlockStates[MBB.getNumber()];
  285. // MBB is a successor of a dirty block, so its first call needs to be
  286. // guarded.
  287. if (BBState.FirstUnguardedCall != MBB.end())
  288. insertVZeroUpper(BBState.FirstUnguardedCall, MBB);
  289. // If this successor was a pass-through block, then it is now dirty. Its
  290. // successors need to be added to the worklist (if they haven't been
  291. // already).
  292. if (BBState.ExitState == PASS_THROUGH) {
  293. LLVM_DEBUG(dbgs() << "MBB #" << MBB.getNumber()
  294. << " was Pass-through, is now Dirty-out.\n");
  295. for (MachineBasicBlock *Succ : MBB.successors())
  296. addDirtySuccessor(*Succ);
  297. }
  298. }
  299. BlockStates.clear();
  300. return EverMadeChange;
  301. }