MatrixUtils.h 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- MatrixUtils.h - Utilities to lower matrix intrinsics -----*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. // Utilities for generating tiled loops for matrix operations.
  15. //
  16. //===----------------------------------------------------------------------===//
  17. #ifndef LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
  18. #define LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
  19. #include "llvm/ADT/StringRef.h"
  20. namespace llvm {
  21. class DomTreeUpdater;
  22. class BasicBlock;
  23. class Value;
  24. class Loop;
  25. class LoopInfo;
  26. class IRBuilderBase;
  27. /// A helper struct to create IR loop nests for tiling in IR of the following
  28. /// form:
  29. /// for CurrentColumn = 0..NumColumns
  30. /// for CurrentRow = 0..NumRows
  31. /// for CurrentInner = 0..NumInner
  32. struct TileInfo {
  33. /// Number of rows of the matrix.
  34. unsigned NumRows;
  35. /// Number of columns of the matrix.
  36. unsigned NumColumns;
  37. /// Number of columns of the first matrix of a multiply /
  38. /// number of rows of the second matrix of a multiply.
  39. unsigned NumInner;
  40. /// Number of rows/columns in a tile.
  41. unsigned TileSize = -1;
  42. /// Start row of the current tile to compute.
  43. Value *CurrentRow;
  44. /// Start column of the current tile to compute.
  45. Value *CurrentCol;
  46. /// Current tile offset during the tile computation.
  47. Value *CurrentK;
  48. /// Header of the outermost loop iterating from 0..NumColumns.
  49. BasicBlock *ColumnLoopHeader = nullptr;
  50. /// Header of the second loop iterating from 0..NumRows.
  51. BasicBlock *RowLoopHeader = nullptr;
  52. /// Latch of the second loop iterating from 0..NumRows.
  53. BasicBlock *RowLoopLatch = nullptr;
  54. /// Header of the innermost loop iterating from 0..NumInner.
  55. BasicBlock *InnerLoopHeader = nullptr;
  56. /// Latch of the innermost loop iterating from 0..NumInner.
  57. BasicBlock *InnerLoopLatch = nullptr;
  58. TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
  59. unsigned TileSize)
  60. : NumRows(NumRows), NumColumns(NumColumns), NumInner(NumInner),
  61. TileSize(TileSize) {}
  62. /// Creates an IR loop nests for tiling of the form below. Returns the block
  63. /// for the inner loop body and sets {Column,Row,Inner}LoopHeader/Latch
  64. /// fields.
  65. ///
  66. /// for CurrentColumn = 0..NumColumns
  67. /// for CurrentRow = 0..NumRows
  68. /// for CurrentInner = 0..NumInner
  69. BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
  70. IRBuilderBase &B, DomTreeUpdater &DTU,
  71. LoopInfo &LI);
  72. private:
  73. /// Creates a new loop with header, body and latch blocks that iterates from
  74. /// [0, Bound). Updates \p Preheader to branch to the new header and uses \p
  75. /// Exit as exit block. Adds the new loop blocks to \L and applies dominator
  76. /// tree updates to \p DTU.
  77. static BasicBlock *CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
  78. Value *Bound, Value *Step, StringRef Name,
  79. IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
  80. LoopInfo &LI);
  81. };
  82. } // namespace llvm
  83. #endif
  84. #ifdef __GNUC__
  85. #pragma GCC diagnostic pop
  86. #endif