SwitchLoweringUtils.h 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
  14. #define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
  15. #include "llvm/ADT/SmallVector.h"
  16. #include "llvm/CodeGen/ISDOpcodes.h"
  17. #include "llvm/CodeGen/SelectionDAGNodes.h"
  18. #include "llvm/IR/InstrTypes.h"
  19. #include "llvm/Support/BranchProbability.h"
  20. #include <vector>
  21. namespace llvm {
  // Forward declarations for every llvm class this header names only through
  // pointers or references. DataLayout (SwitchLowering::init) and SwitchInst
  // (the find*/build* methods) are declared here so the header does not rely
  // on them arriving through transitive includes.
  class BlockFrequencyInfo;
  class ConstantInt;
  class DataLayout;
  class FunctionLoweringInfo;
  class MachineBasicBlock;
  class ProfileSummaryInfo;
  class SwitchInst;
  class TargetLowering;
  class TargetMachine;
  29. namespace SwitchCG {
  /// Discriminator telling how a CaseCluster will be lowered.
  enum CaseClusterKind {
    /// Adjacent case labels sharing one destination, or a single case.
    CC_Range = 0,
    /// Cases to be lowered through a jump table.
    CC_JumpTable = 1,
    /// Cases to be lowered through bit tests.
    CC_BitTests = 2
  };
  39. /// A cluster of case labels.
  40. struct CaseCluster {
  41. CaseClusterKind Kind;
  42. const ConstantInt *Low, *High;
  43. union {
  44. MachineBasicBlock *MBB;
  45. unsigned JTCasesIndex;
  46. unsigned BTCasesIndex;
  47. };
  48. BranchProbability Prob;
  49. static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
  50. MachineBasicBlock *MBB, BranchProbability Prob) {
  51. CaseCluster C;
  52. C.Kind = CC_Range;
  53. C.Low = Low;
  54. C.High = High;
  55. C.MBB = MBB;
  56. C.Prob = Prob;
  57. return C;
  58. }
  59. static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High,
  60. unsigned JTCasesIndex, BranchProbability Prob) {
  61. CaseCluster C;
  62. C.Kind = CC_JumpTable;
  63. C.Low = Low;
  64. C.High = High;
  65. C.JTCasesIndex = JTCasesIndex;
  66. C.Prob = Prob;
  67. return C;
  68. }
  69. static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
  70. unsigned BTCasesIndex, BranchProbability Prob) {
  71. CaseCluster C;
  72. C.Kind = CC_BitTests;
  73. C.Low = Low;
  74. C.High = High;
  75. C.BTCasesIndex = BTCasesIndex;
  76. C.Prob = Prob;
  77. return C;
  78. }
  79. };
  80. using CaseClusterVector = std::vector<CaseCluster>;
  81. using CaseClusterIt = CaseClusterVector::iterator;
  82. /// Sort Clusters and merge adjacent cases.
  83. void sortAndRangeify(CaseClusterVector &Clusters);
  84. struct CaseBits {
  85. uint64_t Mask = 0;
  86. MachineBasicBlock *BB = nullptr;
  87. unsigned Bits = 0;
  88. BranchProbability ExtraProb;
  89. CaseBits() = default;
  90. CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits,
  91. BranchProbability Prob)
  92. : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {}
  93. };
  94. using CaseBitsVector = std::vector<CaseBits>;
  95. /// This structure is used to communicate between SelectionDAGBuilder and
  96. /// SDISel for the code generation of additional basic blocks needed by
  97. /// multi-case switch statements.
  98. struct CaseBlock {
  99. // For the GISel interface.
  100. struct PredInfoPair {
  101. CmpInst::Predicate Pred;
  102. // Set when no comparison should be emitted.
  103. bool NoCmp;
  104. };
  105. union {
  106. // The condition code to use for the case block's setcc node.
  107. // Besides the integer condition codes, this can also be SETTRUE, in which
  108. // case no comparison gets emitted.
  109. ISD::CondCode CC;
  110. struct PredInfoPair PredInfo;
  111. };
  112. // The LHS/MHS/RHS of the comparison to emit.
  113. // Emit by default LHS op RHS. MHS is used for range comparisons:
  114. // If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
  115. const Value *CmpLHS, *CmpMHS, *CmpRHS;
  116. // The block to branch to if the setcc is true/false.
  117. MachineBasicBlock *TrueBB, *FalseBB;
  118. // The block into which to emit the code for the setcc and branches.
  119. MachineBasicBlock *ThisBB;
  120. /// The debug location of the instruction this CaseBlock was
  121. /// produced from.
  122. SDLoc DL;
  123. DebugLoc DbgLoc;
  124. // Branch weights.
  125. BranchProbability TrueProb, FalseProb;
  126. // Constructor for SelectionDAG.
  127. CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
  128. const Value *cmpmiddle, MachineBasicBlock *truebb,
  129. MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
  130. BranchProbability trueprob = BranchProbability::getUnknown(),
  131. BranchProbability falseprob = BranchProbability::getUnknown())
  132. : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
  133. TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
  134. TrueProb(trueprob), FalseProb(falseprob) {}
  135. // Constructor for GISel.
  136. CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs,
  137. const Value *cmprhs, const Value *cmpmiddle,
  138. MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
  139. MachineBasicBlock *me, DebugLoc dl,
  140. BranchProbability trueprob = BranchProbability::getUnknown(),
  141. BranchProbability falseprob = BranchProbability::getUnknown())
  142. : PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle),
  143. CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me),
  144. DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {}
  145. };
  146. struct JumpTable {
  147. /// The virtual register containing the index of the jump table entry
  148. /// to jump to.
  149. unsigned Reg;
  150. /// The JumpTableIndex for this jump table in the function.
  151. unsigned JTI;
  152. /// The MBB into which to emit the code for the indirect jump.
  153. MachineBasicBlock *MBB;
  154. /// The MBB of the default bb, which is a successor of the range
  155. /// check MBB. This is when updating PHI nodes in successors.
  156. MachineBasicBlock *Default;
  157. JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D)
  158. : Reg(R), JTI(J), MBB(M), Default(D) {}
  159. };
  160. struct JumpTableHeader {
  161. APInt First;
  162. APInt Last;
  163. const Value *SValue;
  164. MachineBasicBlock *HeaderBB;
  165. bool Emitted;
  166. bool FallthroughUnreachable = false;
  167. JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
  168. bool E = false)
  169. : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
  170. Emitted(E) {}
  171. };
  172. using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
  173. struct BitTestCase {
  174. uint64_t Mask;
  175. MachineBasicBlock *ThisBB;
  176. MachineBasicBlock *TargetBB;
  177. BranchProbability ExtraProb;
  178. BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr,
  179. BranchProbability Prob)
  180. : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {}
  181. };
  182. using BitTestInfo = SmallVector<BitTestCase, 3>;
  183. struct BitTestBlock {
  184. APInt First;
  185. APInt Range;
  186. const Value *SValue;
  187. unsigned Reg;
  188. MVT RegVT;
  189. bool Emitted;
  190. bool ContiguousRange;
  191. MachineBasicBlock *Parent;
  192. MachineBasicBlock *Default;
  193. BitTestInfo Cases;
  194. BranchProbability Prob;
  195. BranchProbability DefaultProb;
  196. bool FallthroughUnreachable = false;
  197. BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E,
  198. bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
  199. BitTestInfo C, BranchProbability Pr)
  200. : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
  201. RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
  202. Cases(std::move(C)), Prob(Pr) {}
  203. };
  204. /// Return the range of values within a range.
  205. uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
  206. unsigned Last);
  207. /// Return the number of cases within a range.
  208. uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
  209. unsigned First, unsigned Last);
  210. struct SwitchWorkListItem {
  211. MachineBasicBlock *MBB;
  212. CaseClusterIt FirstCluster;
  213. CaseClusterIt LastCluster;
  214. const ConstantInt *GE;
  215. const ConstantInt *LT;
  216. BranchProbability DefaultProb;
  217. };
  218. using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
  219. class SwitchLowering {
  220. public:
  221. SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {}
  222. void init(const TargetLowering &tli, const TargetMachine &tm,
  223. const DataLayout &dl) {
  224. TLI = &tli;
  225. TM = &tm;
  226. DL = &dl;
  227. }
  228. /// Vector of CaseBlock structures used to communicate SwitchInst code
  229. /// generation information.
  230. std::vector<CaseBlock> SwitchCases;
  231. /// Vector of JumpTable structures used to communicate SwitchInst code
  232. /// generation information.
  233. std::vector<JumpTableBlock> JTCases;
  234. /// Vector of BitTestBlock structures used to communicate SwitchInst code
  235. /// generation information.
  236. std::vector<BitTestBlock> BitTestCases;
  237. void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
  238. MachineBasicBlock *DefaultMBB,
  239. ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI);
  240. bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
  241. unsigned Last, const SwitchInst *SI,
  242. MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
  243. void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);
  244. /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
  245. /// decides it's not a good idea.
  246. bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
  247. const SwitchInst *SI, CaseCluster &BTCluster);
  248. virtual void addSuccessorWithProb(
  249. MachineBasicBlock *Src, MachineBasicBlock *Dst,
  250. BranchProbability Prob = BranchProbability::getUnknown()) = 0;
  251. virtual ~SwitchLowering() = default;
  252. private:
  253. const TargetLowering *TLI;
  254. const TargetMachine *TM;
  255. const DataLayout *DL;
  256. FunctionLoweringInfo &FuncInfo;
  257. };
  258. } // namespace SwitchCG
  259. } // namespace llvm
  260. #endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
  261. #ifdef __GNUC__
  262. #pragma GCC diagnostic pop
  263. #endif