mkql_match_recognize_save_load.h 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  #pragma once
  #include <yql/essentials/minikql/computation/mkql_computation_node.h>
  #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  #include <yql/essentials/minikql/comp_nodes/mkql_saveload.h>
  #include <yql/essentials/minikql/mkql_string_util.h>

  #include <cstdint>
  #include <map>
  #include <set>
  #include <utility>
  6. namespace NKikimr::NMiniKQL::NMatchRecognize {
  7. struct TSerializerContext {
  8. TSerializerContext(
  9. TComputationContext& ctx,
  10. TType* rowType,
  11. const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker)
  12. : Ctx(ctx)
  13. , RowType(rowType)
  14. , RowPacker(rowPacker)
  15. {}
  16. TComputationContext& Ctx;
  17. TType* RowType;
  18. const TMutableObjectOverBoxedValue<TValuePackerBoxed>& RowPacker;
  19. };
  /// Type-dependent constant that is `false` for every template argument.
  /// Because the value depends on `TAny`, it can be used as
  /// `static_assert(always_false_v<T>, "...")` inside a template branch; the
  /// assertion is only evaluated when that branch is actually instantiated.
  template <typename TAny>
  inline constexpr bool always_false_v = false;
  22. struct TMrOutputSerializer : TOutputSerializer {
  23. private:
  24. enum class TPtrStateMode {
  25. Saved = 0,
  26. FromCache = 1
  27. };
  28. public:
  29. TMrOutputSerializer(const TSerializerContext& context, EMkqlStateType stateType, ui32 stateVersion, TComputationContext& ctx)
  30. : TOutputSerializer(stateType, stateVersion, ctx)
  31. , Context(context)
  32. {}
  33. using TOutputSerializer::Write;
  34. template <typename... Ts>
  35. void operator()(Ts&&... args) {
  36. (Write(std::forward<Ts>(args)), ...);
  37. }
  38. void Write(const NUdf::TUnboxedValue& value) {
  39. WriteUnboxedValue(Context.RowPacker.RefMutableObject(Context.Ctx, false, Context.RowType), value);
  40. }
  41. template<class Type>
  42. void Write(const TIntrusivePtr<Type>& ptr) {
  43. bool isValid = static_cast<bool>(ptr);
  44. WriteBool(Buf, isValid);
  45. if (!isValid) {
  46. return;
  47. }
  48. auto addr = reinterpret_cast<std::uintptr_t>(ptr.Get());
  49. WriteUi64(Buf, addr);
  50. auto it = Cache.find(addr);
  51. if (it != Cache.end()) {
  52. WriteByte(Buf, static_cast<ui8>(TPtrStateMode::FromCache));
  53. return;
  54. }
  55. WriteByte(Buf, static_cast<ui8>(TPtrStateMode::Saved));
  56. ptr->Save(*this);
  57. Cache[addr] = addr;
  58. }
  59. private:
  60. const TSerializerContext& Context;
  61. mutable std::map<std::uintptr_t, std::uintptr_t> Cache;
  62. };
  63. struct TMrInputSerializer : TInputSerializer {
  64. private:
  65. enum class TPtrStateMode {
  66. Saved = 0,
  67. FromCache = 1
  68. };
  69. public:
  70. TMrInputSerializer(TSerializerContext& context, const NUdf::TUnboxedValue& state)
  71. : TInputSerializer(state, EMkqlStateType::SIMPLE_BLOB)
  72. , Context(context) {
  73. }
  74. using TInputSerializer::Read;
  75. template <typename... Ts>
  76. void operator()(Ts&... args) {
  77. (Read(args), ...);
  78. }
  79. void Read(NUdf::TUnboxedValue& value) {
  80. value = ReadUnboxedValue(Context.RowPacker.RefMutableObject(Context.Ctx, false, Context.RowType), Context.Ctx);
  81. }
  82. template<class Type>
  83. void Read(TIntrusivePtr<Type>& ptr) {
  84. bool isValid = Read<bool>();
  85. if (!isValid) {
  86. ptr.Reset();
  87. return;
  88. }
  89. ui64 addr = Read<ui64>();
  90. TPtrStateMode mode = static_cast<TPtrStateMode>(Read<ui8>());
  91. if (mode == TPtrStateMode::Saved) {
  92. ptr = MakeIntrusive<Type>();
  93. ptr->Load(*this);
  94. Cache[addr] = ptr.Get();
  95. return;
  96. }
  97. auto it = Cache.find(addr);
  98. MKQL_ENSURE(it != Cache.end(), "Internal error");
  99. auto* cachePtr = static_cast<Type*>(it->second);
  100. ptr = TIntrusivePtr<Type>(cachePtr);
  101. }
  102. private:
  103. TSerializerContext& Context;
  104. mutable std::map<std::uintptr_t, void *> Cache;
  105. };
  106. } //namespace NKikimr::NMiniKQL::NMatchRecognize