linear_model.h 882 B

123456789101112131415161718192021222324252627282930313233343536373839404142
#pragma once

#include <util/generic/algorithm.h>
#include <util/generic/vector.h>
#include <util/ysaveload.h>

#include <numeric>
#include <stdexcept>
#include <utility>
  6. class TLinearModel {
  7. private:
  8. TVector<double> Coefficients;
  9. double Intercept;
  10. public:
  11. Y_SAVELOAD_DEFINE(Coefficients, Intercept);
  12. TLinearModel(TVector<double>&& coefficients, const double intercept)
  13. : Coefficients(std::move(coefficients))
  14. , Intercept(intercept)
  15. {
  16. }
  17. explicit TLinearModel(size_t featuresCount = 0)
  18. : Coefficients(featuresCount)
  19. , Intercept(0.)
  20. {
  21. }
  22. const TVector<double>& GetCoefficients() const {
  23. return Coefficients;
  24. }
  25. double GetIntercept() const {
  26. return Intercept;
  27. }
  28. template <typename T>
  29. double Prediction(const TVector<T>& features) const {
  30. return InnerProduct(Coefficients, features, Intercept);
  31. }
  32. };