Clustering.h 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. //===-- Clustering.h --------------------------------------------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. ///
  9. /// \file
  10. /// Utilities to compute benchmark result clusters.
  11. ///
  12. //===----------------------------------------------------------------------===//
  13. #ifndef LLVM_TOOLS_LLVM_EXEGESIS_CLUSTERING_H
  14. #define LLVM_TOOLS_LLVM_EXEGESIS_CLUSTERING_H
  15. #include "BenchmarkResult.h"
  16. #include "llvm/ADT/Optional.h"
  17. #include "llvm/Support/Error.h"
  18. #include <limits>
  19. #include <vector>
  20. namespace llvm {
  21. namespace exegesis {
  22. class InstructionBenchmarkClustering {
  23. public:
  24. enum ModeE { Dbscan, Naive };
  25. // Clusters `Points` using DBSCAN with the given parameters. See the cc file
  26. // for more explanations on the algorithm.
  27. static Expected<InstructionBenchmarkClustering>
  28. create(const std::vector<InstructionBenchmark> &Points, ModeE Mode,
  29. size_t DbscanMinPts, double AnalysisClusteringEpsilon,
  30. const MCSubtargetInfo *SubtargetInfo = nullptr,
  31. const MCInstrInfo *InstrInfo = nullptr);
  32. class ClusterId {
  33. public:
  34. static ClusterId noise() { return ClusterId(kNoise); }
  35. static ClusterId error() { return ClusterId(kError); }
  36. static ClusterId makeValid(size_t Id, bool IsUnstable = false) {
  37. return ClusterId(Id, IsUnstable);
  38. }
  39. static ClusterId makeValidUnstable(size_t Id) {
  40. return makeValid(Id, /*IsUnstable=*/true);
  41. }
  42. ClusterId() : Id_(kUndef), IsUnstable_(false) {}
  43. // Compare id's, ignoring the 'unstability' bit.
  44. bool operator==(const ClusterId &O) const { return Id_ == O.Id_; }
  45. bool operator<(const ClusterId &O) const { return Id_ < O.Id_; }
  46. bool isValid() const { return Id_ <= kMaxValid; }
  47. bool isUnstable() const { return IsUnstable_; }
  48. bool isNoise() const { return Id_ == kNoise; }
  49. bool isError() const { return Id_ == kError; }
  50. bool isUndef() const { return Id_ == kUndef; }
  51. // Precondition: isValid().
  52. size_t getId() const {
  53. assert(isValid());
  54. return Id_;
  55. }
  56. private:
  57. ClusterId(size_t Id, bool IsUnstable = false)
  58. : Id_(Id), IsUnstable_(IsUnstable) {}
  59. static constexpr const size_t kMaxValid =
  60. (std::numeric_limits<size_t>::max() >> 1) - 4;
  61. static constexpr const size_t kNoise = kMaxValid + 1;
  62. static constexpr const size_t kError = kMaxValid + 2;
  63. static constexpr const size_t kUndef = kMaxValid + 3;
  64. size_t Id_ : (std::numeric_limits<size_t>::digits - 1);
  65. size_t IsUnstable_ : 1;
  66. };
  67. static_assert(sizeof(ClusterId) == sizeof(size_t), "should be a bit field.");
  68. struct Cluster {
  69. Cluster() = delete;
  70. explicit Cluster(const ClusterId &Id) : Id(Id) {}
  71. const ClusterId Id;
  72. // Indices of benchmarks within the cluster.
  73. std::vector<int> PointIndices;
  74. };
  75. ClusterId getClusterIdForPoint(size_t P) const {
  76. return ClusterIdForPoint_[P];
  77. }
  78. const std::vector<InstructionBenchmark> &getPoints() const { return Points_; }
  79. const Cluster &getCluster(ClusterId Id) const {
  80. assert(!Id.isUndef() && "unlabeled cluster");
  81. if (Id.isNoise()) {
  82. return NoiseCluster_;
  83. }
  84. if (Id.isError()) {
  85. return ErrorCluster_;
  86. }
  87. return Clusters_[Id.getId()];
  88. }
  89. const std::vector<Cluster> &getValidClusters() const { return Clusters_; }
  90. // Returns true if the given point is within a distance Epsilon of each other.
  91. bool isNeighbour(const std::vector<BenchmarkMeasure> &P,
  92. const std::vector<BenchmarkMeasure> &Q,
  93. const double EpsilonSquared_) const {
  94. double DistanceSquared = 0.0;
  95. for (size_t I = 0, E = P.size(); I < E; ++I) {
  96. const auto Diff = P[I].PerInstructionValue - Q[I].PerInstructionValue;
  97. DistanceSquared += Diff * Diff;
  98. }
  99. return DistanceSquared <= EpsilonSquared_;
  100. }
  101. private:
  102. InstructionBenchmarkClustering(
  103. const std::vector<InstructionBenchmark> &Points,
  104. double AnalysisClusteringEpsilonSquared);
  105. Error validateAndSetup();
  106. void clusterizeDbScan(size_t MinPts);
  107. void clusterizeNaive(const MCSubtargetInfo &SubtargetInfo,
  108. const MCInstrInfo &InstrInfo);
  109. // Stabilization is only needed if dbscan was used to clusterize.
  110. void stabilize(unsigned NumOpcodes);
  111. void rangeQuery(size_t Q, std::vector<size_t> &Scratchpad) const;
  112. bool areAllNeighbours(ArrayRef<size_t> Pts) const;
  113. const std::vector<InstructionBenchmark> &Points_;
  114. const double AnalysisClusteringEpsilonSquared_;
  115. int NumDimensions_ = 0;
  116. // ClusterForPoint_[P] is the cluster id for Points[P].
  117. std::vector<ClusterId> ClusterIdForPoint_;
  118. std::vector<Cluster> Clusters_;
  119. Cluster NoiseCluster_;
  120. Cluster ErrorCluster_;
  121. };
  122. class SchedClassClusterCentroid {
  123. public:
  124. const std::vector<PerInstructionStats> &getStats() const {
  125. return Representative;
  126. }
  127. std::vector<BenchmarkMeasure> getAsPoint() const;
  128. void addPoint(ArrayRef<BenchmarkMeasure> Point);
  129. bool validate(InstructionBenchmark::ModeE Mode) const;
  130. private:
  131. // Measurement stats for the points in the SchedClassCluster.
  132. std::vector<PerInstructionStats> Representative;
  133. };
  134. } // namespace exegesis
  135. } // namespace llvm
  136. #endif // LLVM_TOOLS_LLVM_EXEGESIS_CLUSTERING_H