123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- #pragma once
- #ifdef __GNUC__
- #pragma GCC diagnostic push
- #pragma GCC diagnostic ignored "-Wunused-parameter"
- #endif
- //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===//
- //
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
- // See https://llvm.org/LICENSE.txt for license information.
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- //
- //===----------------------------------------------------------------------===//
- //
- /// \file Shape utility for AMX.
- /// AMX hardware requires the shape of a tile data register to be configured
- /// before use. The 2D shape consists of a row count and a column count. In the
- /// AMX intrinsics interface the shape is passed as the 1st and 2nd parameters,
- /// which are lowered to the 1st and 2nd machine operands of the AMX pseudo
- /// instructions. The ShapeT class wraps these row and column machine operands
- /// to facilitate tile configuration and register allocation.
- //
- //===----------------------------------------------------------------------===//
- #ifndef LLVM_CODEGEN_TILESHAPEINFO_H
- #define LLVM_CODEGEN_TILESHAPEINFO_H
- #include "llvm/ADT/DenseMapInfo.h"
- #include "llvm/CodeGen/MachineInstr.h"
- #include "llvm/CodeGen/MachineOperand.h"
- #include "llvm/CodeGen/MachineRegisterInfo.h"
- #include "llvm/CodeGen/Register.h"
- namespace llvm {
- class ShapeT {
- public:
- ShapeT(MachineOperand *Row, MachineOperand *Col,
- const MachineRegisterInfo *MRI = nullptr)
- : Row(Row), Col(Col) {
- if (MRI)
- deduceImm(MRI);
- }
- ShapeT()
- : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
- ColImm(InvalidImmShape) {}
- bool operator==(const ShapeT &Shape) {
- MachineOperand *R = Shape.Row;
- MachineOperand *C = Shape.Col;
- if (!R || !C)
- return false;
- if (!Row || !Col)
- return false;
- if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
- return true;
- if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape))
- return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
- return false;
- }
- bool operator!=(const ShapeT &Shape) { return !(*this == Shape); }
- MachineOperand *getRow() const { return Row; }
- MachineOperand *getCol() const { return Col; }
- int64_t getRowImm() const { return RowImm; }
- int64_t getColImm() const { return ColImm; }
- bool isValid() { return (Row != nullptr) && (Col != nullptr); }
- void deduceImm(const MachineRegisterInfo *MRI) {
- // All def must be the same value, otherwise it is invalid MIs.
- // Find the immediate.
- // TODO copy propagation.
- auto GetImm = [&](Register Reg) {
- int64_t Imm = InvalidImmShape;
- for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
- const auto *MI = DefMO.getParent();
- if (MI->isMoveImmediate()) {
- Imm = MI->getOperand(1).getImm();
- break;
- }
- }
- return Imm;
- };
- RowImm = GetImm(Row->getReg());
- ColImm = GetImm(Col->getReg());
- }
- private:
- static constexpr int64_t InvalidImmShape = -1;
- MachineOperand *Row;
- MachineOperand *Col;
- int64_t RowImm;
- int64_t ColImm;
- };
- } // namespace llvm
- #endif
- #ifdef __GNUC__
- #pragma GCC diagnostic pop
- #endif
|