  1. //===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file contains switch inst lowering optimizations and utilities for
  10. // codegen, so that it can be used for both SelectionDAG and GlobalISel.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/CodeGen/SwitchLoweringUtils.h"
  14. #include "llvm/CodeGen/FunctionLoweringInfo.h"
  15. #include "llvm/CodeGen/MachineJumpTableInfo.h"
  16. #include "llvm/CodeGen/TargetLowering.h"
  17. #include "llvm/Target/TargetMachine.h"
  18. using namespace llvm;
  19. using namespace SwitchCG;
  20. uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
  21. unsigned First, unsigned Last) {
  22. assert(Last >= First);
  23. const APInt &LowCase = Clusters[First].Low->getValue();
  24. const APInt &HighCase = Clusters[Last].High->getValue();
  25. assert(LowCase.getBitWidth() == HighCase.getBitWidth());
  26. // FIXME: A range of consecutive cases has 100% density, but only requires one
  27. // comparison to lower. We should discriminate against such consecutive ranges
  28. // in jump tables.
  29. return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
  30. }
  31. uint64_t
  32. SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
  33. unsigned First, unsigned Last) {
  34. assert(Last >= First);
  35. assert(TotalCases[Last] >= TotalCases[First]);
  36. uint64_t NumCases =
  37. TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
  38. return NumCases;
  39. }
// Scan the sorted range clusters of switch SI and replace dense runs with
// CC_JumpTable clusters in-place, shrinking Clusters accordingly. Uses a
// dynamic program to split the clusters into the minimum number of partitions,
// preferring (via a score) partitionings that yield more jump tables.
void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              ProfileSummaryInfo *PSI,
                                              BlockFrequencyInfo *BFI) {
#ifndef NDEBUG
  // Clusters must be non-empty, sorted, and only contain Range clusters.
  assert(!Clusters.empty());
  for (CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range);
  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  assert(TLI && "TLI not set!");
  if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
    return;

  const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;

  // Bail if not enough cases.
  const int64_t N = Clusters.size();
  if (N < 2 || N < MinJumpTableEntries)
    return;

  // Accumulated number of cases in each cluster and those prior to it.
  // (Inclusive prefix sum consumed by getJumpTableNumCases.)
  SmallVector<unsigned, 8> TotalCases(N);
  for (unsigned i = 0; i < N; ++i) {
    const APInt &Hi = Clusters[i].High->getValue();
    const APInt &Lo = Clusters[i].Low->getValue();
    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
    if (i != 0)
      TotalCases[i] += TotalCases[i - 1];
  }

  uint64_t Range = getJumpTableRange(Clusters, 0, N - 1);
  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
  assert(NumCases < UINT64_MAX / 100);
  assert(Range >= NumCases);

  // Cheap case: the whole range may be suitable for jump table.
  if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
    CaseCluster JTCluster;
    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
      // Everything collapsed into a single jump-table cluster.
      Clusters[0] = JTCluster;
      Clusters.resize(1);
      return;
    }
  }

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // Split Clusters into minimum number of dense partitions. The algorithm uses
  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
  // for the Case Statement'" (1994), but builds the MinPartitions array in
  // reverse order to make it easier to reconstruct the partitions in ascending
  // order. In the choice between two optimal partitionings, it picks the one
  // which yields more jump tables.

  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);
  // PartitionsScore[i] is used to break ties when choosing between two
  // partitionings resulting in the same number of partitions.
  SmallVector<unsigned, 8> PartitionsScore(N);
  // For PartitionsScore, a small number of comparisons is considered as good as
  // a jump table and a single comparison is considered better than a jump
  // table.
  enum PartitionScores : unsigned {
    NoTable = 0,
    Table = 1,
    FewCases = 1,
    SingleCase = 2
  };

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  PartitionsScore[N - 1] = PartitionScores::SingleCase;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; i--) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;
    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;

    // Search for a solution that results in fewer partitions.
    for (int64_t j = N - 1; j > i; j--) {
      // Try building a partition from Clusters[i..j].
      Range = getJumpTableRange(Clusters, i, j);
      NumCases = getJumpTableNumCases(TotalCases, i, j);
      assert(NumCases < UINT64_MAX / 100);
      assert(Range >= NumCases);

      if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
        int64_t NumEntries = j - i + 1;

        // Score this candidate partition by its size class.
        if (NumEntries == 1)
          Score += PartitionScores::SingleCase;
        else if (NumEntries <= SmallNumberOfEntries)
          Score += PartitionScores::FewCases;
        else if (NumEntries >= MinJumpTableEntries)
          Score += PartitionScores::Table;

        // If this leads to fewer partitions, or to the same number of
        // partitions with better score, it is a better partitioning.
        if (NumPartitions < MinPartitions[i] ||
            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
          MinPartitions[i] = NumPartitions;
          LastElement[i] = j;
          PartitionsScore[i] = Score;
        }
      }
    }
  }

  // Iterate over the partitions, replacing some with jump tables in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(Last >= First);
    assert(DstIndex <= First);
    unsigned NumClusters = Last - First + 1;

    CaseCluster JTCluster;
    if (NumClusters >= MinJumpTableEntries &&
        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
      Clusters[DstIndex++] = JTCluster;
    } else {
      // Partition stays as plain range clusters; slide them left to compact.
      for (unsigned I = First; I <= Last; ++I)
        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
    }
  }
  Clusters.resize(DstIndex);
}
// Materialize clusters [First, Last] as one jump table: build the table of
// destination MBBs (padding gaps with DefaultMBB), create the dispatch block,
// register the table with MachineJumpTableInfo, and emit the combined cluster
// into JTCluster. Returns false (no side effects on the function) when the
// range looks better suited to bit tests.
bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
                                              unsigned First, unsigned Last,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              CaseCluster &JTCluster) {
  assert(First <= Last);

  auto Prob = BranchProbability::getZero();
  unsigned NumCmps = 0;
  std::vector<MachineBasicBlock*> Table;
  DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;

  // Initialize probabilities in JTProbs.
  for (unsigned I = First; I <= Last; ++I)
    JTProbs[Clusters[I].MBB] = BranchProbability::getZero();

  for (unsigned I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Prob += Clusters[I].Prob;
    const APInt &Low = Clusters[I].Low->getValue();
    const APInt &High = Clusters[I].High->getValue();
    // A single-value cluster costs one comparison to lower directly; a range
    // costs two. Used below to decide suitability for bit tests.
    NumCmps += (Low == High) ? 1 : 2;
    if (I != First) {
      // Fill the gap between this and the previous cluster.
      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
      assert(PreviousHigh.slt(Low));
      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
      for (uint64_t J = 0; J < Gap; J++)
        Table.push_back(DefaultMBB);
    }
    // One table slot per case value covered by this cluster.
    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
    for (uint64_t J = 0; J < ClusterSize; ++J)
      Table.push_back(Clusters[I].MBB);
    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
  }

  unsigned NumDests = JTProbs.size();
  if (TLI->isSuitableForBitTests(NumDests, NumCmps,
                                 Clusters[First].Low->getValue(),
                                 Clusters[Last].High->getValue(), *DL)) {
    // Clusters[First..Last] should be lowered as bit tests instead.
    return false;
  }

  // Create the MBB that will load from and jump through the table.
  // Note: We create it here, but it's not inserted into the function yet.
  MachineFunction *CurMF = FuncInfo.MF;
  MachineBasicBlock *JumpTableMBB =
      CurMF->CreateMachineBasicBlock(SI->getParent());

  // Add successors. Note: use table order for determinism.
  SmallPtrSet<MachineBasicBlock *, 8> Done;
  for (MachineBasicBlock *Succ : Table) {
    if (Done.count(Succ))
      continue;
    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
    Done.insert(Succ);
  }
  JumpTableMBB->normalizeSuccProbs();

  unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
                     ->createJumpTableIndex(Table);

  // Set up the jump table info.
  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
  JumpTableHeader JTH(Clusters[First].Low->getValue(),
                      Clusters[Last].High->getValue(), SI->getCondition(),
                      nullptr, false);
  JTCases.emplace_back(std::move(JTH), std::move(JT));

  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                     JTCases.size() - 1, Prob);
  return true;
}
// Replace runs of clusters that fit in a machine word and target few distinct
// blocks with CC_BitTests clusters, compacting Clusters in-place. Mirrors the
// partitioning DP in findJumpTables, with the inner search bounded by the
// pointer bit width.
void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
                                                   const SwitchInst *SI) {
  // Partition Clusters into as few subsets as possible, where each subset has a
  // range that fits in a machine word and has <= 3 unique destinations.
#ifndef NDEBUG
  // Clusters must be sorted and contain Range or JumpTable clusters.
  assert(!Clusters.empty());
  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
  for (const CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  for (unsigned i = 1; i < Clusters.size(); ++i)
    assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // If target does not have legal shift left, do not emit bit tests at all.
  EVT PTy = TLI->getPointerTy(*DL);
  if (!TLI->isOperationLegal(ISD::SHL, PTy))
    return;

  int BitWidth = PTy.getSizeInBits();
  const int64_t N = Clusters.size();

  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);

  // FIXME: This might not be the best algorithm for finding bit test clusters.

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; --i) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;

    // Search for a solution that results in fewer partitions.
    // Note: the search is limited by BitWidth, reducing time complexity.
    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
      // Try building a partition from Clusters[i..j].

      // Check the range.
      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
                                Clusters[j].High->getValue(), *DL))
        continue;

      // Check nbr of destinations and cluster types.
      // FIXME: This works, but doesn't seem very efficient.
      bool RangesOnly = true;
      BitVector Dests(FuncInfo.MF->getNumBlockIDs());
      for (int64_t k = i; k <= j; k++) {
        if (Clusters[k].Kind != CC_Range) {
          RangesOnly = false;
          break;
        }
        Dests.set(Clusters[k].MBB->getNumber());
      }
      // Note: a larger j can only add destinations, so once the limit is
      // exceeded the search for this i can stop (break, not continue).
      if (!RangesOnly || Dests.count() > 3)
        break;

      // Check if it's a better partition.
      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
      if (NumPartitions < MinPartitions[i]) {
        // Found a better partition.
        MinPartitions[i] = NumPartitions;
        LastElement[i] = j;
      }
    }
  }

  // Iterate over the partitions, replacing with bit-test clusters in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(First <= Last);
    assert(DstIndex <= First);

    CaseCluster BitTestCluster;
    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
      Clusters[DstIndex++] = BitTestCluster;
    } else {
      // Keep the original clusters; shift the whole run left in one move.
      size_t NumClusters = Last - First + 1;
      std::memmove(&Clusters[DstIndex], &Clusters[First],
                   sizeof(Clusters[0]) * NumClusters);
      DstIndex += NumClusters;
    }
  }
  Clusters.resize(DstIndex);
}
// Lower clusters [First, Last] as a sequence of bit tests: one mask per unique
// destination, tested against (condition - LowBound). On success, records the
// bit-test info in BitTestCases and writes the combined cluster to BTCluster.
// Returns false when the range is unsuitable (single cluster, or fails the
// target's bit-test profitability check).
bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
                                             unsigned First, unsigned Last,
                                             const SwitchInst *SI,
                                             CaseCluster &BTCluster) {
  assert(First <= Last);
  if (First == Last)
    return false;

  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  unsigned NumCmps = 0;
  for (int64_t I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Dests.set(Clusters[I].MBB->getNumber());
    // Single-value clusters need one comparison; ranges need two.
    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  }
  unsigned NumDests = Dests.count();

  APInt Low = Clusters[First].Low->getValue();
  APInt High = Clusters[Last].High->getValue();
  assert(Low.slt(High));

  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
    return false;

  APInt LowBound;
  APInt CmpRange;

  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
         "Case range must fit in bit mask!");

  // Check if the clusters cover a contiguous range such that no value in the
  // range will jump to the default statement.
  bool ContiguousRange = true;
  for (int64_t I = First + 1; I <= Last; ++I) {
    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
      ContiguousRange = false;
      break;
    }
  }

  if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
    // Optimize the case where all the case values fit in a word without having
    // to subtract minValue. In this case, we can optimize away the subtraction.
    LowBound = APInt::getZero(Low.getBitWidth());
    CmpRange = High;
    // NOTE(review): with no subtraction, values in [0, Low) also fall inside
    // CmpRange yet must reach the default block, so the contiguous-range
    // shortcut is disabled here — presumably for that reason; confirm.
    ContiguousRange = false;
  } else {
    LowBound = Low;
    CmpRange = High - Low;
  }

  CaseBitsVector CBV;
  auto TotalProb = BranchProbability::getZero();
  for (unsigned i = First; i <= Last; ++i) {
    // Find the CaseBits for this destination.
    unsigned j;
    for (j = 0; j < CBV.size(); ++j)
      if (CBV[j].BB == Clusters[i].MBB)
        break;
    if (j == CBV.size())
      CBV.push_back(
          CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
    CaseBits *CB = &CBV[j];

    // Update Mask, Bits and ExtraProb.
    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
    // Set bits [Lo, Hi] of the mask: a run of (Hi - Lo + 1) ones shifted to Lo.
    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
    CB->Bits += Hi - Lo + 1;
    CB->ExtraProb += Clusters[i].Prob;
    TotalProb += Clusters[i].Prob;
  }

  BitTestInfo BTI;
  llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
    // Sort by probability first, number of bits second, bit mask third.
    if (a.ExtraProb != b.ExtraProb)
      return a.ExtraProb > b.ExtraProb;
    if (a.Bits != b.Bits)
      return a.Bits > b.Bits;
    return a.Mask < b.Mask;
  });

  for (auto &CB : CBV) {
    MachineBasicBlock *BitTestBB =
        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
    BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
  }
  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                            SI->getCondition(), -1U, MVT::Other, false,
                            ContiguousRange, nullptr, nullptr, std::move(BTI),
                            TotalProb);

  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                    BitTestCases.size() - 1, TotalProb);
  return true;
}
  402. void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
  403. #ifndef NDEBUG
  404. for (const CaseCluster &CC : Clusters)
  405. assert(CC.Low == CC.High && "Input clusters must be single-case");
  406. #endif
  407. llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
  408. return a.Low->getValue().slt(b.Low->getValue());
  409. });
  410. // Merge adjacent clusters with the same destination.
  411. const unsigned N = Clusters.size();
  412. unsigned DstIndex = 0;
  413. for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
  414. CaseCluster &CC = Clusters[SrcIndex];
  415. const ConstantInt *CaseVal = CC.Low;
  416. MachineBasicBlock *Succ = CC.MBB;
  417. if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
  418. (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
  419. // If this case has the same successor and is a neighbour, merge it into
  420. // the previous cluster.
  421. Clusters[DstIndex - 1].High = CaseVal;
  422. Clusters[DstIndex - 1].Prob += CC.Prob;
  423. } else {
  424. std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
  425. sizeof(Clusters[SrcIndex]));
  426. }
  427. }
  428. Clusters.resize(DstIndex);
  429. }