arrow.h 54 KB


  1. #pragma once
  2. #include <yql/essentials/public/udf/arrow/block_reader.h>
  3. #include <yql/essentials/public/udf/arrow/block_builder.cpp>
  4. #include <arrow/compute/kernel.h>
  5. #include <yql/essentials/minikql/mkql_type_builder.h>
  6. #include <yql/essentials/minikql/arrow/arrow_util.h>
  7. #include <yql/essentials/parser/pg_catalog/catalog.h>
  8. #include <yql/essentials/minikql/comp_nodes/mkql_block_agg_factory.h>
  9. #include "arena_ctx.h"
  10. #include <functional>
  11. extern "C" {
  12. #include "postgres.h"
  13. #include "fmgr.h"
  14. #include "varatt.h"
  15. #include "catalog/pg_type_d.h"
  16. #include "catalog/pg_collation_d.h"
  17. }
  18. #include "utils.h"
  19. namespace NYql {
  20. struct TPgKernelState : arrow::compute::KernelState {
  21. FmgrInfo flinfo; /* lookup info used for this call */
  22. fmNodePtr context; /* pass info about context of call */
  23. fmNodePtr resultinfo; /* pass or return extra info about result */
  24. Oid fncollation; /* collation for function to use */
  25. TString Name;
  26. std::vector<bool> IsFixedArg;
  27. bool IsFixedResult;
  28. i32 TypeLen;
  29. std::shared_ptr<void> FmgrDataHolder;
  30. const NPg::TProcDesc* ProcDesc;
  31. };
  32. template <PGFunction PgFunc>
  33. struct TPgDirectFunc {
  34. Datum operator()(FunctionCallInfo info) const {
  35. return PgFunc(info);
  36. }
  37. };
  38. struct TPgIndirectFunc {
  39. TPgIndirectFunc(PGFunction pgFunc)
  40. : PgFunc(pgFunc)
  41. {}
  42. Datum operator()(FunctionCallInfo info) const {
  43. return PgFunc(info);
  44. }
  45. PGFunction PgFunc;
  46. };
  47. template <bool IsFixed>
  48. Datum CloneDatumToAggContext(Datum src, i32 typeLen) {
  49. if constexpr (IsFixed) {
  50. return src;
  51. } else {
  52. Y_ENSURE(NKikimr::NMiniKQL::TlsAllocState->CurrentContext);
  53. ui32 len;
  54. if (typeLen == -1) {
  55. len = GetFullVarSize((const text*)src);
  56. } else if (typeLen == -2) {
  57. len = 1 + strlen((const char*)src);
  58. } else {
  59. len = typeLen;
  60. }
  61. auto ret = (Datum)palloc(len);
  62. memcpy((void*)ret, (const void*)src, len);
  63. return ret;
  64. }
  65. }
  66. template <bool IsFixed>
  67. void CopyState(NullableDatum src, NullableDatum& dst) {
  68. if constexpr (IsFixed) {
  69. dst = src;
  70. } else {
  71. if (src.isnull == dst.isnull && src.value == dst.value) {
  72. return;
  73. }
  74. if (!dst.isnull) {
  75. pfree((void*)dst.value);
  76. }
  77. dst = src;
  78. }
  79. }
  80. template <bool IsFixed>
  81. void SaveToAggContext(NullableDatum& d, i32 typeLen) {
  82. if constexpr (IsFixed) {
  83. return;
  84. }
  85. if (d.isnull) {
  86. return;
  87. }
  88. // arrow Scalars/Arrays have null memory context
  89. if (NUdf::GetMemoryContext((void*)d.value)) {
  90. return;
  91. }
  92. d.value = CloneDatumToAggContext<false>(d.value, typeLen);
  93. }
  94. template <typename TArgsPolicy>
  95. struct TInputArgsAccessor {
  96. std::array<NullableDatum, TArgsPolicy::IsFixedArg.size()> Scalars;
  97. std::array<bool, TArgsPolicy::IsFixedArg.size()> IsScalar;
  98. std::array<ui64, TArgsPolicy::IsFixedArg.size()> Offsets;
  99. std::array<const ui8*, TArgsPolicy::IsFixedArg.size()> ValidMasks;
  100. std::array<ui64, TArgsPolicy::IsFixedArg.size()> ValidOffsetMask;
  101. ui8 fakeValidByte = 0xFF;
  102. std::array<const ui64*, TArgsPolicy::IsFixedArg.size()> FixedArrays;
  103. std::array<const ui32*, TArgsPolicy::IsFixedArg.size()> StringOffsetsArrays;
  104. std::array<const ui8*, TArgsPolicy::IsFixedArg.size()> StringDataArrays;
  105. void Bind(const std::vector<arrow::Datum>& values, size_t skipArgs = 0, TMaybe<size_t> realArgsCount = {}) {
  106. if constexpr (!TArgsPolicy::VarArgs) {
  107. const size_t argCount = realArgsCount.GetOrElse(TArgsPolicy::IsFixedArg.size());
  108. Y_ENSURE(argCount == values.size() + skipArgs);
  109. for (size_t j = skipArgs; j < argCount; ++j) {
  110. IsScalar[j] = values[j - skipArgs].is_scalar();
  111. if (IsScalar[j]) {
  112. const auto& scalar = *values[j - skipArgs].scalar();
  113. if (!scalar.is_valid) {
  114. Scalars[j].isnull = true;
  115. } else {
  116. Scalars[j].isnull = false;
  117. if (TArgsPolicy::IsFixedArg[j]) {
  118. Scalars[j].value = (Datum)*static_cast<const ui64*>(arrow::internal::checked_cast<const arrow::internal::PrimitiveScalarBase&>(scalar).data());
  119. } else {
  120. auto buffer = arrow::internal::checked_cast<const arrow::BaseBinaryScalar&>(scalar).value;
  121. Scalars[j].value = (Datum)(buffer->data() + sizeof(void*));
  122. }
  123. }
  124. } else {
  125. const auto& array = *values[j - skipArgs].array();
  126. Offsets[j] = array.offset;
  127. ValidMasks[j] = array.GetValues<ui8>(0, 0);
  128. if (ValidMasks[j]) {
  129. ValidOffsetMask[j] = ~0ull;
  130. } else {
  131. ValidOffsetMask[j] = 0ull;
  132. ValidMasks[j] = &fakeValidByte;
  133. }
  134. if (TArgsPolicy::IsFixedArg[j]) {
  135. FixedArrays[j] = array.GetValues<ui64>(1);
  136. } else {
  137. StringOffsetsArrays[j] = array.GetValues<ui32>(1);
  138. StringDataArrays[j] = array.GetValues<ui8>(2, 0);
  139. }
  140. }
  141. }
  142. }
  143. }
  144. };
  145. template <bool HasNulls, bool IsFixed>
  146. void FillScalarItem(const arrow::Scalar& scalar, NullableDatum& d) {
  147. if constexpr (IsFixed) {
  148. NUdf::TFixedSizeBlockReader<ui64, HasNulls> reader;
  149. auto item = reader.GetScalarItem(scalar);
  150. if (HasNulls && !item) {
  151. d.isnull = true;
  152. } else {
  153. d.isnull = false;
  154. d.value = (Datum)item.template As<ui64>();
  155. }
  156. } else {
  157. NUdf::TStringBlockReader<arrow::BinaryType, HasNulls> reader;
  158. auto item = reader.GetScalarItem(scalar);
  159. if (HasNulls && !item) {
  160. d.isnull = true;
  161. } else {
  162. d.isnull = false;
  163. d.value = (Datum)(item.AsStringRef().Data() + sizeof(void*));
  164. }
  165. }
  166. }
  167. template <bool HasNulls, bool IsFixed>
  168. void FillArrayItem(const arrow::ArrayData& array, size_t i, NullableDatum& d) {
  169. if constexpr (IsFixed) {
  170. NUdf::TFixedSizeBlockReader<ui64, HasNulls> reader;
  171. auto item = reader.GetItem(array, i);
  172. if (HasNulls && !item) {
  173. d.isnull = true;
  174. } else {
  175. d.isnull = false;
  176. d.value = (Datum)item.template As<ui64>();
  177. }
  178. } else {
  179. NUdf::TStringBlockReader<arrow::BinaryType, HasNulls> reader;
  180. auto item = reader.GetItem(array, i);
  181. if (HasNulls && !item) {
  182. d.isnull = true;
  183. } else {
  184. d.isnull = false;
  185. d.value = (Datum)(item.AsStringRef().Data() + sizeof(void*));
  186. }
  187. }
  188. }
  /// Compile-time unrolled loop: calls `f(std::integral_constant<decltype(Start), I>{})` for
  /// I = Start, Start + Inc, ... while I < End. Iteration stops at the first call that returns
  /// false, and the function then returns false. It returns true if every call succeeded
  /// (including the empty range).
  template <auto Start, auto End, auto Inc, class F>
  constexpr bool constexpr_for(F&& f) {
      if constexpr (!(Start < End)) {
          return true;
      } else {
          return f(std::integral_constant<decltype(Start), Start>())
              && constexpr_for<Start + Inc, End, Inc>(f);
      }
  }
  199. template <class F, class Tuple>
  200. constexpr bool constexpr_for_tuple(F&& f, Tuple&& tuple) {
  201. constexpr size_t cnt = std::tuple_size_v<std::decay_t<Tuple>>;
  202. return constexpr_for<size_t(0), cnt, size_t(1)>([&](auto i) {
  203. return f(i.value, std::get<i.value>(tuple));
  204. });
  205. }
  /// Which argument of a two-argument call is known, before the row loop, to be a scalar.
  enum class EScalarArgBinary {
      Unknown = 0, // neither argument, or the call does not have exactly two fixed-layout arguments
      First = 1,   // argument 0 is a scalar
      Second = 2,  // argument 1 is a scalar
  };
  /// Argument policy for calls whose argument layout is known only at run time:
  /// VarArgs is set and no compile-time per-argument flags are provided.
  struct TDefaultArgsPolicy {
      static constexpr std::array<bool, 0> IsFixedArg{};
      static constexpr bool VarArgs{true};
  };
  215. Y_PRAGMA_DIAGNOSTIC_PUSH
  216. Y_PRAGMA("GCC diagnostic ignored \"-Wreturn-type-c-linkage\"")
  217. extern "C" TPgKernelState& GetPGKernelState(arrow::compute::KernelContext* ctx);
  218. Y_PRAGMA_DIAGNOSTIC_POP
  219. template <typename TFunc, bool IsStrict, bool IsFixedResult, typename TArgsPolicy = TDefaultArgsPolicy>
  220. struct TGenericExec {
  221. TGenericExec(TFunc func)
  222. : Func(func)
  223. {}
  224. Y_NO_INLINE arrow::Status operator()(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
  225. auto& state = GetPGKernelState(ctx);
  226. if constexpr (!TArgsPolicy::VarArgs) {
  227. Y_ENSURE(batch.values.size() == TArgsPolicy::IsFixedArg.size());
  228. Y_ENSURE(state.IsFixedArg.size() == TArgsPolicy::IsFixedArg.size());
  229. for (ui32 i = 0; i < TArgsPolicy::IsFixedArg.size(); ++i) {
  230. Y_ENSURE(state.IsFixedArg[i] == TArgsPolicy::IsFixedArg[i]);
  231. }
  232. }
  233. size_t length = 1;
  234. bool hasNulls = false;
  235. bool hasArrays = false;
  236. bool hasScalars = false;
  237. for (const auto& v : batch.values) {
  238. if (v.is_array()) {
  239. length = v.array()->length;
  240. if (v.array()->GetNullCount() > 0) {
  241. hasNulls = true;
  242. }
  243. hasArrays = true;
  244. } else {
  245. hasScalars = true;
  246. if (!v.scalar()->is_valid) {
  247. hasNulls = true;
  248. }
  249. }
  250. }
  251. Y_ENSURE(hasArrays);
  252. Y_ENSURE(state.flinfo.fn_strict == IsStrict);
  253. Y_ENSURE(state.IsFixedResult == IsFixedResult);
  254. TArenaMemoryContext arena;
  255. Dispatch1(hasScalars, hasNulls, ctx, batch, length, state, res);
  256. return arrow::Status::OK();
  257. }
  258. Y_NO_INLINE void Dispatch1(bool hasScalars, bool hasNulls, arrow::compute::KernelContext* ctx,
  259. const arrow::compute::ExecBatch& batch, size_t length, TPgKernelState& state, arrow::Datum* res) const {
  260. if (hasScalars) {
  261. if (hasNulls) {
  262. if constexpr (IsFixedResult) {
  263. NUdf::TFixedSizeArrayBuilder<ui64, true> builder(NKikimr::NMiniKQL::TTypeInfoHelper(), arrow::uint64(), *ctx->memory_pool(), length);
  264. *res = Dispatch2<true, true>(batch, length, state, builder);
  265. } else {
  266. NUdf::TStringArrayBuilder<arrow::BinaryType, true> builder(NKikimr::NMiniKQL::TTypeInfoHelper(), arrow::binary(), *ctx->memory_pool(), length);
  267. *res = Dispatch2<true, true>(batch, length, state, builder);
  268. }
  269. } else {
  270. if constexpr (IsFixedResult) {
  271. NUdf::TFixedSizeArrayBuilder<ui64, true> builder(NKikimr::NMiniKQL::TTypeInfoHelper(), arrow::uint64(), *ctx->memory_pool(), length);
  272. *res = Dispatch2<true, false>(batch, length, state, builder);
  273. } else {
  274. NUdf::TStringArrayBuilder<arrow::BinaryType, true> builder(NKikimr::NMiniKQL::TTypeInfoHelper(), arrow::binary(), *ctx->memory_pool(), length);
  275. *res = Dispatch2<true, false>(batch, length, state, builder);
  276. }
  277. }
  278. } else {
  279. if (hasNulls) {
  280. if constexpr (IsFixedResult) {
  281. NUdf::TFixedSizeArrayBuilder<ui64, true> builder(NKikimr::NMiniKQL::TTypeInfoHelper(), arrow::uint64(), *ctx->memory_pool(), length);
  282. *res = Dispatch2<false, true>(batch, length, state, builder);
  283. } else {
  284. NUdf::TStringArrayBuilder<arrow::BinaryType, true> builder(NKikimr::NMiniKQL::TTypeInfoHelper(), arrow::binary(), *ctx->memory_pool(), length);
  285. *res = Dispatch2<false, true>(batch, length, state, builder);
  286. }
  287. } else {
  288. if constexpr (IsFixedResult) {
  289. NUdf::TFixedSizeArrayBuilder<ui64, true> builder(NKikimr::NMiniKQL::TTypeInfoHelper(), arrow::uint64(), *ctx->memory_pool(), length);
  290. *res = Dispatch2<false, false>(batch, length, state, builder);
  291. } else {
  292. NUdf::TStringArrayBuilder<arrow::BinaryType, true> builder(NKikimr::NMiniKQL::TTypeInfoHelper(), arrow::binary(), *ctx->memory_pool(), length);
  293. *res = Dispatch2<false, false>(batch, length, state, builder);
  294. }
  295. }
  296. }
  297. }
  298. template <bool HasScalars, bool HasNulls, typename TBuilder>
  299. Y_NO_INLINE arrow::Datum Dispatch2(const arrow::compute::ExecBatch& batch, size_t length, TPgKernelState& state, TBuilder& builder) const {
  300. if constexpr (!TArgsPolicy::VarArgs) {
  301. if (TArgsPolicy::IsFixedArg.size() == 2) {
  302. if (batch.values[0].is_scalar()) {
  303. return Dispatch3<HasScalars, HasNulls, EScalarArgBinary::First>(batch, length, state, builder);
  304. }
  305. if (batch.values[1].is_scalar()) {
  306. return Dispatch3<HasScalars, HasNulls, EScalarArgBinary::Second>(batch, length, state, builder);
  307. }
  308. }
  309. }
  310. return Dispatch3<HasScalars, HasNulls, EScalarArgBinary::Unknown>(batch, length, state, builder);
  311. }
  312. template <bool HasScalars, bool HasNulls, EScalarArgBinary ScalarArgBinary, typename TBuilder>
  313. Y_NO_INLINE arrow::Datum Dispatch3(const arrow::compute::ExecBatch& batch, size_t length, TPgKernelState& state, TBuilder& builder) const {
  314. LOCAL_FCINFO(fcinfo, FUNC_MAX_ARGS);
  315. fcinfo->flinfo = &state.flinfo;
  316. fcinfo->context = state.context;
  317. fcinfo->resultinfo = state.resultinfo;
  318. fcinfo->fncollation = state.fncollation;
  319. fcinfo->nargs = batch.values.size();
  320. TInputArgsAccessor<TArgsPolicy> inputArgsAccessor;
  321. inputArgsAccessor.Bind(batch.values);
  322. ui64* fixedResultData = nullptr;
  323. ui8* fixedResultValidMask = nullptr;
  324. if constexpr (IsFixedResult) {
  325. builder.UnsafeReserve(length);
  326. fixedResultData = builder.MutableData();
  327. fixedResultValidMask = builder.MutableValidMask();
  328. }
  329. for (size_t i = 0; i < length; ++i) {
  330. Datum ret;
  331. bool needToFree = false;
  332. if constexpr (!TArgsPolicy::VarArgs) {
  333. if (!constexpr_for_tuple([&](auto const& j, auto const& v) {
  334. NullableDatum d;
  335. if (HasScalars && (
  336. (ScalarArgBinary == EScalarArgBinary::First && j == 0) ||
  337. (ScalarArgBinary == EScalarArgBinary::Second && j == 1) ||
  338. inputArgsAccessor.IsScalar[j])) {
  339. d = inputArgsAccessor.Scalars[j];
  340. } else {
  341. d.isnull = false;
  342. if constexpr (HasNulls) {
  343. ui64 fullIndex = (i + inputArgsAccessor.Offsets[j]) & inputArgsAccessor.ValidOffsetMask[j];
  344. d.isnull = ((inputArgsAccessor.ValidMasks[j][fullIndex >> 3] >> (fullIndex & 0x07)) & 1) == 0;
  345. }
  346. if (v) {
  347. d.value = (Datum)inputArgsAccessor.FixedArrays[j][i];
  348. } else {
  349. d.value = (Datum)(sizeof(void*) + inputArgsAccessor.StringOffsetsArrays[j][i] + inputArgsAccessor.StringDataArrays[j]);
  350. }
  351. }
  352. if (HasNulls && IsStrict && d.isnull) {
  353. return false;
  354. }
  355. fcinfo->args[j] = d;
  356. return true;
  357. }, TArgsPolicy::IsFixedArg)) {
  358. if constexpr (IsFixedResult) {
  359. fixedResultValidMask[i] = 0;
  360. } else {
  361. builder.Add(NUdf::TBlockItem{});
  362. }
  363. goto SkipCall;
  364. }
  365. } else {
  366. for (size_t j = 0; j < batch.values.size(); ++j) {
  367. NullableDatum d;
  368. if (HasScalars && batch.values[j].is_scalar()) {
  369. if (state.IsFixedArg[j]) {
  370. FillScalarItem<HasNulls, true>(*batch.values[j].scalar(), d);
  371. } else {
  372. FillScalarItem<HasNulls, false>(*batch.values[j].scalar(), d);
  373. }
  374. } else {
  375. if (state.IsFixedArg[j]) {
  376. FillArrayItem<HasNulls, true>(*batch.values[j].array(), i, d);
  377. } else {
  378. FillArrayItem<HasNulls, false>(*batch.values[j].array(), i, d);
  379. }
  380. }
  381. if (HasNulls && IsStrict && d.isnull) {
  382. if constexpr (IsFixedResult) {
  383. fixedResultValidMask[i] = 0;
  384. } else {
  385. builder.Add(NUdf::TBlockItem{});
  386. }
  387. goto SkipCall;
  388. }
  389. fcinfo->args[j] = d;
  390. }
  391. }
  392. fcinfo->isnull = false;
  393. if constexpr (TArgsPolicy::VarArgs) {
  394. needToFree = PrepareVariadicArray(*fcinfo, *state.ProcDesc);
  395. }
  396. ret = Func(fcinfo);
  397. if constexpr (TArgsPolicy::VarArgs) {
  398. if (needToFree) {
  399. FreeVariadicArray(*fcinfo, batch.values.size());
  400. }
  401. }
  402. if constexpr (IsFixedResult) {
  403. fixedResultData[i] = ui64(ret);
  404. fixedResultValidMask[i] = !fcinfo->isnull;
  405. } else {
  406. if (fcinfo->isnull) {
  407. builder.Add(NUdf::TBlockItem{});
  408. } else {
  409. auto ptr = (char*)ret;
  410. ui32 len;
  411. if (state.TypeLen == -1) {
  412. len = GetFullVarSize((const text*)ptr);
  413. } else if (state.TypeLen == -2) {
  414. len = 1 + strlen(ptr);
  415. } else {
  416. len = state.TypeLen;
  417. }
  418. builder.AddPgItem(NUdf::TStringRef(ptr, len));
  419. }
  420. }
  421. SkipCall:;
  422. }
  423. return builder.Build(true);
  424. }
  425. TFunc Func;
  426. };
  427. using TExecFunc = std::function<arrow::Status(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res)>;
  428. template <bool IsStrict, bool IsFixedResult>
  429. TExecFunc MakeIndirectExec(PGFunction pgFunc) {
  430. return TGenericExec<TPgIndirectFunc, IsStrict, IsFixedResult>(TPgIndirectFunc(pgFunc));
  431. }
  432. template <bool IsFixed, typename TArgsPolicy>
  433. NullableDatum GetInputValue(const TInputArgsAccessor<TArgsPolicy>& accessor, ui32 j, ui64 row) {
  434. static_assert(!TArgsPolicy::VarArgs);
  435. if (accessor.IsScalar[j]) {
  436. return accessor.Scalars[j];
  437. } else {
  438. NullableDatum d;
  439. ui64 fullIndex = (row + accessor.Offsets[j]) & accessor.ValidOffsetMask[j];
  440. d.isnull = ((accessor.ValidMasks[j][fullIndex >> 3] >> (fullIndex & 0x07)) & 1) == 0;
  441. if constexpr (IsFixed) {
  442. d.value = (Datum)accessor.FixedArrays[j][row];
  443. } else {
  444. d.value = (Datum)(sizeof(void*) + accessor.StringOffsetsArrays[j][row] + accessor.StringDataArrays[j]);
  445. }
  446. return d;
  447. }
  448. }
  449. template <bool IsFixed>
  450. NullableDatum GetInputValueSlow(const std::vector<arrow::Datum>& values, ui32 j, ui64 row) {
  451. NullableDatum d;
  452. if (values[j].is_scalar()) {
  453. if constexpr (IsFixed) {
  454. FillScalarItem<true, true>(*values[j].scalar(), d);
  455. } else {
  456. FillScalarItem<true, false>(*values[j].scalar(), d);
  457. }
  458. } else {
  459. if constexpr (IsFixed) {
  460. FillArrayItem<true, true>(*values[j].array(), row, d);
  461. } else {
  462. FillArrayItem<true, false>(*values[j].array(), row, d);
  463. }
  464. }
  465. return d;
  466. }
  467. template <bool IsFixed, bool HasFunc, typename TFunc, bool IsStrict, typename TBuilder>
  468. class TAggColumnBuilder : public NKikimr::NMiniKQL::IAggColumnBuilder {
  469. public:
  470. TAggColumnBuilder(const TString& name, TFunc func, ui64 size, FmgrInfo* funcInfo, const std::shared_ptr<arrow::DataType>& dataType,
  471. NKikimr::NMiniKQL::TComputationContext& ctx, i32 typeLen)
  472. : Name_(name)
  473. , Func_(func)
  474. , FuncInfo_(funcInfo)
  475. , Builder_(NKikimr::NMiniKQL::TTypeInfoHelper(), dataType, ctx.ArrowMemoryPool, size)
  476. , Ctx_(ctx)
  477. , TypeLen_(typeLen)
  478. {
  479. }
  480. void Add(const void* state) final {
  481. auto typedState = (NullableDatum*)state;
  482. auto ret = *typedState;
  483. if constexpr (HasFunc) {
  484. if (!IsStrict || !typedState->isnull) {
  485. LOCAL_FCINFO(callInfo, 1);
  486. callInfo->flinfo = FuncInfo_;
  487. callInfo->nargs = 1;
  488. callInfo->fncollation = DEFAULT_COLLATION_OID;
  489. callInfo->context = (Node*)NKikimr::NMiniKQL::TlsAllocState->CurrentContext;
  490. callInfo->isnull = false;
  491. callInfo->args[0].isnull = typedState->isnull;
  492. callInfo->args[0].value = typedState->value;
  493. ret.value = Func_(callInfo);
  494. ret.isnull = callInfo->isnull;
  495. }
  496. }
  497. if (ret.isnull) {
  498. Builder_.Add(NYql::NUdf::TBlockItem());
  499. } else {
  500. if constexpr (IsFixed) {
  501. Builder_.Add(NYql::NUdf::TBlockItem(ui64(ret.value)));
  502. } else if (TypeLen_ == -1) {
  503. auto ptr = (char*)ret.value;
  504. ui32 len = GetFullVarSize((const text*)ptr);
  505. Builder_.AddPgItem(NYql::NUdf::TStringRef(ptr, len));
  506. } else if (TypeLen_ == -2) {
  507. auto ptr = (char*)ret.value;
  508. ui32 len = 1 + strlen(ptr);
  509. Builder_.AddPgItem(NYql::NUdf::TStringRef(ptr, len));
  510. } else {
  511. auto ptr = (char*)ret.value;
  512. ui32 len = TypeLen_;
  513. Builder_.AddPgItem(NYql::NUdf::TStringRef(ptr, len));
  514. }
  515. }
  516. }
  517. NUdf::TUnboxedValue Build() final {
  518. return Ctx_.HolderFactory.CreateArrowBlock(Builder_.Build(true));
  519. }
  520. private:
  521. const TString Name_;
  522. const TFunc Func_;
  523. FmgrInfo* FuncInfo_;
  524. TBuilder Builder_;
  525. NKikimr::NMiniKQL::TComputationContext& Ctx_;
  526. const i32 TypeLen_;
  527. };
  528. template <typename TTransFunc, bool IsTransStrict, typename TTransArgsPolicy,
  529. typename TCombineFunc, bool IsCombineStrict, typename TCombineArgsPolicy,
  530. bool HasSerialize, typename TSerializeFunc, typename TSerializeArgsPolicy,
  531. bool HasDeserialize, typename TDeserializeFunc, typename TDeserializeArgsPolicy,
  532. bool HasFinal, typename TFinalFunc, bool IsFinalStrict, typename TFinalArgsPolicy,
  533. bool IsTransTypeFixed, bool IsSerializedTypeFixed, bool IsFinalTypeFixed, bool HasInitValue>
  534. class TGenericAgg {
  535. public:
  536. TGenericAgg(TTransFunc transFunc, TCombineFunc combineFunc, TSerializeFunc serializeFunc,
  537. TDeserializeFunc deserializeFunc, TFinalFunc finalFunc)
  538. : TransFunc(transFunc)
  539. , CombineFunc(combineFunc)
  540. , SerializeFunc(serializeFunc)
  541. , DeserializeFunc(deserializeFunc)
  542. , FinalFunc(finalFunc)
  543. {}
  544. private:
  545. template <typename TAggregatorBase>
  546. class TCombineAggregatorBase: public TAggregatorBase {
  547. protected:
  548. TCombineAggregatorBase(TTransFunc transFunc, TSerializeFunc serializeFunc, const std::vector<ui32>& argsColumns,
  549. std::optional<ui32> filterColumn, const NPg::TAggregateDesc& aggDesc, NKikimr::NMiniKQL::TComputationContext& ctx)
  550. : TAggregatorBase(sizeof(NullableDatum), filterColumn, ctx)
  551. , TransFunc_(transFunc)
  552. , SerializeFunc_(serializeFunc)
  553. , ArgsColumns_(argsColumns)
  554. , AggDesc_(aggDesc)
  555. , TransTypeLen_(NPg::LookupType(this->AggDesc_.TransTypeId).TypeLen)
  556. {
  557. if (!HasInitValue && IsTransStrict) {
  558. Y_ENSURE(AggDesc_.ArgTypes.size() == 1);
  559. }
  560. const auto& transDesc = NPg::LookupProc(AggDesc_.TransFuncId);
  561. for (ui32 i = 1; i < transDesc.ArgTypes.size(); ++i) {
  562. IsFixedArg_.push_back(NPg::LookupType(transDesc.ArgTypes[i]).PassByValue);
  563. }
  564. Zero(TransFuncInfo_);
  565. GetPgFuncAddr(AggDesc_.TransFuncId, TransFuncInfo_);
  566. Y_ENSURE(TransFuncInfo_.fn_addr);
  567. auto nargs = NPg::LookupProc(AggDesc_.TransFuncId).ArgTypes.size();
  568. if constexpr (HasSerialize) {
  569. Zero(SerializeFuncInfo_);
  570. GetPgFuncAddr(AggDesc_.SerializeFuncId, SerializeFuncInfo_);
  571. Y_ENSURE(SerializeFuncInfo_.fn_addr);
  572. }
  573. if constexpr (HasInitValue) {
  574. Zero(InFuncInfo_);
  575. const auto& transTypeDesc = NPg::LookupType(AggDesc_.TransTypeId);
  576. auto inFuncId = transTypeDesc.InFuncId;
  577. if (transTypeDesc.TypeId == transTypeDesc.ArrayTypeId) {
  578. inFuncId = NPg::LookupProc("array_in", { 0,0,0 }).ProcId;
  579. }
  580. TypeIOParam_ = MakeTypeIOParam(transTypeDesc);
  581. GetPgFuncAddr(inFuncId, InFuncInfo_);
  582. Y_ENSURE(InFuncInfo_.fn_addr);
  583. LOCAL_FCINFO(inCallInfo, 3);
  584. inCallInfo->flinfo = &this->InFuncInfo_;
  585. inCallInfo->nargs = 3;
  586. inCallInfo->fncollation = DEFAULT_COLLATION_OID;
  587. inCallInfo->context = (Node*)NKikimr::NMiniKQL::TlsAllocState->CurrentContext;
  588. inCallInfo->isnull = false;
  589. inCallInfo->args[0] = { (Datum)this->AggDesc_.InitValue.c_str(), false };
  590. inCallInfo->args[1] = { ObjectIdGetDatum(this->TypeIOParam_), false };
  591. inCallInfo->args[2] = { Int32GetDatum(-1), false };
  592. auto state = this->InFuncInfo_.fn_addr(inCallInfo);
  593. Y_ENSURE(!inCallInfo->isnull);
  594. PreparedInitValue_ = AnyDatumToPod(state, IsTransTypeFixed);
  595. }
  596. }
  597. const TTransFunc TransFunc_;
  598. const TSerializeFunc SerializeFunc_;
  599. const std::vector<ui32> ArgsColumns_;
  600. const NPg::TAggregateDesc& AggDesc_;
  601. const i32 TransTypeLen_;
  602. std::vector<bool> IsFixedArg_;
  603. FmgrInfo TransFuncInfo_;
  604. FmgrInfo SerializeFuncInfo_;
  605. FmgrInfo InFuncInfo_;
  606. ui32 TypeIOParam_ = 0;
  607. NKikimr::NUdf::TUnboxedValue PreparedInitValue_;
  608. };
  609. template <bool HasFilter>
  610. class TCombineAllAggregator : public TCombineAggregatorBase<NKikimr::NMiniKQL::TCombineAllTag::TBase> {
  611. public:
  612. using TBase = TCombineAggregatorBase<NKikimr::NMiniKQL::TCombineAllTag::TBase>;
  613. TCombineAllAggregator(TTransFunc transFunc, TSerializeFunc serializeFunc, const std::vector<ui32>& argsColumns,
  614. std::optional<ui32> filterColumn, const NPg::TAggregateDesc& aggDesc, NKikimr::NMiniKQL::TComputationContext& ctx)
  615. : TBase(transFunc, serializeFunc, argsColumns, filterColumn, aggDesc, ctx)
  616. {
  617. Y_ENSURE(HasFilter == filterColumn.has_value());
  618. }
  619. private:
  620. void DestroyState(void* state) noexcept final {
  621. Y_UNUSED(state);
  622. }
  623. void InitState(void* state) final {
  624. new(state) NullableDatum();
  625. auto typedState = (NullableDatum*)state;
  626. typedState->isnull = true;
  627. typedState->value = 0;
  628. if constexpr (HasInitValue) {
  629. auto datum = IsTransTypeFixed ? ScalarDatumFromPod(this->PreparedInitValue_) : PointerDatumFromPod(this->PreparedInitValue_);
  630. typedState->isnull = false;
  631. typedState->value = CloneDatumToAggContext<IsTransTypeFixed>(datum, this->TransTypeLen_);
  632. }
  633. }
  634. void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final {
  635. auto typedState = (NullableDatum*)state;
  636. std::vector<arrow::Datum> values;
  637. values.reserve(this->ArgsColumns_.size());
  638. for (auto col : this->ArgsColumns_) {
  639. values.push_back(NKikimr::NMiniKQL::TArrowBlock::From(columns[col]).GetDatum());
  640. }
  641. bool hasNulls = false;
  642. bool hasScalars = false;
  643. for (const auto& v : values) {
  644. if (v.is_array()) {
  645. if (v.array()->GetNullCount() > 0) {
  646. hasNulls = true;
  647. }
  648. } else {
  649. hasScalars = true;
  650. if (!v.scalar()->is_valid) {
  651. hasNulls = true;
  652. }
  653. }
  654. }
  655. const ui8* filterBitmap = nullptr;
  656. if constexpr(HasFilter) {
  657. const auto& filterDatum = NKikimr::NMiniKQL::TArrowBlock::From(columns[*this->FilterColumn_]).GetDatum();
  658. const auto& filterArray = filterDatum.array();
  659. Y_ENSURE(filterArray->GetNullCount() == 0);
  660. filterBitmap = filterArray->template GetValues<uint8_t>(1);
  661. }
  662. if (hasNulls) {
  663. if (hasScalars) {
  664. AddManyImpl<true, true>(typedState, values, batchLength, filterBitmap);
  665. } else {
  666. AddManyImpl<true, false>(typedState, values, batchLength, filterBitmap);
  667. }
  668. } else {
  669. if (hasScalars) {
  670. AddManyImpl<false, true>(typedState, values, batchLength, filterBitmap);
  671. } else {
  672. AddManyImpl<false, false>(typedState, values, batchLength, filterBitmap);
  673. }
  674. }
  675. }
  676. template <bool HasNulls, bool HasScalars>
  677. void AddManyImpl(NullableDatum* typedState, const std::vector<arrow::Datum>& values, ui64 batchLength, const ui8* filterBitmap) {
  678. LOCAL_FCINFO(transCallInfo, FUNC_MAX_ARGS);
  679. transCallInfo->flinfo = &this->TransFuncInfo_;
  680. if constexpr (!TTransArgsPolicy::VarArgs) {
  681. transCallInfo->nargs = TTransArgsPolicy::IsFixedArg.size();
  682. } else {
  683. transCallInfo->nargs = 1 + values.size();
  684. }
  685. transCallInfo->fncollation = DEFAULT_COLLATION_OID;
  686. transCallInfo->context = (Node*)NKikimr::NMiniKQL::TlsAllocState->CurrentContext;
  687. transCallInfo->args[0] = *typedState;
  688. TInputArgsAccessor<TTransArgsPolicy> inputArgsAccessor;
  689. inputArgsAccessor.Bind(values, 1);
  690. for (ui64 i = 0; i < batchLength; ++i) {
  691. if constexpr (HasFilter) {
  692. if (!filterBitmap[i]) {
  693. continue;
  694. }
  695. }
  696. Datum ret;
  697. if constexpr (!TTransArgsPolicy::VarArgs) {
  698. if (!constexpr_for_tuple([&](auto const& j, auto const& v) {
  699. if (j == 0) {
  700. return true;
  701. }
  702. NullableDatum d;
  703. if (HasScalars && inputArgsAccessor.IsScalar[j]) {
  704. d = inputArgsAccessor.Scalars[j];
  705. } else {
  706. d.isnull = false;
  707. if constexpr (HasNulls) {
  708. ui64 fullIndex = (i + inputArgsAccessor.Offsets[j]) & inputArgsAccessor.ValidOffsetMask[j];
  709. d.isnull = ((inputArgsAccessor.ValidMasks[j][fullIndex >> 3] >> (fullIndex & 0x07)) & 1) == 0;
  710. }
  711. if (v) {
  712. d.value = (Datum)inputArgsAccessor.FixedArrays[j][i];
  713. } else {
  714. d.value = (Datum)(sizeof(void*) + inputArgsAccessor.StringOffsetsArrays[j][i] + inputArgsAccessor.StringDataArrays[j]);
  715. }
  716. }
  717. if (HasNulls && IsTransStrict && d.isnull) {
  718. return false;
  719. }
  720. transCallInfo->args[j] = d;
  721. return true;
  722. }, TTransArgsPolicy::IsFixedArg)) {
  723. goto SkipCall;
  724. }
  725. } else {
  726. for (size_t j = 0; j < values.size(); ++j) {
  727. NullableDatum d;
  728. if (HasScalars && values[j].is_scalar()) {
  729. if (this->IsFixedArg_[j]) {
  730. FillScalarItem<HasNulls, true>(*values[j].scalar(), d);
  731. } else {
  732. FillScalarItem<HasNulls, false>(*values[j].scalar(), d);
  733. }
  734. } else {
  735. if (this->IsFixedArg_[j]) {
  736. FillArrayItem<HasNulls, true>(*values[j].array(), i, d);
  737. } else {
  738. FillArrayItem<HasNulls, false>(*values[j].array(), i, d);
  739. }
  740. }
  741. if (HasNulls && IsTransStrict && d.isnull) {
  742. goto SkipCall;
  743. }
  744. transCallInfo->args[1 + j] = d;
  745. }
  746. }
  747. if (!HasInitValue && IsTransStrict) {
  748. if (transCallInfo->args[0].isnull) {
  749. transCallInfo->args[0] = transCallInfo->args[1];
  750. continue;
  751. }
  752. }
  753. transCallInfo->isnull = false;
  754. ret = this->TransFunc_(transCallInfo);
  755. transCallInfo->args[0].value = ret;
  756. transCallInfo->args[0].isnull = transCallInfo->isnull;
  757. SkipCall:;
  758. }
  759. CopyState<IsTransTypeFixed>(transCallInfo->args[0], *typedState);
  760. SaveToAggContext<IsTransTypeFixed>(*typedState, this->TransTypeLen_);
  761. }
  762. NUdf::TUnboxedValue FinishOne(const void* state) final {
  763. auto typedState = (NullableDatum*)state;
  764. if (typedState->isnull) {
  765. return {};
  766. }
  767. if constexpr (HasSerialize) {
  768. NUdf::TUnboxedValue ret;
  769. LOCAL_FCINFO(serializeCallInfo, 1);
  770. serializeCallInfo->flinfo = &this->SerializeFuncInfo_;
  771. serializeCallInfo->nargs = 1;
  772. serializeCallInfo->fncollation = DEFAULT_COLLATION_OID;
  773. serializeCallInfo->context = (Node*)NKikimr::NMiniKQL::TlsAllocState->CurrentContext;
  774. serializeCallInfo->isnull = false;
  775. serializeCallInfo->args[0].isnull = false;
  776. serializeCallInfo->args[0].value = typedState->value;
  777. auto ser = this->SerializeFunc_(serializeCallInfo);
  778. Y_ENSURE(!serializeCallInfo->isnull);
  779. if constexpr (IsSerializedTypeFixed) {
  780. ret = ScalarDatumToPod(ser);
  781. } else {
  782. ret = PointerDatumToPod(ser);
  783. if (ser == typedState->value) {
  784. typedState->value = 0;
  785. typedState->isnull = true;
  786. }
  787. }
  788. return ret;
  789. } else {
  790. if constexpr (IsTransTypeFixed) {
  791. return ScalarDatumToPod(typedState->value);
  792. } else {
  793. auto ret = PointerDatumToPod(typedState->value);
  794. typedState->value = 0;
  795. typedState->isnull = true;
  796. return ret;
  797. }
  798. }
  799. }
  800. };
  801. class TCombineKeysAggregator : public TCombineAggregatorBase<NKikimr::NMiniKQL::TCombineKeysTag::TBase> {
  802. public:
  803. using TBase = TCombineAggregatorBase<NKikimr::NMiniKQL::TCombineKeysTag::TBase>;
  804. TCombineKeysAggregator(TTransFunc transFunc, TSerializeFunc serializeFunc, const std::vector<ui32>& argsColumns,
  805. const NPg::TAggregateDesc& aggDesc, NKikimr::NMiniKQL::TComputationContext& ctx)
  806. : TBase(transFunc, serializeFunc, argsColumns, std::optional<ui32>(), aggDesc, ctx)
  807. , SerializedType_(HasSerialize ? NPg::LookupProc(this->AggDesc_.SerializeFuncId).ResultType : this->AggDesc_.TransTypeId)
  808. {
  809. Values_.reserve(this->IsFixedArg_.size());
  810. }
  811. void DestroyState(void* state) noexcept final {
  812. Y_UNUSED(state);
  813. }
  814. void PrepareBatch(ui64 batchNum, const NKikimr::NUdf::TUnboxedValue* columns) {
  815. Values_.clear();
  816. for (auto col : this->ArgsColumns_) {
  817. Values_.push_back(NKikimr::NMiniKQL::TArrowBlock::From(columns[col]).GetDatum());
  818. }
  819. InputArgsAccessor_.Bind(Values_, 1);
  820. BatchNum_ = batchNum;
  821. }
  822. void InitKey(void* state, ui64 batchNum, const NKikimr::NUdf::TUnboxedValue* columns, ui64 row) final {
  823. new(state) NullableDatum();
  824. auto typedState = (NullableDatum*)state;
  825. typedState->isnull = true;
  826. typedState->value = 0;
  827. if constexpr (HasInitValue) {
  828. auto datum = IsTransTypeFixed ? ScalarDatumFromPod(this->PreparedInitValue_) : PointerDatumFromPod(this->PreparedInitValue_);
  829. typedState->isnull = false;
  830. typedState->value = CloneDatumToAggContext<IsTransTypeFixed>(datum, this->TransTypeLen_);
  831. }
  832. UpdateKey(state, batchNum, columns, row);
  833. }
  834. void UpdateKey(void* state, ui64 batchNum, const NKikimr::NUdf::TUnboxedValue* columns, ui64 row) final {
  835. auto typedState = (NullableDatum*)state;
  836. if (batchNum != BatchNum_) {
  837. PrepareBatch(batchNum, columns);
  838. }
  839. LOCAL_FCINFO(transCallInfo, FUNC_MAX_ARGS);
  840. transCallInfo->flinfo = &this->TransFuncInfo_;
  841. if constexpr (!TTransArgsPolicy::VarArgs) {
  842. transCallInfo->nargs = TTransArgsPolicy::IsFixedArg.size();
  843. } else {
  844. transCallInfo->nargs = 1 + Values_.size();
  845. }
  846. transCallInfo->fncollation = DEFAULT_COLLATION_OID;
  847. transCallInfo->context = (Node*)NKikimr::NMiniKQL::TlsAllocState->CurrentContext;
  848. transCallInfo->args[0] = *typedState;
  849. Datum ret;
  850. if constexpr (!TTransArgsPolicy::VarArgs) {
  851. if (!constexpr_for_tuple([&](auto const& j, auto const& v) {
  852. if (j == 0) {
  853. return true;
  854. }
  855. NullableDatum d;
  856. if (v) {
  857. d = GetInputValue<true>(InputArgsAccessor_, j, row);
  858. } else {
  859. d = GetInputValue<false>(InputArgsAccessor_, j, row);
  860. }
  861. if (IsTransStrict && d.isnull) {
  862. return false;
  863. }
  864. transCallInfo->args[j] = d;
  865. return true;
  866. }, TTransArgsPolicy::IsFixedArg)) {
  867. return;
  868. }
  869. } else {
  870. for (size_t j = 0; j < Values_.size(); ++j) {
  871. NullableDatum d;
  872. if (this->IsFixedArg_[j]) {
  873. d = GetInputValueSlow<true>(Values_, j, row);
  874. } else {
  875. d = GetInputValueSlow<false>(Values_, j, row);
  876. }
  877. if (IsTransStrict && d.isnull) {
  878. return;
  879. }
  880. transCallInfo->args[1 + j] = d;
  881. }
  882. }
  883. if (!HasInitValue && IsTransStrict) {
  884. if (transCallInfo->args[0].isnull) {
  885. typedState->isnull = false;
  886. typedState->value = CloneDatumToAggContext<IsTransTypeFixed>(transCallInfo->args[1].value, this->TransTypeLen_);
  887. return;
  888. }
  889. }
  890. transCallInfo->isnull = false;
  891. ret = this->TransFunc_(transCallInfo);
  892. CopyState<IsTransTypeFixed>({ret, transCallInfo->isnull}, *typedState);
  893. SaveToAggContext<IsTransTypeFixed>(*typedState, this->TransTypeLen_);
  894. }
  895. std::unique_ptr<NKikimr::NMiniKQL::IAggColumnBuilder> MakeStateBuilder(ui64 size) final {
  896. auto typeLen = NPg::LookupType(SerializedType_).TypeLen;
  897. if constexpr (IsSerializedTypeFixed) {
  898. return std::make_unique<TAggColumnBuilder<true, HasSerialize, TSerializeFunc, true, NYql::NUdf::TFixedSizeArrayBuilder<ui64, true>>>(
  899. this->AggDesc_.Name, this->SerializeFunc_, size, &this->SerializeFuncInfo_, arrow::uint64(), this->Ctx_, typeLen);
  900. } else {
  901. return std::make_unique<TAggColumnBuilder<false, HasSerialize, TSerializeFunc, true, NYql::NUdf::TStringArrayBuilder<arrow::BinaryType, true, NYql::NUdf::EPgStringType::Text>>>(
  902. this->AggDesc_.Name, this->SerializeFunc_, size, &this->SerializeFuncInfo_, arrow::binary(), this->Ctx_, typeLen);
  903. }
  904. }
  905. const ui32 SerializedType_;
  906. ui64 BatchNum_ = Max<ui64>();
  907. std::vector<arrow::Datum> Values_;
  908. TInputArgsAccessor<TTransArgsPolicy> InputArgsAccessor_;
  909. };
  910. class TFinalizeKeysAggregator : public NKikimr::NMiniKQL::TFinalizeKeysTag::TBase {
  911. public:
  912. using TBase = NKikimr::NMiniKQL::TFinalizeKeysTag::TBase;
  913. TFinalizeKeysAggregator(TDeserializeFunc deserializeFunc, TCombineFunc combineFunc, TFinalFunc finalFunc,
  914. ui32 stateColumn, const NPg::TAggregateDesc& aggDesc, NKikimr::NMiniKQL::TComputationContext& ctx)
  915. : TBase(sizeof(NullableDatum), std::optional<ui32>(), ctx)
  916. , DeserializeFunc_(deserializeFunc)
  917. , CombineFunc_(combineFunc)
  918. , FinalFunc_(finalFunc)
  919. , StateColumn_(stateColumn)
  920. , AggDesc_(aggDesc)
  921. , SerializedType_(HasSerialize ? NPg::LookupProc(this->AggDesc_.SerializeFuncId).ResultType : this->AggDesc_.TransTypeId)
  922. , FinalType_(HasFinal ? NPg::LookupProc(this->AggDesc_.FinalFuncId).ResultType : this->AggDesc_.TransTypeId)
  923. , TransTypeLen_(NPg::LookupType(this->AggDesc_.TransTypeId).TypeLen)
  924. {
  925. Values_.reserve(1);
  926. }
  927. void DestroyState(void* state) noexcept final {
  928. Y_UNUSED(state);
  929. }
  930. void PrepareBatch(ui64 batchNum, const NUdf::TUnboxedValue* columns) {
  931. Values_.clear();
  932. Values_.push_back(NKikimr::NMiniKQL::TArrowBlock::From(columns[StateColumn_]).GetDatum());
  933. if constexpr (HasDeserialize) {
  934. DeserializeAccessor_.Bind(Values_, 0, 1);
  935. } else {
  936. CombineAccessor_.Bind(Values_, 1);
  937. }
  938. BatchNum_ = batchNum;
  939. }
  940. void Deserialize(Datum ser, NullableDatum& result) {
  941. LOCAL_FCINFO(deserializeCallInfo, 1);
  942. deserializeCallInfo->flinfo = &this->DeserializeFuncInfo_;
  943. deserializeCallInfo->nargs = 1;
  944. deserializeCallInfo->fncollation = DEFAULT_COLLATION_OID;
  945. deserializeCallInfo->context = (Node*)NKikimr::NMiniKQL::TlsAllocState->CurrentContext;
  946. deserializeCallInfo->isnull = false;
  947. deserializeCallInfo->args[0].isnull = false;
  948. deserializeCallInfo->args[0].value = ser;
  949. result.value = this->DeserializeFunc_(deserializeCallInfo);
  950. result.isnull = deserializeCallInfo->isnull;
  951. }
  952. void LoadState(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  953. new(state) NullableDatum();
  954. auto typedState = (NullableDatum*)state;
  955. typedState->isnull = true;
  956. typedState->value = 0;
  957. if (BatchNum_ != batchNum) {
  958. PrepareBatch(batchNum, columns);
  959. }
  960. NullableDatum d;
  961. if constexpr (HasDeserialize && !TDeserializeArgsPolicy::VarArgs) {
  962. d = GetInputValue<IsSerializedTypeFixed>(DeserializeAccessor_, 0, row);
  963. } else if constexpr (!HasDeserialize && !TCombineArgsPolicy::VarArgs) {
  964. d = GetInputValue<IsSerializedTypeFixed>(CombineAccessor_, 1, row);
  965. } else {
  966. d = GetInputValueSlow<IsSerializedTypeFixed>(Values_, 0, row);
  967. }
  968. if (d.isnull) {
  969. return;
  970. }
  971. if constexpr (!HasDeserialize) {
  972. typedState->isnull = false;
  973. typedState->value = CloneDatumToAggContext<IsTransTypeFixed>(d.value, this->TransTypeLen_);
  974. } else {
  975. Deserialize(d.value, *typedState);
  976. }
  977. SaveToAggContext<IsTransTypeFixed>(*typedState, this->TransTypeLen_);
  978. }
  979. void UpdateState(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  980. auto typedState = (NullableDatum*)state;
  981. if (BatchNum_ != batchNum) {
  982. PrepareBatch(batchNum, columns);
  983. }
  984. NullableDatum d;
  985. if constexpr (HasDeserialize && !TDeserializeArgsPolicy::VarArgs) {
  986. d = GetInputValue<IsSerializedTypeFixed>(DeserializeAccessor_, 0, row);
  987. } else if constexpr (!HasDeserialize && !TCombineArgsPolicy::VarArgs) {
  988. d = GetInputValue<IsSerializedTypeFixed>(CombineAccessor_, 1, row);
  989. } else {
  990. d = GetInputValueSlow<IsSerializedTypeFixed>(Values_, 0, row);
  991. }
  992. if (IsCombineStrict && d.isnull) {
  993. return;
  994. }
  995. NullableDatum deser;
  996. if (d.isnull) {
  997. deser.isnull = true;
  998. deser.value = 0;
  999. } else {
  1000. if constexpr (!HasDeserialize) {
  1001. if (IsCombineStrict && typedState->isnull) {
  1002. typedState->isnull = false;
  1003. typedState->value = CloneDatumToAggContext<IsTransTypeFixed>(d.value, this->TransTypeLen_);
  1004. return;
  1005. }
  1006. deser = d;
  1007. } else {
  1008. Deserialize(d.value, deser);
  1009. if (IsCombineStrict && typedState->isnull) {
  1010. *typedState = deser;
  1011. return;
  1012. }
  1013. }
  1014. }
  1015. LOCAL_FCINFO(combineCallInfo, 2);
  1016. combineCallInfo->flinfo = &this->CombineFuncInfo_;
  1017. combineCallInfo->nargs = 2;
  1018. combineCallInfo->fncollation = DEFAULT_COLLATION_OID;
  1019. combineCallInfo->context = (Node*)NKikimr::NMiniKQL::TlsAllocState->CurrentContext;
  1020. combineCallInfo->isnull = false;
  1021. combineCallInfo->args[0] = *typedState;
  1022. combineCallInfo->args[1] = deser;
  1023. auto ret = this->CombineFunc_(combineCallInfo);
  1024. if constexpr (!HasDeserialize) {
  1025. if (!combineCallInfo->isnull && ret == d.value) {
  1026. typedState->isnull = false;
  1027. typedState->value = CloneDatumToAggContext<IsTransTypeFixed>(d.value, this->TransTypeLen_);
  1028. return;
  1029. }
  1030. }
  1031. CopyState<IsTransTypeFixed>({ret, combineCallInfo->isnull}, *typedState);
  1032. SaveToAggContext<IsTransTypeFixed>(*typedState, this->TransTypeLen_);
  1033. }
  1034. std::unique_ptr<NKikimr::NMiniKQL::IAggColumnBuilder> MakeResultBuilder(ui64 size) final {
  1035. auto typeLen = NPg::LookupType(FinalType_).TypeLen;
  1036. if constexpr (IsFinalTypeFixed) {
  1037. return std::make_unique<TAggColumnBuilder<true, HasFinal, TFinalFunc, IsFinalStrict, NYql::NUdf::TFixedSizeArrayBuilder<ui64, true>>>(
  1038. this->AggDesc_.Name, this->FinalFunc_, size, &this->FinalFuncInfo_, arrow::uint64(), this->Ctx_, typeLen);
  1039. } else {
  1040. return std::make_unique<TAggColumnBuilder<false, HasFinal, TFinalFunc, IsFinalStrict, NYql::NUdf::TStringArrayBuilder<arrow::BinaryType, true>>>(
  1041. this->AggDesc_.Name, this->FinalFunc_, size, &this->FinalFuncInfo_, arrow::binary(), this->Ctx_, typeLen);
  1042. }
  1043. }
  1044. const TDeserializeFunc DeserializeFunc_;
  1045. const TCombineFunc CombineFunc_;
  1046. const TFinalFunc FinalFunc_;
  1047. const ui32 StateColumn_;
  1048. const NPg::TAggregateDesc& AggDesc_;
  1049. const ui32 SerializedType_;
  1050. const ui32 FinalType_;
  1051. const i32 TransTypeLen_;
  1052. ui64 BatchNum_ = Max<ui64>();
  1053. std::vector<arrow::Datum> Values_;
  1054. TInputArgsAccessor<TDeserializeArgsPolicy> DeserializeAccessor_;
  1055. TInputArgsAccessor<TCombineArgsPolicy> CombineAccessor_;
  1056. FmgrInfo DeserializeFuncInfo_;
  1057. FmgrInfo CombineFuncInfo_;
  1058. FmgrInfo FinalFuncInfo_;
  1059. };
  1060. class TPreparedCombineAllAggregator : public NKikimr::NMiniKQL::IPreparedBlockAggregator<NKikimr::NMiniKQL::IBlockAggregatorCombineAll>{
  1061. public:
  1062. TPreparedCombineAllAggregator(TTransFunc transFunc, TSerializeFunc serializeFunc, const std::vector<ui32>& argsColumns,
  1063. std::optional<ui32> filterColumn, const NPg::TAggregateDesc& aggDesc)
  1064. : IPreparedBlockAggregator(sizeof(NullableDatum))
  1065. , TransFunc(transFunc)
  1066. , SerializeFunc(serializeFunc)
  1067. , ArgsColumns(argsColumns)
  1068. , FilterColumn(filterColumn)
  1069. , AggDesc(aggDesc)
  1070. {}
  1071. private:
  1072. std::unique_ptr<NKikimr::NMiniKQL::IBlockAggregatorCombineAll> Make(NKikimr::NMiniKQL::TComputationContext& ctx) const {
  1073. if (FilterColumn.has_value()) {
  1074. return std::make_unique<TCombineAllAggregator<true>>(TransFunc, SerializeFunc, ArgsColumns, FilterColumn, AggDesc, ctx);
  1075. } else {
  1076. return std::make_unique<TCombineAllAggregator<false>>(TransFunc, SerializeFunc, ArgsColumns, FilterColumn, AggDesc, ctx);
  1077. }
  1078. }
  1079. const TTransFunc TransFunc;
  1080. const TSerializeFunc SerializeFunc;
  1081. const std::vector<ui32> ArgsColumns;
  1082. const std::optional<ui32> FilterColumn;
  1083. const NPg::TAggregateDesc& AggDesc;
  1084. };
  1085. class TPreparedCombineKeysAggregator : public NKikimr::NMiniKQL::IPreparedBlockAggregator<NKikimr::NMiniKQL::IBlockAggregatorCombineKeys>{
  1086. public:
  1087. TPreparedCombineKeysAggregator(TTransFunc transFunc, TSerializeFunc serializeFunc, const std::vector<ui32>& argsColumns,
  1088. const NPg::TAggregateDesc& aggDesc)
  1089. : IPreparedBlockAggregator(sizeof(NullableDatum))
  1090. , TransFunc(transFunc)
  1091. , SerializeFunc(serializeFunc)
  1092. , ArgsColumns(argsColumns)
  1093. , AggDesc(aggDesc)
  1094. {}
  1095. private:
  1096. std::unique_ptr<NKikimr::NMiniKQL::IBlockAggregatorCombineKeys> Make(NKikimr::NMiniKQL::TComputationContext& ctx) const {
  1097. return std::make_unique<TCombineKeysAggregator>(TransFunc, SerializeFunc, ArgsColumns, AggDesc, ctx);
  1098. }
  1099. const TTransFunc TransFunc;
  1100. const TSerializeFunc SerializeFunc;
  1101. const std::vector<ui32> ArgsColumns;
  1102. const NPg::TAggregateDesc& AggDesc;
  1103. };
  1104. class TPreparedFinalizeKeysAggregator : public NKikimr::NMiniKQL::IPreparedBlockAggregator<NKikimr::NMiniKQL::IBlockAggregatorFinalizeKeys>{
  1105. public:
  1106. TPreparedFinalizeKeysAggregator(TDeserializeFunc deserializeFunc, TCombineFunc combineFunc, TFinalFunc finalFunc, ui32 stateColumn,
  1107. const NPg::TAggregateDesc& aggDesc)
  1108. : IPreparedBlockAggregator(sizeof(NullableDatum))
  1109. , DeserializeFunc(deserializeFunc)
  1110. , CombineFunc(combineFunc)
  1111. , FinalFunc(finalFunc)
  1112. , StateColumn(stateColumn)
  1113. , AggDesc(aggDesc)
  1114. {}
  1115. private:
  1116. std::unique_ptr<NKikimr::NMiniKQL::IBlockAggregatorFinalizeKeys> Make(NKikimr::NMiniKQL::TComputationContext& ctx) const {
  1117. return std::make_unique<TFinalizeKeysAggregator>(DeserializeFunc, CombineFunc, FinalFunc, StateColumn, AggDesc, ctx);
  1118. }
  1119. const TDeserializeFunc DeserializeFunc;
  1120. const TCombineFunc CombineFunc;
  1121. const TFinalFunc FinalFunc;
  1122. const ui32 StateColumn;
  1123. const NPg::TAggregateDesc& AggDesc;
  1124. };
  1125. public:
  1126. std::unique_ptr<NKikimr::NMiniKQL::IPreparedBlockAggregator<NKikimr::NMiniKQL::IBlockAggregatorCombineAll>> PrepareCombineAll(
  1127. std::optional<ui32> filterColumn,
  1128. const std::vector<ui32>& argsColumns,
  1129. const NPg::TAggregateDesc& aggDesc) const {
  1130. return std::make_unique<TPreparedCombineAllAggregator>(TransFunc, SerializeFunc, argsColumns, filterColumn, aggDesc);
  1131. }
  1132. std::unique_ptr<NKikimr::NMiniKQL::IPreparedBlockAggregator<NKikimr::NMiniKQL::IBlockAggregatorCombineKeys>> PrepareCombineKeys(
  1133. const std::vector<ui32>& argsColumns,
  1134. const NPg::TAggregateDesc& aggDesc) {
  1135. return std::make_unique<TPreparedCombineKeysAggregator>(TransFunc, SerializeFunc, argsColumns, aggDesc);
  1136. }
  1137. std::unique_ptr<NKikimr::NMiniKQL::IPreparedBlockAggregator<NKikimr::NMiniKQL::IBlockAggregatorFinalizeKeys>> PrepareFinalizeKeys(
  1138. ui32 stateColumn,
  1139. const NPg::TAggregateDesc& aggDesc) {
  1140. return std::make_unique<TPreparedFinalizeKeysAggregator>(DeserializeFunc, CombineFunc, FinalFunc, stateColumn, aggDesc);
  1141. }
private:
    // The aggregate's PG function wrappers. Each Prepare* method copies the ones it
    // needs by value into the prepared aggregator it creates.
    const TTransFunc TransFunc;
    const TCombineFunc CombineFunc;
    const TSerializeFunc SerializeFunc;
    const TDeserializeFunc DeserializeFunc;
    const TFinalFunc FinalFunc;
  1148. };
  1149. #if defined(_tsan_enabled_) || defined(_msan_enabled_) || defined(_asan_enabled_) || !defined(NDEBUG)
  1150. #ifndef USE_SLOW_PG_KERNELS
  1151. #define USE_SLOW_PG_KERNELS
  1152. #endif
  1153. #endif
  1154. TExecFunc FindExec(Oid oid);
  1155. const NPg::TAggregateDesc& ResolveAggregation(const TString& name, NKikimr::NMiniKQL::TTupleType* tupleType,
  1156. const std::vector<ui32>& argsColumns, NKikimr::NMiniKQL::TType* returnType, ui32 hint = 0);
  1157. }