mkql_block_agg_sum.cpp 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877
  1. #include "mkql_block_agg_sum.h"
  2. #include "mkql_block_agg_state_helper.h"
  3. #include <yql/essentials/minikql/mkql_node_builder.h>
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. #include <yql/essentials/minikql/mkql_node_printer.h>
  6. #include <yql/essentials/minikql/computation/mkql_block_builder.h>
  7. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  8. #include <yql/essentials/minikql/arrow/arrow_defs.h>
  9. #include <yql/essentials/minikql/arrow/arrow_util.h>
  10. #include <yql/essentials/minikql/arrow/mkql_bit_utils.h>
  11. #include <arrow/scalar.h>
  12. #include <arrow/array/builder_primitive.h>
  13. namespace NKikimr {
  14. namespace NMiniKQL {
  15. namespace {
  16. template<bool IsNullable, typename TSum>
  17. struct TSumState;
  18. template<typename TSum>
  19. struct TSumState<true, TSum> {
  20. typename TPrimitiveDataType<TSum>::TArithmetic Sum_ = 0;
  21. ui8 IsValid_ = 0;
  22. };
  23. template<typename TSum>
  24. struct TSumState<false, TSum> {
  25. typename TPrimitiveDataType<TSum>::TArithmetic Sum_ = 0;
  26. };
  27. template<typename TOut>
  28. struct TAvgState {
  29. typename TPrimitiveDataType<TOut>::TArithmetic Sum_ = 0;
  30. ui64 Count_ = 0;
  31. };
  32. template <bool IsNullable, typename TSum>
  33. class TSumColumnBuilder : public IAggColumnBuilder {
  34. public:
  35. using TStateType = TSumState<IsNullable, TSum>;
  36. TSumColumnBuilder(ui64 size, TType* dataType, TComputationContext& ctx)
  37. : Builder_(dataType, TTypeInfoHelper(), ctx.ArrowMemoryPool, size)
  38. , Ctx_(ctx)
  39. {
  40. }
  41. void Add(const void* state) final {
  42. auto typedState = MakeStateWrapper<TStateType>(state);
  43. if constexpr (IsNullable) {
  44. if (!typedState->IsValid_) {
  45. Builder_.Add(TBlockItem());
  46. return;
  47. }
  48. }
  49. Builder_.Add(TBlockItem(TSum(typedState->Sum_)));
  50. }
  51. NUdf::TUnboxedValue Build() final {
  52. return Ctx_.HolderFactory.CreateArrowBlock(Builder_.Build(true));
  53. }
  54. private:
  55. NYql::NUdf::TFixedSizeArrayBuilder<TSum, IsNullable> Builder_;
  56. TComputationContext& Ctx_;
  57. };
  58. template<typename TOut>
  59. class TAvgStateColumnBuilder : public IAggColumnBuilder {
  60. public:
  61. TAvgStateColumnBuilder(ui64 size, TType* outputType, TComputationContext& ctx)
  62. : Ctx_(ctx)
  63. , Builder_(MakeArrayBuilder(TTypeInfoHelper(), outputType, ctx.ArrowMemoryPool, size, &ctx.Builder->GetPgBuilder()))
  64. {
  65. }
  66. void Add(const void* state) final {
  67. auto typedState = MakeStateWrapper<TAvgState<TOut>>(state);
  68. auto tupleBuilder = static_cast<NUdf::TTupleArrayBuilder<true>*>(Builder_.get());
  69. if (typedState->Count_) {
  70. TBlockItem tupleItems[] = { TBlockItem(TOut(typedState->Sum_)), TBlockItem(typedState->Count_)} ;
  71. tupleBuilder->Add(TBlockItem(tupleItems));
  72. } else {
  73. tupleBuilder->Add(TBlockItem());
  74. }
  75. }
  76. NUdf::TUnboxedValue Build() final {
  77. return Ctx_.HolderFactory.CreateArrowBlock(Builder_->Build(true));
  78. }
  79. private:
  80. TComputationContext& Ctx_;
  81. const std::unique_ptr<IArrayBuilder> Builder_;
  82. };
  83. template<typename TOut>
  84. class TAvgResultColumnBuilder : public IAggColumnBuilder {
  85. public:
  86. TAvgResultColumnBuilder(ui64 size, TComputationContext& ctx)
  87. : Ctx_(ctx)
  88. , Builder_(TTypeInfoHelper(), arrow::TypeTraits<typename TPrimitiveDataType<TOut>::TResult>::type_singleton(), ctx.ArrowMemoryPool, size)
  89. {
  90. }
  91. void Add(const void* state) final {
  92. auto typedState = MakeStateWrapper<TAvgState<TOut>>(state);
  93. if (typedState->Count_) {
  94. Builder_.Add(TBlockItem(TOut(typedState->Sum_ / typedState->Count_)));
  95. } else {
  96. Builder_.Add(TBlockItem());
  97. }
  98. }
  99. NUdf::TUnboxedValue Build() final {
  100. return Ctx_.HolderFactory.CreateArrowBlock(Builder_.Build(true));
  101. }
  102. private:
  103. TComputationContext& Ctx_;
  104. NYql::NUdf::TFixedSizeArrayBuilder<TOut, /*Nullable=*/true> Builder_;
  105. };
  106. template <typename TTag, bool IsNullable, bool IsScalar, typename TIn, typename TSum>
  107. class TSumBlockAggregator;
  108. template <typename TTag, typename TIn, typename TOut>
  109. class TAvgBlockAggregator;
  110. template <bool IsNullable, bool IsScalar, typename TIn, typename TSum>
  111. class TSumBlockAggregator<TCombineAllTag, IsNullable, IsScalar, TIn, TSum> : public TCombineAllTag::TBase {
  112. public:
  113. using TBase = TCombineAllTag::TBase;
  114. using TStateType = TSumState<IsNullable, TSum>;
  115. using TInScalar = typename TPrimitiveDataType<TIn>::TScalarResult;
  116. TSumBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TType* dataType, TComputationContext& ctx)
  117. : TBase(sizeof(TStateType), filterColumn, ctx)
  118. , ArgColumn_(argColumn)
  119. {
  120. Y_UNUSED(dataType);
  121. }
  122. void InitState(void* state) final {
  123. TStateType st;
  124. WriteUnaligned<TStateType>(state, st);
  125. }
  126. void DestroyState(void* state) noexcept final {
  127. static_assert(std::is_trivially_destructible<TStateType>::value);
  128. Y_UNUSED(state);
  129. }
  130. void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final {
  131. auto typedState = MakeStateWrapper<TStateType>(state);
  132. const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
  133. if constexpr (IsScalar) {
  134. Y_ENSURE(datum.is_scalar());
  135. if constexpr (IsNullable) {
  136. if (datum.scalar()->is_valid) {
  137. typedState->Sum_ += (filtered ? *filtered : batchLength) * Cast(datum.scalar_as<TInScalar>().value);
  138. typedState->IsValid_ = 1;
  139. }
  140. } else {
  141. typedState->Sum_ += (filtered ? *filtered : batchLength) * Cast(datum.scalar_as<TInScalar>().value);
  142. }
  143. } else {
  144. const auto& array = datum.array();
  145. auto ptr = array->GetValues<TIn>(1);
  146. auto len = array->length;
  147. auto nullCount = IsNullable ? array->GetNullCount() : 0;
  148. auto count = len - nullCount;
  149. if (!count) {
  150. return;
  151. }
  152. if (!filtered) {
  153. if constexpr (IsNullable) {
  154. typedState->IsValid_ = 1;
  155. }
  156. auto sum = typedState->Sum_;
  157. if (IsNullable && nullCount != 0) {
  158. auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0);
  159. for (int64_t i = 0; i < len; ++i) {
  160. ui64 fullIndex = i + array->offset;
  161. ui8 notNull = (nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1;
  162. sum += SelectArg<TIn>(notNull, ptr[i], 0);
  163. }
  164. } else {
  165. for (int64_t i = 0; i < len; ++i) {
  166. sum += ptr[i];
  167. }
  168. }
  169. typedState->Sum_ = sum;
  170. } else {
  171. const auto& filterDatum = TArrowBlock::From(columns[*FilterColumn_]).GetDatum();
  172. const auto& filterArray = filterDatum.array();
  173. MKQL_ENSURE(filterArray->GetNullCount() == 0, "Expected non-nullable bool column");
  174. const ui8* filterBitmap = filterArray->template GetValues<uint8_t>(1);
  175. auto sum = typedState->Sum_;
  176. if (IsNullable && nullCount != 0) {
  177. ui64 count = 0;
  178. auto nullBitmapPtr = array->template GetValues<uint8_t>(0, 0);
  179. for (int64_t i = 0; i < len; ++i) {
  180. ui64 fullIndex = i + array->offset;
  181. ui8 notNullAndFiltered = ((nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1) & filterBitmap[i];
  182. sum += SelectArg<TIn>(notNullAndFiltered, ptr[i], 0);
  183. count += notNullAndFiltered;
  184. }
  185. if constexpr (IsNullable) {
  186. typedState->IsValid_ |= count ? 1 : 0;
  187. }
  188. } else {
  189. for (int64_t i = 0; i < len; ++i) {
  190. sum += SelectArg<TIn>(filterBitmap[i], ptr[i], 0);
  191. }
  192. if constexpr (IsNullable) {
  193. typedState->IsValid_ = 1;
  194. }
  195. }
  196. typedState->Sum_ = sum;
  197. }
  198. }
  199. }
  200. NUdf::TUnboxedValue FinishOne(const void* state) final {
  201. auto typedState = MakeStateWrapper<TStateType>(state);
  202. if constexpr (IsNullable) {
  203. if (!typedState->IsValid_) {
  204. return NUdf::TUnboxedValuePod();
  205. }
  206. }
  207. return NUdf::TUnboxedValuePod(TSum(typedState->Sum_));
  208. }
  209. private:
  210. const ui32 ArgColumn_;
  211. };
  212. template <bool IsNullable, bool IsScalar, typename TIn, typename TSum>
  213. void PushValueToState(TSumState<IsNullable, TSum>* typedState, const arrow::Datum& datum, ui64 row) {
  214. using TInScalar = typename TPrimitiveDataType<TIn>::TScalarResult;
  215. if constexpr (IsScalar) {
  216. Y_ENSURE(datum.is_scalar());
  217. if constexpr (IsNullable) {
  218. if (datum.scalar()->is_valid) {
  219. typedState->Sum_ += Cast(datum.scalar_as<TInScalar>().value);
  220. typedState->IsValid_ = 1;
  221. }
  222. } else {
  223. typedState->Sum_ += Cast(datum.scalar_as<TInScalar>().value);
  224. }
  225. } else {
  226. const auto& array = datum.array();
  227. auto ptr = array->GetValues<TIn>(1);
  228. if constexpr (IsNullable) {
  229. if (array->GetNullCount() == 0) {
  230. typedState->IsValid_ = 1;
  231. typedState->Sum_ += ptr[row];
  232. } else {
  233. auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0);
  234. ui64 fullIndex = row + array->offset;
  235. ui8 notNull = (nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1;
  236. typedState->Sum_ += SelectArg<TIn>(notNull, ptr[row], 0);
  237. typedState->IsValid_ |= notNull;
  238. }
  239. } else {
  240. typedState->Sum_ += ptr[row];
  241. }
  242. }
  243. }
  244. template <bool IsNullable, bool IsScalar, typename TIn, typename TSum>
  245. class TSumBlockAggregator<TCombineKeysTag, IsNullable, IsScalar, TIn, TSum> : public TCombineKeysTag::TBase {
  246. public:
  247. using TBase = TCombineKeysTag::TBase;
  248. using TStateType = TSumState<IsNullable, TSum>;
  249. TSumBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TType* dataType, TComputationContext& ctx)
  250. : TBase(sizeof(TStateType), filterColumn, ctx)
  251. , ArgColumn_(argColumn)
  252. , DataType_(dataType)
  253. {
  254. }
  255. void InitKey(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  256. TStateType st;
  257. WriteUnaligned<TStateType>(state, st);
  258. UpdateKey(state, batchNum, columns, row);
  259. }
  260. void DestroyState(void* state) noexcept final {
  261. static_assert(std::is_trivially_destructible<TStateType>::value);
  262. Y_UNUSED(state);
  263. }
  264. void UpdateKey(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  265. Y_UNUSED(batchNum);
  266. auto typedState = MakeStateWrapper<TStateType>(state);
  267. const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
  268. PushValueToState<IsNullable, IsScalar, TIn, TSum>(typedState.Get(), datum, row);
  269. }
  270. std::unique_ptr<IAggColumnBuilder> MakeStateBuilder(ui64 size) final {
  271. return std::make_unique<TSumColumnBuilder<IsNullable, TSum>>(size, DataType_, Ctx_);
  272. }
  273. private:
  274. const ui32 ArgColumn_;
  275. TType* const DataType_;
  276. };
  277. template <bool IsNullable, bool IsScalar, typename TIn, typename TSum>
  278. class TSumBlockAggregator<TFinalizeKeysTag, IsNullable, IsScalar, TIn, TSum> : public TFinalizeKeysTag::TBase {
  279. public:
  280. using TBase = TFinalizeKeysTag::TBase;
  281. using TStateType = TSumState<IsNullable, TSum>;
  282. TSumBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TType* dataType, TComputationContext& ctx)
  283. : TBase(sizeof(TStateType), filterColumn, ctx)
  284. , ArgColumn_(argColumn)
  285. , DataType_(dataType)
  286. {
  287. }
  288. void LoadState(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  289. TStateType st;
  290. WriteUnaligned<TStateType>(state, st);
  291. UpdateState(state, batchNum, columns, row);
  292. }
  293. void DestroyState(void* state) noexcept final {
  294. static_assert(std::is_trivially_destructible<TStateType>::value);
  295. Y_UNUSED(state);
  296. }
  297. void UpdateState(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  298. Y_UNUSED(batchNum);
  299. auto typedState = MakeStateWrapper<TStateType>(state);
  300. const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
  301. PushValueToState<IsNullable, IsScalar, TIn, TSum>(typedState.Get(), datum, row);
  302. }
  303. std::unique_ptr<IAggColumnBuilder> MakeResultBuilder(ui64 size) final {
  304. return std::make_unique<TSumColumnBuilder<IsNullable, TSum>>(size, DataType_, Ctx_);
  305. }
  306. private:
  307. const ui32 ArgColumn_;
  308. TType* const DataType_;
  309. };
  310. template<typename TIn, typename TOut>
  311. class TAvgBlockAggregator<TCombineAllTag, TIn, TOut> : public TCombineAllTag::TBase {
  312. public:
  313. using TBase = TCombineAllTag::TBase;
  314. using TInScalar = typename TPrimitiveDataType<TIn>::TScalarResult;
  315. TAvgBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TType* outputType, TComputationContext& ctx)
  316. : TBase(sizeof(TAvgState<TOut>), filterColumn, ctx)
  317. , ArgColumn_(argColumn)
  318. {
  319. Y_UNUSED(outputType);
  320. }
  321. void InitState(void* state) final {
  322. TAvgState<TOut> st;
  323. WriteUnaligned<TAvgState<TOut>>(state, st);
  324. }
  325. void DestroyState(void* state) noexcept final {
  326. static_assert(std::is_trivially_destructible<TAvgState<TOut>>::value);
  327. Y_UNUSED(state);
  328. }
  329. void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final {
  330. auto typedState = MakeStateWrapper<TAvgState<TOut>>(state);
  331. const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
  332. if (datum.is_scalar()) {
  333. if (datum.scalar()->is_valid) {
  334. typedState->Sum_ += (filtered ? *filtered : batchLength) * Cast(datum.scalar_as<TInScalar>().value);
  335. typedState->Count_ += batchLength;
  336. }
  337. } else {
  338. const auto& array = datum.array();
  339. auto ptr = array->GetValues<TIn>(1);
  340. auto len = array->length;
  341. auto count = len - array->GetNullCount();
  342. if (!count) {
  343. return;
  344. }
  345. if (!filtered) {
  346. typedState->Count_ += count;
  347. auto sum = typedState->Sum_;
  348. if (array->GetNullCount() == 0) {
  349. for (int64_t i = 0; i < len; ++i) {
  350. sum += ptr[i];
  351. }
  352. } else {
  353. auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0);
  354. for (int64_t i = 0; i < len; ++i) {
  355. ui64 fullIndex = i + array->offset;
  356. // bit 1 -> mask 0xFF..FF, bit 0 -> mask 0x00..00
  357. ui8 notNull = (nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1;
  358. sum += SelectArg<TIn>(notNull, ptr[i], 0);
  359. }
  360. }
  361. typedState->Sum_ = sum;
  362. } else {
  363. const auto& filterDatum = TArrowBlock::From(columns[*FilterColumn_]).GetDatum();
  364. const auto& filterArray = filterDatum.array();
  365. MKQL_ENSURE(filterArray->GetNullCount() == 0, "Expected non-nullable bool column");
  366. const ui8* filterBitmap = filterArray->template GetValues<uint8_t>(1);
  367. auto sum = typedState->Sum_;
  368. ui64 count = typedState->Count_;
  369. if (array->GetNullCount() == 0) {
  370. for (int64_t i = 0; i < len; ++i) {
  371. ui8 filtered = filterBitmap[i];
  372. sum += SelectArg<TIn>(filterBitmap[i], ptr[i], 0);
  373. count += filtered;
  374. }
  375. } else {
  376. auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0);
  377. for (int64_t i = 0; i < len; ++i) {
  378. ui64 fullIndex = i + array->offset;
  379. ui8 notNullAndFiltered = ((nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1) & filterBitmap[i];
  380. sum += SelectArg<TIn>(notNullAndFiltered, ptr[i], 0);
  381. count += notNullAndFiltered;
  382. }
  383. }
  384. typedState->Sum_ = sum;
  385. typedState->Count_ = count;
  386. }
  387. }
  388. }
  389. NUdf::TUnboxedValue FinishOne(const void* state) final {
  390. auto typedState = MakeStateWrapper<TAvgState<TOut>>(state);
  391. if (!typedState->Count_) {
  392. return NUdf::TUnboxedValuePod();
  393. }
  394. NUdf::TUnboxedValue* items;
  395. auto arr = Ctx_.HolderFactory.CreateDirectArrayHolder(2, items);
  396. items[0] = NUdf::TUnboxedValuePod(TOut(typedState->Sum_));
  397. items[1] = NUdf::TUnboxedValuePod(typedState->Count_);
  398. return arr;
  399. }
  400. private:
  401. ui32 ArgColumn_;
  402. };
  403. template <typename TIn, typename TOut>
  404. class TAvgBlockAggregator<TCombineKeysTag, TIn, TOut> : public TCombineKeysTag::TBase {
  405. public:
  406. using TBase = TCombineKeysTag::TBase;
  407. using TInScalar = typename TPrimitiveDataType<TIn>::TScalarResult;
  408. TAvgBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TType* outputType, TComputationContext& ctx)
  409. : TBase(sizeof(TAvgState<TOut>), filterColumn, ctx)
  410. , ArgColumn_(argColumn)
  411. , OutputType_(outputType)
  412. {
  413. }
  414. void InitKey(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  415. TAvgState<TOut> st;
  416. WriteUnaligned<TAvgState<TOut>>(state, st);
  417. UpdateKey(state, batchNum, columns, row);
  418. }
  419. void DestroyState(void* state) noexcept final {
  420. static_assert(std::is_trivially_destructible<TAvgState<TOut>>::value);
  421. Y_UNUSED(state);
  422. }
  423. void UpdateKey(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  424. Y_UNUSED(batchNum);
  425. auto typedState = MakeStateWrapper<TAvgState<TOut>>(state);
  426. const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
  427. if (datum.is_scalar()) {
  428. if (datum.scalar()->is_valid) {
  429. typedState->Sum_ += Cast(datum.scalar_as<TInScalar>().value);
  430. typedState->Count_ += 1;
  431. }
  432. } else {
  433. const auto& array = datum.array();
  434. auto ptr = array->GetValues<TIn>(1);
  435. if (array->GetNullCount() == 0) {
  436. typedState->Sum_ += ptr[row];
  437. typedState->Count_ += 1;
  438. } else {
  439. auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0);
  440. ui64 fullIndex = row + array->offset;
  441. ui8 notNull = (nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1;
  442. typedState->Sum_ += SelectArg<TIn>(notNull, ptr[row], 0);
  443. typedState->Count_ += notNull;
  444. }
  445. }
  446. }
  447. std::unique_ptr<IAggColumnBuilder> MakeStateBuilder(ui64 size) final {
  448. return std::make_unique<TAvgStateColumnBuilder<TOut>>(size, OutputType_, Ctx_);
  449. }
  450. private:
  451. const ui32 ArgColumn_;
  452. TType* const OutputType_;
  453. };
  454. template<typename TOut>
  455. class TAvgBlockAggregatorOverState : public TFinalizeKeysTag::TBase {
  456. public:
  457. using TBase = TFinalizeKeysTag::TBase;
  458. using TInScalar = typename TPrimitiveDataType<TOut>::TScalarResult;
  459. TAvgBlockAggregatorOverState(ui32 argColumn, TComputationContext& ctx)
  460. : TBase(sizeof(TAvgState<TOut>), {}, ctx)
  461. , ArgColumn_(argColumn)
  462. {
  463. }
  464. void LoadState(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  465. TAvgState<TOut> st;
  466. WriteUnaligned<TAvgState<TOut>>(state, st);
  467. UpdateState(state, batchNum, columns, row);
  468. }
  469. void DestroyState(void* state) noexcept final {
  470. static_assert(std::is_trivially_destructible<TAvgState<TOut>>::value);
  471. Y_UNUSED(state);
  472. }
  473. void UpdateState(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  474. Y_UNUSED(batchNum);
  475. auto typedState = MakeStateWrapper<TAvgState<TOut>>(state);
  476. const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
  477. if (datum.is_scalar()) {
  478. if (datum.scalar()->is_valid) {
  479. const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*datum.scalar());
  480. typedState->Sum_ += Cast(arrow::internal::checked_cast<const TInScalar&>(*structScalar.value[0]).value);
  481. typedState->Count_ += arrow::internal::checked_cast<const arrow::UInt64Scalar&>(*structScalar.value[1]).value;
  482. }
  483. } else {
  484. const auto& array = datum.array();
  485. auto sumPtr = array->child_data[0]->GetValues<TOut>(1);
  486. auto countPtr = array->child_data[1]->GetValues<ui64>(1);
  487. if (array->GetNullCount() == 0) {
  488. typedState->Sum_ += sumPtr[row];
  489. typedState->Count_ += countPtr[row];
  490. } else {
  491. auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0);
  492. ui64 fullIndex = row + array->offset;
  493. // bit 1 -> mask 0xFF..FF, bit 0 -> mask 0x00..00
  494. auto bit = (nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1;
  495. ui64 mask = -ui64(bit);
  496. typedState->Sum_ += sumPtr[row] * bit;
  497. typedState->Count_ += mask & countPtr[row];
  498. }
  499. }
  500. }
  501. std::unique_ptr<IAggColumnBuilder> MakeResultBuilder(ui64 size) final {
  502. return std::make_unique<TAvgResultColumnBuilder<TOut>>(size, Ctx_);
  503. }
  504. private:
  505. const ui32 ArgColumn_;
  506. };
  507. template <typename TTag, bool IsNullable, bool IsScalar, typename TIn, typename TSum>
  508. class TPreparedSumBlockAggregator : public TTag::TPreparedAggregator {
  509. public:
  510. using TBase = typename TTag::TPreparedAggregator;
  511. using TStateType = TSumState<IsNullable, TSum>;
  512. TPreparedSumBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TType* dataType)
  513. : TBase(sizeof(TStateType))
  514. , FilterColumn_(filterColumn)
  515. , ArgColumn_(argColumn)
  516. , DataType_(dataType)
  517. {}
  518. std::unique_ptr<typename TTag::TAggregator> Make(TComputationContext& ctx) const final {
  519. return std::make_unique<TSumBlockAggregator<TTag, IsNullable, IsScalar, TIn, TSum>>(FilterColumn_, ArgColumn_, DataType_, ctx);
  520. }
  521. private:
  522. const std::optional<ui32> FilterColumn_;
  523. const ui32 ArgColumn_;
  524. TType* const DataType_;
  525. };
  526. template<typename TTag, typename TIn, typename TSum>
  527. std::unique_ptr<typename TTag::TPreparedAggregator> PrepareSumFixed(TType* type, bool isOptional, bool isScalar, std::optional<ui32> filterColumn, ui32 argColumn) {
  528. if (isScalar) {
  529. if (isOptional) {
  530. return std::make_unique<TPreparedSumBlockAggregator<TTag, true, true, TIn, TSum>>(filterColumn, argColumn, type);
  531. }
  532. return std::make_unique<TPreparedSumBlockAggregator<TTag, false, true, TIn, TSum>>(filterColumn, argColumn, type);
  533. }
  534. if (isOptional) {
  535. return std::make_unique<TPreparedSumBlockAggregator<TTag, true, false, TIn, TSum>>(filterColumn, argColumn, type);
  536. }
  537. return std::make_unique<TPreparedSumBlockAggregator<TTag, false, false, TIn, TSum>>(filterColumn, argColumn, type);
  538. }
  539. template <typename TTag>
  540. std::unique_ptr<typename TTag::TPreparedAggregator> PrepareSum(TTupleType* tupleType, std::optional<ui32> filterColumn, ui32 argColumn, const TTypeEnvironment& env) {
  541. auto blockType = AS_TYPE(TBlockType, tupleType->GetElementType(argColumn));
  542. auto argType = blockType->GetItemType();
  543. bool isOptional;
  544. auto dataType = UnpackOptionalData(argType, isOptional);
  545. bool isScalar = blockType->GetShape() == TBlockType::EShape::Scalar;
  546. TType* sumRetType = nullptr;
  547. const auto& typeInfo = NYql::NUdf::GetDataTypeInfo(*dataType->GetDataSlot());
  548. if (typeInfo.Features & NYql::NUdf::EDataTypeFeatures::SignedIntegralType) {
  549. sumRetType = TDataType::Create(NUdf::TDataType<i64>::Id, env);
  550. } else if (typeInfo.Features & NYql::NUdf::EDataTypeFeatures::UnsignedIntegralType) {
  551. sumRetType = TDataType::Create(NUdf::TDataType<ui64>::Id, env);
  552. } else if (*dataType->GetDataSlot() == NUdf::EDataSlot::Decimal) {
  553. auto decimalType = static_cast<TDataDecimalType*>(dataType);
  554. auto [_, scale] = decimalType->GetParams();
  555. sumRetType = TDataDecimalType::Create(NYql::NDecimal::MaxPrecision, scale, env);
  556. } else if (*dataType->GetDataSlot() == NUdf::EDataSlot::Interval) {
  557. sumRetType = TDataDecimalType::Create(NYql::NDecimal::MaxPrecision, 0, env);
  558. } else {
  559. Y_ENSURE(typeInfo.Features & NYql::NUdf::EDataTypeFeatures::FloatType);
  560. sumRetType = dataType;
  561. }
  562. sumRetType = TOptionalType::Create(sumRetType, env);
  563. switch (*dataType->GetDataSlot()) {
  564. case NUdf::EDataSlot::Int8:
  565. return PrepareSumFixed<TTag, i8, i64>(sumRetType, isOptional, isScalar, filterColumn, argColumn);
  566. case NUdf::EDataSlot::Uint8:
  567. return PrepareSumFixed<TTag, ui8, ui64>(sumRetType, isOptional, isScalar, filterColumn, argColumn);
  568. case NUdf::EDataSlot::Int16:
  569. return PrepareSumFixed<TTag, i16, i64>(sumRetType, isOptional, isScalar, filterColumn, argColumn);
  570. case NUdf::EDataSlot::Uint16:
  571. return PrepareSumFixed<TTag, ui16, ui64>(sumRetType, isOptional, isScalar, filterColumn, argColumn);
  572. case NUdf::EDataSlot::Int32:
  573. return PrepareSumFixed<TTag, i32, i64>(sumRetType, isOptional, isScalar, filterColumn, argColumn);
  574. case NUdf::EDataSlot::Uint32:
  575. return PrepareSumFixed<TTag, ui32, ui64>(sumRetType, isOptional, isScalar, filterColumn, argColumn);
  576. case NUdf::EDataSlot::Int64:
  577. return PrepareSumFixed<TTag, i64, i64>(sumRetType, isOptional, isScalar, filterColumn, argColumn);
  578. case NUdf::EDataSlot::Uint64:
  579. return PrepareSumFixed<TTag, ui64, ui64>(sumRetType, isOptional, isScalar, filterColumn, argColumn);
  580. case NUdf::EDataSlot::Float:
  581. return PrepareSumFixed<TTag, float, float>(sumRetType, isOptional, isScalar, filterColumn, argColumn);
  582. case NUdf::EDataSlot::Double:
  583. return PrepareSumFixed<TTag, double, double>(sumRetType, isOptional, isScalar, filterColumn, argColumn);
  584. case NUdf::EDataSlot::Interval:
  585. return PrepareSumFixed<TTag, i64, NYql::NDecimal::TInt128>(sumRetType, isOptional, isScalar, filterColumn, argColumn);
  586. case NUdf::EDataSlot::Decimal:
  587. return PrepareSumFixed<TTag, NYql::NDecimal::TInt128, NYql::NDecimal::TInt128>(sumRetType, isOptional, isScalar, filterColumn, argColumn);
  588. default:
  589. throw yexception() << "Unsupported SUM input type";
  590. }
  591. }
  592. class TBlockSumFactory : public IBlockAggregatorFactory {
  593. public:
  594. std::unique_ptr<TCombineAllTag::TPreparedAggregator> PrepareCombineAll(
  595. TTupleType* tupleType,
  596. std::optional<ui32> filterColumn,
  597. const std::vector<ui32>& argsColumns,
  598. const TTypeEnvironment& env) const final
  599. {
  600. return PrepareSum<TCombineAllTag>(tupleType, filterColumn, argsColumns[0], env);
  601. }
  602. std::unique_ptr<TCombineKeysTag::TPreparedAggregator> PrepareCombineKeys(
  603. TTupleType* tupleType,
  604. const std::vector<ui32>& argsColumns,
  605. const TTypeEnvironment& env) const final
  606. {
  607. return PrepareSum<TCombineKeysTag>(tupleType, std::optional<ui32>(), argsColumns[0], env);
  608. }
  609. std::unique_ptr<TFinalizeKeysTag::TPreparedAggregator> PrepareFinalizeKeys(
  610. TTupleType* tupleType,
  611. const std::vector<ui32>& argsColumns,
  612. const TTypeEnvironment& env,
  613. TType* returnType,
  614. ui32 hint) const final
  615. {
  616. Y_UNUSED(returnType);
  617. Y_UNUSED(hint);
  618. return PrepareSum<TFinalizeKeysTag>(tupleType, std::optional<ui32>(), argsColumns[0], env);
  619. }
  620. };
  621. template <typename TTag, typename TIn, typename TOut>
  622. class TPreparedAvgBlockAggregator : public TTag::TPreparedAggregator {
  623. public:
  624. using TBase = typename TTag::TPreparedAggregator;
  625. TPreparedAvgBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TType* outputType)
  626. : TBase(sizeof(TAvgState<TOut>))
  627. , FilterColumn_(filterColumn)
  628. , ArgColumn_(argColumn)
  629. , OutputType_(outputType)
  630. {}
  631. std::unique_ptr<typename TTag::TAggregator> Make(TComputationContext& ctx) const final {
  632. return std::make_unique<TAvgBlockAggregator<TTag, TIn, TOut>>(FilterColumn_, ArgColumn_, OutputType_, ctx);
  633. }
  634. private:
  635. const std::optional<ui32> FilterColumn_;
  636. const ui32 ArgColumn_;
  637. TType* const OutputType_;
  638. };
  639. template<typename TOut>
  640. class TPreparedAvgBlockAggregatorOverState : public TFinalizeKeysTag::TPreparedAggregator {
  641. public:
  642. using TBase = TFinalizeKeysTag::TPreparedAggregator;
  643. TPreparedAvgBlockAggregatorOverState(ui32 argColumn)
  644. : TBase(sizeof(TAvgState<TOut>))
  645. , ArgColumn_(argColumn)
  646. {}
  647. std::unique_ptr<typename TFinalizeKeysTag::TAggregator> Make(TComputationContext& ctx) const final {
  648. return std::make_unique<TAvgBlockAggregatorOverState<TOut>>(ArgColumn_, ctx);
  649. }
  650. private:
  651. const ui32 ArgColumn_;
  652. };
  653. template <typename TTag>
  654. std::unique_ptr<typename TTag::TPreparedAggregator> PrepareAvg(TTupleType* tupleType, std::optional<ui32> filterColumn, ui32 argColumn, const TTypeEnvironment& env);
  655. template <typename TTag>
  656. std::unique_ptr<typename TTag::TPreparedAggregator> PrepareAvgOverInput(TTupleType* tupleType, std::optional<ui32> filterColumn, ui32 argColumn, const TTypeEnvironment& env) {
  657. auto doubleType = TDataType::Create(NUdf::TDataType<double>::Id, env);
  658. auto ui64Type = TDataType::Create(NUdf::TDataType<ui64>::Id, env);
  659. TVector<TType*> tupleElements = { doubleType, ui64Type };
  660. auto avgRetType = TOptionalType::Create(TTupleType::Create(2, tupleElements.data(), env), env);
  661. auto argType = AS_TYPE(TBlockType, tupleType->GetElementType(argColumn))->GetItemType();
  662. bool isOptional;
  663. auto dataType = UnpackOptionalData(argType, isOptional);
  664. switch (*dataType->GetDataSlot()) {
  665. case NUdf::EDataSlot::Int8:
  666. return std::make_unique<TPreparedAvgBlockAggregator<TTag, i8, double>>(filterColumn, argColumn, avgRetType);
  667. case NUdf::EDataSlot::Uint8:
  668. case NUdf::EDataSlot::Bool:
  669. return std::make_unique<TPreparedAvgBlockAggregator<TTag, ui8, double>>(filterColumn, argColumn, avgRetType);
  670. case NUdf::EDataSlot::Int16:
  671. return std::make_unique<TPreparedAvgBlockAggregator<TTag, i16, double>>(filterColumn, argColumn, avgRetType);
  672. case NUdf::EDataSlot::Uint16:
  673. return std::make_unique<TPreparedAvgBlockAggregator<TTag, ui16, double>>(filterColumn, argColumn, avgRetType);
  674. case NUdf::EDataSlot::Int32:
  675. return std::make_unique<TPreparedAvgBlockAggregator<TTag, i32, double>>(filterColumn, argColumn, avgRetType);
  676. case NUdf::EDataSlot::Uint32:
  677. return std::make_unique<TPreparedAvgBlockAggregator<TTag, ui32, double>>(filterColumn, argColumn, avgRetType);
  678. case NUdf::EDataSlot::Int64:
  679. return std::make_unique<TPreparedAvgBlockAggregator<TTag, i64, double>>(filterColumn, argColumn, avgRetType);
  680. case NUdf::EDataSlot::Uint64:
  681. return std::make_unique<TPreparedAvgBlockAggregator<TTag, ui64, double>>(filterColumn, argColumn, avgRetType);
  682. case NUdf::EDataSlot::Float:
  683. return std::make_unique<TPreparedAvgBlockAggregator<TTag, float, double>>(filterColumn, argColumn, avgRetType);
  684. case NUdf::EDataSlot::Double:
  685. return std::make_unique<TPreparedAvgBlockAggregator<TTag, double, double>>(filterColumn, argColumn, avgRetType);
  686. case NUdf::EDataSlot::Interval: {
  687. auto decimalType = TDataDecimalType::Create(NYql::NDecimal::MaxPrecision, 0, env);
  688. TVector<TType*> tupleDecimalElements = { decimalType, ui64Type };
  689. auto avgRetDecimalType = TOptionalType::Create(TTupleType::Create(2, tupleDecimalElements.data(), env), env);
  690. return std::make_unique<TPreparedAvgBlockAggregator<TTag, i64, NYql::NDecimal::TInt128>>(filterColumn, argColumn, avgRetDecimalType);
  691. }
  692. case NUdf::EDataSlot::Decimal: {
  693. auto [precision, scale] = static_cast<TDataDecimalType*>(dataType)->GetParams();
  694. auto decimalType = TDataDecimalType::Create(precision, scale, env);
  695. TVector<TType*> tupleDecimalElements = { decimalType, ui64Type };
  696. auto avgRetDecimalType = TOptionalType::Create(TTupleType::Create(2, tupleDecimalElements.data(), env), env);
  697. return std::make_unique<TPreparedAvgBlockAggregator<TTag, NYql::NDecimal::TInt128, NYql::NDecimal::TInt128>>(filterColumn, argColumn, avgRetDecimalType);
  698. }
  699. default:
  700. throw yexception() << "Unsupported AVG input type";
  701. }
  702. }
  703. template <>
  704. std::unique_ptr<typename TCombineAllTag::TPreparedAggregator> PrepareAvg<TCombineAllTag>(TTupleType* tupleType, std::optional<ui32> filterColumn, ui32 argColumn, const TTypeEnvironment& env) {
  705. return PrepareAvgOverInput<TCombineAllTag>(tupleType, filterColumn, argColumn, env);
  706. }
  707. template <>
  708. std::unique_ptr<typename TCombineKeysTag::TPreparedAggregator> PrepareAvg<TCombineKeysTag>(TTupleType* tupleType, std::optional<ui32> filterColumn, ui32 argColumn, const TTypeEnvironment& env) {
  709. return PrepareAvgOverInput<TCombineKeysTag>(tupleType, filterColumn, argColumn, env);
  710. }
  711. template <>
  712. std::unique_ptr<typename TFinalizeKeysTag::TPreparedAggregator> PrepareAvg<TFinalizeKeysTag>(TTupleType* tupleType, std::optional<ui32> filterColumn, ui32 argColumn, const TTypeEnvironment& env) {
  713. Y_UNUSED(filterColumn);
  714. Y_UNUSED(env);
  715. auto argType = AS_TYPE(TBlockType, tupleType->GetElementType(argColumn))->GetItemType();
  716. bool isOptional;
  717. auto aggTupleType = UnpackOptional(argType, isOptional);
  718. MKQL_ENSURE(aggTupleType->IsTuple(),
  719. "Expected tuple or optional of tuple, actual: " << PrintNode(argType, true));
  720. auto dataType = UnpackOptionalData(AS_TYPE(TTupleType, aggTupleType)->GetElementType(0), isOptional);
  721. switch (*dataType->GetDataSlot()) {
  722. case NUdf::EDataSlot::Decimal:
  723. return std::make_unique<TPreparedAvgBlockAggregatorOverState<NYql::NDecimal::TInt128>>(argColumn);
  724. case NUdf::EDataSlot::Double:
  725. return std::make_unique<TPreparedAvgBlockAggregatorOverState<double>>(argColumn);
  726. default:
  727. throw yexception() << "Unsupported Finalize input type";
  728. }
  729. }
  730. class TBlockAvgFactory : public IBlockAggregatorFactory {
  731. public:
  732. std::unique_ptr<TCombineAllTag::TPreparedAggregator> PrepareCombineAll(
  733. TTupleType* tupleType,
  734. std::optional<ui32> filterColumn,
  735. const std::vector<ui32>& argsColumns,
  736. const TTypeEnvironment& env) const final {
  737. return PrepareAvg<TCombineAllTag>(tupleType, filterColumn, argsColumns[0], env);
  738. }
  739. std::unique_ptr<TCombineKeysTag::TPreparedAggregator> PrepareCombineKeys(
  740. TTupleType* tupleType,
  741. const std::vector<ui32>& argsColumns,
  742. const TTypeEnvironment& env) const final {
  743. return PrepareAvg<TCombineKeysTag>(tupleType, std::optional<ui32>(), argsColumns[0], env);
  744. }
  745. std::unique_ptr<TFinalizeKeysTag::TPreparedAggregator> PrepareFinalizeKeys(
  746. TTupleType* tupleType,
  747. const std::vector<ui32>& argsColumns,
  748. const TTypeEnvironment& env,
  749. TType* returnType,
  750. ui32 hint) const final {
  751. Y_UNUSED(returnType);
  752. Y_UNUSED(hint);
  753. return PrepareAvg<TFinalizeKeysTag>(tupleType, std::optional<ui32>(), argsColumns[0], env);
  754. }
  755. };
  756. }
  757. std::unique_ptr<IBlockAggregatorFactory> MakeBlockSumFactory() {
  758. return std::make_unique<TBlockSumFactory>();
  759. }
  760. std::unique_ptr<IBlockAggregatorFactory> MakeBlockAvgFactory() {
  761. return std::make_unique<TBlockAvgFactory>();
  762. }
  763. }
  764. }