// sample.h (2.5 KB)

  1. #pragma once
  2. #include <library/cpp/deprecated/accessors/accessors.h>
  3. #include <util/generic/buffer.h>
  4. #include <util/generic/vector.h>
  5. #include <util/random/fast.h>
  6. #include <util/random/shuffle.h>
  7. #include <functional>
  8. #include <type_traits>
  9. namespace NCodecs {
  10. class ISequenceReader {
  11. public:
  12. virtual bool NextRegion(TStringBuf& s) = 0;
  13. virtual ~ISequenceReader() = default;
  14. };
  15. template <class TValue>
  16. TStringBuf ValueToStringBuf(TValue&& t) {
  17. return TStringBuf{NAccessors::Begin(t), NAccessors::End(t)};
  18. }
  19. template <class TIter>
  20. TStringBuf IterToStringBuf(TIter iter) {
  21. return ValueToStringBuf(*iter);
  22. }
  23. template <class TItem>
  24. class TSimpleSequenceReader: public ISequenceReader {
  25. const TVector<TItem>& Items;
  26. size_t Idx = 0;
  27. public:
  28. TSimpleSequenceReader(const TVector<TItem>& items)
  29. : Items(items)
  30. {
  31. }
  32. bool NextRegion(TStringBuf& s) override {
  33. if (Idx >= Items.size()) {
  34. return false;
  35. }
  36. s = ValueToStringBuf(Items[Idx++]);
  37. return true;
  38. }
  39. };
  // Sums `getter(it).size()` over every iterator position in [begin, end).
  //
  // `getter` receives the iterator itself (not the dereferenced value) and must
  // return something exposing size(). An empty range yields 0.
  template <class TIter, class TGetter>
  size_t GetInputSize(TIter begin, TIter end, TGetter getter) {
      size_t bytes = 0;
      // `begin` is a by-value copy, so it can serve as the cursor directly.
      for (; begin != end; ++begin) {
          bytes += getter(begin).size();
      }
      return bytes;
  }
  48. template <class TIter>
  49. size_t GetInputSize(TIter begin, TIter end) {
  50. return GetInputSize(begin, end, IterToStringBuf<TIter>);
  51. }
  52. template <class TIter, class TGetter>
  53. TVector<TBuffer> GetSample(TIter begin, TIter end, size_t sampleSizeBytes, TGetter getter) {
  54. TFastRng64 rng{0x1ce1f2e507541a05, 0x07d45659, 0x7b8771030dd9917e, 0x2d6636ce};
  55. size_t totalBytes = GetInputSize(begin, end, getter);
  56. double sampleProb = (double)sampleSizeBytes / Max<size_t>(1, totalBytes);
  57. TVector<TBuffer> result;
  58. for (TIter iter = begin; iter != end; ++iter) {
  59. if (sampleProb >= 1 || rng.GenRandReal1() < sampleProb) {
  60. TStringBuf reg = getter(iter);
  61. result.emplace_back(reg.data(), reg.size());
  62. }
  63. }
  64. Shuffle(result.begin(), result.end(), rng);
  65. return result;
  66. }
  67. template <class TIter>
  68. TVector<TBuffer> GetSample(TIter begin, TIter end, size_t sampleSizeBytes) {
  69. return GetSample(begin, end, sampleSizeBytes, IterToStringBuf<TIter>);
  70. }
  71. }