welford.h 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. #pragma once
  2. #include <library/cpp/accurate_accumulate/accurate_accumulate.h>
  3. #include <util/ysaveload.h>
  4. // accurately computes (w_1 * x_1 + w_2 * x_2 + ... + w_n * x_n) / (w_1 + w_2 + ... + w_n)
  5. class TMeanCalculator {
  6. private:
  7. double Mean = 0.;
  8. TKahanAccumulator<double> SumWeights;
  9. public:
  10. Y_SAVELOAD_DEFINE(Mean, SumWeights);
  11. void Multiply(const double value);
  12. void Add(const double value, const double weight = 1.);
  13. void Remove(const double value, const double weight = 1.);
  14. double GetMean() const;
  15. double GetSumWeights() const;
  16. void Reset();
  17. bool operator<(const TMeanCalculator& other) const {
  18. return Mean < other.Mean;
  19. }
  20. bool operator>(const TMeanCalculator& other) const {
  21. return Mean > other.Mean;
  22. }
  23. };
  24. // accurately computes (w_1 * x_1 * y_1 + w_2 * x_2 * y_2 + ... + w_n * x_n * y_n) / (w_1 + w_2 + ... + w_n)
  25. class TCovariationCalculator {
  26. private:
  27. double Covariation = 0.;
  28. double FirstValueMean = 0.;
  29. double SecondValueMean = 0.;
  30. TKahanAccumulator<double> SumWeights;
  31. public:
  32. Y_SAVELOAD_DEFINE(Covariation, FirstValueMean, SecondValueMean, SumWeights);
  33. void Add(const double firstValue, const double secondValue, const double weight = 1.);
  34. void Remove(const double firstValue, const double secondValue, const double weight = 1.);
  35. double GetFirstValueMean() const;
  36. double GetSecondValueMean() const;
  37. double GetCovariation() const;
  38. double GetSumWeights() const;
  39. void Reset();
  40. };
  41. // accurately computes (w_1 * x_1 * x_1 + w_2 * x_2 * x_2 + ... + w_n * x_n * x_n) / (w_1 + w_2 + ... + w_n)
  42. class TDeviationCalculator {
  43. private:
  44. double Deviation = 0.;
  45. TMeanCalculator MeanCalculator;
  46. public:
  47. Y_SAVELOAD_DEFINE(Deviation, MeanCalculator);
  48. void Add(const double value, const double weight = 1.);
  49. void Remove(const double value, const double weight = 1.);
  50. double GetMean() const;
  51. double GetDeviation() const;
  52. double GetStdDev() const;
  53. double GetSumWeights() const;
  54. void Reset();
  55. };