bin_saver.h 21 KB


  1. #pragma once
  2. #include "buffered_io.h"
  3. #include "class_factory.h"
  4. #include <library/cpp/containers/2d_array/2d_array.h>
  5. #include <util/generic/hash_set.h>
  6. #include <util/generic/buffer.h>
  7. #include <util/generic/list.h>
  8. #include <util/generic/maybe.h>
  9. #include <util/generic/bitmap.h>
  10. #include <util/generic/variant.h>
  11. #include <util/generic/ylimits.h>
  12. #include <util/memory/blob.h>
  13. #include <util/digest/murmur.h>
  14. #include <array>
  15. #include <bitset>
  16. #include <list>
  17. #include <string>
  18. #ifdef _MSC_VER
  19. #pragma warning(disable : 4127)
  20. #endif
  // Saver operating modes.
  // NOTE(review): none of these enumerators is referenced elsewhere in this header
  // (IBinSaver itself is driven by a plain bool); presumably they are consumed by
  // code outside this file - confirm before changing the numeric values.
  enum ESaverMode {
      SAVER_MODE_READ = 1,
      SAVER_MODE_WRITE = 2,
      SAVER_MODE_WRITE_COMPRESSED = 3,
  };
  namespace NBinSaverInternals {
      // Tag type used to rank otherwise ambiguous overloads explicitly.
      // TOverloadPriority<P> derives from TOverloadPriority<P - 1>, so an argument of
      // type TOverloadPriority<P> binds best to the overload taking the highest P that
      // is still viable: the higher P, the higher the priority in overload resolution.
      template <int P>
      struct TOverloadPriority;

      // Bottom of the chain: the lowest priority, has no base.
      template <>
      struct TOverloadPriority<0> {
      };

      // Every other level inherits from the level directly below it.
      template <int P>
      struct TOverloadPriority: TOverloadPriority<P - 1> {
      };
  }
  36. //////////////////////////////////////////////////////////////////////////
  37. struct IBinSaver {
  38. public:
  39. typedef unsigned char chunk_id;
  40. typedef ui32 TStoredSize; // changing this will break compatibility
  41. private:
  42. // This overload is required to avoid infinite recursion when overriding serialization in derived classes:
  43. // struct B {
  44. // virtual int operator &(IBinSaver& f) {
  45. // return 0;
  46. // }
  47. // };
  48. //
  49. // struct D : B {
  50. // int operator &(IBinSaver& f) override {
  51. // f.Add(0, static_cast<B*>(this));
  52. // return 0;
  53. // }
  54. // };
  55. template <class T, typename = decltype(std::declval<T*>()->T::operator&(std::declval<IBinSaver&>()))>
  56. void CallObjectSerialize(T* p, NBinSaverInternals::TOverloadPriority<2>) { // highest priority - will be resolved first if enabled
  57. // Note: p->operator &(*this) would lead to infinite recursion
  58. p->T::operator&(*this);
  59. }
  60. template <class T, typename = decltype(std::declval<T&>() & std::declval<IBinSaver&>())>
  61. void CallObjectSerialize(T* p, NBinSaverInternals::TOverloadPriority<1>) { // lower priority - will be resolved second if enabled
  62. (*p) & (*this);
  63. }
  64. template <class T>
  65. void CallObjectSerialize(T* p, NBinSaverInternals::TOverloadPriority<0>) { // lower priority - will be resolved last
  66. #if (!defined(_MSC_VER))
  67. // In MSVC __has_trivial_copy returns false to enums, primitive types and arrays.
  68. static_assert(__has_trivial_copy(T), "Class is nontrivial copyable, you must define operator&, see");
  69. #endif
  70. DataChunk(p, sizeof(T));
  71. }
  72. // vector
  73. template <class T, class TA>
  74. void DoVector(TVector<T, TA>& data) {
  75. TStoredSize nSize;
  76. if (IsReading()) {
  77. data.clear();
  78. Add(2, &nSize);
  79. data.resize(nSize);
  80. } else {
  81. nSize = data.size();
  82. CheckOverflow(nSize, data.size());
  83. Add(2, &nSize);
  84. }
  85. for (TStoredSize i = 0; i < nSize; i++)
  86. Add(1, &data[i]);
  87. }
  88. template <class T, int N>
  89. void DoArray(T (&data)[N]) {
  90. for (size_t i = 0; i < N; i++) {
  91. Add(1, &(data[i]));
  92. }
  93. }
  94. template <typename TLarge>
  95. void CheckOverflow(TStoredSize nSize, TLarge origSize) {
  96. if (nSize != origSize) {
  97. fprintf(stderr, "IBinSaver: object size is too large to be serialized (%" PRIu32 " != %" PRIu64 ")\n", nSize, (ui64)origSize);
  98. abort();
  99. }
  100. }
  101. template <class T, class TA>
  102. void DoDataVector(TVector<T, TA>& data) {
  103. TStoredSize nSize = data.size();
  104. CheckOverflow(nSize, data.size());
  105. Add(1, &nSize);
  106. if (IsReading()) {
  107. data.clear();
  108. data.resize(nSize);
  109. }
  110. if (nSize > 0)
  111. DataChunk(&data[0], sizeof(T) * nSize);
  112. }
  113. template <class AM>
  114. void DoAnyMap(AM& data) {
  115. if (IsReading()) {
  116. data.clear();
  117. TStoredSize nSize;
  118. Add(3, &nSize);
  119. TVector<typename AM::key_type, typename std::allocator_traits<typename AM::allocator_type>::template rebind_alloc<typename AM::key_type>> indices;
  120. indices.resize(nSize);
  121. for (TStoredSize i = 0; i < nSize; ++i)
  122. Add(1, &indices[i]);
  123. for (TStoredSize i = 0; i < nSize; ++i)
  124. Add(2, &data[indices[i]]);
  125. } else {
  126. TStoredSize nSize = data.size();
  127. CheckOverflow(nSize, data.size());
  128. Add(3, &nSize);
  129. TVector<typename AM::key_type, typename std::allocator_traits<typename AM::allocator_type>::template rebind_alloc<typename AM::key_type>> indices;
  130. indices.resize(nSize);
  131. TStoredSize i = 1;
  132. for (auto pos = data.begin(); pos != data.end(); ++pos, ++i)
  133. indices[nSize - i] = pos->first;
  134. for (TStoredSize j = 0; j < nSize; ++j)
  135. Add(1, &indices[j]);
  136. for (TStoredSize j = 0; j < nSize; ++j)
  137. Add(2, &data[indices[j]]);
  138. }
  139. }
  140. // hash_multimap
  141. template <class AMM>
  142. void DoAnyMultiMap(AMM& data) {
  143. if (IsReading()) {
  144. data.clear();
  145. TStoredSize nSize;
  146. Add(3, &nSize);
  147. TVector<typename AMM::key_type, typename std::allocator_traits<typename AMM::allocator_type>::template rebind_alloc<typename AMM::key_type>> indices;
  148. indices.resize(nSize);
  149. for (TStoredSize i = 0; i < nSize; ++i)
  150. Add(1, &indices[i]);
  151. for (TStoredSize i = 0; i < nSize; ++i) {
  152. std::pair<typename AMM::key_type, typename AMM::mapped_type> valToInsert;
  153. valToInsert.first = indices[i];
  154. Add(2, &valToInsert.second);
  155. data.insert(valToInsert);
  156. }
  157. } else {
  158. TStoredSize nSize = data.size();
  159. CheckOverflow(nSize, data.size());
  160. Add(3, &nSize);
  161. for (auto pos = data.begin(); pos != data.end(); ++pos)
  162. Add(1, (typename AMM::key_type*)(&pos->first));
  163. for (auto pos = data.begin(); pos != data.end(); ++pos)
  164. Add(2, &pos->second);
  165. }
  166. }
  167. template <class T>
  168. void DoAnySet(T& data) {
  169. if (IsReading()) {
  170. data.clear();
  171. TStoredSize nSize;
  172. Add(2, &nSize);
  173. for (TStoredSize i = 0; i < nSize; ++i) {
  174. typename T::value_type member;
  175. Add(1, &member);
  176. data.insert(member);
  177. }
  178. } else {
  179. TStoredSize nSize = data.size();
  180. CheckOverflow(nSize, data.size());
  181. Add(2, &nSize);
  182. for (const auto& elem : data) {
  183. auto member = elem;
  184. Add(1, &member);
  185. }
  186. }
  187. }
  188. // 2D array
  189. template <class T>
  190. void Do2DArray(TArray2D<T>& a) {
  191. int nXSize = a.GetXSize(), nYSize = a.GetYSize();
  192. Add(1, &nXSize);
  193. Add(2, &nYSize);
  194. if (IsReading())
  195. a.SetSizes(nXSize, nYSize);
  196. for (int i = 0; i < nXSize * nYSize; i++)
  197. Add(3, &a[i / nXSize][i % nXSize]);
  198. }
  199. template <class T>
  200. void Do2DArrayData(TArray2D<T>& a) {
  201. int nXSize = a.GetXSize(), nYSize = a.GetYSize();
  202. Add(1, &nXSize);
  203. Add(2, &nYSize);
  204. if (IsReading())
  205. a.SetSizes(nXSize, nYSize);
  206. if (nXSize * nYSize > 0)
  207. DataChunk(&a[0][0], sizeof(T) * nXSize * nYSize);
  208. }
  209. // strings
  210. template <class TStringType>
  211. void DataChunkStr(TStringType& data, i64 elemSize) {
  212. if (bRead) {
  213. TStoredSize nCount = 0;
  214. File.Read(&nCount, sizeof(TStoredSize));
  215. data.resize(nCount);
  216. if (nCount)
  217. File.Read(&*data.begin(), nCount * elemSize);
  218. } else {
  219. TStoredSize nCount = data.size();
  220. CheckOverflow(nCount, data.size());
  221. File.Write(&nCount, sizeof(TStoredSize));
  222. File.Write(data.c_str(), nCount * elemSize);
  223. }
  224. }
  225. void DataChunkString(std::string& data) {
  226. DataChunkStr(data, sizeof(char));
  227. }
  228. void DataChunkStroka(TString& data) {
  229. DataChunkStr(data, sizeof(TString::char_type));
  230. }
  231. void DataChunkWtroka(TUtf16String& data) {
  232. DataChunkStr(data, sizeof(wchar16));
  233. }
  234. void DataChunk(void* pData, i64 nSize) {
  235. i64 chunkSize = 1 << 30;
  236. for (i64 offset = 0; offset < nSize; offset += chunkSize) {
  237. void* ptr = (char*)pData + offset;
  238. i64 size = offset + chunkSize < nSize ? chunkSize : (nSize - offset);
  239. if (bRead)
  240. File.Read(ptr, size);
  241. else
  242. File.Write(ptr, size);
  243. }
  244. }
  245. // storing/loading pointers to objects
  246. void StoreObject(IObjectBase* pObject);
  247. IObjectBase* LoadObject();
  248. bool bRead;
  249. TBufferedStream<> File;
  250. // maps objects addresses during save(first) to addresses during load(second) - during loading
  251. // or serves as a sign that some object has been already stored - during storing
  252. bool StableOutput;
  253. typedef THashMap<void*, ui32> PtrIdHash;
  254. TAutoPtr<PtrIdHash> PtrIds;
  255. typedef THashMap<ui64, TPtr<IObjectBase>> CObjectsHash;
  256. TAutoPtr<CObjectsHash> Objects;
  257. TVector<IObjectBase*> ObjectQueue;
  258. public:
  259. bool IsReading() {
  260. return bRead;
  261. }
  262. void AddRawData(const chunk_id, void* pData, i64 nSize) {
  263. DataChunk(pData, nSize);
  264. }
  265. // return type of Add() is used to detect specialized serializer (see HasNonTrivialSerializer below)
  266. template <class T>
  267. char Add(const chunk_id, T* p) {
  268. CallObjectSerialize(p, NBinSaverInternals::TOverloadPriority<2>());
  269. return 0;
  270. }
  271. int Add(const chunk_id, std::string* pStr) {
  272. DataChunkString(*pStr);
  273. return 0;
  274. }
  275. int Add(const chunk_id, TString* pStr) {
  276. DataChunkStroka(*pStr);
  277. return 0;
  278. }
  279. int Add(const chunk_id, TUtf16String* pStr) {
  280. DataChunkWtroka(*pStr);
  281. return 0;
  282. }
  283. int Add(const chunk_id, TBlob* blob) {
  284. if (bRead) {
  285. ui64 size = 0;
  286. File.Read(&size, sizeof(size));
  287. TBuffer buffer;
  288. buffer.Advance(size);
  289. if (size > 0)
  290. File.Read(buffer.Data(), buffer.Size());
  291. (*blob) = TBlob::FromBuffer(buffer);
  292. } else {
  293. const ui64 size = blob->Size();
  294. File.Write(&size, sizeof(size));
  295. File.Write(blob->Data(), blob->Size());
  296. }
  297. return 0;
  298. }
  299. template <class T1, class TA>
  300. int Add(const chunk_id, TVector<T1, TA>* pVec) {
  301. if (HasNonTrivialSerializer<T1>(0u))
  302. DoVector(*pVec);
  303. else
  304. DoDataVector(*pVec);
  305. return 0;
  306. }
  307. template <class T, int N>
  308. int Add(const chunk_id, T (*pVec)[N]) {
  309. if (HasNonTrivialSerializer<T>(0u))
  310. DoArray(*pVec);
  311. else
  312. DataChunk(pVec, sizeof(*pVec));
  313. return 0;
  314. }
  315. template <class T1, class T2, class T3, class T4>
  316. int Add(const chunk_id, TMap<T1, T2, T3, T4>* pMap) {
  317. DoAnyMap(*pMap);
  318. return 0;
  319. }
  320. template <class T1, class T2, class T3, class T4, class T5>
  321. int Add(const chunk_id, THashMap<T1, T2, T3, T4, T5>* pHash) {
  322. DoAnyMap(*pHash);
  323. return 0;
  324. }
  325. template <class T1, class T2, class T3, class T4, class T5>
  326. int Add(const chunk_id, THashMultiMap<T1, T2, T3, T4, T5>* pHash) {
  327. DoAnyMultiMap(*pHash);
  328. return 0;
  329. }
  330. template <class K, class L, class A>
  331. int Add(const chunk_id, TSet<K, L, A>* pSet) {
  332. DoAnySet(*pSet);
  333. return 0;
  334. }
  335. template <class T1, class T2, class T3, class T4>
  336. int Add(const chunk_id, THashSet<T1, T2, T3, T4>* pHash) {
  337. DoAnySet(*pHash);
  338. return 0;
  339. }
  340. template <class T1>
  341. int Add(const chunk_id, TArray2D<T1>* pArr) {
  342. if (HasNonTrivialSerializer<T1>(0u))
  343. Do2DArray(*pArr);
  344. else
  345. Do2DArrayData(*pArr);
  346. return 0;
  347. }
  348. template <class T1>
  349. int Add(const chunk_id, TList<T1>* pList) {
  350. TList<T1>& data = *pList;
  351. if (IsReading()) {
  352. int nSize;
  353. Add(2, &nSize);
  354. data.clear();
  355. data.insert(data.begin(), nSize, T1());
  356. } else {
  357. int nSize = data.size();
  358. Add(2, &nSize);
  359. }
  360. int i = 1;
  361. for (typename TList<T1>::iterator k = data.begin(); k != data.end(); ++k, ++i)
  362. Add(i + 2, &(*k));
  363. return 0;
  364. }
  365. template <class T1, class T2>
  366. int Add(const chunk_id, std::pair<T1, T2>* pData) {
  367. Add(1, &(pData->first));
  368. Add(2, &(pData->second));
  369. return 0;
  370. }
  371. template <class T1, size_t N>
  372. int Add(const chunk_id, std::array<T1, N>* pData) {
  373. if (HasNonTrivialSerializer<T1>(0u)) {
  374. for (size_t i = 0; i < N; ++i)
  375. Add(1, &(*pData)[i]);
  376. } else {
  377. DataChunk((void*)pData->data(), pData->size() * sizeof(T1));
  378. }
  379. return 0;
  380. }
  381. template <size_t N>
  382. int Add(const chunk_id, std::bitset<N>* pData) {
  383. if (IsReading()) {
  384. std::string s;
  385. Add(1, &s);
  386. *pData = std::bitset<N>(s);
  387. } else {
  388. std::string s = pData->template to_string<char, std::char_traits<char>, std::allocator<char>>();
  389. Add(1, &s);
  390. }
  391. return 0;
  392. }
  393. int Add(const chunk_id, TDynBitMap* pData) {
  394. if (IsReading()) {
  395. ui64 count = 0;
  396. Add(1, &count);
  397. pData->Clear();
  398. pData->Reserve(count * sizeof(TDynBitMap::TChunk) * 8);
  399. for (ui64 i = 0; i < count; ++i) {
  400. TDynBitMap::TChunk chunk = 0;
  401. Add(i + 1, &chunk);
  402. if (i > 0) {
  403. pData->LShift(8 * sizeof(TDynBitMap::TChunk));
  404. }
  405. pData->Or(chunk);
  406. }
  407. } else {
  408. ui64 count = pData->GetChunkCount();
  409. Add(1, &count);
  410. for (ui64 i = 0; i < count; ++i) {
  411. // Write in reverse order
  412. TDynBitMap::TChunk chunk = pData->GetChunks()[count - i - 1];
  413. Add(i + 1, &chunk);
  414. }
  415. }
  416. return 0;
  417. }
  418. template <class TVariantClass>
  419. struct TLoadFromTypeFromListHelper {
  420. template <class T0, class... TTail>
  421. static void Do(IBinSaver& binSaver, ui32 typeIndex, TVariantClass* pData) {
  422. if constexpr (sizeof...(TTail) == 0) {
  423. Y_ASSERT(typeIndex == 0);
  424. T0 chunk;
  425. binSaver.Add(2, &chunk);
  426. *pData = std::move(chunk);
  427. } else {
  428. if (typeIndex == 0) {
  429. Do<T0>(binSaver, 0, pData);
  430. } else {
  431. Do<TTail...>(binSaver, typeIndex - 1, pData);
  432. }
  433. }
  434. }
  435. };
  436. template <class... TVariantTypes>
  437. int Add(const chunk_id, std::variant<TVariantTypes...>* pData) {
  438. static_assert(std::variant_size_v<std::variant<TVariantTypes...>> < Max<ui32>());
  439. ui32 index;
  440. if (IsReading()) {
  441. Add(1, &index);
  442. TLoadFromTypeFromListHelper<std::variant<TVariantTypes...>>::template Do<TVariantTypes...>(
  443. *this,
  444. index,
  445. pData
  446. );
  447. } else {
  448. index = pData->index(); // type cast is safe because of static_assert check above
  449. Add(1, &index);
  450. std::visit([&](auto& dst) -> void { Add(2, &dst); }, *pData);
  451. }
  452. return 0;
  453. }
  454. void AddPolymorphicBase(chunk_id, IObjectBase* pObject) {
  455. (*pObject) & (*this);
  456. }
  457. template <class T1, class T2>
  458. void DoPtr(TPtrBase<T1, T2>* pData) {
  459. if (pData && pData->Get()) {
  460. }
  461. if (IsReading())
  462. pData->Set(CastToUserObject(LoadObject(), (T1*)nullptr));
  463. else
  464. StoreObject(pData->GetBarePtr());
  465. }
  466. template <class T, class TPolicy>
  467. int Add(const chunk_id, TMaybe<T, TPolicy>* pData) {
  468. TMaybe<T, TPolicy>& data = *pData;
  469. if (IsReading()) {
  470. bool defined = false;
  471. Add(1, &defined);
  472. if (defined) {
  473. data = T();
  474. Add(2, data.Get());
  475. }
  476. } else {
  477. bool defined = data.Defined();
  478. Add(1, &defined);
  479. if (defined) {
  480. Add(2, data.Get());
  481. }
  482. }
  483. return 0;
  484. }
  485. template <typename TOne>
  486. void AddMulti(TOne& one) {
  487. Add(0, &one);
  488. }
  489. template <typename THead, typename... TTail>
  490. void AddMulti(THead& head, TTail&... tail) {
  491. Add(0, &head);
  492. AddMulti(tail...);
  493. }
  494. template <class T, typename = decltype(std::declval<T&>() & std::declval<IBinSaver&>())>
  495. static bool HasNonTrivialSerializer(ui32) {
  496. return true;
  497. }
  498. template <class T>
  499. static bool HasNonTrivialSerializer(...) {
  500. return sizeof(std::declval<IBinSaver*>()->Add(0, std::declval<T*>())) != 1;
  501. }
  502. public:
  503. IBinSaver(IBinaryStream& stream, bool _bRead, bool stableOutput = false)
  504. : bRead(_bRead)
  505. , File(_bRead, stream)
  506. , StableOutput(stableOutput)
  507. {
  508. }
  509. virtual ~IBinSaver();
  510. bool IsValid() const {
  511. return File.IsValid();
  512. }
  513. };
  // Implementation of the forward-declared serialization operator of TPtrBase.
  // Hands the pointer over to IBinSaver::DoPtr and always returns 0.
  template <class TUserObj, class TRef>
  int TPtrBase<TUserObj, TRef>::operator&(IBinSaver& f) {
      f.DoPtr(this);
      return 0;
  }
  520. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Class factory in which TRegisterSaveLoadType registers types; defined elsewhere.
  extern TClassFactory<IObjectBase>* pSaverClasses;
  // Called by TRegisterSaveLoadType before each registration; defined elsewhere.
  // NOTE(review): presumably makes sure pSaverClasses is initialised - confirm against the definition.
  void StartRegisterSaveload();
  523. template <class TReg>
  524. struct TRegisterSaveLoadType {
  525. TRegisterSaveLoadType(int num) {
  526. StartRegisterSaveload();
  527. pSaverClasses->RegisterType(num, TReg::NewSaveLoadNullItem, (TReg*)nullptr);
  528. }
  529. };
  // Registers class `name` for save/load without an explicit numeric id: the id is
  // MurmurHash<int> computed from the stringized class name. sizeof(#name) is passed as
  // the length, and sizeof of a string literal counts its terminating NUL.
  #define Y_BINSAVER_REGISTER(name) \
      BASIC_REGISTER_CLASS(name)    \
      static TRegisterSaveLoadType<name> init##name(MurmurHash<int>(#name, sizeof(#name)));
  // Registers class `name` for save/load under the explicit id N.
  // Expands to BASIC_REGISTER_CLASS(name) plus a static registrar object named init<name><N>.
  #define REGISTER_SAVELOAD_CLASS(N, name) \
      BASIC_REGISTER_CLASS(name)           \
      static TRegisterSaveLoadType<name> init##name##N(N);
  // using TObj/TRef on forward declared templ class will not work
  // but multiple registration with same id is allowed
  //
  // Registers the instantiation className<T> under the id N. Unlike REGISTER_SAVELOAD_CLASS,
  // this macro does not expand BASIC_REGISTER_CLASS. T is pasted into the registrar's name,
  // so it has to be a single identifier.
  #define REGISTER_SAVELOAD_TEMPL1_CLASS(N, className, T) \
      static TRegisterSaveLoadType<className<T>> init##className##T##N(N);
  // Registers the instantiation className<T1, T2> under the id N.
  // T1 and T2 are pasted into generated identifiers, so each has to be a single identifier.
  // NOTE(review): the typedef is not used by the macro itself; presumably it only forces
  // the template-id to be named once - confirm before removing.
  #define REGISTER_SAVELOAD_TEMPL2_CLASS(N, className, T1, T2)    \
      typedef className<T1, T2> temp##className##T1##_##T2##temp; \
      static TRegisterSaveLoadType<className<T1, T2>> init##className##T1##_##T2##N(N);
  // Registers the instantiation className<T1, T2, T3> under the id N.
  // T1, T2 and T3 are pasted into generated identifiers, so each has to be a single identifier.
  // NOTE(review): the typedef is not used by the macro itself; presumably it only forces
  // the template-id to be named once - confirm before removing.
  #define REGISTER_SAVELOAD_TEMPL3_CLASS(N, className, T1, T2, T3)           \
      typedef className<T1, T2, T3> temp##className##T1##_##T2##_##T3##temp; \
      static TRegisterSaveLoadType<className<T1, T2, T3>> init##className##T1##_##T2##_##T3##N(N);
  // Registers nmspace::className for save/load under the explicit id N.
  // The registrar object is named init_<nmspace>_<className><N>; the class name is pasted
  // from the `className` parameter (pasting a token that is not a macro parameter would
  // put that literal token into the identifier instead of the class name).
  #define REGISTER_SAVELOAD_NM_CLASS(N, nmspace, className) \
      BASIC_REGISTER_CLASS(nmspace::className)              \
      static TRegisterSaveLoadType<nmspace::className> init_##nmspace##_##className##N(N);
  // Registers nmspace1::nmspace2::className for save/load under the explicit id N.
  // The registrar object is named init_<nmspace1>_<nmspace2>_<className><N>; the class name
  // is pasted from the `className` parameter so that it really appears in the identifier.
  #define REGISTER_SAVELOAD_NM2_CLASS(N, nmspace1, nmspace2, className) \
      BASIC_REGISTER_CLASS(nmspace1::nmspace2::className)               \
      static TRegisterSaveLoadType<nmspace1::nmspace2::className> init_##nmspace1##_##nmspace2##_##className##N(N);
  // Registers the instantiation nmspace::className<T> for save/load under the explicit id N.
  // The registrar object is named temp_init<nmspace>_<className><T><N>; the class name is
  // pasted from the `className` parameter so that it really appears in the identifier.
  // T is pasted into generated identifiers, so it has to be a single identifier.
  #define REGISTER_SAVELOAD_TEMPL1_NM_CLASS(N, nmspace, className, T)       \
      typedef nmspace::className<T> temp_init##nmspace##className##T##temp; \
      BASIC_REGISTER_CLASS(nmspace::className<T>)                           \
      static TRegisterSaveLoadType<nmspace::className<T>> temp_init##nmspace##_##className##T##N(N);
  // Registers class `cls` under the id N; `name` is used only to build the registrar
  // object's identifier (init<name><N>), which allows `cls` to be a qualified name.
  #define REGISTER_SAVELOAD_CLASS_NAME(N, cls, name) \
      BASIC_REGISTER_CLASS(cls)                      \
      static TRegisterSaveLoadType<cls> init##name##N(N);
  // Registers ns::cls under the id N via REGISTER_SAVELOAD_CLASS_NAME, passing
  // _<pref>_<cls> as the name from which the registrar object's identifier is built.
  #define REGISTER_SAVELOAD_CLASS_NS_PREF(N, cls, ns, pref) \
      REGISTER_SAVELOAD_CLASS_NAME(N, ns ::cls, _##pref##_##cls)
  // Defines a non-virtual operator&(IBinSaver&) that passes the listed members, in order,
  // to IBinSaver::AddMulti and returns 0. Use inside a class body: SAVELOAD(a, b, c)
  #define SAVELOAD(...)             \
      int operator&(IBinSaver& f) { \
          f.AddMulti(__VA_ARGS__);  \
          return 0;                 \
      }
  // Same as SAVELOAD, but the generated operator& is marked `override`; the base class
  // operator& is NOT called, so only the listed members are passed to AddMulti.
  #define SAVELOAD_OVERRIDE_WITHOUT_BASE(...) \
      int operator&(IBinSaver& f) override {  \
          f.AddMulti(__VA_ARGS__);            \
          return 0;                           \
      }
  // Defines an overriding operator&(IBinSaver&) that first calls base::operator&(f)
  // and then passes the listed members, in order, to AddMulti. Returns 0.
  #define SAVELOAD_OVERRIDE(base, ...)       \
      int operator&(IBinSaver& f) override { \
          base::operator&(f);                \
          f.AddMulti(__VA_ARGS__);           \
          return 0;                          \
      }
  // Defines a non-virtual operator&(IBinSaver&) that first calls TBase::operator&(f)
  // and then passes the listed members, in order, to AddMulti. Returns 0.
  // The enclosing class must make the name TBase visible (e.g. as a typedef of its base).
  #define SAVELOAD_BASE(...)        \
      int operator&(IBinSaver& f) { \
          TBase::operator&(f);      \
          f.AddMulti(__VA_ARGS__);  \
          return 0;                 \
      }