tdigest.cpp 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. #include "tdigest.h"
  2. #include <library/cpp/tdigest/tdigest.pb.h>
  3. #include <cmath>
  4. // TODO: rewrite to https://github.com/tdunning/t-digest/blob/master/src/main/java/com/tdunning/math/stats/MergingDigest.java
  5. TDigest::TDigest(double delta, double k)
  6. : N(0)
  7. , Delta(delta)
  8. , K(k)
  9. {
  10. }
  11. TDigest::TDigest(double delta, double k, double firstValue)
  12. : TDigest(delta, k)
  13. {
  14. AddValue(firstValue);
  15. }
  16. TDigest::TDigest(TStringBuf serializedDigest)
  17. : N(0)
  18. {
  19. NTDigest::TDigest digest;
  20. Y_ABORT_UNLESS(digest.ParseFromArray(serializedDigest.data(), serializedDigest.size()));
  21. Delta = digest.GetDelta();
  22. K = digest.GetK();
  23. for (int i = 0; i < digest.centroids_size(); ++i) {
  24. const NTDigest::TDigest::TCentroid& centroid = digest.centroids(i);
  25. Update(centroid.GetMean(), centroid.GetWeight());
  26. }
  27. }
  28. TDigest::TDigest(const TDigest* digest1, const TDigest* digest2)
  29. : N(0)
  30. , Delta(std::min(digest1->Delta, digest2->Delta))
  31. , K(std::max(digest1->K, digest2->K))
  32. {
  33. Add(*digest1);
  34. Add(*digest2);
  35. }
  36. void TDigest::Add(const TDigest& otherDigest) {
  37. for (auto& it : otherDigest.Centroids)
  38. Update(it.Mean, it.Count);
  39. for (auto& it : otherDigest.Unmerged)
  40. Update(it.Mean, it.Count);
  41. }
  42. TDigest TDigest::operator+(const TDigest& other) {
  43. TDigest T(Delta, K);
  44. T.Add(*this);
  45. T.Add(other);
  46. return T;
  47. }
  48. TDigest& TDigest::operator+=(const TDigest& other) {
  49. Add(other);
  50. return *this;
  51. }
  52. void TDigest::AddCentroid(const TCentroid& centroid) {
  53. Unmerged.push_back(centroid);
  54. N += centroid.Count;
  55. }
  56. double TDigest::GetThreshold(double q) {
  57. return 4 * N * Delta * q * (1 - q);
  58. }
  59. void TDigest::MergeCentroid(TVector<TCentroid>& merged, double& sum, const TCentroid& centroid) {
  60. if (merged.empty()) {
  61. merged.push_back(centroid);
  62. sum += centroid.Count;
  63. return;
  64. }
  65. // Use quantile that has the tightest k
  66. double q1 = (sum - merged.back().Count * 0.5) / N;
  67. double q2 = (sum + centroid.Count * 0.5) / N;
  68. double k = GetThreshold(q1);
  69. double k2 = GetThreshold(q2);
  70. if (k > k2) {
  71. k = k2;
  72. }
  73. if (merged.back().Count + centroid.Count <= k) {
  74. merged.back().Update(centroid.Mean, centroid.Count);
  75. } else {
  76. merged.push_back(centroid);
  77. }
  78. sum += centroid.Count;
  79. }
  80. void TDigest::Update(double x, double w) {
  81. AddCentroid(TCentroid(x, w));
  82. if (Unmerged.size() >= K / Delta) {
  83. Compress();
  84. }
  85. }
  86. void TDigest::Compress() {
  87. if (Unmerged.empty())
  88. return;
  89. // Merge Centroids and Unmerged into Merged
  90. std::stable_sort(Unmerged.begin(), Unmerged.end());
  91. Merged.clear();
  92. double sum = 0;
  93. iter_t i = Centroids.begin();
  94. iter_t j = Unmerged.begin();
  95. while (i != Centroids.end() && j != Unmerged.end()) {
  96. if (i->Mean <= j->Mean) {
  97. MergeCentroid(Merged, sum, *i++);
  98. } else {
  99. MergeCentroid(Merged, sum, *j++);
  100. }
  101. }
  102. while (i != Centroids.end()) {
  103. MergeCentroid(Merged, sum, *i++);
  104. }
  105. while (j != Unmerged.end()) {
  106. MergeCentroid(Merged, sum, *j++);
  107. }
  108. swap(Centroids, Merged);
  109. Unmerged.clear();
  110. }
  111. void TDigest::Clear() {
  112. Centroids.clear();
  113. Unmerged.clear();
  114. N = 0;
  115. }
  116. void TDigest::AddValue(double value) {
  117. Update(value, 1);
  118. }
  119. double TDigest::GetPercentile(double percentile) {
  120. Compress();
  121. if (Centroids.empty())
  122. return 0.0;
  123. // This algorithm uses C=1/2 with 0.5 optimized away
  124. // See https://en.wikipedia.org/wiki/Percentile#First_Variant.2C
  125. double x = percentile * N;
  126. double sum = 0.0;
  127. double prev_x = 0;
  128. double prev_mean = Centroids.front().Mean;
  129. for (const auto& C : Centroids) {
  130. double current_x = sum + C.Count * 0.5;
  131. if (x <= current_x) {
  132. double k = (x - prev_x) / (current_x - prev_x);
  133. return prev_mean + k * (C.Mean - prev_mean);
  134. }
  135. sum += C.Count;
  136. prev_x = current_x;
  137. prev_mean = C.Mean;
  138. }
  139. return Centroids.back().Mean;
  140. }
  141. double TDigest::GetRank(double value) {
  142. Compress();
  143. if (Centroids.empty()) {
  144. return 0.0;
  145. }
  146. if (value < Centroids.front().Mean) {
  147. return 0.0;
  148. }
  149. if (value == Centroids.front().Mean) {
  150. return Centroids.front().Count * 0.5 / N;
  151. }
  152. double sum = 0.0;
  153. double prev_x = 0.0;
  154. double prev_mean = Centroids.front().Mean;
  155. for (const auto& C : Centroids) {
  156. double current_x = sum + C.Count * 0.5;
  157. if (value <= C.Mean) {
  158. double k = (value - prev_mean) / (C.Mean - prev_mean);
  159. return (prev_x + k * (current_x - prev_x)) / N;
  160. }
  161. sum += C.Count;
  162. prev_mean = C.Mean;
  163. prev_x = current_x;
  164. }
  165. return 1.0;
  166. }
  167. TString TDigest::Serialize() {
  168. Compress();
  169. NTDigest::TDigest digest;
  170. digest.SetDelta(Delta);
  171. digest.SetK(K);
  172. for (const auto& it : Centroids) {
  173. NTDigest::TDigest::TCentroid* centroid = digest.AddCentroids();
  174. centroid->SetMean(it.Mean);
  175. centroid->SetWeight(it.Count);
  176. }
  177. return digest.SerializeAsString();
  178. }
  179. i64 TDigest::GetCount() const {
  180. return std::llround(N);
  181. }