123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- //===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
- //
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
- // See https://llvm.org/LICENSE.txt for license information.
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- //
- //===----------------------------------------------------------------------===//
- //
- // Utilities for generating tiled loops for matrix operations.
- //
- //===----------------------------------------------------------------------===//
- #include "llvm/Transforms/Utils/MatrixUtils.h"
- #include "llvm/Analysis/DomTreeUpdater.h"
- #include "llvm/Analysis/LoopInfo.h"
- #include "llvm/IR/BasicBlock.h"
- #include "llvm/IR/Dominators.h"
- #include "llvm/IR/IRBuilder.h"
- #include "llvm/IR/Type.h"
- using namespace llvm;
- BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
- Value *Bound, Value *Step, StringRef Name,
- IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
- LoopInfo &LI) {
- LLVMContext &Ctx = Preheader->getContext();
- BasicBlock *Header = BasicBlock::Create(
- Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
- BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
- Header->getParent(), Exit);
- BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
- Header->getParent(), Exit);
- Type *I32Ty = Type::getInt64Ty(Ctx);
- BranchInst::Create(Body, Header);
- BranchInst::Create(Latch, Body);
- PHINode *IV =
- PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator());
- IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader);
- B.SetInsertPoint(Latch);
- Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
- Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
- BranchInst::Create(Header, Exit, Cond, Latch);
- IV->addIncoming(Inc, Latch);
- BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
- BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
- PreheaderBr->setSuccessor(0, Header);
- DTU.applyUpdatesPermissive({
- {DominatorTree::Delete, Preheader, Tmp},
- {DominatorTree::Insert, Header, Body},
- {DominatorTree::Insert, Body, Latch},
- {DominatorTree::Insert, Latch, Header},
- {DominatorTree::Insert, Latch, Exit},
- {DominatorTree::Insert, Preheader, Header},
- });
- L->addBasicBlockToLoop(Header, LI);
- L->addBasicBlockToLoop(Body, LI);
- L->addBasicBlockToLoop(Latch, LI);
- return Body;
- }
- // Creates the following loop nest skeleton:
- // for C = 0; C < NumColumns; C += TileSize
- // for R = 0; R < NumRows; R += TileSize
- // for K = 0; K < Inner ; K += TileSize
- BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
- IRBuilderBase &B, DomTreeUpdater &DTU,
- LoopInfo &LI) {
- Loop *ColumnLoopInfo = LI.AllocateLoop();
- Loop *RowLoopInfo = LI.AllocateLoop();
- Loop *KLoopInfo = LI.AllocateLoop();
- RowLoopInfo->addChildLoop(KLoopInfo);
- ColumnLoopInfo->addChildLoop(RowLoopInfo);
- if (Loop *ParentL = LI.getLoopFor(Start))
- ParentL->addChildLoop(ColumnLoopInfo);
- else
- LI.addTopLevelLoop(ColumnLoopInfo);
- BasicBlock *ColBody =
- CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
- "cols", B, DTU, ColumnLoopInfo, LI);
- ColumnLoop.Latch = ColBody->getSingleSuccessor();
- BasicBlock *RowBody =
- CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows),
- B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI);
- RowLoop.Latch = RowBody->getSingleSuccessor();
- BasicBlock *InnerBody =
- CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner),
- B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI);
- KLoop.Latch = InnerBody->getSingleSuccessor();
- ColumnLoop.Header = ColBody->getSinglePredecessor();
- RowLoop.Header = RowBody->getSinglePredecessor();
- KLoop.Header = InnerBody->getSinglePredecessor();
- RowLoop.Index = &*RowLoop.Header->begin();
- ColumnLoop.Index = &*ColumnLoop.Header->begin();
- KLoop.Index = &*KLoop.Header->begin();
- return InnerBody;
- }
|