// linear_regression.cpp

#include "linear_model.h"
#include "linear_regression.h"

#include <util/generic/algorithm.h> // MinElement, used by ArgMinPrecise below
#include <util/generic/ymath.h>

#ifdef _sse2_
#include <emmintrin.h>
#include <xmmintrin.h>
#endif

#include <algorithm>
#include <functional>

namespace {
    inline void AddFeaturesProduct(const double weight, const TVector<double>& features, TVector<double>& linearizedOLSTriangleMatrix);

    TVector<double> Solve(const TVector<double>& olsMatrix, const TVector<double>& olsVector);

    double SumSquaredErrors(const TVector<double>& olsMatrix,
                            const TVector<double>& olsVector,
                            const TVector<double>& solution,
                            const double goalsDeviation);
}
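
// Accumulates the normal equations for OLS with an implicit bias term:
// LinearizedOLSMatrix holds the linearized upper triangle of [X|1]^T W [X|1],
// OLSVector holds [X|1]^T W y, and SumSquaredGoals holds y^T W y.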
bool TFastLinearRegressionSolver::Add(const TVector<double>& features, const double goal, const double weight) {
    const size_t featuresCount = features.size();

    if (LinearizedOLSMatrix.empty()) {
        LinearizedOLSMatrix.resize((featuresCount + 1) * (featuresCount + 2) / 2);
        OLSVector.resize(featuresCount + 1);
    }

    AddFeaturesProduct(weight, features, LinearizedOLSMatrix);

    const double weightedGoal = goal * weight;
    double* olsVectorElement = OLSVector.data();
    for (const double feature : features) {
        *olsVectorElement += feature * weightedGoal;
        ++olsVectorElement;
    }
    *olsVectorElement += weightedGoal;

    SumSquaredGoals += goal * goal * weight;

    return true;
}
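
// Numerically stabler variant: features and goals are centered on the fly via
// incremental weighted means, so the accumulated matrix is a covariance-like
// matrix rather than a matrix of raw cross-products.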
bool TLinearRegressionSolver::Add(const TVector<double>& features, const double goal, const double weight) {
    const size_t featuresCount = features.size();

    if (FeatureMeans.empty()) {
        FeatureMeans.resize(featuresCount);
        LastMeans.resize(featuresCount);
        NewMeans.resize(featuresCount);

        LinearizedOLSMatrix.resize(featuresCount * (featuresCount + 1) / 2);
        OLSVector.resize(featuresCount);
    }

    SumWeights += weight;
    if (!SumWeights.Get()) {
        return false;
    }

    for (size_t featureNumber = 0; featureNumber < featuresCount; ++featureNumber) {
        const double feature = features[featureNumber];
        double& featureMean = FeatureMeans[featureNumber];

        LastMeans[featureNumber] = weight * (feature - featureMean);
        featureMean += weight * (feature - featureMean) / SumWeights.Get();
        NewMeans[featureNumber] = feature - featureMean;
    }

    double* olsMatrixElement = LinearizedOLSMatrix.data();

    const double* lastMean = LastMeans.data();
    const double* newMean = NewMeans.data();
    const double* lastMeansEnd = lastMean + LastMeans.size();
    const double* newMeansEnd = newMean + NewMeans.size();

#ifdef _sse2_
    for (; lastMean != lastMeansEnd; ++lastMean, ++newMean) {
        __m128d factor = _mm_set_pd(*lastMean, *lastMean);
        const double* secondFeatureMean = newMean;
        // Process pairs of matrix elements with SSE2, then finish the odd tail.
        for (; secondFeatureMean + 1 < newMeansEnd; secondFeatureMean += 2, olsMatrixElement += 2) {
            __m128d matrixElem = _mm_loadu_pd(olsMatrixElement);
            __m128d secondFeatureMeanElem = _mm_loadu_pd(secondFeatureMean);
            __m128d product = _mm_mul_pd(factor, secondFeatureMeanElem);
            __m128d addition = _mm_add_pd(matrixElem, product);
            _mm_storeu_pd(olsMatrixElement, addition);
        }
        for (; secondFeatureMean < newMeansEnd; ++secondFeatureMean) {
            *olsMatrixElement++ += *lastMean * *secondFeatureMean;
        }
    }
#else
    for (; lastMean != lastMeansEnd; ++lastMean, ++newMean) {
        for (const double* secondFeatureMean = newMean; secondFeatureMean < newMeansEnd; ++secondFeatureMean) {
            *olsMatrixElement++ += *lastMean * *secondFeatureMean;
        }
    }
#endif

    for (size_t firstFeatureNumber = 0; firstFeatureNumber < features.size(); ++firstFeatureNumber) {
        OLSVector[firstFeatureNumber] += weight * (features[firstFeatureNumber] - FeatureMeans[firstFeatureNumber]) * (goal - GoalsMean);
    }

    const double oldGoalsMean = GoalsMean;
    GoalsMean += weight * (goal - GoalsMean) / SumWeights.Get();
    GoalsDeviation += weight * (goal - oldGoalsMean) * (goal - GoalsMean);

    return true;
}
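
// Solving the accumulated normal equations yields the coefficients; the fast
// solver reads the intercept from the extra bias column, while the centered
// solver restores it from the feature and goal means.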
TLinearModel TFastLinearRegressionSolver::Solve() const {
    TVector<double> coefficients = ::Solve(LinearizedOLSMatrix, OLSVector);
    double intercept = 0.;

    if (!coefficients.empty()) {
        intercept = coefficients.back();
        coefficients.pop_back();
    }

    return TLinearModel(std::move(coefficients), intercept);
}

TLinearModel TLinearRegressionSolver::Solve() const {
    TVector<double> coefficients = ::Solve(LinearizedOLSMatrix, OLSVector);
    double intercept = GoalsMean;

    const size_t featuresCount = OLSVector.size();
    for (size_t featureNumber = 0; featureNumber < featuresCount; ++featureNumber) {
        intercept -= FeatureMeans[featureNumber] * coefficients[featureNumber];
    }

    return TLinearModel(std::move(coefficients), intercept);
}

double TFastLinearRegressionSolver::SumSquaredErrors() const {
    const TVector<double> coefficients = ::Solve(LinearizedOLSMatrix, OLSVector);
    return ::SumSquaredErrors(LinearizedOLSMatrix, OLSVector, coefficients, SumSquaredGoals.Get());
}

double TLinearRegressionSolver::SumSquaredErrors() const {
    const TVector<double> coefficients = ::Solve(LinearizedOLSMatrix, OLSVector);
    return ::SumSquaredErrors(LinearizedOLSMatrix, OLSVector, coefficients, GoalsDeviation);
}
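
// Simple (single-feature) linear regression: means, deviations and the
// feature-goal covariation are maintained incrementally.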
bool TSLRSolver::Add(const double feature, const double goal, const double weight) {
    SumWeights += weight;
    if (!SumWeights.Get()) {
        return false;
    }

    const double weightedFeatureDiff = weight * (feature - FeaturesMean);
    const double weightedGoalDiff = weight * (goal - GoalsMean);

    FeaturesMean += weightedFeatureDiff / SumWeights.Get();
    FeaturesDeviation += weightedFeatureDiff * (feature - FeaturesMean);

    GoalsMean += weightedGoalDiff / SumWeights.Get();
    GoalsDeviation += weightedGoalDiff * (goal - GoalsMean);

    Covariation += weightedFeatureDiff * (goal - GoalsMean);

    return true;
}

bool TSLRSolver::Add(const double* featuresBegin,
                     const double* featuresEnd,
                     const double* goalsBegin) {
    for (; featuresBegin != featuresEnd; ++featuresBegin, ++goalsBegin) {
        Add(*featuresBegin, *goalsBegin);
    }
    return true;
}

bool TSLRSolver::Add(const double* featuresBegin,
                     const double* featuresEnd,
                     const double* goalsBegin,
                     const double* weightsBegin) {
    for (; featuresBegin != featuresEnd; ++featuresBegin, ++goalsBegin, ++weightsBegin) {
        Add(*featuresBegin, *goalsBegin, *weightsBegin);
    }
    return true;
}

double TSLRSolver::SumSquaredErrors(const double regularizationParameter) const {
    double factor, offset;
    Solve(factor, offset, regularizationParameter);

    return factor * factor * FeaturesDeviation - 2 * factor * Covariation + GoalsDeviation;
}
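
// Helpers: LDL^T decomposition of the linearized OLS matrix, the triangular
// solves built on top of it, and the cross-product accumulation used above.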
namespace {
    // LDL matrix decomposition, see http://en.wikipedia.org/wiki/Cholesky_decomposition#LDL_decomposition_2
    bool LDLDecomposition(const TVector<double>& linearizedOLSMatrix,
                          const double regularizationThreshold,
                          const double regularizationParameter,
                          TVector<double>& decompositionTrace,
                          TVector<TVector<double>>& decompositionMatrix) {
        const size_t featuresCount = decompositionTrace.size();

        size_t olsMatrixElementIdx = 0;
        for (size_t rowNumber = 0; rowNumber < featuresCount; ++rowNumber) {
            double& decompositionTraceElement = decompositionTrace[rowNumber];
            decompositionTraceElement = linearizedOLSMatrix[olsMatrixElementIdx] + regularizationParameter;

            TVector<double>& decompositionRow = decompositionMatrix[rowNumber];
            for (size_t i = 0; i < rowNumber; ++i) {
                decompositionTraceElement -= decompositionRow[i] * decompositionRow[i] * decompositionTrace[i];
            }

            if (fabs(decompositionTraceElement) < regularizationThreshold) {
                return false;
            }

            ++olsMatrixElementIdx;
            decompositionRow[rowNumber] = 1.;
            for (size_t columnNumber = rowNumber + 1; columnNumber < featuresCount; ++columnNumber) {
                TVector<double>& secondDecompositionRow = decompositionMatrix[columnNumber];
                double& decompositionMatrixElement = secondDecompositionRow[rowNumber];

                decompositionMatrixElement = linearizedOLSMatrix[olsMatrixElementIdx];
                for (size_t j = 0; j < rowNumber; ++j) {
                    decompositionMatrixElement -= decompositionRow[j] * secondDecompositionRow[j] * decompositionTrace[j];
                }

                decompositionMatrixElement /= decompositionTraceElement;
                decompositionRow[columnNumber] = decompositionMatrixElement;

                ++olsMatrixElementIdx;
            }
        }

        return true;
    }
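
    // Retries the decomposition with a growing ridge term (doubled each time,
    // starting from 1e-5) until the diagonal stays away from zero, so a
    // degenerate matrix still yields a solution.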
    void LDLDecomposition(const TVector<double>& linearizedOLSMatrix,
                          TVector<double>& decompositionTrace,
                          TVector<TVector<double>>& decompositionMatrix) {
        const double regularizationThreshold = 1e-5;
        double regularizationParameter = 0.;

        while (!LDLDecomposition(linearizedOLSMatrix,
                                 regularizationThreshold,
                                 regularizationParameter,
                                 decompositionTrace,
                                 decompositionMatrix)) {
            regularizationParameter = regularizationParameter ? 2 * regularizationParameter : 1e-5;
        }
    }
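
    // Forward substitution with the unit lower-triangular factor (plus scaling
    // by the diagonal D), then back substitution with its transpose.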
    TVector<double> SolveLower(const TVector<TVector<double>>& decompositionMatrix,
                               const TVector<double>& decompositionTrace,
                               const TVector<double>& olsVector) {
        const size_t featuresCount = olsVector.size();

        TVector<double> solution(featuresCount);
        for (size_t featureNumber = 0; featureNumber < featuresCount; ++featureNumber) {
            double& solutionElement = solution[featureNumber];
            solutionElement = olsVector[featureNumber];

            const TVector<double>& decompositionRow = decompositionMatrix[featureNumber];
            for (size_t i = 0; i < featureNumber; ++i) {
                solutionElement -= solution[i] * decompositionRow[i];
            }
        }

        for (size_t featureNumber = 0; featureNumber < featuresCount; ++featureNumber) {
            solution[featureNumber] /= decompositionTrace[featureNumber];
        }

        return solution;
    }

    TVector<double> SolveUpper(const TVector<TVector<double>>& decompositionMatrix,
                               const TVector<double>& lowerSolution) {
        const size_t featuresCount = lowerSolution.size();

        TVector<double> solution(featuresCount);
        for (size_t featureNumber = featuresCount; featureNumber > 0; --featureNumber) {
            double& solutionElement = solution[featureNumber - 1];
            solutionElement = lowerSolution[featureNumber - 1];

            const TVector<double>& decompositionRow = decompositionMatrix[featureNumber - 1];
            for (size_t i = featureNumber; i < featuresCount; ++i) {
                solutionElement -= solution[i] * decompositionRow[i];
            }
        }

        return solution;
    }
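
    // Full solve: decompose A = L * D * L^T, then solve the two triangular systems.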
    TVector<double> Solve(const TVector<double>& olsMatrix, const TVector<double>& olsVector) {
        const size_t featuresCount = olsVector.size();

        TVector<double> decompositionTrace(featuresCount);
        TVector<TVector<double>> decompositionMatrix(featuresCount, TVector<double>(featuresCount));

        LDLDecomposition(olsMatrix, decompositionTrace, decompositionMatrix);

        return SolveUpper(decompositionMatrix, SolveLower(decompositionMatrix, decompositionTrace, olsVector));
    }
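
    // SSE = w^T A w - 2 w^T b + goalsDeviation for the accumulated system (A, b)
    // and solution w, expanded over the linearized upper triangle of A.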
    double SumSquaredErrors(const TVector<double>& olsMatrix,
                            const TVector<double>& olsVector,
                            const TVector<double>& solution,
                            const double goalsDeviation) {
        const size_t featuresCount = olsVector.size();

        double sumSquaredErrors = goalsDeviation;
        size_t olsMatrixElementIdx = 0;
        for (size_t i = 0; i < featuresCount; ++i) {
            sumSquaredErrors += olsMatrix[olsMatrixElementIdx] * solution[i] * solution[i];
            ++olsMatrixElementIdx;
            for (size_t j = i + 1; j < featuresCount; ++j) {
                sumSquaredErrors += 2 * olsMatrix[olsMatrixElementIdx] * solution[i] * solution[j];
                ++olsMatrixElementIdx;
            }
            sumSquaredErrors -= 2 * solution[i] * olsVector[i];
        }
        return sumSquaredErrors;
    }
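
    // Adds weight * features[i] * features[j] to every element of the linearized
    // upper triangle, including the implicit bias column and corner.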
#ifdef _sse2_
    inline void AddFeaturesProduct(const double weight, const TVector<double>& features, TVector<double>& linearizedOLSTriangleMatrix) {
        const double* leftFeature = features.data();
        const double* featuresEnd = features.data() + features.size();
        double* matrixElement = linearizedOLSTriangleMatrix.data();

        size_t unaligned = features.size() & 0x1;

        for (; leftFeature != featuresEnd; ++leftFeature, ++matrixElement) {
            const double weightedFeature = weight * *leftFeature;
            const double* rightFeature = leftFeature;
            __m128d wf = {weightedFeature, weightedFeature};

            // Handle the odd leading element when the remaining row length is odd,
            // then process the rest of the row in pairs with SSE2.
            for (size_t i = 0; i < unaligned; ++i, ++rightFeature, ++matrixElement) {
                *matrixElement += weightedFeature * *rightFeature;
            }
            unaligned = (unaligned + 1) & 0x1;

            for (; rightFeature != featuresEnd; rightFeature += 2, matrixElement += 2) {
                __m128d rf = _mm_loadu_pd(rightFeature);
                __m128d matrixRow = _mm_loadu_pd(matrixElement);
                __m128d rowAdd = _mm_mul_pd(rf, wf);
                _mm_storeu_pd(matrixElement, _mm_add_pd(rowAdd, matrixRow));
            }

            *matrixElement += weightedFeature;
        }
        linearizedOLSTriangleMatrix.back() += weight;
    }
#else
    inline void AddFeaturesProduct(const double weight, const TVector<double>& features, TVector<double>& linearizedTriangleMatrix) {
        const double* leftFeature = features.data();
        const double* featuresEnd = features.data() + features.size();
        double* matrixElement = linearizedTriangleMatrix.data();

        for (; leftFeature != featuresEnd; ++leftFeature, ++matrixElement) {
            const double weightedFeature = weight * *leftFeature;
            const double* rightFeature = leftFeature;
            for (; rightFeature != featuresEnd; ++rightFeature, ++matrixElement) {
                *matrixElement += weightedFeature * *rightFeature;
            }
            *matrixElement += weightedFeature;
        }
        linearizedTriangleMatrix.back() += weight;
    }
#endif
}

namespace {
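    // Minimizes func on [left, right] by repeatedly evaluating it on a uniform
    // grid and shrinking the interval around the best grid point.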
    inline double ArgMinPrecise(std::function<double(double)> func, double left, double right) {
        const size_t intervalsCount = 20;

        double points[intervalsCount + 1];
        double values[intervalsCount + 1];

        while (right > left + 1e-5) {
            for (size_t pointNumber = 0; pointNumber <= intervalsCount; ++pointNumber) {
                points[pointNumber] = left + pointNumber * (right - left) / intervalsCount;
                values[pointNumber] = func(points[pointNumber]);
            }

            size_t bestPointNumber = MinElement(values, values + intervalsCount + 1) - values;
            if (bestPointNumber == 0) {
                right = points[bestPointNumber + 1];
                continue;
            }
            if (bestPointNumber == intervalsCount) {
                left = points[bestPointNumber - 1];
                continue;
            }

            right = points[bestPointNumber + 1];
            left = points[bestPointNumber - 1];
        }

        return func(left) < func(right) ? left : right;
    }
}
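
// Coordinate descent over the transformation parameters: each iteration tunes
// FeatureOffset and FeatureNormalizer by grid search (ArgMinPrecise) and then
// refits the linear part with a single-feature regression on the transformed feature.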
TFeaturesTransformer TFeaturesTransformerLearner::Solve(const size_t iterationsCount /* = 100 */) {
    TTransformationParameters transformationParameters;

    auto updateParameter = [this, &transformationParameters](double TTransformationParameters::*parameter,
                                                             const double left,
                                                             const double right) {
        auto evalParameter = [this, &transformationParameters, parameter](double parameterValue) {
            transformationParameters.*parameter = parameterValue;

            TFeaturesTransformer transformer(TransformationType, transformationParameters);

            double sse = 0.;
            for (const TPoint& point : Points) {
                const double error = transformer.Transformation(point.Argument) - point.Target;
                sse += error * error;
            }
            return sse;
        };

        transformationParameters.*parameter = ArgMinPrecise(evalParameter, left, right);
    };

    auto updateRegressionParameters = [this, &transformationParameters]() {
        TFeaturesTransformer transformer(TransformationType, transformationParameters);

        TSLRSolver slrSolver;
        for (const TPoint& point : Points) {
            slrSolver.Add(transformer.Transformation(point.Argument), point.Target);
        }

        double factor, intercept;
        slrSolver.Solve(factor, intercept);

        transformationParameters.RegressionFactor *= factor;
        transformationParameters.RegressionIntercept *= factor;
        transformationParameters.RegressionIntercept += intercept;
    };

    for (size_t iterationNumber = 0; iterationNumber < iterationsCount; ++iterationNumber) {
        updateParameter(&TTransformationParameters::FeatureOffset, MinimalArgument, MaximalArgument);
        updateParameter(&TTransformationParameters::FeatureNormalizer, 0., MaximalArgument - MinimalArgument);
        updateRegressionParameters();
    }

    return TFeaturesTransformer(TransformationType, transformationParameters);
}