AArch64PBQPRegAlloc.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. // This file contains the AArch64 / Cortex-A57 specific register allocation
  9. // constraints for use by the PBQP register allocator.
  10. //
  11. // It is essentially a transcription of what is contained in
  12. // AArch64A57FPLoadBalancing, which tries to use a balanced
  13. // mix of odd and even D-registers when performing a critical sequence of
  14. // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
  15. //===----------------------------------------------------------------------===//
  16. #include "AArch64PBQPRegAlloc.h"
  17. #include "AArch64.h"
  18. #include "AArch64RegisterInfo.h"
  19. #include "llvm/CodeGen/LiveIntervals.h"
  20. #include "llvm/CodeGen/MachineBasicBlock.h"
  21. #include "llvm/CodeGen/MachineFunction.h"
  22. #include "llvm/CodeGen/MachineRegisterInfo.h"
  23. #include "llvm/CodeGen/RegAllocPBQP.h"
  24. #include "llvm/Support/Debug.h"
  25. #include "llvm/Support/ErrorHandling.h"
  26. #include "llvm/Support/raw_ostream.h"
  27. #define DEBUG_TYPE "aarch64-pbqp"
  28. using namespace llvm;
  29. namespace {
  30. #ifndef NDEBUG
  31. bool isFPReg(unsigned reg) {
  32. return AArch64::FPR32RegClass.contains(reg) ||
  33. AArch64::FPR64RegClass.contains(reg) ||
  34. AArch64::FPR128RegClass.contains(reg);
  35. }
  36. #endif
  37. bool isOdd(unsigned reg) {
  38. switch (reg) {
  39. default:
  40. llvm_unreachable("Register is not from the expected class !");
  41. case AArch64::S1:
  42. case AArch64::S3:
  43. case AArch64::S5:
  44. case AArch64::S7:
  45. case AArch64::S9:
  46. case AArch64::S11:
  47. case AArch64::S13:
  48. case AArch64::S15:
  49. case AArch64::S17:
  50. case AArch64::S19:
  51. case AArch64::S21:
  52. case AArch64::S23:
  53. case AArch64::S25:
  54. case AArch64::S27:
  55. case AArch64::S29:
  56. case AArch64::S31:
  57. case AArch64::D1:
  58. case AArch64::D3:
  59. case AArch64::D5:
  60. case AArch64::D7:
  61. case AArch64::D9:
  62. case AArch64::D11:
  63. case AArch64::D13:
  64. case AArch64::D15:
  65. case AArch64::D17:
  66. case AArch64::D19:
  67. case AArch64::D21:
  68. case AArch64::D23:
  69. case AArch64::D25:
  70. case AArch64::D27:
  71. case AArch64::D29:
  72. case AArch64::D31:
  73. case AArch64::Q1:
  74. case AArch64::Q3:
  75. case AArch64::Q5:
  76. case AArch64::Q7:
  77. case AArch64::Q9:
  78. case AArch64::Q11:
  79. case AArch64::Q13:
  80. case AArch64::Q15:
  81. case AArch64::Q17:
  82. case AArch64::Q19:
  83. case AArch64::Q21:
  84. case AArch64::Q23:
  85. case AArch64::Q25:
  86. case AArch64::Q27:
  87. case AArch64::Q29:
  88. case AArch64::Q31:
  89. return true;
  90. case AArch64::S0:
  91. case AArch64::S2:
  92. case AArch64::S4:
  93. case AArch64::S6:
  94. case AArch64::S8:
  95. case AArch64::S10:
  96. case AArch64::S12:
  97. case AArch64::S14:
  98. case AArch64::S16:
  99. case AArch64::S18:
  100. case AArch64::S20:
  101. case AArch64::S22:
  102. case AArch64::S24:
  103. case AArch64::S26:
  104. case AArch64::S28:
  105. case AArch64::S30:
  106. case AArch64::D0:
  107. case AArch64::D2:
  108. case AArch64::D4:
  109. case AArch64::D6:
  110. case AArch64::D8:
  111. case AArch64::D10:
  112. case AArch64::D12:
  113. case AArch64::D14:
  114. case AArch64::D16:
  115. case AArch64::D18:
  116. case AArch64::D20:
  117. case AArch64::D22:
  118. case AArch64::D24:
  119. case AArch64::D26:
  120. case AArch64::D28:
  121. case AArch64::D30:
  122. case AArch64::Q0:
  123. case AArch64::Q2:
  124. case AArch64::Q4:
  125. case AArch64::Q6:
  126. case AArch64::Q8:
  127. case AArch64::Q10:
  128. case AArch64::Q12:
  129. case AArch64::Q14:
  130. case AArch64::Q16:
  131. case AArch64::Q18:
  132. case AArch64::Q20:
  133. case AArch64::Q22:
  134. case AArch64::Q24:
  135. case AArch64::Q26:
  136. case AArch64::Q28:
  137. case AArch64::Q30:
  138. return false;
  139. }
  140. }
  141. bool haveSameParity(unsigned reg1, unsigned reg2) {
  142. assert(isFPReg(reg1) && "Expecting an FP register for reg1");
  143. assert(isFPReg(reg2) && "Expecting an FP register for reg2");
  144. return isOdd(reg1) == isOdd(reg2);
  145. }
  146. }
  147. bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
  148. unsigned Ra) {
  149. if (Rd == Ra)
  150. return false;
  151. LiveIntervals &LIs = G.getMetadata().LIS;
  152. if (Register::isPhysicalRegister(Rd) || Register::isPhysicalRegister(Ra)) {
  153. LLVM_DEBUG(dbgs() << "Rd is a physical reg:"
  154. << Register::isPhysicalRegister(Rd) << '\n');
  155. LLVM_DEBUG(dbgs() << "Ra is a physical reg:"
  156. << Register::isPhysicalRegister(Ra) << '\n');
  157. return false;
  158. }
  159. PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
  160. PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
  161. const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
  162. &G.getNodeMetadata(node1).getAllowedRegs();
  163. const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
  164. &G.getNodeMetadata(node2).getAllowedRegs();
  165. PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
  166. // The edge does not exist. Create one with the appropriate interference
  167. // costs.
  168. if (edge == G.invalidEdgeId()) {
  169. const LiveInterval &ld = LIs.getInterval(Rd);
  170. const LiveInterval &la = LIs.getInterval(Ra);
  171. bool livesOverlap = ld.overlaps(la);
  172. PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
  173. vRaAllowed->size() + 1, 0);
  174. for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
  175. unsigned pRd = (*vRdAllowed)[i];
  176. for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
  177. unsigned pRa = (*vRaAllowed)[j];
  178. if (livesOverlap && TRI->regsOverlap(pRd, pRa))
  179. costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
  180. else
  181. costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
  182. }
  183. }
  184. G.addEdge(node1, node2, std::move(costs));
  185. return true;
  186. }
  187. if (G.getEdgeNode1Id(edge) == node2) {
  188. std::swap(node1, node2);
  189. std::swap(vRdAllowed, vRaAllowed);
  190. }
  191. // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
  192. PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
  193. for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
  194. unsigned pRd = (*vRdAllowed)[i];
  195. // Get the maximum cost (excluding unallocatable reg) for same parity
  196. // registers
  197. PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
  198. for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
  199. unsigned pRa = (*vRaAllowed)[j];
  200. if (haveSameParity(pRd, pRa))
  201. if (costs[i + 1][j + 1] !=
  202. std::numeric_limits<PBQP::PBQPNum>::infinity() &&
  203. costs[i + 1][j + 1] > sameParityMax)
  204. sameParityMax = costs[i + 1][j + 1];
  205. }
  206. // Ensure all registers with a different parity have a higher cost
  207. // than sameParityMax
  208. for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
  209. unsigned pRa = (*vRaAllowed)[j];
  210. if (!haveSameParity(pRd, pRa))
  211. if (sameParityMax > costs[i + 1][j + 1])
  212. costs[i + 1][j + 1] = sameParityMax + 1.0;
  213. }
  214. }
  215. G.updateEdgeCosts(edge, std::move(costs));
  216. return true;
  217. }
  218. void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
  219. unsigned Ra) {
  220. LiveIntervals &LIs = G.getMetadata().LIS;
  221. // Do some Chain management
  222. if (Chains.count(Ra)) {
  223. if (Rd != Ra) {
  224. LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)
  225. << " to " << printReg(Rd, TRI) << '\n';);
  226. Chains.remove(Ra);
  227. Chains.insert(Rd);
  228. }
  229. } else {
  230. LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)
  231. << '\n';);
  232. Chains.insert(Rd);
  233. }
  234. PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
  235. const LiveInterval &ld = LIs.getInterval(Rd);
  236. for (auto r : Chains) {
  237. // Skip self
  238. if (r == Rd)
  239. continue;
  240. const LiveInterval &lr = LIs.getInterval(r);
  241. if (ld.overlaps(lr)) {
  242. const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
  243. &G.getNodeMetadata(node1).getAllowedRegs();
  244. PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
  245. const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
  246. &G.getNodeMetadata(node2).getAllowedRegs();
  247. PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
  248. assert(edge != G.invalidEdgeId() &&
  249. "PBQP error ! The edge should exist !");
  250. LLVM_DEBUG(dbgs() << "Refining constraint !\n";);
  251. if (G.getEdgeNode1Id(edge) == node2) {
  252. std::swap(node1, node2);
  253. std::swap(vRdAllowed, vRrAllowed);
  254. }
  255. // Enforce that cost is higher with all other Chains of the same parity
  256. PBQP::Matrix costs(G.getEdgeCosts(edge));
  257. for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
  258. unsigned pRd = (*vRdAllowed)[i];
  259. // Get the maximum cost (excluding unallocatable reg) for all other
  260. // parity registers
  261. PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
  262. for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
  263. unsigned pRa = (*vRrAllowed)[j];
  264. if (!haveSameParity(pRd, pRa))
  265. if (costs[i + 1][j + 1] !=
  266. std::numeric_limits<PBQP::PBQPNum>::infinity() &&
  267. costs[i + 1][j + 1] > sameParityMax)
  268. sameParityMax = costs[i + 1][j + 1];
  269. }
  270. // Ensure all registers with same parity have a higher cost
  271. // than sameParityMax
  272. for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
  273. unsigned pRa = (*vRrAllowed)[j];
  274. if (haveSameParity(pRd, pRa))
  275. if (sameParityMax > costs[i + 1][j + 1])
  276. costs[i + 1][j + 1] = sameParityMax + 1.0;
  277. }
  278. }
  279. G.updateEdgeCosts(edge, std::move(costs));
  280. }
  281. }
  282. }
  283. static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
  284. const MachineInstr &MI) {
  285. const LiveInterval &LI = LIs.getInterval(reg);
  286. SlotIndex SI = LIs.getInstructionIndex(MI);
  287. return LI.expiredAt(SI);
  288. }
  289. void A57ChainingConstraint::apply(PBQPRAGraph &G) {
  290. const MachineFunction &MF = G.getMetadata().MF;
  291. LiveIntervals &LIs = G.getMetadata().LIS;
  292. TRI = MF.getSubtarget().getRegisterInfo();
  293. LLVM_DEBUG(MF.dump());
  294. for (const auto &MBB: MF) {
  295. Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
  296. for (const auto &MI: MBB) {
  297. // Forget Chains which have expired
  298. for (auto r : Chains) {
  299. SmallVector<unsigned, 8> toDel;
  300. if(regJustKilledBefore(LIs, r, MI)) {
  301. LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";
  302. MI.print(dbgs()););
  303. toDel.push_back(r);
  304. }
  305. while (!toDel.empty()) {
  306. Chains.remove(toDel.back());
  307. toDel.pop_back();
  308. }
  309. }
  310. switch (MI.getOpcode()) {
  311. case AArch64::FMSUBSrrr:
  312. case AArch64::FMADDSrrr:
  313. case AArch64::FNMSUBSrrr:
  314. case AArch64::FNMADDSrrr:
  315. case AArch64::FMSUBDrrr:
  316. case AArch64::FMADDDrrr:
  317. case AArch64::FNMSUBDrrr:
  318. case AArch64::FNMADDDrrr: {
  319. Register Rd = MI.getOperand(0).getReg();
  320. Register Ra = MI.getOperand(3).getReg();
  321. if (addIntraChainConstraint(G, Rd, Ra))
  322. addInterChainConstraint(G, Rd, Ra);
  323. break;
  324. }
  325. case AArch64::FMLAv2f32:
  326. case AArch64::FMLSv2f32: {
  327. Register Rd = MI.getOperand(0).getReg();
  328. addInterChainConstraint(G, Rd, Rd);
  329. break;
  330. }
  331. default:
  332. break;
  333. }
  334. }
  335. }
  336. }