main.cpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. #include "pool.h"
  2. #include <library/cpp/linear_regression/linear_regression.h>
  3. #include <util/datetime/base.h>
  4. #include <util/datetime/cputimer.h>
  5. #include <util/system/type_name.h>
  6. #include <util/string/printf.h>
  7. template <typename TLRSolver>
  8. void QualityBenchmark(const TPool& originalPool) {
  9. auto measure = [&](const double injureFactor, const double injureOffset) {
  10. TPool injuredPool = originalPool.InjurePool(injureFactor, injureOffset);
  11. static const size_t runsCount = 10;
  12. static const size_t foldsCount = 10;
  13. TMeanCalculator determinationCoefficientCalculator;
  14. TPool::TCVIterator learnIterator = injuredPool.CrossValidationIterator(foldsCount, TPool::LearnIterator);
  15. TPool::TCVIterator testIterator = injuredPool.CrossValidationIterator(foldsCount, TPool::TestIterator);
  16. for (size_t runNumber = 0; runNumber < runsCount; ++runNumber) {
  17. for (size_t foldNumber = 0; foldNumber < foldsCount; ++foldNumber) {
  18. learnIterator.ResetShuffle();
  19. learnIterator.SetTestFold(foldNumber);
  20. testIterator.ResetShuffle();
  21. testIterator.SetTestFold(foldNumber);
  22. TLRSolver solver;
  23. for (; learnIterator.IsValid(); ++learnIterator) {
  24. solver.Add(learnIterator->Features, learnIterator->Goal, learnIterator->Weight);
  25. }
  26. TLinearModel model = solver.Solve();
  27. TDeviationCalculator goalsCalculator;
  28. TKahanAccumulator<double> errorsCalculator;
  29. for (; testIterator.IsValid(); ++testIterator) {
  30. const double prediction = model.Prediction(testIterator->Features);
  31. const double goal = testIterator->Goal;
  32. const double weight = testIterator->Weight;
  33. const double error = goal - prediction;
  34. goalsCalculator.Add(goal, weight);
  35. errorsCalculator += error * error * weight;
  36. }
  37. const double determinationCoefficient = 1 - errorsCalculator.Get() / goalsCalculator.GetDeviation();
  38. determinationCoefficientCalculator.Add(determinationCoefficient);
  39. }
  40. }
  41. return determinationCoefficientCalculator.GetMean();
  42. };
  43. Cout << TypeName<TLRSolver>() << ":\n";
  44. Cout << "\t" << Sprintf("base : %.10lf\n", measure(1., 0.));
  45. Cout << "\t" << Sprintf("injure1 : %.10lf\n", measure(1e-1, 1e+1));
  46. Cout << "\t" << Sprintf("injure2 : %.10lf\n", measure(1e-3, 1e+4));
  47. Cout << "\t" << Sprintf("injure3 : %.10lf\n", measure(1e-3, 1e+5));
  48. Cout << "\t" << Sprintf("injure4 : %.10lf\n", measure(1e-3, 1e+6));
  49. Cout << "\t" << Sprintf("injure5 : %.10lf\n", measure(1e-4, 1e+6));
  50. Cout << "\t" << Sprintf("injure6 : %.10lf\n", measure(1e-4, 1e+7));
  51. Cout << Endl;
  52. }
  53. template <typename TLRSolver>
  54. void SpeedBenchmark(const TPool& originalPool) {
  55. TDeviationCalculator speedTest;
  56. static const size_t runsCount = 1000;
  57. for (size_t runNumber = 0; runNumber < runsCount; ++runNumber) {
  58. TLRSolver solver;
  59. TLinearModel model;
  60. {
  61. TSimpleTimer timer;
  62. for (const TInstance& instance : originalPool) {
  63. solver.Add(instance.Features, instance.Goal, instance.Weight);
  64. }
  65. model = solver.Solve();
  66. speedTest.Add(timer.Get().MicroSeconds());
  67. }
  68. }
  69. const double multiplier = 1e-6;
  70. Cout << Sprintf("%.5lf +/- %.5lf: ", speedTest.GetMean() * multiplier, speedTest.GetStdDev() * multiplier) << TypeName<TLRSolver>() << Endl;
  71. }
  72. int main(int argc, const char** argv) {
  73. for (int taskNumber = 1; taskNumber < argc; ++taskNumber) {
  74. TPool pool;
  75. pool.ReadFromFeatures(argv[taskNumber]);
  76. Cout << argv[taskNumber] << ":" << Endl;
  77. QualityBenchmark<TFastBestSLRSolver>(pool);
  78. QualityBenchmark<TKahanBestSLRSolver>(pool);
  79. QualityBenchmark<TBestSLRSolver>(pool);
  80. QualityBenchmark<TLinearRegressionSolver>(pool);
  81. QualityBenchmark<TFastLinearRegressionSolver>(pool);
  82. SpeedBenchmark<TFastBestSLRSolver>(pool);
  83. SpeedBenchmark<TKahanBestSLRSolver>(pool);
  84. SpeedBenchmark<TBestSLRSolver>(pool);
  85. SpeedBenchmark<TLinearRegressionSolver>(pool);
  86. SpeedBenchmark<TFastLinearRegressionSolver>(pool);
  87. }
  88. return 0;
  89. }