// linear_regression.h (11 KB)
  1. #pragma once
  2. #include "linear_model.h"
  3. #include "welford.h"
  4. #include <library/cpp/accurate_accumulate/accurate_accumulate.h>
  5. #include <util/generic/vector.h>
  6. #include <util/generic/hash.h>
  7. #include <util/generic/ymath.h>
  8. class TFastLinearRegressionSolver {
  9. private:
  10. TKahanAccumulator<double> SumSquaredGoals;
  11. TVector<double> LinearizedOLSMatrix;
  12. TVector<double> OLSVector;
  13. public:
  14. bool Add(const TVector<double>& features, const double goal, const double weight = 1.);
  15. TLinearModel Solve() const;
  16. double SumSquaredErrors() const;
  17. };
  18. class TLinearRegressionSolver {
  19. private:
  20. double GoalsMean = 0.;
  21. double GoalsDeviation = 0.;
  22. TVector<double> FeatureMeans;
  23. TVector<double> LastMeans;
  24. TVector<double> NewMeans;
  25. TVector<double> LinearizedOLSMatrix;
  26. TVector<double> OLSVector;
  27. TKahanAccumulator<double> SumWeights;
  28. public:
  29. bool Add(const TVector<double>& features, const double goal, const double weight = 1.);
  30. TLinearModel Solve() const;
  31. double SumSquaredErrors() const;
  32. };
  /// Weighted simple (single-feature) linear regression built on plain running sums.
  ///
  /// TStoreType is the accumulator used for every sum. It must be
  /// default-constructible, support `+= double` and be convertible to double
  /// (the file instantiates it with `double` and `TKahanAccumulator<double>`).
  template <typename TStoreType>
  class TTypedFastSLRSolver {
  private:
      TStoreType SumFeatures = TStoreType();        // sum of w * x
      TStoreType SumSquaredFeatures = TStoreType(); // sum of w * x^2
      TStoreType SumGoals = TStoreType();           // sum of w * y
      TStoreType SumSquaredGoals = TStoreType();    // sum of w * y^2
      TStoreType SumProducts = TStoreType();        // sum of w * x * y
      TStoreType SumWeights = TStoreType();         // sum of w

  public:
      /// Accumulates one (feature, goal) observation with the given weight.
      /// Always returns true.
      bool Add(const double feature, const double goal, const double weight = 1.) {
          SumFeatures += feature * weight;
          SumSquaredFeatures += feature * feature * weight;
          SumGoals += goal * weight;
          SumSquaredGoals += goal * goal * weight;
          SumProducts += goal * feature * weight;
          SumWeights += weight;
          return true;
      }

      /// Computes `goal ~ factor * feature + intercept`.
      ///
      /// @param factor                  receives the slope
      /// @param intercept               receives the offset
      /// @param regularizationParameter added to the feature deviation in the
      ///                                denominator of the slope
      ///
      /// With no accumulated weight both outputs are zero; with a constant
      /// feature the slope is zero and the intercept is the weighted goal mean.
      template <typename TFloatType>
      void Solve(TFloatType& factor, TFloatType& intercept, const double regularizationParameter = 0.1) const {
          // The guard has to look at the total weight, not at the sum of goals:
          // goals that merely sum to zero (e.g. centered targets) still define
          // a perfectly valid regression, while a zero total weight would make
          // every division below meaningless.
          const double sumWeights = static_cast<double>(SumWeights);
          if (!sumWeights) {
              factor = intercept = TFloatType();
              return;
          }

          const double goalsMean = static_cast<double>(SumGoals) / sumWeights;

          double productsDeviation, featuresDeviation;
          SetupSolutionFactors(productsDeviation, featuresDeviation);

          if (!featuresDeviation) {
              factor = TFloatType();
              intercept = goalsMean;
              return;
          }

          factor = productsDeviation / (featuresDeviation + regularizationParameter);
          intercept = goalsMean - factor * static_cast<double>(SumFeatures) / sumWeights;
      }

      /// Weighted sum of squared residuals of the model Solve() would return
      /// for the same regularizationParameter, clamped to be non-negative.
      /// Returns 0 when no weight has been accumulated.
      double SumSquaredErrors(const double regularizationParameter = 0.1) const {
          const double sumWeights = static_cast<double>(SumWeights);
          if (!sumWeights) {
              return 0.;
          }

          const double sumGoals = static_cast<double>(SumGoals);
          const double sumGoalSquaredDeviations = static_cast<double>(SumSquaredGoals) - sumGoals / sumWeights * sumGoals;

          double productsDeviation, featuresDeviation;
          SetupSolutionFactors(productsDeviation, featuresDeviation);

          if (!featuresDeviation) {
              return sumGoalSquaredDeviations;
          }

          const double factor = productsDeviation / (featuresDeviation + regularizationParameter);
          const double sumSquaredErrors = factor * factor * featuresDeviation - 2 * factor * productsDeviation + sumGoalSquaredDeviations;

          // Rounding can push the value slightly below zero; clamp it.
          return 0. < sumSquaredErrors ? sumSquaredErrors : 0.;
      }

  private:
      /// Fills the weighted feature deviation and the feature/goal
      /// co-deviation. Both are zero when there is no weight or when the
      /// feature is constant.
      void SetupSolutionFactors(double& productsDeviation, double& featuresDeviation) const {
          // Always give both outputs a defined value, whichever path is taken.
          productsDeviation = featuresDeviation = 0.;

          const double sumWeights = static_cast<double>(SumWeights);
          if (!sumWeights) {
              return;
          }

          const double sumFeatures = static_cast<double>(SumFeatures);
          featuresDeviation = static_cast<double>(SumSquaredFeatures) - sumFeatures / sumWeights * sumFeatures;
          if (!featuresDeviation) {
              return;
          }

          productsDeviation = static_cast<double>(SumProducts) - sumFeatures / sumWeights * static_cast<double>(SumGoals);
      }
  };
  95. using TFastSLRSolver = TTypedFastSLRSolver<double>;
  96. using TKahanSLRSolver = TTypedFastSLRSolver<TKahanAccumulator<double>>;
  97. class TSLRSolver {
  98. private:
  99. double FeaturesMean = 0.;
  100. double FeaturesDeviation = 0.;
  101. double GoalsMean = 0.;
  102. double GoalsDeviation = 0.;
  103. TKahanAccumulator<double> SumWeights;
  104. double Covariation = 0.;
  105. public:
  106. bool Add(const double feature, const double goal, const double weight = 1.);
  107. bool Add(const double* featuresBegin, const double* featuresEnd, const double* goalsBegin);
  108. bool Add(const double* featuresBegin, const double* featuresEnd, const double* goalsBegin, const double* weightsBegin);
  109. bool Add(const TVector<double>& features, const TVector<double>& goals) {
  110. Y_ASSERT(features.size() == goals.size());
  111. return Add(features.data(), features.data() + features.size(), goals.data());
  112. }
  113. bool Add(const TVector<double>& features, const TVector<double>& goals, const TVector<double>& weights) {
  114. Y_ASSERT(features.size() == goals.size() && features.size() == weights.size());
  115. return Add(features.data(), features.data() + features.size(), goals.data(), weights.data());
  116. }
  117. template <typename TFloatType>
  118. void Solve(TFloatType& factor, TFloatType& intercept, const double regularizationParameter = 0.1) const {
  119. if (!FeaturesDeviation) {
  120. factor = 0.;
  121. intercept = GoalsMean;
  122. return;
  123. }
  124. factor = Covariation / (FeaturesDeviation + regularizationParameter);
  125. intercept = GoalsMean - factor * FeaturesMean;
  126. }
  127. double SumSquaredErrors(const double regularizationParameter = 0.1) const;
  128. double GetSumWeights() const {
  129. return SumWeights.Get();
  130. }
  131. };
  132. template <typename TSLRSolverType>
  133. class TTypedBestSLRSolver {
  134. private:
  135. TVector<TSLRSolverType> SLRSolvers;
  136. public:
  137. bool Add(const TVector<double>& features, const double goal, const double weight = 1.) {
  138. if (SLRSolvers.empty()) {
  139. SLRSolvers.resize(features.size());
  140. }
  141. for (size_t featureNumber = 0; featureNumber < features.size(); ++featureNumber) {
  142. SLRSolvers[featureNumber].Add(features[featureNumber], goal, weight);
  143. }
  144. return true;
  145. }
  146. TLinearModel Solve(const double regularizationParameter = 0.1) const {
  147. const TSLRSolverType* bestSolver = nullptr;
  148. for (const TSLRSolverType& solver : SLRSolvers) {
  149. if (!bestSolver || solver.SumSquaredErrors(regularizationParameter) < bestSolver->SumSquaredErrors(regularizationParameter)) {
  150. bestSolver = &solver;
  151. }
  152. }
  153. TVector<double> coefficients(SLRSolvers.size());
  154. double intercept = 0.0;
  155. if (bestSolver) {
  156. bestSolver->Solve(coefficients[bestSolver - SLRSolvers.begin()], intercept, regularizationParameter);
  157. }
  158. TLinearModel model(std::move(coefficients), intercept);
  159. return model;
  160. }
  161. double SumSquaredErrors(const double regularizationParameter = 0.1) const {
  162. if (SLRSolvers.empty()) {
  163. return 0.;
  164. }
  165. double sse = SLRSolvers.begin()->SumSquaredErrors(regularizationParameter);
  166. for (const TSLRSolver& solver : SLRSolvers) {
  167. sse = Min(solver.SumSquaredErrors(regularizationParameter), sse);
  168. }
  169. return sse;
  170. }
  171. };
  172. using TFastBestSLRSolver = TTypedBestSLRSolver<TFastSLRSolver>;
  173. using TKahanBestSLRSolver = TTypedBestSLRSolver<TKahanSLRSolver>;
  174. using TBestSLRSolver = TTypedBestSLRSolver<TSLRSolver>;
  /// Kinds of scalar feature transformation understood by TFeaturesTransformer.
  enum ETransformationType {
      TT_IDENTITY = 0, // the value is returned unchanged
      TT_SIGMA = 1,    // the value goes through the x / (|x| + normalizer) squashing
  };
  179. struct TTransformationParameters {
  180. double RegressionFactor = 1.;
  181. double RegressionIntercept = 0.;
  182. double FeatureOffset = 0.;
  183. double FeatureNormalizer = 1.;
  184. Y_SAVELOAD_DEFINE(RegressionFactor,
  185. RegressionIntercept,
  186. FeatureOffset,
  187. FeatureNormalizer);
  188. };
  189. class TFeaturesTransformer {
  190. private:
  191. ETransformationType TransformationType;
  192. TTransformationParameters TransformationParameters;
  193. public:
  194. Y_SAVELOAD_DEFINE(TransformationType, TransformationParameters);
  195. TFeaturesTransformer() = default;
  196. TFeaturesTransformer(const ETransformationType transformationType,
  197. const TTransformationParameters transformationParameters)
  198. : TransformationType(transformationType)
  199. , TransformationParameters(transformationParameters)
  200. {
  201. }
  202. double Transformation(const double value) const {
  203. switch (TransformationType) {
  204. case ETransformationType::TT_IDENTITY: {
  205. return value;
  206. }
  207. case ETransformationType::TT_SIGMA: {
  208. const double valueWithoutOffset = value - TransformationParameters.FeatureOffset;
  209. const double transformedValue = valueWithoutOffset / (fabs(valueWithoutOffset) + TransformationParameters.FeatureNormalizer);
  210. return TransformationParameters.RegressionIntercept + TransformationParameters.RegressionFactor * transformedValue;
  211. }
  212. }
  213. Y_ASSERT(0);
  214. return 0.;
  215. }
  216. };
  217. class TFeaturesTransformerLearner {
  218. private:
  219. struct TPoint {
  220. float Argument;
  221. float Target;
  222. };
  223. float MinimalArgument = Max<float>();
  224. float MaximalArgument = Min<float>();
  225. ETransformationType TransformationType;
  226. TVector<TPoint> Points;
  227. public:
  228. TFeaturesTransformerLearner(const ETransformationType transformationType)
  229. : TransformationType(transformationType)
  230. {
  231. }
  232. void Add(const float argument, const float target) {
  233. Points.push_back(TPoint{argument, target});
  234. MinimalArgument = Min(MinimalArgument, argument);
  235. MaximalArgument = Max(MaximalArgument, argument);
  236. }
  237. TFeaturesTransformer Solve(const size_t iterationsCount = 100);
  238. };
  239. class TFastFeaturesTransformerLearner {
  240. private:
  241. ETransformationType TransformationType;
  242. struct TBucket {
  243. TMeanCalculator ArgumentsMean;
  244. TMeanCalculator TargetsMean;
  245. };
  246. THashMap<double, TBucket> Buckets;
  247. double Step;
  248. public:
  249. TFastFeaturesTransformerLearner(const ETransformationType transformationType, const double step = 0.1)
  250. : TransformationType(transformationType)
  251. , Step(step)
  252. {
  253. }
  254. void Add(const float argument, const float target) {
  255. TBucket& bucket = Buckets[round(argument / Step)];
  256. bucket.ArgumentsMean.Add(argument);
  257. bucket.TargetsMean.Add(target);
  258. }
  259. TFeaturesTransformer Solve(const size_t iterationsCount = 100) {
  260. TFeaturesTransformerLearner learner(TransformationType);
  261. for (auto&& argumentWithBucket : Buckets) {
  262. const TBucket& bucket = argumentWithBucket.second;
  263. learner.Add(bucket.ArgumentsMean.GetMean(), bucket.TargetsMean.GetMean());
  264. }
  265. return learner.Solve(iterationsCount);
  266. }
  267. };