mkql_block_agg_count.cpp 14 KB


  1. #include "mkql_block_agg_count.h"
  2. #include <yql/essentials/minikql/arrow/arrow_defs.h>
  3. #include <yql/essentials/minikql/computation/mkql_block_builder.h>
  4. namespace NKikimr {
  5. namespace NMiniKQL {
  6. namespace {
  7. struct TState {
  8. ui64 Count_ = 0;
  9. };
  10. class TColumnBuilder : public IAggColumnBuilder {
  11. public:
  12. TColumnBuilder(ui64 size, TComputationContext& ctx)
  13. : Builder_(TTypeInfoHelper(), arrow::uint64(), ctx.ArrowMemoryPool, size)
  14. , Ctx_(ctx)
  15. {
  16. }
  17. void Add(const void* state) final {
  18. auto typedState = static_cast<const TState*>(state);
  19. Builder_.Add(TBlockItem(typedState->Count_));
  20. }
  21. NUdf::TUnboxedValue Build() final {
  22. return Ctx_.HolderFactory.CreateArrowBlock(Builder_.Build(true));
  23. }
  24. private:
  25. NYql::NUdf::TFixedSizeArrayBuilder<ui64, false> Builder_;
  26. TComputationContext& Ctx_;
  27. };
  // Primary templates, declared only. This file defines explicit specializations
  // for TCombineAllTag, TCombineKeysTag and TFinalizeKeysTag.
  // TCountAllAggregator counts rows without looking at argument values;
  // TCountAggregator counts rows whose argument value is not null.
  template <class TTag>
  class TCountAllAggregator;

  template <class TTag>
  class TCountAggregator;
  32. template <>
  33. class TCountAllAggregator<TCombineAllTag> : public TCombineAllTag::TBase {
  34. public:
  35. using TBase = TCombineAllTag::TBase;
  36. TCountAllAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TComputationContext& ctx)
  37. : TBase(sizeof(TState), filterColumn, ctx)
  38. {
  39. Y_UNUSED(argColumn);
  40. }
  41. void InitState(void* state) final {
  42. new(state) TState();
  43. }
  44. void DestroyState(void* state) noexcept final {
  45. static_assert(std::is_trivially_destructible<TState>::value);
  46. Y_UNUSED(state);
  47. }
  48. void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final {
  49. auto typedState = static_cast<TState*>(state);
  50. Y_UNUSED(columns);
  51. if (filtered) {
  52. typedState->Count_ += *filtered;
  53. }
  54. else {
  55. typedState->Count_ += batchLength;
  56. }
  57. }
  58. NUdf::TUnboxedValue FinishOne(const void* state) final {
  59. auto typedState = static_cast<const TState*>(state);
  60. return NUdf::TUnboxedValuePod(typedState->Count_);
  61. }
  62. };
  63. template <>
  64. class TCountAllAggregator<TCombineKeysTag> : public TCombineKeysTag::TBase {
  65. public:
  66. using TBase = TCombineKeysTag::TBase;
  67. TCountAllAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TComputationContext& ctx)
  68. : TBase(sizeof(TState), filterColumn, ctx)
  69. {
  70. Y_UNUSED(argColumn);
  71. }
  72. void InitKey(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  73. new(state) TState();
  74. UpdateKey(state, batchNum, columns, row);
  75. }
  76. void DestroyState(void* state) noexcept final {
  77. static_assert(std::is_trivially_destructible<TState>::value);
  78. Y_UNUSED(state);
  79. }
  80. void UpdateKey(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  81. Y_UNUSED(batchNum);
  82. Y_UNUSED(columns);
  83. Y_UNUSED(row);
  84. auto typedState = static_cast<TState*>(state);
  85. typedState->Count_ += 1;
  86. }
  87. std::unique_ptr<IAggColumnBuilder> MakeStateBuilder(ui64 size) final {
  88. return std::make_unique<TColumnBuilder>(size, Ctx_);
  89. }
  90. };
  91. template <>
  92. class TCountAllAggregator<TFinalizeKeysTag> : public TFinalizeKeysTag::TBase {
  93. public:
  94. using TBase = TFinalizeKeysTag::TBase;
  95. TCountAllAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TComputationContext& ctx)
  96. : TBase(sizeof(TState), filterColumn, ctx)
  97. , ArgColumn_(argColumn)
  98. {
  99. }
  100. void LoadState(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  101. new(state) TState();
  102. UpdateState(state, batchNum, columns, row);
  103. }
  104. void DestroyState(void* state) noexcept final {
  105. static_assert(std::is_trivially_destructible<TState>::value);
  106. Y_UNUSED(state);
  107. }
  108. void UpdateState(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  109. Y_UNUSED(batchNum);
  110. auto typedState = static_cast<TState*>(state);
  111. const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
  112. if (datum.is_scalar()) {
  113. MKQL_ENSURE(datum.scalar()->is_valid, "Expected not null");
  114. typedState->Count_ += datum.scalar_as<arrow::UInt64Scalar>().value;
  115. } else {
  116. const auto& array = datum.array();
  117. auto ptr = array->GetValues<ui64>(1);
  118. MKQL_ENSURE(array->GetNullCount() == 0, "Expected not null");
  119. typedState->Count_ += ptr[row];
  120. }
  121. }
  122. std::unique_ptr<IAggColumnBuilder> MakeResultBuilder(ui64 size) final {
  123. return std::make_unique<TColumnBuilder>(size, Ctx_);
  124. }
  125. private:
  126. const ui32 ArgColumn_;
  127. };
  128. template <>
  129. class TCountAggregator<TCombineAllTag> : public TCombineAllTag::TBase {
  130. public:
  131. using TBase = TCombineAllTag::TBase;
  132. TCountAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TComputationContext& ctx)
  133. : TBase(sizeof(TState), filterColumn, ctx)
  134. , ArgColumn_(argColumn)
  135. {
  136. }
  137. void InitState(void* state) final {
  138. new(state) TState();
  139. }
  140. void DestroyState(void* state) noexcept final {
  141. static_assert(std::is_trivially_destructible<TState>::value);
  142. Y_UNUSED(state);
  143. }
  144. void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final {
  145. auto typedState = static_cast<TState*>(state);
  146. const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
  147. if (datum.is_scalar()) {
  148. if (datum.scalar()->is_valid) {
  149. typedState->Count_ += filtered ? *filtered : batchLength;
  150. }
  151. } else {
  152. const auto& array = datum.array();
  153. if (!filtered) {
  154. typedState->Count_ += array->length - array->GetNullCount();
  155. } else if (array->GetNullCount() == array->length) {
  156. // all nulls
  157. return;
  158. } else if (array->GetNullCount() == 0) {
  159. // no nulls
  160. typedState->Count_ += *filtered;
  161. } else {
  162. const auto& filterDatum = TArrowBlock::From(columns[*FilterColumn_]).GetDatum();
  163. // intersect masks from nulls and filter column
  164. const auto& filterArray = filterDatum.array();
  165. MKQL_ENSURE(filterArray->GetNullCount() == 0, "Expected non-nullable bool column");
  166. auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0);
  167. const ui8* filterBitmap = filterArray->GetValues<uint8_t>(1);
  168. auto state = typedState->Count_;
  169. for (ui32 i = 0; i < array->length; ++i) {
  170. ui64 fullIndex = i + array->offset;
  171. auto bit1 = ((nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1);
  172. auto bit2 = filterBitmap[i];
  173. state += bit1 & bit2;
  174. }
  175. typedState->Count_ = state;
  176. }
  177. }
  178. }
  179. NUdf::TUnboxedValue FinishOne(const void* state) final {
  180. auto typedState = static_cast<const TState*>(state);
  181. return NUdf::TUnboxedValuePod(typedState->Count_);
  182. }
  183. private:
  184. const ui32 ArgColumn_;
  185. };
  186. template <>
  187. class TCountAggregator<TCombineKeysTag> : public TCombineKeysTag::TBase {
  188. public:
  189. using TBase = TCombineKeysTag::TBase;
  190. TCountAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TComputationContext& ctx)
  191. : TBase(sizeof(TState), filterColumn, ctx)
  192. , ArgColumn_(argColumn)
  193. {
  194. }
  195. void InitKey(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  196. new(state) TState();
  197. UpdateKey(state, batchNum, columns, row);
  198. }
  199. void DestroyState(void* state) noexcept final {
  200. static_assert(std::is_trivially_destructible<TState>::value);
  201. Y_UNUSED(state);
  202. }
  203. void UpdateKey(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
  204. Y_UNUSED(batchNum);
  205. auto typedState = static_cast<TState*>(state);
  206. const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
  207. if (datum.is_scalar()) {
  208. if (datum.scalar()->is_valid) {
  209. typedState->Count_ += 1;
  210. }
  211. } else {
  212. const auto& array = datum.array();
  213. if (array->GetNullCount() == 0) {
  214. typedState->Count_ += 1;
  215. } else {
  216. auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0);
  217. auto fullIndex = row + array->offset;
  218. auto bit = ((nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1);
  219. typedState->Count_ += bit;
  220. }
  221. }
  222. }
  223. std::unique_ptr<IAggColumnBuilder> MakeStateBuilder(ui64 size) final {
  224. return std::make_unique<TColumnBuilder>(size, Ctx_);
  225. }
  226. private:
  227. const ui32 ArgColumn_;
  228. };
  229. template <>
  230. class TCountAggregator<TFinalizeKeysTag> : public TCountAllAggregator<TFinalizeKeysTag>
  231. {
  232. public:
  233. using TBase = TCountAllAggregator<TFinalizeKeysTag>;
  234. TCountAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TComputationContext& ctx)
  235. : TBase(filterColumn, argColumn, ctx)
  236. {}
  237. };
  238. template <typename TTag>
  239. class TPreparedCountAll : public TTag::TPreparedAggregator {
  240. public:
  241. using TBase = typename TTag::TPreparedAggregator;
  242. TPreparedCountAll(std::optional<ui32> filterColumn, ui32 argColumn)
  243. : TBase(sizeof(TState))
  244. , FilterColumn_(filterColumn)
  245. , ArgColumn_(argColumn)
  246. {}
  247. std::unique_ptr<typename TTag::TAggregator> Make(TComputationContext& ctx) const final {
  248. return std::make_unique<TCountAllAggregator<TTag>>(FilterColumn_, ArgColumn_, ctx);
  249. }
  250. private:
  251. const std::optional<ui32> FilterColumn_;
  252. const ui32 ArgColumn_;
  253. };
  254. template <typename TTag>
  255. class TPreparedCount : public TTag::TPreparedAggregator {
  256. public:
  257. using TBase = typename TTag::TPreparedAggregator;
  258. TPreparedCount(std::optional<ui32> filterColumn, ui32 argColumn)
  259. : TBase(sizeof(TState))
  260. , FilterColumn_(filterColumn)
  261. , ArgColumn_(argColumn)
  262. {}
  263. std::unique_ptr<typename TTag::TAggregator> Make(TComputationContext& ctx) const final {
  264. return std::make_unique<TCountAggregator<TTag>>(FilterColumn_, ArgColumn_, ctx);
  265. }
  266. private:
  267. const std::optional<ui32> FilterColumn_;
  268. const ui32 ArgColumn_;
  269. };
  270. template <typename TTag>
  271. std::unique_ptr<typename TTag::TPreparedAggregator> PrepareCountAll(std::optional<ui32> filterColumn, ui32 argColumn) {
  272. return std::make_unique<TPreparedCountAll<TTag>>(filterColumn, argColumn);
  273. }
  274. template <typename TTag>
  275. std::unique_ptr<typename TTag::TPreparedAggregator> PrepareCount(std::optional<ui32> filterColumn, ui32 argColumn) {
  276. return std::make_unique<TPreparedCount<TTag>>(filterColumn, argColumn);
  277. }
  278. class TBlockCountAllFactory : public IBlockAggregatorFactory {
  279. public:
  280. std::unique_ptr<TCombineAllTag::TPreparedAggregator> PrepareCombineAll(
  281. TTupleType* tupleType,
  282. std::optional<ui32> filterColumn,
  283. const std::vector<ui32>& argsColumns,
  284. const TTypeEnvironment& env) const final {
  285. Y_UNUSED(tupleType);
  286. Y_UNUSED(argsColumns);
  287. Y_UNUSED(env);
  288. return PrepareCountAll<TCombineAllTag>(filterColumn, 0);
  289. }
  290. std::unique_ptr<TCombineKeysTag::TPreparedAggregator> PrepareCombineKeys(
  291. TTupleType* tupleType,
  292. const std::vector<ui32>& argsColumns,
  293. const TTypeEnvironment& env) const final {
  294. Y_UNUSED(tupleType);
  295. Y_UNUSED(argsColumns);
  296. Y_UNUSED(env);
  297. return PrepareCountAll<TCombineKeysTag>(std::optional<ui32>(), 0);
  298. }
  299. std::unique_ptr<TFinalizeKeysTag::TPreparedAggregator> PrepareFinalizeKeys(
  300. TTupleType* tupleType,
  301. const std::vector<ui32>& argsColumns,
  302. const TTypeEnvironment& env,
  303. TType* returnType,
  304. ui32 hint) const final {
  305. Y_UNUSED(tupleType);
  306. Y_UNUSED(argsColumns);
  307. Y_UNUSED(env);
  308. Y_UNUSED(returnType);
  309. Y_UNUSED(hint);
  310. return PrepareCountAll<TFinalizeKeysTag>(std::optional<ui32>(), argsColumns[0]);
  311. }
  312. };
  313. class TBlockCountFactory : public IBlockAggregatorFactory {
  314. public:
  315. std::unique_ptr<TCombineAllTag::TPreparedAggregator> PrepareCombineAll(
  316. TTupleType* tupleType,
  317. std::optional<ui32> filterColumn,
  318. const std::vector<ui32>& argsColumns,
  319. const TTypeEnvironment& env) const final {
  320. Y_UNUSED(tupleType);
  321. Y_UNUSED(env);
  322. return PrepareCount<TCombineAllTag>(filterColumn, argsColumns[0]);
  323. }
  324. std::unique_ptr<TCombineKeysTag::TPreparedAggregator> PrepareCombineKeys(
  325. TTupleType* tupleType,
  326. const std::vector<ui32>& argsColumns,
  327. const TTypeEnvironment& env) const final {
  328. Y_UNUSED(tupleType);
  329. Y_UNUSED(argsColumns);
  330. Y_UNUSED(env);
  331. return PrepareCount<TCombineKeysTag>(std::optional<ui32>(), argsColumns[0]);
  332. }
  333. std::unique_ptr<TFinalizeKeysTag::TPreparedAggregator> PrepareFinalizeKeys(
  334. TTupleType* tupleType,
  335. const std::vector<ui32>& argsColumns,
  336. const TTypeEnvironment& env,
  337. TType* returnType,
  338. ui32 hint) const final {
  339. Y_UNUSED(tupleType);
  340. Y_UNUSED(argsColumns);
  341. Y_UNUSED(env);
  342. Y_UNUSED(returnType);
  343. Y_UNUSED(hint);
  344. return PrepareCount<TFinalizeKeysTag>(std::optional<ui32>(), argsColumns[0]);
  345. }
  346. };
  347. }
  348. std::unique_ptr<IBlockAggregatorFactory> MakeBlockCountAllFactory() {
  349. return std::make_unique<TBlockCountAllFactory>();
  350. }
  351. std::unique_ptr<IBlockAggregatorFactory> MakeBlockCountFactory() {
  352. return std::make_unique<TBlockCountFactory>();
  353. }
  354. }
  355. }