//===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the MatrixBuilder class, which is used as a convenient way
// to lower matrix operations to LLVM IR.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_IR_MATRIXBUILDER_H
#define LLVM_IR_MATRIXBUILDER_H

#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Alignment.h"

namespace llvm {

class Function;
class Twine;
class Module;

class MatrixBuilder {
  IRBuilderBase &B;

  Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }

  std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
                                                         Value *RHS) {
    assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
           "One of the operands must be a matrix (embedded in a vector)");
    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
      assert(!isa<ScalableVectorType>(LHS->getType()) &&
             "LHS Assumed to be fixed width");
      RHS = B.CreateVectorSplat(
          cast<VectorType>(LHS->getType())->getElementCount(), RHS,
          "scalar.splat");
    } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
      assert(!isa<ScalableVectorType>(RHS->getType()) &&
             "RHS Assumed to be fixed width");
      LHS = B.CreateVectorSplat(
          cast<VectorType>(RHS->getType())->getElementCount(), LHS,
          "scalar.splat");
    }
    return {LHS, RHS};
  }

public:
  MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {}

  /// Create a column major, strided matrix load.
  /// \p EltTy   - Matrix element type
  /// \p DataPtr - Start address of the matrix read
  /// \p Rows    - Number of rows in matrix (must be a constant)
  /// \p Columns - Number of columns in matrix (must be a constant)
  /// \p Stride  - Space between columns
  CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment,
                                  Value *Stride, bool IsVolatile, unsigned Rows,
                                  unsigned Columns, const Twine &Name = "") {
    auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);

    Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
                    B.getInt32(Columns)};
    Type *OverloadedTypes[] = {RetType, Stride->getType()};

    Function *TheFn = Intrinsic::getDeclaration(
        getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);

    CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
    Attribute AlignAttr =
        Attribute::getWithAlignment(Call->getContext(), Alignment);
    Call->addParamAttr(0, AlignAttr);
    return Call;
  }
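  // A minimal usage sketch for the strided load above (illustrative only, not
  // part of the upstream API; assumes an IRBuilderBase `Builder` positioned at
  // a valid insertion point and a pointer value `Ptr` to a column-major 4x4
  // matrix of doubles):
  //
  //   MatrixBuilder MB(Builder);
  //   Value *Stride = Builder.getInt64(4); // elements between column starts
  //   CallInst *M = MB.CreateColumnMajorLoad(
  //       Builder.getDoubleTy(), Ptr, Align(8), Stride, /*IsVolatile=*/false,
  //       /*Rows=*/4, /*Columns=*/4, "a");
  //
  // The result is the matrix flattened into a <16 x double> vector, which the
  // other helpers below operate on.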
  /// Create a column major, strided matrix store.
  /// \p Matrix - Matrix to store
  /// \p Ptr    - Pointer to write back to
  /// \p Stride - Space between columns
  CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
                                   Value *Stride, bool IsVolatile,
                                   unsigned Rows, unsigned Columns,
                                   const Twine &Name = "") {
    Value *Ops[] = {Matrix, Ptr, Stride, B.getInt1(IsVolatile),
                    B.getInt32(Rows), B.getInt32(Columns)};
    Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};

    Function *TheFn = Intrinsic::getDeclaration(
        getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);

    CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
    Attribute AlignAttr =
        Attribute::getWithAlignment(Call->getContext(), Alignment);
    Call->addParamAttr(1, AlignAttr);
    return Call;
  }

  /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
  /// rows and \p Columns columns.
  CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
                                  unsigned Columns, const Twine &Name = "") {
    auto *OpType = cast<VectorType>(Matrix->getType());
    auto *ReturnType =
        FixedVectorType::get(OpType->getElementType(), Rows * Columns);

    Type *OverloadedTypes[] = {ReturnType};
    Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
    Function *TheFn = Intrinsic::getDeclaration(
        getModule(), Intrinsic::matrix_transpose, OverloadedTypes);

    return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  }

  /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
  /// RHS.
  CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
                                 unsigned LHSColumns, unsigned RHSColumns,
                                 const Twine &Name = "") {
    auto *LHSType = cast<VectorType>(LHS->getType());
    auto *RHSType = cast<VectorType>(RHS->getType());

    auto *ReturnType =
        FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);

    Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
                    B.getInt32(RHSColumns)};
    Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};

    Function *TheFn = Intrinsic::getDeclaration(
        getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
    return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
  }

  /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx,
  /// \p ColumnIdx).
  Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
                            Value *ColumnIdx, unsigned NumRows) {
    return B.CreateInsertElement(
        Matrix, NewVal,
        B.CreateAdd(
            B.CreateMul(ColumnIdx,
                        ConstantInt::get(ColumnIdx->getType(), NumRows)),
            RowIdx));
  }

  /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
  /// matrixes.
  Value *CreateAdd(Value *LHS, Value *RHS) {
    assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
      assert(!isa<ScalableVectorType>(LHS->getType()) &&
             "LHS Assumed to be fixed width");
      RHS = B.CreateVectorSplat(
          cast<VectorType>(LHS->getType())->getElementCount(), RHS,
          "scalar.splat");
    } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
      assert(!isa<ScalableVectorType>(RHS->getType()) &&
             "RHS Assumed to be fixed width");
      LHS = B.CreateVectorSplat(
          cast<VectorType>(RHS->getType())->getElementCount(), LHS,
          "scalar.splat");
    }

    return cast<VectorType>(LHS->getType())
                   ->getElementType()
                   ->isFloatingPointTy()
               ? B.CreateFAdd(LHS, RHS)
               : B.CreateAdd(LHS, RHS);
  }
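  // Sketch of composing the helpers above (illustrative only; `A` and `C` are
  // assumed to be flattened matrix values of the same element type, shaped
  // 3x4 and 4x5 respectively):
  //
  //   MatrixBuilder MB(Builder);
  //   // (3x4) x (4x5) yields a 3x5 matrix, returned as a flat 15-element
  //   // vector.
  //   Value *Prod = MB.CreateMatrixMultiply(A, C, /*LHSRows=*/3,
  //                                         /*LHSColumns=*/4,
  //                                         /*RHSColumns=*/5, "prod");
  //   // Transpose the 3x5 product into a 5x3 matrix.
  //   Value *ProdT = MB.CreateMatrixTranspose(Prod, /*Rows=*/3,
  //                                           /*Columns=*/5, "prod.t");
  //   // Element-wise add; CreateAdd also accepts one scalar operand and
  //   // splats it across the matrix.
  //   Value *Sum = MB.CreateAdd(Prod, Prod);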
  /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
  /// point matrixes.
  Value *CreateSub(Value *LHS, Value *RHS) {
    assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
      assert(!isa<ScalableVectorType>(LHS->getType()) &&
             "LHS Assumed to be fixed width");
      RHS = B.CreateVectorSplat(
          cast<VectorType>(LHS->getType())->getElementCount(), RHS,
          "scalar.splat");
    } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
      assert(!isa<ScalableVectorType>(RHS->getType()) &&
             "RHS Assumed to be fixed width");
      LHS = B.CreateVectorSplat(
          cast<VectorType>(RHS->getType())->getElementCount(), LHS,
          "scalar.splat");
    }

    return cast<VectorType>(LHS->getType())
                   ->getElementType()
                   ->isFloatingPointTy()
               ? B.CreateFSub(LHS, RHS)
               : B.CreateSub(LHS, RHS);
  }

  /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix
  /// \p RHS.
  Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
    std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
    if (LHS->getType()->getScalarType()->isFloatingPointTy())
      return B.CreateFMul(LHS, RHS);
    return B.CreateMul(LHS, RHS);
  }

  /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
  /// IsUnsigned indicates whether UDiv or SDiv should be used.
  Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
    assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
    assert(!isa<ScalableVectorType>(LHS->getType()) &&
           "LHS Assumed to be fixed width");
    RHS = B.CreateVectorSplat(
        cast<VectorType>(LHS->getType())->getElementCount(), RHS,
        "scalar.splat");
    return cast<VectorType>(LHS->getType())
                   ->getElementType()
                   ->isFloatingPointTy()
               ? B.CreateFDiv(LHS, RHS)
               : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
  }

  /// Create an assumption that \p Idx is less than \p NumElements.
  void CreateIndexAssumption(Value *Idx, unsigned NumElements,
                             Twine const &Name = "") {
    Value *NumElts =
        B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
    auto *Cmp = B.CreateICmpULT(Idx, NumElts);
    if (isa<ConstantInt>(Cmp))
      assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!");
    else
      B.CreateAssumption(Cmp);
  }

  /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
  /// a matrix with \p NumRows embedded in a vector.
  Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
                     Twine const &Name = "") {
    unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
                                 ColumnIdx->getType()->getScalarSizeInBits());
    Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
    RowIdx = B.CreateZExt(RowIdx, IntTy);
    ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
    Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
    return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
  }
};

} // end namespace llvm

#endif // LLVM_IR_MATRIXBUILDER_H
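// End-to-end sketch of scalar element access built from the indexing helpers
// above (illustrative only; `Builder` is an IRBuilderBase, `Mat` is a
// flattened 4x4 matrix value, and `Row`, `Col`, `NewVal` are assumed values):
//
//   MatrixBuilder MB(Builder);
//   // Compute the flat vector index for (Row, Col) in a column-major layout
//   // and record that it is in bounds.
//   Value *Idx = MB.CreateIndex(Row, Col, /*NumRows=*/4);
//   MB.CreateIndexAssumption(Idx, /*NumElements=*/16);
//   // Read one element, then write an updated value back at (Row, Col).
//   Value *Elt = Builder.CreateExtractElement(Mat, Idx);
//   Value *Updated = MB.CreateMatrixInsert(Mat, NewVal, Row, Col,
//                                          /*NumRows=*/4);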