pool.h 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. #pragma once
  2. #include <util/generic/vector.h>
  3. #include <util/generic/string.h>
  4. #include <util/random/mersenne.h>
  5. #include <util/random/shuffle.h>
  6. struct TInstance {
  7. TVector<double> Features;
  8. double Goal;
  9. double Weight;
  10. static TInstance FromFeaturesString(const TString& featuresString);
  11. };
  12. struct TPool: public TVector<TInstance> {
  13. enum EIteratorType {
  14. LearnIterator,
  15. TestIterator,
  16. };
  17. class TCVIterator {
  18. private:
  19. const TPool& ParentPool;
  20. size_t FoldsCount;
  21. EIteratorType IteratorType;
  22. size_t TestFoldNumber;
  23. TVector<size_t> InstanceFoldNumbers;
  24. const size_t* Current;
  25. TMersenne<ui64> RandomGenerator;
  26. public:
  27. TCVIterator(const TPool& parentPool,
  28. const size_t foldsCount,
  29. const EIteratorType iteratorType);
  30. void ResetShuffle();
  31. void SetTestFold(const size_t testFoldNumber);
  32. bool IsValid() const;
  33. const TInstance& operator*() const;
  34. const TInstance* operator->() const;
  35. TPool::TCVIterator& operator++();
  36. private:
  37. void Advance();
  38. bool TakeCurrent() const;
  39. };
  40. void ReadFromFeatures(const TString& featuresPath);
  41. TCVIterator CrossValidationIterator(const size_t foldsCount, const EIteratorType iteratorType) const;
  42. TPool InjurePool(const double injureFactir, const double injureOffset) const;
  43. };