MatrixUtils.cpp 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. //===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Utilities for generating tiled loops for matrix operations.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "llvm/Transforms/Utils/MatrixUtils.h"
  13. #include "llvm/Analysis/DomTreeUpdater.h"
  14. #include "llvm/Analysis/LoopInfo.h"
  15. #include "llvm/IR/BasicBlock.h"
  16. #include "llvm/IR/Dominators.h"
  17. #include "llvm/IR/IRBuilder.h"
  18. #include "llvm/IR/Type.h"
  19. using namespace llvm;
  20. BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
  21. Value *Bound, Value *Step, StringRef Name,
  22. IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
  23. LoopInfo &LI) {
  24. LLVMContext &Ctx = Preheader->getContext();
  25. BasicBlock *Header = BasicBlock::Create(
  26. Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
  27. BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
  28. Header->getParent(), Exit);
  29. BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
  30. Header->getParent(), Exit);
  31. Type *I32Ty = Type::getInt64Ty(Ctx);
  32. BranchInst::Create(Body, Header);
  33. BranchInst::Create(Latch, Body);
  34. PHINode *IV =
  35. PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator());
  36. IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader);
  37. B.SetInsertPoint(Latch);
  38. Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
  39. Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
  40. BranchInst::Create(Header, Exit, Cond, Latch);
  41. IV->addIncoming(Inc, Latch);
  42. BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
  43. BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
  44. PreheaderBr->setSuccessor(0, Header);
  45. DTU.applyUpdatesPermissive({
  46. {DominatorTree::Delete, Preheader, Tmp},
  47. {DominatorTree::Insert, Header, Body},
  48. {DominatorTree::Insert, Body, Latch},
  49. {DominatorTree::Insert, Latch, Header},
  50. {DominatorTree::Insert, Latch, Exit},
  51. {DominatorTree::Insert, Preheader, Header},
  52. });
  53. L->addBasicBlockToLoop(Header, LI);
  54. L->addBasicBlockToLoop(Body, LI);
  55. L->addBasicBlockToLoop(Latch, LI);
  56. return Body;
  57. }
  58. // Creates the following loop nest skeleton:
  59. // for C = 0; C < NumColumns; C += TileSize
  60. // for R = 0; R < NumRows; R += TileSize
  61. // for K = 0; K < Inner ; K += TileSize
  62. BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
  63. IRBuilderBase &B, DomTreeUpdater &DTU,
  64. LoopInfo &LI) {
  65. Loop *ColumnLoopInfo = LI.AllocateLoop();
  66. Loop *RowLoopInfo = LI.AllocateLoop();
  67. Loop *KLoopInfo = LI.AllocateLoop();
  68. RowLoopInfo->addChildLoop(KLoopInfo);
  69. ColumnLoopInfo->addChildLoop(RowLoopInfo);
  70. if (Loop *ParentL = LI.getLoopFor(Start))
  71. ParentL->addChildLoop(ColumnLoopInfo);
  72. else
  73. LI.addTopLevelLoop(ColumnLoopInfo);
  74. BasicBlock *ColBody =
  75. CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
  76. "cols", B, DTU, ColumnLoopInfo, LI);
  77. ColumnLoop.Latch = ColBody->getSingleSuccessor();
  78. BasicBlock *RowBody =
  79. CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows),
  80. B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI);
  81. RowLoop.Latch = RowBody->getSingleSuccessor();
  82. BasicBlock *InnerBody =
  83. CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner),
  84. B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI);
  85. KLoop.Latch = InnerBody->getSingleSuccessor();
  86. ColumnLoop.Header = ColBody->getSinglePredecessor();
  87. RowLoop.Header = RowBody->getSinglePredecessor();
  88. KLoop.Header = InnerBody->getSinglePredecessor();
  89. RowLoop.Index = &*RowLoop.Header->begin();
  90. ColumnLoop.Index = &*ColumnLoop.Header->begin();
  91. KLoop.Index = &*KLoop.Header->begin();
  92. return InnerBody;
  93. }