AArch64CondBrTuning.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. //===-- AArch64CondBrTuning.cpp --- Conditional branch tuning for AArch64 -===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. /// \file
  9. /// This file contains a pass that transforms CBZ/CBNZ/TBZ/TBNZ instructions
  10. /// into a conditional branch (B.cond), when the NZCV flags can be set for
  11. /// "free". This is preferred on targets that have more flexibility when
  12. /// scheduling B.cond instructions as compared to CBZ/CBNZ/TBZ/TBNZ (assuming
  13. /// all other variables are equal). This can also reduce register pressure.
  14. ///
  15. /// A few examples:
  16. ///
  17. /// 1) add w8, w0, w1 -> cmn w0, w1 ; CMN is an alias of ADDS.
/// cbz w8, .LBB0_2 -> b.eq .LBB0_2
  19. ///
  20. /// 2) add w8, w0, w1 -> adds w8, w0, w1 ; w8 has multiple uses.
  21. /// cbz w8, .LBB1_2 -> b.eq .LBB1_2
  22. ///
  23. /// 3) sub w8, w0, w1 -> subs w8, w0, w1 ; w8 has multiple uses.
  24. /// tbz w8, #31, .LBB6_2 -> b.pl .LBB6_2
  25. ///
  26. //===----------------------------------------------------------------------===//
  27. #include "AArch64.h"
  28. #include "AArch64Subtarget.h"
  29. #include "llvm/CodeGen/MachineFunction.h"
  30. #include "llvm/CodeGen/MachineFunctionPass.h"
  31. #include "llvm/CodeGen/MachineInstrBuilder.h"
  32. #include "llvm/CodeGen/MachineRegisterInfo.h"
  33. #include "llvm/CodeGen/Passes.h"
  34. #include "llvm/CodeGen/TargetInstrInfo.h"
  35. #include "llvm/CodeGen/TargetRegisterInfo.h"
  36. #include "llvm/CodeGen/TargetSubtargetInfo.h"
  37. #include "llvm/Support/Debug.h"
  38. #include "llvm/Support/raw_ostream.h"
  39. using namespace llvm;
  40. #define DEBUG_TYPE "aarch64-cond-br-tuning"
  41. #define AARCH64_CONDBR_TUNING_NAME "AArch64 Conditional Branch Tuning"
  42. namespace {
  43. class AArch64CondBrTuning : public MachineFunctionPass {
  44. const AArch64InstrInfo *TII;
  45. const TargetRegisterInfo *TRI;
  46. MachineRegisterInfo *MRI;
  47. public:
  48. static char ID;
  49. AArch64CondBrTuning() : MachineFunctionPass(ID) {
  50. initializeAArch64CondBrTuningPass(*PassRegistry::getPassRegistry());
  51. }
  52. void getAnalysisUsage(AnalysisUsage &AU) const override;
  53. bool runOnMachineFunction(MachineFunction &MF) override;
  54. StringRef getPassName() const override { return AARCH64_CONDBR_TUNING_NAME; }
  55. private:
  56. MachineInstr *getOperandDef(const MachineOperand &MO);
  57. MachineInstr *convertToFlagSetting(MachineInstr &MI, bool IsFlagSetting,
  58. bool Is64Bit);
  59. MachineInstr *convertToCondBr(MachineInstr &MI);
  60. bool tryToTuneBranch(MachineInstr &MI, MachineInstr &DefMI);
  61. };
  62. } // end anonymous namespace
// Storage for the static pass identifier declared in the class.
char AArch64CondBrTuning::ID = 0;
// Register the pass under the command-line name "aarch64-cond-br-tuning" (the
// same string as DEBUG_TYPE) with the description AARCH64_CONDBR_TUNING_NAME.
// NOTE(review): the trailing `false, false` are presumably INITIALIZE_PASS's
// CFG-only and is-analysis flags -- confirm against llvm/PassSupport.h.
INITIALIZE_PASS(AArch64CondBrTuning, "aarch64-cond-br-tuning",
                AARCH64_CONDBR_TUNING_NAME, false, false)
  66. void AArch64CondBrTuning::getAnalysisUsage(AnalysisUsage &AU) const {
  67. AU.setPreservesCFG();
  68. MachineFunctionPass::getAnalysisUsage(AU);
  69. }
  70. MachineInstr *AArch64CondBrTuning::getOperandDef(const MachineOperand &MO) {
  71. if (!MO.getReg().isVirtual())
  72. return nullptr;
  73. return MRI->getUniqueVRegDef(MO.getReg());
  74. }
  75. MachineInstr *AArch64CondBrTuning::convertToFlagSetting(MachineInstr &MI,
  76. bool IsFlagSetting,
  77. bool Is64Bit) {
  78. // If this is already the flag setting version of the instruction (e.g., SUBS)
  79. // just make sure the implicit-def of NZCV isn't marked dead.
  80. if (IsFlagSetting) {
  81. for (MachineOperand &MO : MI.implicit_operands())
  82. if (MO.isReg() && MO.isDead() && MO.getReg() == AArch64::NZCV)
  83. MO.setIsDead(false);
  84. return &MI;
  85. }
  86. unsigned NewOpc = TII->convertToFlagSettingOpc(MI.getOpcode());
  87. Register NewDestReg = MI.getOperand(0).getReg();
  88. if (MRI->hasOneNonDBGUse(MI.getOperand(0).getReg()))
  89. NewDestReg = Is64Bit ? AArch64::XZR : AArch64::WZR;
  90. MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
  91. TII->get(NewOpc), NewDestReg);
  92. for (const MachineOperand &MO : llvm::drop_begin(MI.operands()))
  93. MIB.add(MO);
  94. return MIB;
  95. }
  96. MachineInstr *AArch64CondBrTuning::convertToCondBr(MachineInstr &MI) {
  97. AArch64CC::CondCode CC;
  98. MachineBasicBlock *TargetMBB = TII->getBranchDestBlock(MI);
  99. switch (MI.getOpcode()) {
  100. default:
  101. llvm_unreachable("Unexpected opcode!");
  102. case AArch64::CBZW:
  103. case AArch64::CBZX:
  104. CC = AArch64CC::EQ;
  105. break;
  106. case AArch64::CBNZW:
  107. case AArch64::CBNZX:
  108. CC = AArch64CC::NE;
  109. break;
  110. case AArch64::TBZW:
  111. case AArch64::TBZX:
  112. CC = AArch64CC::PL;
  113. break;
  114. case AArch64::TBNZW:
  115. case AArch64::TBNZX:
  116. CC = AArch64CC::MI;
  117. break;
  118. }
  119. return BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), TII->get(AArch64::Bcc))
  120. .addImm(CC)
  121. .addMBB(TargetMBB);
  122. }
  123. bool AArch64CondBrTuning::tryToTuneBranch(MachineInstr &MI,
  124. MachineInstr &DefMI) {
  125. // We don't want NZCV bits live across blocks.
  126. if (MI.getParent() != DefMI.getParent())
  127. return false;
  128. bool IsFlagSetting = true;
  129. unsigned MIOpc = MI.getOpcode();
  130. MachineInstr *NewCmp = nullptr, *NewBr = nullptr;
  131. switch (DefMI.getOpcode()) {
  132. default:
  133. return false;
  134. case AArch64::ADDWri:
  135. case AArch64::ADDWrr:
  136. case AArch64::ADDWrs:
  137. case AArch64::ADDWrx:
  138. case AArch64::ANDWri:
  139. case AArch64::ANDWrr:
  140. case AArch64::ANDWrs:
  141. case AArch64::BICWrr:
  142. case AArch64::BICWrs:
  143. case AArch64::SUBWri:
  144. case AArch64::SUBWrr:
  145. case AArch64::SUBWrs:
  146. case AArch64::SUBWrx:
  147. IsFlagSetting = false;
  148. [[fallthrough]];
  149. case AArch64::ADDSWri:
  150. case AArch64::ADDSWrr:
  151. case AArch64::ADDSWrs:
  152. case AArch64::ADDSWrx:
  153. case AArch64::ANDSWri:
  154. case AArch64::ANDSWrr:
  155. case AArch64::ANDSWrs:
  156. case AArch64::BICSWrr:
  157. case AArch64::BICSWrs:
  158. case AArch64::SUBSWri:
  159. case AArch64::SUBSWrr:
  160. case AArch64::SUBSWrs:
  161. case AArch64::SUBSWrx:
  162. switch (MIOpc) {
  163. default:
  164. llvm_unreachable("Unexpected opcode!");
  165. case AArch64::CBZW:
  166. case AArch64::CBNZW:
  167. case AArch64::TBZW:
  168. case AArch64::TBNZW:
  169. // Check to see if the TBZ/TBNZ is checking the sign bit.
  170. if ((MIOpc == AArch64::TBZW || MIOpc == AArch64::TBNZW) &&
  171. MI.getOperand(1).getImm() != 31)
  172. return false;
  173. // There must not be any instruction between DefMI and MI that clobbers or
  174. // reads NZCV.
  175. if (isNZCVTouchedInInstructionRange(DefMI, MI, TRI))
  176. return false;
  177. LLVM_DEBUG(dbgs() << " Replacing instructions:\n ");
  178. LLVM_DEBUG(DefMI.print(dbgs()));
  179. LLVM_DEBUG(dbgs() << " ");
  180. LLVM_DEBUG(MI.print(dbgs()));
  181. NewCmp = convertToFlagSetting(DefMI, IsFlagSetting, /*Is64Bit=*/false);
  182. NewBr = convertToCondBr(MI);
  183. break;
  184. }
  185. break;
  186. case AArch64::ADDXri:
  187. case AArch64::ADDXrr:
  188. case AArch64::ADDXrs:
  189. case AArch64::ADDXrx:
  190. case AArch64::ANDXri:
  191. case AArch64::ANDXrr:
  192. case AArch64::ANDXrs:
  193. case AArch64::BICXrr:
  194. case AArch64::BICXrs:
  195. case AArch64::SUBXri:
  196. case AArch64::SUBXrr:
  197. case AArch64::SUBXrs:
  198. case AArch64::SUBXrx:
  199. IsFlagSetting = false;
  200. [[fallthrough]];
  201. case AArch64::ADDSXri:
  202. case AArch64::ADDSXrr:
  203. case AArch64::ADDSXrs:
  204. case AArch64::ADDSXrx:
  205. case AArch64::ANDSXri:
  206. case AArch64::ANDSXrr:
  207. case AArch64::ANDSXrs:
  208. case AArch64::BICSXrr:
  209. case AArch64::BICSXrs:
  210. case AArch64::SUBSXri:
  211. case AArch64::SUBSXrr:
  212. case AArch64::SUBSXrs:
  213. case AArch64::SUBSXrx:
  214. switch (MIOpc) {
  215. default:
  216. llvm_unreachable("Unexpected opcode!");
  217. case AArch64::CBZX:
  218. case AArch64::CBNZX:
  219. case AArch64::TBZX:
  220. case AArch64::TBNZX: {
  221. // Check to see if the TBZ/TBNZ is checking the sign bit.
  222. if ((MIOpc == AArch64::TBZX || MIOpc == AArch64::TBNZX) &&
  223. MI.getOperand(1).getImm() != 63)
  224. return false;
  225. // There must not be any instruction between DefMI and MI that clobbers or
  226. // reads NZCV.
  227. if (isNZCVTouchedInInstructionRange(DefMI, MI, TRI))
  228. return false;
  229. LLVM_DEBUG(dbgs() << " Replacing instructions:\n ");
  230. LLVM_DEBUG(DefMI.print(dbgs()));
  231. LLVM_DEBUG(dbgs() << " ");
  232. LLVM_DEBUG(MI.print(dbgs()));
  233. NewCmp = convertToFlagSetting(DefMI, IsFlagSetting, /*Is64Bit=*/true);
  234. NewBr = convertToCondBr(MI);
  235. break;
  236. }
  237. }
  238. break;
  239. }
  240. (void)NewCmp; (void)NewBr;
  241. assert(NewCmp && NewBr && "Expected new instructions.");
  242. LLVM_DEBUG(dbgs() << " with instruction:\n ");
  243. LLVM_DEBUG(NewCmp->print(dbgs()));
  244. LLVM_DEBUG(dbgs() << " ");
  245. LLVM_DEBUG(NewBr->print(dbgs()));
  246. // If this was a flag setting version of the instruction, we use the original
  247. // instruction by just clearing the dead marked on the implicit-def of NCZV.
  248. // Therefore, we should not erase this instruction.
  249. if (!IsFlagSetting)
  250. DefMI.eraseFromParent();
  251. MI.eraseFromParent();
  252. return true;
  253. }
  254. bool AArch64CondBrTuning::runOnMachineFunction(MachineFunction &MF) {
  255. if (skipFunction(MF.getFunction()))
  256. return false;
  257. LLVM_DEBUG(
  258. dbgs() << "********** AArch64 Conditional Branch Tuning **********\n"
  259. << "********** Function: " << MF.getName() << '\n');
  260. TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
  261. TRI = MF.getSubtarget().getRegisterInfo();
  262. MRI = &MF.getRegInfo();
  263. bool Changed = false;
  264. for (MachineBasicBlock &MBB : MF) {
  265. bool LocalChange = false;
  266. for (MachineInstr &MI : MBB.terminators()) {
  267. switch (MI.getOpcode()) {
  268. default:
  269. break;
  270. case AArch64::CBZW:
  271. case AArch64::CBZX:
  272. case AArch64::CBNZW:
  273. case AArch64::CBNZX:
  274. case AArch64::TBZW:
  275. case AArch64::TBZX:
  276. case AArch64::TBNZW:
  277. case AArch64::TBNZX:
  278. MachineInstr *DefMI = getOperandDef(MI.getOperand(0));
  279. LocalChange = (DefMI && tryToTuneBranch(MI, *DefMI));
  280. break;
  281. }
  282. // If the optimization was successful, we can't optimize any other
  283. // branches because doing so would clobber the NZCV flags.
  284. if (LocalChange) {
  285. Changed = true;
  286. break;
  287. }
  288. }
  289. }
  290. return Changed;
  291. }
  292. FunctionPass *llvm::createAArch64CondBrTuning() {
  293. return new AArch64CondBrTuning();
  294. }