// linear_regression.h
  1. #pragma once
  2. #include "linear_model.h"
  3. #include "welford.h"
  4. #include <library/cpp/accurate_accumulate/accurate_accumulate.h>
  5. #include <util/generic/vector.h>
  6. #include <util/generic/hash.h>
  7. #include <util/generic/ylimits.h>
  8. #include <util/generic/ymath.h>
  9. class TFastLinearRegressionSolver {
  10. private:
  11. TKahanAccumulator<double> SumSquaredGoals;
  12. TVector<double> LinearizedOLSMatrix;
  13. TVector<double> OLSVector;
  14. public:
  15. bool Add(const TVector<double>& features, const double goal, const double weight = 1.);
  16. TLinearModel Solve() const;
  17. double SumSquaredErrors() const;
  18. };
  19. class TLinearRegressionSolver {
  20. private:
  21. double GoalsMean = 0.;
  22. double GoalsDeviation = 0.;
  23. TVector<double> FeatureMeans;
  24. TVector<double> LastMeans;
  25. TVector<double> NewMeans;
  26. TVector<double> LinearizedOLSMatrix;
  27. TVector<double> OLSVector;
  28. TKahanAccumulator<double> SumWeights;
  29. public:
  30. bool Add(const TVector<double>& features, const double goal, const double weight = 1.);
  31. TLinearModel Solve() const;
  32. double SumSquaredErrors() const;
  33. };
  // Simple (single-feature) weighted linear regression built on six running
  // sums: sum(w*x), sum(w*x^2), sum(w*y), sum(w*y^2), sum(w*x*y) and sum(w).
  //
  // TStoreType is the accumulator type. It has to be default-constructible,
  // accept `+= double` and be convertible to double; this file instantiates it
  // with double and with TKahanAccumulator<double>.
  template <typename TStoreType>
  class TTypedFastSLRSolver {
  private:
      TStoreType SumFeatures = TStoreType();
      TStoreType SumSquaredFeatures = TStoreType();
      TStoreType SumGoals = TStoreType();
      TStoreType SumSquaredGoals = TStoreType();
      TStoreType SumProducts = TStoreType();
      TStoreType SumWeights = TStoreType();

  public:
      // Adds one (feature, goal) sample with the given weight. Always returns true.
      bool Add(const double feature, const double goal, const double weight = 1.) {
          SumFeatures += feature * weight;
          SumSquaredFeatures += feature * feature * weight;
          SumGoals += goal * weight;
          SumSquaredGoals += goal * goal * weight;
          SumProducts += goal * feature * weight;
          SumWeights += weight;
          return true;
      }

      // Writes the fitted line goal ~= factor * feature + intercept.
      //
      // - No accumulated weight: both outputs are zero.
      // - Zero feature deviation: factor is zero, intercept is the weighted goal mean.
      // - Otherwise: factor = productsDeviation / (featuresDeviation + regularizationParameter).
      template <typename TFloatType>
      void Solve(TFloatType& factor, TFloatType& intercept, const double regularizationParameter = 0.1) const {
          // The guard protects the divisions by SumWeights below, so it has to
          // look at SumWeights. Testing SumGoals instead produced an all-zero
          // model for any data set whose weighted goals happen to sum to zero.
          if (!static_cast<double>(SumWeights)) {
              factor = intercept = TFloatType();
              return;
          }

          double productsDeviation = 0.;
          double featuresDeviation = 0.;
          SetupSolutionFactors(productsDeviation, featuresDeviation);

          if (!featuresDeviation) {
              factor = TFloatType();
              intercept = static_cast<double>(SumGoals) / static_cast<double>(SumWeights);
              return;
          }

          factor = productsDeviation / (featuresDeviation + regularizationParameter);
          // factor is used after being stored as TFloatType, exactly as before.
          intercept = static_cast<double>(SumGoals) / static_cast<double>(SumWeights) - factor * static_cast<double>(SumFeatures) / static_cast<double>(SumWeights);
      }

      // Sum of squared errors of the line Solve() would produce for the same
      // regularizationParameter. Returns 0 when no weight has been accumulated.
      double SumSquaredErrors(const double regularizationParameter = 0.1) const {
          if (!static_cast<double>(SumWeights)) {
              return 0.;
          }

          const double sumGoalSquaredDeviations = static_cast<double>(SumSquaredGoals) - static_cast<double>(SumGoals) / static_cast<double>(SumWeights) * static_cast<double>(SumGoals);

          double productsDeviation = 0.;
          double featuresDeviation = 0.;
          SetupSolutionFactors(productsDeviation, featuresDeviation);
          if (!featuresDeviation) {
              return sumGoalSquaredDeviations;
          }

          const double factor = productsDeviation / (featuresDeviation + regularizationParameter);
          const double sumSquaredErrors = factor * factor * featuresDeviation - 2 * factor * productsDeviation + sumGoalSquaredDeviations;

          // Rounding can push the closed-form value slightly below zero; clamp it.
          return 0. < sumSquaredErrors ? sumSquaredErrors : 0.;
      }

  private:
      // Computes the weighted, non-normalized feature deviation and
      // feature/goal co-deviation. Both outputs are always written; they are
      // zero when there is no weight, and productsDeviation stays zero when the
      // feature deviation is zero.
      void SetupSolutionFactors(double& productsDeviation, double& featuresDeviation) const {
          productsDeviation = 0.;
          featuresDeviation = 0.;

          if (!static_cast<double>(SumWeights)) {
              return;
          }

          featuresDeviation = static_cast<double>(SumSquaredFeatures) - static_cast<double>(SumFeatures) / static_cast<double>(SumWeights) * static_cast<double>(SumFeatures);
          if (!featuresDeviation) {
              return;
          }

          productsDeviation = static_cast<double>(SumProducts) - static_cast<double>(SumFeatures) / static_cast<double>(SumWeights) * static_cast<double>(SumGoals);
      }
  };
  // Single-feature solver that keeps its running sums in plain doubles.
  using TFastSLRSolver = TTypedFastSLRSolver<double>;
  // Same solver with every running sum stored in a TKahanAccumulator<double>.
  using TKahanSLRSolver = TTypedFastSLRSolver<TKahanAccumulator<double>>;
  98. class TSLRSolver {
  99. private:
  100. double FeaturesMean = 0.;
  101. double FeaturesDeviation = 0.;
  102. double GoalsMean = 0.;
  103. double GoalsDeviation = 0.;
  104. TKahanAccumulator<double> SumWeights;
  105. double Covariation = 0.;
  106. public:
  107. bool Add(const double feature, const double goal, const double weight = 1.);
  108. bool Add(const double* featuresBegin, const double* featuresEnd, const double* goalsBegin);
  109. bool Add(const double* featuresBegin, const double* featuresEnd, const double* goalsBegin, const double* weightsBegin);
  110. bool Add(const TVector<double>& features, const TVector<double>& goals) {
  111. Y_ASSERT(features.size() == goals.size());
  112. return Add(features.data(), features.data() + features.size(), goals.data());
  113. }
  114. bool Add(const TVector<double>& features, const TVector<double>& goals, const TVector<double>& weights) {
  115. Y_ASSERT(features.size() == goals.size() && features.size() == weights.size());
  116. return Add(features.data(), features.data() + features.size(), goals.data(), weights.data());
  117. }
  118. template <typename TFloatType>
  119. void Solve(TFloatType& factor, TFloatType& intercept, const double regularizationParameter = 0.1) const {
  120. if (!FeaturesDeviation) {
  121. factor = 0.;
  122. intercept = GoalsMean;
  123. return;
  124. }
  125. factor = Covariation / (FeaturesDeviation + regularizationParameter);
  126. intercept = GoalsMean - factor * FeaturesMean;
  127. }
  128. double SumSquaredErrors(const double regularizationParameter = 0.1) const;
  129. double GetSumWeights() const {
  130. return SumWeights.Get();
  131. }
  132. };
  133. template <typename TSLRSolverType>
  134. class TTypedBestSLRSolver {
  135. private:
  136. TVector<TSLRSolverType> SLRSolvers;
  137. public:
  138. bool Add(const TVector<double>& features, const double goal, const double weight = 1.) {
  139. if (SLRSolvers.empty()) {
  140. SLRSolvers.resize(features.size());
  141. }
  142. for (size_t featureNumber = 0; featureNumber < features.size(); ++featureNumber) {
  143. SLRSolvers[featureNumber].Add(features[featureNumber], goal, weight);
  144. }
  145. return true;
  146. }
  147. TLinearModel Solve(const double regularizationParameter = 0.1) const {
  148. const TSLRSolverType* bestSolver = nullptr;
  149. for (const TSLRSolverType& solver : SLRSolvers) {
  150. if (!bestSolver || solver.SumSquaredErrors(regularizationParameter) < bestSolver->SumSquaredErrors(regularizationParameter)) {
  151. bestSolver = &solver;
  152. }
  153. }
  154. TVector<double> coefficients(SLRSolvers.size());
  155. double intercept = 0.0;
  156. if (bestSolver) {
  157. bestSolver->Solve(coefficients[bestSolver - SLRSolvers.begin()], intercept, regularizationParameter);
  158. }
  159. TLinearModel model(std::move(coefficients), intercept);
  160. return model;
  161. }
  162. double SumSquaredErrors(const double regularizationParameter = 0.1) const {
  163. if (SLRSolvers.empty()) {
  164. return 0.;
  165. }
  166. double sse = SLRSolvers.begin()->SumSquaredErrors(regularizationParameter);
  167. for (const TSLRSolver& solver : SLRSolvers) {
  168. sse = Min(solver.SumSquaredErrors(regularizationParameter), sse);
  169. }
  170. return sse;
  171. }
  172. };
  // Best-single-feature solvers, one alias per underlying single-feature solver type.
  using TFastBestSLRSolver = TTypedBestSLRSolver<TFastSLRSolver>;
  using TKahanBestSLRSolver = TTypedBestSLRSolver<TKahanSLRSolver>;
  using TBestSLRSolver = TTypedBestSLRSolver<TSLRSolver>;
  // Kind of feature transformation applied by TFeaturesTransformer::Transformation().
  enum ETransformationType {
  // The value is returned unchanged.
  TT_IDENTITY,
  // The value is shifted by an offset, squashed as v / (|v| + normalizer) and
  // then mapped through a linear function (intercept + factor * squashed).
  TT_SIGMA,
  };
  // Parameters consumed by the TT_SIGMA branch of TFeaturesTransformer::Transformation().
  struct TTransformationParameters {
  // Multiplier applied to the squashed value.
  double RegressionFactor = 1.;
  // Constant added to the scaled squashed value.
  double RegressionIntercept = 0.;
  // Subtracted from the input value before squashing.
  double FeatureOffset = 0.;
  // Added to |value - FeatureOffset| in the denominator of the squashing step.
  double FeatureNormalizer = 1.;
  // NOTE(review): presumably generates the save/load serialization for the
  // listed fields, in this order -- confirm against the macro definition.
  Y_SAVELOAD_DEFINE(RegressionFactor,
  RegressionIntercept,
  FeatureOffset,
  FeatureNormalizer);
  };
  190. class TFeaturesTransformer {
  191. private:
  192. ETransformationType TransformationType;
  193. TTransformationParameters TransformationParameters;
  194. public:
  195. Y_SAVELOAD_DEFINE(TransformationType, TransformationParameters);
  196. TFeaturesTransformer() = default;
  197. TFeaturesTransformer(const ETransformationType transformationType,
  198. const TTransformationParameters transformationParameters)
  199. : TransformationType(transformationType)
  200. , TransformationParameters(transformationParameters)
  201. {
  202. }
  203. double Transformation(const double value) const {
  204. switch (TransformationType) {
  205. case ETransformationType::TT_IDENTITY: {
  206. return value;
  207. }
  208. case ETransformationType::TT_SIGMA: {
  209. const double valueWithoutOffset = value - TransformationParameters.FeatureOffset;
  210. const double transformedValue = valueWithoutOffset / (fabs(valueWithoutOffset) + TransformationParameters.FeatureNormalizer);
  211. return TransformationParameters.RegressionIntercept + TransformationParameters.RegressionFactor * transformedValue;
  212. }
  213. }
  214. Y_ASSERT(0);
  215. return 0.;
  216. }
  217. };
  218. class TFeaturesTransformerLearner {
  219. private:
  220. struct TPoint {
  221. float Argument;
  222. float Target;
  223. };
  224. float MinimalArgument = Max<float>();
  225. float MaximalArgument = Min<float>();
  226. ETransformationType TransformationType;
  227. TVector<TPoint> Points;
  228. public:
  229. TFeaturesTransformerLearner(const ETransformationType transformationType)
  230. : TransformationType(transformationType)
  231. {
  232. }
  233. void Add(const float argument, const float target) {
  234. Points.push_back(TPoint{argument, target});
  235. MinimalArgument = Min(MinimalArgument, argument);
  236. MaximalArgument = Max(MaximalArgument, argument);
  237. }
  238. TFeaturesTransformer Solve(const size_t iterationsCount = 100);
  239. };
  240. class TFastFeaturesTransformerLearner {
  241. private:
  242. ETransformationType TransformationType;
  243. struct TBucket {
  244. TMeanCalculator ArgumentsMean;
  245. TMeanCalculator TargetsMean;
  246. };
  247. THashMap<double, TBucket> Buckets;
  248. double Step;
  249. public:
  250. TFastFeaturesTransformerLearner(const ETransformationType transformationType, const double step = 0.1)
  251. : TransformationType(transformationType)
  252. , Step(step)
  253. {
  254. }
  255. void Add(const float argument, const float target) {
  256. TBucket& bucket = Buckets[round(argument / Step)];
  257. bucket.ArgumentsMean.Add(argument);
  258. bucket.TargetsMean.Add(target);
  259. }
  260. TFeaturesTransformer Solve(const size_t iterationsCount = 100) {
  261. TFeaturesTransformerLearner learner(TransformationType);
  262. for (auto&& argumentWithBucket : Buckets) {
  263. const TBucket& bucket = argumentWithBucket.second;
  264. learner.Add(bucket.ArgumentsMean.GetMean(), bucket.TargetsMean.GetMean());
  265. }
  266. return learner.Solve(iterationsCount);
  267. }
  268. };