TileShapeInfo.h 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
/// \file Shape utility for AMX.
/// AMX hardware requires the shape of a tile data register to be configured
/// before use. The 2D shape consists of a row and a column. In the AMX
/// intrinsics interface the shape is passed as the 1st and 2nd parameters,
/// which are lowered to the 1st and 2nd machine operands of the AMX pseudo
/// instructions. The ShapeT class facilitates tile configuration and register
/// allocation. The row and column are machine operands of the AMX pseudo
/// instructions.
  21. //
  22. //===----------------------------------------------------------------------===//
  23. #ifndef LLVM_CODEGEN_TILESHAPEINFO_H
  24. #define LLVM_CODEGEN_TILESHAPEINFO_H
  25. #include "llvm/ADT/DenseMapInfo.h"
  26. #include "llvm/CodeGen/MachineInstr.h"
  27. #include "llvm/CodeGen/MachineOperand.h"
  28. #include "llvm/CodeGen/MachineRegisterInfo.h"
  29. #include "llvm/CodeGen/Register.h"
  30. namespace llvm {
  31. class ShapeT {
  32. public:
  33. ShapeT(MachineOperand *Row, MachineOperand *Col,
  34. const MachineRegisterInfo *MRI = nullptr)
  35. : Row(Row), Col(Col) {
  36. if (MRI)
  37. deduceImm(MRI);
  38. }
  39. ShapeT()
  40. : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
  41. ColImm(InvalidImmShape) {}
  42. bool operator==(const ShapeT &Shape) {
  43. MachineOperand *R = Shape.Row;
  44. MachineOperand *C = Shape.Col;
  45. if (!R || !C)
  46. return false;
  47. if (!Row || !Col)
  48. return false;
  49. if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
  50. return true;
  51. if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape))
  52. return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
  53. return false;
  54. }
  55. bool operator!=(const ShapeT &Shape) { return !(*this == Shape); }
  56. MachineOperand *getRow() const { return Row; }
  57. MachineOperand *getCol() const { return Col; }
  58. int64_t getRowImm() const { return RowImm; }
  59. int64_t getColImm() const { return ColImm; }
  60. bool isValid() { return (Row != nullptr) && (Col != nullptr); }
  61. void deduceImm(const MachineRegisterInfo *MRI) {
  62. // All def must be the same value, otherwise it is invalid MIs.
  63. // Find the immediate.
  64. // TODO copy propagation.
  65. auto GetImm = [&](Register Reg) {
  66. int64_t Imm = InvalidImmShape;
  67. for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
  68. const auto *MI = DefMO.getParent();
  69. if (MI->isMoveImmediate()) {
  70. Imm = MI->getOperand(1).getImm();
  71. break;
  72. }
  73. }
  74. return Imm;
  75. };
  76. RowImm = GetImm(Row->getReg());
  77. ColImm = GetImm(Col->getReg());
  78. }
  79. private:
  80. static constexpr int64_t InvalidImmShape = -1;
  81. MachineOperand *Row;
  82. MachineOperand *Col;
  83. int64_t RowImm;
  84. int64_t ColImm;
  85. };
  86. } // namespace llvm
  87. #endif
  88. #ifdef __GNUC__
  89. #pragma GCC diagnostic pop
  90. #endif