MatrixUtils.h 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- MatrixUtils.h - Utilities to lower matrix intrinsics -----*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. // Utilities for generating tiled loops for matrix operations.
  15. //
  16. //===----------------------------------------------------------------------===//
  17. #ifndef LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
  18. #define LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
  19. #include "llvm/ADT/StringRef.h"
  20. namespace llvm {
  21. class DomTreeUpdater;
  22. class BasicBlock;
  23. class Value;
  24. class Loop;
  25. class LoopInfo;
  26. class IRBuilderBase;
  27. /// A helper struct to create IR loop nests for tiling in IR of the following
  28. /// form:
  29. /// for ColumnLoop.Index = 0..NumColumns
  30. /// for RowLoop.Index = 0..NumRows
  31. /// for KLoop.Index = 0..NumInner
  32. struct TileInfo {
  33. /// Number of rows of the matrix.
  34. unsigned NumRows;
  35. /// Number of columns of the matrix.
  36. unsigned NumColumns;
  37. /// Number of columns of the first matrix of a multiply /
  38. /// number of rows of the second matrix of a multiply.
  39. unsigned NumInner;
  40. /// Number of rows/columns in a tile.
  41. unsigned TileSize = -1;
  42. /// Properties of a single loop used when generating the tiled loop nest.
  43. struct MatrixLoop {
  44. /// The index updated on every iteration.
  45. Value *Index = nullptr;
  46. /// The header and latch of the loop.
  47. BasicBlock *Header = nullptr;
  48. BasicBlock *Latch = nullptr;
  49. };
  50. /// The loop iterating on the rows.
  51. MatrixLoop RowLoop;
  52. /// The loop iterating on the columns.
  53. MatrixLoop ColumnLoop;
  54. /// The loop iterating on k (inner dimension).
  55. MatrixLoop KLoop;
  56. TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
  57. unsigned TileSize)
  58. : NumRows(NumRows), NumColumns(NumColumns), NumInner(NumInner),
  59. TileSize(TileSize) {}
  60. /// Creates an IR loop nests for tiling of the form below. Returns the block
  61. /// for the inner loop body and sets {Column,Row,Inner}LoopHeader/Latch
  62. /// fields.
  63. ///
  64. /// for ColumnLoop.Index = 0..NumColumns
  65. /// for RowLoop.Index = 0..NumRows
  66. /// for InnerLoop.Index = 0..NumInner
  67. BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
  68. IRBuilderBase &B, DomTreeUpdater &DTU,
  69. LoopInfo &LI);
  70. private:
  71. /// Creates a new loop with header, body and latch blocks that iterates from
  72. /// [0, Bound). Updates \p Preheader to branch to the new header and uses \p
  73. /// Exit as exit block. Adds the new loop blocks to \L and applies dominator
  74. /// tree updates to \p DTU.
  75. static BasicBlock *CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
  76. Value *Bound, Value *Step, StringRef Name,
  77. IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
  78. LoopInfo &LI);
  79. };
  80. } // namespace llvm
  81. #endif
  82. #ifdef __GNUC__
  83. #pragma GCC diagnostic pop
  84. #endif