linear_regression_ut.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. #include <library/cpp/linear_regression/linear_regression.h>
  2. #include <library/cpp/testing/unittest/registar.h>
  3. #include <util/generic/vector.h>
  4. #include <util/generic/ymath.h>
  5. #include <util/random/random.h>
  6. #include <util/system/defaults.h>
  7. namespace {
  8. void ValueIsCorrect(const double value, const double expectedValue, double possibleRelativeError) {
  9. UNIT_ASSERT_DOUBLES_EQUAL(value, expectedValue, possibleRelativeError * expectedValue);
  10. }
  11. }
  12. Y_UNIT_TEST_SUITE(TLinearRegressionTest) {
  13. Y_UNIT_TEST(MeanAndDeviationTest) {
  14. TVector<double> arguments;
  15. TVector<double> weights;
  16. const size_t argumentsCount = 100;
  17. for (size_t i = 0; i < argumentsCount; ++i) {
  18. arguments.push_back(i);
  19. weights.push_back(i);
  20. }
  21. TDeviationCalculator deviationCalculator;
  22. TMeanCalculator meanCalculator;
  23. for (size_t i = 0; i < arguments.size(); ++i) {
  24. meanCalculator.Add(arguments[i], weights[i]);
  25. deviationCalculator.Add(arguments[i], weights[i]);
  26. }
  27. double actualMean = InnerProduct(arguments, weights) / Accumulate(weights, 0.0);
  28. double actualDeviation = 0.;
  29. for (size_t i = 0; i < arguments.size(); ++i) {
  30. double deviation = arguments[i] - actualMean;
  31. actualDeviation += deviation * deviation * weights[i];
  32. }
  33. UNIT_ASSERT(IsValidFloat(meanCalculator.GetMean()));
  34. UNIT_ASSERT_DOUBLES_EQUAL(meanCalculator.GetMean(), actualMean, 1e-10);
  35. UNIT_ASSERT(IsValidFloat(deviationCalculator.GetDeviation()));
  36. UNIT_ASSERT_DOUBLES_EQUAL(meanCalculator.GetMean(), deviationCalculator.GetMean(), 0);
  37. UNIT_ASSERT(IsValidFloat(meanCalculator.GetSumWeights()));
  38. UNIT_ASSERT(IsValidFloat(deviationCalculator.GetSumWeights()));
  39. UNIT_ASSERT_DOUBLES_EQUAL(meanCalculator.GetSumWeights(), deviationCalculator.GetSumWeights(), 0);
  40. UNIT_ASSERT_DOUBLES_EQUAL(meanCalculator.GetSumWeights(), Accumulate(weights, 0.0), 0);
  41. ValueIsCorrect(deviationCalculator.GetDeviation(), actualDeviation, 1e-5);
  42. TMeanCalculator checkRemovingMeanCalculator;
  43. TDeviationCalculator checkRemovingDeviationCalculator;
  44. const size_t argumentsToRemoveCount = argumentsCount / 3;
  45. for (size_t i = 0; i < argumentsCount; ++i) {
  46. if (i < argumentsToRemoveCount) {
  47. meanCalculator.Remove(arguments[i], weights[i]);
  48. deviationCalculator.Remove(arguments[i], weights[i]);
  49. } else {
  50. checkRemovingMeanCalculator.Add(arguments[i], weights[i]);
  51. checkRemovingDeviationCalculator.Add(arguments[i], weights[i]);
  52. }
  53. }
  54. UNIT_ASSERT(IsValidFloat(meanCalculator.GetMean()));
  55. UNIT_ASSERT(IsValidFloat(checkRemovingMeanCalculator.GetMean()));
  56. UNIT_ASSERT(IsValidFloat(deviationCalculator.GetDeviation()));
  57. UNIT_ASSERT(IsValidFloat(checkRemovingDeviationCalculator.GetDeviation()));
  58. UNIT_ASSERT_DOUBLES_EQUAL(meanCalculator.GetMean(), deviationCalculator.GetMean(), 0);
  59. UNIT_ASSERT_DOUBLES_EQUAL(meanCalculator.GetMean(), checkRemovingMeanCalculator.GetMean(), 1e-10);
  60. ValueIsCorrect(deviationCalculator.GetDeviation(), checkRemovingDeviationCalculator.GetDeviation(), 1e-5);
  61. }
  62. Y_UNIT_TEST(CovariationTest) {
  63. TVector<double> firstValues;
  64. TVector<double> secondValues;
  65. TVector<double> weights;
  66. const size_t argumentsCount = 100;
  67. for (size_t i = 0; i < argumentsCount; ++i) {
  68. firstValues.push_back(i);
  69. secondValues.push_back(i * i);
  70. weights.push_back(i);
  71. }
  72. TCovariationCalculator covariationCalculator;
  73. for (size_t i = 0; i < argumentsCount; ++i) {
  74. covariationCalculator.Add(firstValues[i], secondValues[i], weights[i]);
  75. }
  76. const double firstValuesMean = InnerProduct(firstValues, weights) / Accumulate(weights, 0.0);
  77. const double secondValuesMean = InnerProduct(secondValues, weights) / Accumulate(weights, 0.0);
  78. double actualCovariation = 0.;
  79. for (size_t i = 0; i < argumentsCount; ++i) {
  80. actualCovariation += (firstValues[i] - firstValuesMean) * (secondValues[i] - secondValuesMean) * weights[i];
  81. }
  82. UNIT_ASSERT(IsValidFloat(covariationCalculator.GetCovariation()));
  83. UNIT_ASSERT(IsValidFloat(covariationCalculator.GetFirstValueMean()));
  84. UNIT_ASSERT(IsValidFloat(covariationCalculator.GetSecondValueMean()));
  85. UNIT_ASSERT_DOUBLES_EQUAL(covariationCalculator.GetFirstValueMean(), firstValuesMean, 1e-10);
  86. UNIT_ASSERT_DOUBLES_EQUAL(covariationCalculator.GetSecondValueMean(), secondValuesMean, 1e-10);
  87. UNIT_ASSERT(IsValidFloat(covariationCalculator.GetSumWeights()));
  88. UNIT_ASSERT_DOUBLES_EQUAL(covariationCalculator.GetSumWeights(), Accumulate(weights, 0.0), 0);
  89. ValueIsCorrect(covariationCalculator.GetCovariation(), actualCovariation, 1e-5);
  90. TCovariationCalculator checkRemovingCovariationCalculator;
  91. const size_t argumentsToRemoveCount = argumentsCount / 3;
  92. for (size_t i = 0; i < argumentsCount; ++i) {
  93. if (i < argumentsToRemoveCount) {
  94. covariationCalculator.Remove(firstValues[i], secondValues[i], weights[i]);
  95. } else {
  96. checkRemovingCovariationCalculator.Add(firstValues[i], secondValues[i], weights[i]);
  97. }
  98. }
  99. ValueIsCorrect(covariationCalculator.GetCovariation(), checkRemovingCovariationCalculator.GetCovariation(), 1e-5);
  100. }
  101. template <typename TSLRSolverType>
  102. void SLRTest() {
  103. TVector<double> arguments;
  104. TVector<double> weights;
  105. TVector<double> goals;
  106. const double factor = 2.;
  107. const double intercept = 105.;
  108. const double randomError = 0.01;
  109. const size_t argumentsCount = 10;
  110. for (size_t i = 0; i < argumentsCount; ++i) {
  111. arguments.push_back(i);
  112. weights.push_back(i);
  113. goals.push_back(arguments.back() * factor + intercept + 2 * (i % 2 - 0.5) * randomError);
  114. }
  115. TSLRSolverType slrSolver;
  116. for (size_t i = 0; i < argumentsCount; ++i) {
  117. slrSolver.Add(arguments[i], goals[i], weights[i]);
  118. }
  119. for (double regularizationThreshold = 0.; regularizationThreshold < 0.05; regularizationThreshold += 0.01) {
  120. double solutionFactor, solutionIntercept;
  121. slrSolver.Solve(solutionFactor, solutionIntercept, regularizationThreshold);
  122. double predictedSumSquaredErrors = slrSolver.SumSquaredErrors(regularizationThreshold);
  123. UNIT_ASSERT(IsValidFloat(solutionFactor));
  124. UNIT_ASSERT(IsValidFloat(solutionIntercept));
  125. UNIT_ASSERT(IsValidFloat(predictedSumSquaredErrors));
  126. UNIT_ASSERT_DOUBLES_EQUAL(solutionFactor, factor, 1e-2);
  127. UNIT_ASSERT_DOUBLES_EQUAL(solutionIntercept, intercept, 1e-2);
  128. double sumSquaredErrors = 0.;
  129. for (size_t i = 0; i < argumentsCount; ++i) {
  130. double error = goals[i] - arguments[i] * solutionFactor - solutionIntercept;
  131. sumSquaredErrors += error * error * weights[i];
  132. }
  133. if (!regularizationThreshold) {
  134. UNIT_ASSERT(predictedSumSquaredErrors < Accumulate(weights, 0.0) * randomError * randomError);
  135. }
  136. UNIT_ASSERT_DOUBLES_EQUAL(predictedSumSquaredErrors, sumSquaredErrors, 1e-8);
  137. }
  138. }
  139. Y_UNIT_TEST(FastSLRTest) {
  140. SLRTest<TFastSLRSolver>();
  141. }
  142. Y_UNIT_TEST(KahanSLRTest) {
  143. SLRTest<TKahanSLRSolver>();
  144. }
  145. Y_UNIT_TEST(SLRTest) {
  146. SLRTest<TSLRSolver>();
  147. }
  148. template <typename TLinearRegressionSolverType>
  149. void LinearRegressionTest() {
  150. const size_t featuresCount = 10;
  151. const size_t instancesCount = 10000;
  152. const double randomError = 0.01;
  153. TVector<double> coefficients;
  154. for (size_t featureNumber = 0; featureNumber < featuresCount; ++featureNumber) {
  155. coefficients.push_back(featureNumber);
  156. }
  157. const double intercept = 10;
  158. TVector<TVector<double>> featuresMatrix;
  159. TVector<double> goals;
  160. TVector<double> weights;
  161. for (size_t instanceNumber = 0; instanceNumber < instancesCount; ++instanceNumber) {
  162. TVector<double> features;
  163. for (size_t featureNumber = 0; featureNumber < featuresCount; ++featureNumber) {
  164. features.push_back(RandomNumber<double>());
  165. }
  166. featuresMatrix.push_back(features);
  167. const double goal = InnerProduct(coefficients, features) + intercept + 2 * (instanceNumber % 2 - 0.5) * randomError;
  168. goals.push_back(goal);
  169. weights.push_back(instanceNumber);
  170. }
  171. TLinearRegressionSolverType lrSolver;
  172. for (size_t instanceNumber = 0; instanceNumber < instancesCount; ++instanceNumber) {
  173. lrSolver.Add(featuresMatrix[instanceNumber], goals[instanceNumber], weights[instanceNumber]);
  174. }
  175. const TLinearModel model = lrSolver.Solve();
  176. for (size_t featureNumber = 0; featureNumber < featuresCount; ++featureNumber) {
  177. UNIT_ASSERT_DOUBLES_EQUAL(model.GetCoefficients()[featureNumber], coefficients[featureNumber], 1e-2);
  178. }
  179. UNIT_ASSERT_DOUBLES_EQUAL(model.GetIntercept(), intercept, 1e-2);
  180. const double expectedSumSquaredErrors = randomError * randomError * Accumulate(weights, 0.0);
  181. UNIT_ASSERT_DOUBLES_EQUAL(lrSolver.SumSquaredErrors(), expectedSumSquaredErrors, expectedSumSquaredErrors * 0.01);
  182. }
  183. Y_UNIT_TEST(FastLRTest) {
  184. LinearRegressionTest<TFastLinearRegressionSolver>();
  185. }
  186. Y_UNIT_TEST(LRTest) {
  187. LinearRegressionTest<TLinearRegressionSolver>();
  188. }
  189. void TransformationTest(const ETransformationType transformationType, const size_t pointsCount) {
  190. TVector<float> arguments;
  191. TVector<float> goals;
  192. const double regressionFactor = 10.;
  193. const double regressionIntercept = 100;
  194. const double featureOffset = -1.5;
  195. const double featureNormalizer = 15;
  196. const double left = -100.;
  197. const double right = +100.;
  198. const double step = (right - left) / pointsCount;
  199. for (double argument = left; argument <= right; argument += step) {
  200. const double goal = regressionIntercept + regressionFactor * (argument - featureOffset) / (fabs(argument - featureOffset) + featureNormalizer);
  201. arguments.push_back(argument);
  202. goals.push_back(goal);
  203. }
  204. TFastFeaturesTransformerLearner learner(transformationType);
  205. for (size_t instanceNumber = 0; instanceNumber < arguments.size(); ++instanceNumber) {
  206. learner.Add(arguments[instanceNumber], goals[instanceNumber]);
  207. }
  208. TFeaturesTransformer transformer = learner.Solve();
  209. double sse = 0.;
  210. for (size_t instanceNumber = 0; instanceNumber < arguments.size(); ++instanceNumber) {
  211. const double error = transformer.Transformation(arguments[instanceNumber]) - goals[instanceNumber];
  212. sse += error * error;
  213. }
  214. const double rmse = sqrt(sse / arguments.size());
  215. UNIT_ASSERT_DOUBLES_EQUAL(rmse, 0., 1e-3);
  216. }
  217. Y_UNIT_TEST(SigmaTest100) {
  218. TransformationTest(ETransformationType::TT_SIGMA, 100);
  219. }
  220. Y_UNIT_TEST(SigmaTest1000) {
  221. TransformationTest(ETransformationType::TT_SIGMA, 1000);
  222. }
  223. Y_UNIT_TEST(SigmaTest10000) {
  224. TransformationTest(ETransformationType::TT_SIGMA, 10000);
  225. }
  226. Y_UNIT_TEST(SigmaTest100000) {
  227. TransformationTest(ETransformationType::TT_SIGMA, 100000);
  228. }
  229. Y_UNIT_TEST(SigmaTest1000000) {
  230. TransformationTest(ETransformationType::TT_SIGMA, 1000000);
  231. }
  232. Y_UNIT_TEST(SigmaTest10000000) {
  233. TransformationTest(ETransformationType::TT_SIGMA, 10000000);
  234. }
  235. Y_UNIT_TEST(ResetCalculatorTest) {
  236. TVector<double> arguments;
  237. TVector<double> weights;
  238. const double eps = 1e-10;
  239. const size_t argumentsCount = 100;
  240. for (size_t i = 0; i < argumentsCount; ++i) {
  241. arguments.push_back(i);
  242. weights.push_back(i);
  243. }
  244. TDeviationCalculator deviationCalculator1, deviationCalculator2;
  245. TMeanCalculator meanCalculator1, meanCalculator2;
  246. TCovariationCalculator covariationCalculator1, covariationCalculator2;
  247. for (size_t i = 0; i < arguments.size(); ++i) {
  248. meanCalculator1.Add(arguments[i], weights[i]);
  249. meanCalculator2.Add(arguments[i], weights[i]);
  250. deviationCalculator1.Add(arguments[i], weights[i]);
  251. deviationCalculator2.Add(arguments[i], weights[i]);
  252. covariationCalculator1.Add(arguments[i], arguments[arguments.size() - i - 1], weights[i]);
  253. covariationCalculator2.Add(arguments[i], arguments[arguments.size() - i - 1], weights[i]);
  254. }
  255. UNIT_ASSERT_DOUBLES_EQUAL(meanCalculator1.GetMean(), meanCalculator2.GetMean(), eps);
  256. UNIT_ASSERT_DOUBLES_EQUAL(meanCalculator1.GetSumWeights(), meanCalculator2.GetSumWeights(), eps);
  257. UNIT_ASSERT_DOUBLES_EQUAL(deviationCalculator1.GetMean(), deviationCalculator2.GetMean(), eps);
  258. UNIT_ASSERT_DOUBLES_EQUAL(deviationCalculator1.GetDeviation(), deviationCalculator2.GetDeviation(), eps);
  259. UNIT_ASSERT_DOUBLES_EQUAL(deviationCalculator1.GetStdDev(), deviationCalculator2.GetStdDev(), eps);
  260. UNIT_ASSERT_DOUBLES_EQUAL(deviationCalculator1.GetSumWeights(), deviationCalculator2.GetSumWeights(), eps);
  261. UNIT_ASSERT_DOUBLES_EQUAL(covariationCalculator1.GetFirstValueMean(), covariationCalculator2.GetFirstValueMean(), eps);
  262. UNIT_ASSERT_DOUBLES_EQUAL(covariationCalculator1.GetSecondValueMean(), covariationCalculator2.GetSecondValueMean(), eps);
  263. UNIT_ASSERT_DOUBLES_EQUAL(covariationCalculator1.GetCovariation(), covariationCalculator2.GetCovariation(), eps);
  264. UNIT_ASSERT_DOUBLES_EQUAL(covariationCalculator1.GetSumWeights(), covariationCalculator2.GetSumWeights(), eps);
  265. meanCalculator2.Reset();
  266. deviationCalculator2.Reset();
  267. covariationCalculator2.Reset();
  268. UNIT_ASSERT_DOUBLES_EQUAL(0.0, meanCalculator2.GetMean(), eps);
  269. UNIT_ASSERT_DOUBLES_EQUAL(0.0, meanCalculator2.GetSumWeights(), eps);
  270. UNIT_ASSERT_DOUBLES_EQUAL(0.0, deviationCalculator2.GetMean(), eps);
  271. UNIT_ASSERT_DOUBLES_EQUAL(0.0, deviationCalculator2.GetDeviation(), eps);
  272. UNIT_ASSERT_DOUBLES_EQUAL(0.0, deviationCalculator2.GetStdDev(), eps);
  273. UNIT_ASSERT_DOUBLES_EQUAL(0.0, deviationCalculator2.GetSumWeights(), eps);
  274. UNIT_ASSERT_DOUBLES_EQUAL(0.0, covariationCalculator2.GetFirstValueMean(), eps);
  275. UNIT_ASSERT_DOUBLES_EQUAL(0.0, covariationCalculator2.GetSecondValueMean(), eps);
  276. UNIT_ASSERT_DOUBLES_EQUAL(0.0, covariationCalculator2.GetCovariation(), eps);
  277. UNIT_ASSERT_DOUBLES_EQUAL(0.0, covariationCalculator2.GetSumWeights(), eps);
  278. for (size_t i = 0; i < arguments.size(); ++i) {
  279. meanCalculator2.Add(arguments[i], weights[i]);
  280. deviationCalculator2.Add(arguments[i], weights[i]);
  281. covariationCalculator2.Add(arguments[i], arguments[arguments.size() - i - 1], weights[i]);
  282. }
  283. UNIT_ASSERT_DOUBLES_EQUAL(meanCalculator1.GetMean(), meanCalculator2.GetMean(), 1e-10);
  284. UNIT_ASSERT_DOUBLES_EQUAL(meanCalculator1.GetSumWeights(), meanCalculator2.GetSumWeights(), 1e-10);
  285. UNIT_ASSERT_DOUBLES_EQUAL(deviationCalculator1.GetMean(), deviationCalculator2.GetMean(), eps);
  286. UNIT_ASSERT_DOUBLES_EQUAL(deviationCalculator1.GetDeviation(), deviationCalculator2.GetDeviation(), eps);
  287. UNIT_ASSERT_DOUBLES_EQUAL(deviationCalculator1.GetStdDev(), deviationCalculator2.GetStdDev(), eps);
  288. UNIT_ASSERT_DOUBLES_EQUAL(deviationCalculator1.GetSumWeights(), deviationCalculator2.GetSumWeights(), eps);
  289. UNIT_ASSERT_DOUBLES_EQUAL(covariationCalculator1.GetFirstValueMean(), covariationCalculator2.GetFirstValueMean(), eps);
  290. UNIT_ASSERT_DOUBLES_EQUAL(covariationCalculator1.GetSecondValueMean(), covariationCalculator2.GetSecondValueMean(), eps);
  291. UNIT_ASSERT_DOUBLES_EQUAL(covariationCalculator1.GetCovariation(), covariationCalculator2.GetCovariation(), eps);
  292. UNIT_ASSERT_DOUBLES_EQUAL(covariationCalculator1.GetSumWeights(), covariationCalculator2.GetSumWeights(), eps);
  293. }
  294. }