// mkql_block_top.cpp (32 KB)
  1. #include "mkql_block_top.h"
  2. #include <yql/essentials/minikql/computation/mkql_block_reader.h>
  3. #include <yql/essentials/minikql/computation/mkql_block_builder.h>
  4. #include <yql/essentials/minikql/computation/mkql_block_impl.h>
  5. #include <yql/essentials/minikql/computation/mkql_block_impl_codegen.h> // Y_IGNORE
  6. #include <yql/essentials/public/udf/arrow/block_item_comparator.h>
  7. #include <yql/essentials/minikql/arrow/arrow_defs.h>
  8. #include <yql/essentials/minikql/arrow/arrow_util.h>
  9. #include <yql/essentials/minikql/mkql_type_builder.h>
  10. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  11. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  12. #include <yql/essentials/minikql/mkql_node_builder.h>
  13. #include <yql/essentials/minikql/mkql_node_cast.h>
  14. #include <yql/essentials/utils/sort.h>
  15. namespace NKikimr {
  16. namespace NMiniKQL {
  17. namespace {
  18. template <bool Sort, bool HasCount>
  19. class TTopOrSortBlocksWrapper : public TStatefulWideFlowCodegeneratorNode<TTopOrSortBlocksWrapper<Sort, HasCount>> {
  20. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TTopOrSortBlocksWrapper<Sort, HasCount>>;
  21. using TChunkedArrayIndex = std::vector<IArrayBuilder::TArrayDataItem>;
  22. public:
  23. TTopOrSortBlocksWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TArrayRef<TType* const> wideComponents, IComputationNode* count,
  24. TComputationNodePtrVector&& directions, std::vector<ui32>&& keyIndicies)
  25. : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
  26. , Flow_(flow)
  27. , Count_(count)
  28. , Directions_(std::move(directions))
  29. , KeyIndicies_(std::move(keyIndicies))
  30. , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(wideComponents.size()))
  31. {
  32. for (ui32 i = 0; i < wideComponents.size() - 1; ++i) {
  33. Columns_.push_back(AS_TYPE(TBlockType, wideComponents[i]));
  34. }
  35. }
  36. EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  37. Y_ABORT_UNLESS(output[Columns_.size()]);
  38. auto& s = GetState(state, ctx);
  39. if (!s.Count) {
  40. if (s.IsFinished_)
  41. return EFetchResult::Finish;
  42. if (!s.WritingOutput_) {
  43. for (const auto fields = ctx.WideFields.data() + WideFieldsIndex_;;) {
  44. switch (Flow_->FetchValues(ctx, fields)) {
  45. case EFetchResult::Yield:
  46. return EFetchResult::Yield;
  47. case EFetchResult::One:
  48. s.ProcessInput();
  49. continue;
  50. case EFetchResult::Finish:
  51. break;
  52. }
  53. break;
  54. }
  55. }
  56. if (!s.FillOutput(ctx.HolderFactory))
  57. return EFetchResult::Finish;
  58. }
  59. const auto sliceSize = s.Slice();
  60. for (size_t i = 0; i <= Columns_.size(); ++i) {
  61. if (const auto out = output[i]) {
  62. *out = s.Get(sliceSize, ctx.HolderFactory, i);
  63. }
  64. }
  65. return EFetchResult::One;
  66. }
  67. #ifndef MKQL_DISABLE_CODEGEN
  // LLVM code generator counterpart of DoCalculate().
  // Emits IR that creates the state on first use, drains the upstream flow through
  // TState::ProcessInput, prepares output with TState::FillOutput and slices it with
  // TState::Slice. Returns the status value (a PHI in the `over` block) together with
  // one lazy getter per output column; each getter calls TState::Get.
  68. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  69. auto& context = ctx.Codegen.GetContext();
  // Output width: all block columns plus one trailing item.
  70. const auto width = Columns_.size() + 1U;
  71. const auto valueType = Type::getInt128Ty(context);
  72. const auto statusType = Type::getInt32Ty(context);
  73. const auto indexType = Type::getInt64Ty(context);
  74. const auto flagType = Type::getInt1Ty(context);
  75. const auto arrayType = ArrayType::get(valueType, width);
  76. const auto ptrValuesType = PointerType::getUnqual(arrayType);
  // LLVM-side description of the state structure, used for field addressing below.
  77. TLLVMFieldsStructureState stateFields(context, width);
  78. const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
  79. const auto statePtrType = PointerType::getUnqual(stateType);
  // Everything created with `atTop` is inserted before the last instruction of the entry block.
  80. const auto atTop = &ctx.Func->getEntryBlock().back();
  81. const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Get));
  82. const auto getType = FunctionType::get(valueType, {statePtrType, indexType, ctx.GetFactory()->getType(), indexType}, false);
  83. const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop);
  // Stack slots through which the getters see the slice height and the state pointer;
  // both are initialised to zero/null in the entry block.
  84. const auto heightPtr = new AllocaInst(indexType, 0U, "height_ptr", atTop);
  85. const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop);
  86. new StoreInst(ConstantInt::get(indexType, 0), heightPtr, atTop);
  87. new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop);
  88. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  89. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  90. const auto more = BasicBlock::Create(context, "more", ctx.Func);
  91. const auto test = BasicBlock::Create(context, "test", ctx.Func);
  92. const auto read = BasicBlock::Create(context, "read", ctx.Func);
  93. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  94. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  95. const auto fill = BasicBlock::Create(context, "fill", ctx.Func);
  96. const auto over = BasicBlock::Create(context, "over", ctx.Func);
  // Go to `make` while the state slot is still invalid, otherwise straight to `main`.
  97. BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
  // --- make: evaluate the count (constant 0 when !HasCount) and the direction flags,
  //           then call MakeState(this, ctx, statePtr, dirs, count).
  98. block = make;
  99. llvm::Value* trunc;
  100. if constexpr (HasCount) {
  101. const auto count = GetNodeValue(Count_, ctx, block);
  102. trunc = GetterFor<ui64>(count, context, block);
  103. } else {
  104. trunc = ConstantInt::get(Type::getInt64Ty(context), 0U);
  105. }
  106. const auto dirsType = ArrayType::get(flagType, Directions_.size());
  107. const auto dirs = new AllocaInst(dirsType, 0U, "dirs", block);
  108. for (auto i = 0U; i < Directions_.size(); ++i) {
  109. const auto dir = GetNodeValue(Directions_[i], ctx, block);
  110. const auto cut = GetterFor<bool>(dir, context, block);
  111. const auto ptr = GetElementPtrInst::CreateInBounds(dirsType, dirs, {ConstantInt::get(indexType, 0), ConstantInt::get(indexType, i)}, "ptr", block);
  112. new StoreInst(cut, ptr, block);
  113. }
  114. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  // `this` is baked into the generated code as an integer constant cast to a pointer.
  115. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  116. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TTopOrSortBlocksWrapper::MakeState));
  117. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType(), dirs->getType(), trunc->getType()}, false);
  118. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  119. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr, dirs, trunc}, "", block);
  120. BranchInst::Create(main, block);
  // --- main: the low 64 bits of the 128-bit state value are used as the state pointer;
  //           load its Count field and go to `more` when it is zero, else to `fill`.
  121. block = main;
  122. const auto state = new LoadInst(valueType, statePtr, "state", block);
  123. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  124. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  125. const auto countPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetCount() }, "count_ptr", block);
  126. const auto count = new LoadInst(indexType, countPtr, "count", block);
  127. const auto none = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, count, ConstantInt::get(indexType, 0), "none", block);
  128. BranchInst::Create(more, fill, none, block);
  // --- more: if the "is finished" flag is set, leave with Finish; otherwise go to `test`.
  //           NOTE(review): the flag field presumably mirrors TState::IsFinished_ - confirm layout.
  129. block = more;
  130. const auto finishedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsFinished() }, "is_finished_ptr", block);
  131. const auto finished = new LoadInst(flagType, finishedPtr, "finished", block);
  // Status PHI in `over`; it receives four incoming edges (more, read, work, fill).
  132. const auto result = PHINode::Create(statusType, 4U, "result", over);
  133. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
  134. BranchInst::Create(over, test, finished, block);
  // --- test: when the "writing output" flag is set skip reading and go to `work`.
  135. block = test;
  136. const auto writingOutputPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetWritingOutput() }, "writing_output_ptr", block);
  137. const auto writingOutput = new LoadInst(flagType, writingOutputPtr, "writing_output", block);
  138. BranchInst::Create(work, read, writingOutput, block);
  // --- read: fetch from the upstream flow; Finish -> `work`, Yield -> `over` (status Yield),
  //           any other status -> `good`.
  139. block = read;
  140. const auto getres = GetNodeValues(Flow_, ctx, block);
  141. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
  142. const auto way = SwitchInst::Create(getres.first, good, 2U, block);
  143. way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Finish)), work);
  144. way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Yield)), over);
  // --- good: materialise every fetched value (taking a reference on each), store them as one
  //           array through the state's values pointer, call TState::ProcessInput and loop to `read`.
  145. block = good;
  146. const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block);
  147. const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
  148. Value* array = UndefValue::get(arrayType);
  149. for (auto idx = 0U; idx < getres.second.size(); ++idx) {
  150. const auto value = getres.second[idx](ctx, block);
  151. AddRefBoxed(value, ctx, block);
  152. array = InsertValueInst::Create(array, value, {idx}, (TString("value_") += ToString(idx)).c_str(), block);
  153. }
  154. new StoreInst(array, values, block);
  155. const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ProcessInput));
  156. const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType}, false);
  157. const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block);
  158. CallInst::Create(processBlockType, processBlockPtr, {stateArg}, "", block);
  159. BranchInst::Create(read, block);
  // --- work: call TState::FillOutput; a false result leaves with Finish, true continues to `fill`.
  160. block = work;
  161. const auto fillBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::FillOutput));
  162. const auto fillBlockType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false);
  163. const auto fillBlockPtr = CastInst::Create(Instruction::IntToPtr, fillBlockFunc, PointerType::getUnqual(fillBlockType), "fill_output_func", block);
  164. const auto hasData = CallInst::Create(fillBlockType, fillBlockPtr, {stateArg, ctx.GetFactory()}, "fill_output", block);
  165. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
  166. BranchInst::Create(fill, over, hasData, block);
  // --- fill: call TState::Slice, stash the slice height and the state pointer for the getters,
  //           and leave with One.
  167. block = fill;
  168. const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Slice));
  169. const auto sliceType = FunctionType::get(indexType, {statePtrType}, false);
  170. const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice_func", block);
  171. const auto slice = CallInst::Create(sliceType, slicePtr, {stateArg}, "slice", block);
  172. new StoreInst(slice, heightPtr, block);
  173. new StoreInst(stateArg, stateOnStack, block);
  174. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
  175. BranchInst::Create(over, block);
  // --- over: common exit. Each getter reloads the stashed state pointer and height and calls
  //           TState::Get(state, height, factory, idx) at the point where the consumer needs it.
  176. block = over;
  177. ICodegeneratorInlineWideNode::TGettersList getters(width);
  178. for (size_t idx = 0U; idx < getters.size(); ++idx) {
  179. getters[idx] = [idx, getType, getPtr, heightPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) {
  180. const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
  181. const auto heightArg = new LoadInst(indexType, heightPtr, "height", block);
  182. return CallInst::Create(getType, getPtr, {stateArg, heightArg, ctx.GetFactory(), ConstantInt::get(indexType, idx)}, "get", block);
  183. };
  184. }
  185. return {result, std::move(getters)};
  186. }
  187. #endif
  188. private:
  189. void RegisterDependencies() const final {
  190. if (const auto flow = this->FlowDependsOn(Flow_)) {
  191. if constexpr (HasCount) {
  192. this->DependsOn(flow, Count_);
  193. }
  194. for (auto dir : Directions_) {
  195. this->DependsOn(flow, dir);
  196. }
  197. }
  198. }
  199. class TState : public TBlockState {
  200. public:
  201. bool WritingOutput_ = false;
  202. bool IsFinished_ = false;
  203. ui64 OutputLength_ = 0;
  204. ui64 Written_ = 0;
  205. const std::vector<bool> Directions_;
  206. const ui64 Count_;
  207. const std::vector<TBlockType*> Columns_;
  208. const std::vector<ui32> KeyIndicies_;
  209. std::vector<std::vector<arrow::Datum>> SortInput_;
  210. std::vector<ui64> SortPermutation_;
  211. std::vector<TChunkedArrayIndex> SortArrays_;
  212. bool ScalarsFilled_ = false;
  213. TUnboxedValueVector ScalarValues_;
  214. std::vector<std::unique_ptr<IBlockReader>> LeftReaders_;
  215. std::vector<std::unique_ptr<IBlockReader>> RightReaders_;
  216. std::vector<std::unique_ptr<IArrayBuilder>> Builders_;
  217. ui64 BuilderMaxLength_ = 0;
  218. ui64 BuilderLength_ = 0;
  219. std::vector<NUdf::IBlockItemComparator::TPtr> Comparators_; // by key columns only
  220. TState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, const std::vector<ui32>& keyIndicies, const std::vector<TBlockType*>& columns, const bool* directions, ui64 count)
  221. : TBlockState(memInfo, columns.size() + 1U)
  222. , IsFinished_(HasCount && !count)
  223. , Directions_(directions, directions + keyIndicies.size())
  224. , Count_(count)
  225. , Columns_(columns)
  226. , KeyIndicies_(keyIndicies)
  227. , SortInput_(Columns_.size())
  228. , SortArrays_(Columns_.size())
  229. , LeftReaders_(Columns_.size())
  230. , RightReaders_(Columns_.size())
  231. , Builders_(Columns_.size())
  232. , Comparators_(KeyIndicies_.size())
  233. {
  234. for (ui32 i = 0; i < Columns_.size(); ++i) {
  235. if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) {
  236. continue;
  237. }
  238. LeftReaders_[i] = MakeBlockReader(TTypeInfoHelper(), columns[i]->GetItemType());
  239. RightReaders_[i] = MakeBlockReader(TTypeInfoHelper(), columns[i]->GetItemType());
  240. }
  241. for (ui32 k = 0; k < KeyIndicies_.size(); ++k) {
  242. Comparators_[k] = TBlockTypeHelper().MakeComparator(Columns_[KeyIndicies_[k]]->GetItemType());
  243. }
  244. BuilderMaxLength_ = GetStorageLength();
  245. size_t maxBlockItemSize = 0;
  246. for (ui32 i = 0; i < Columns_.size(); ++i) {
  247. if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) {
  248. continue;
  249. }
  250. maxBlockItemSize = Max(maxBlockItemSize, CalcMaxBlockItemSize(Columns_[i]->GetItemType()));
  251. };
  252. BuilderMaxLength_ = Max(BuilderMaxLength_, CalcBlockLen(maxBlockItemSize));
  253. for (ui32 i = 0; i < Columns_.size(); ++i) {
  254. if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) {
  255. continue;
  256. }
  257. Builders_[i] = MakeArrayBuilder(TTypeInfoHelper(), Columns_[i]->GetItemType(), ctx.ArrowMemoryPool, BuilderMaxLength_, &ctx.Builder->GetPgBuilder());
  258. }
  259. }
  260. void Add(const NUdf::TUnboxedValuePod value, size_t idx) {
  261. Values[idx] = value;
  262. }
  263. void ProcessInput() {
  264. const ui64 blockLen = TArrowBlock::From(Values.back()).GetDatum().template scalar_as<arrow::UInt64Scalar>().value;
  265. if (!ScalarsFilled_) {
  266. for (ui32 i = 0; i < Columns_.size(); ++i) {
  267. if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) {
  268. ScalarValues_[i] = std::move(Values[i]);
  269. }
  270. }
  271. ScalarsFilled_ = true;
  272. }
  273. if constexpr (!HasCount) {
  274. for (ui32 i = 0; i < Columns_.size(); ++i) {
  275. auto datum = TArrowBlock::From(Values[i]).GetDatum();
  276. if (Columns_[i]->GetShape() != TBlockType::EShape::Scalar) {
  277. SortInput_[i].emplace_back(datum);
  278. }
  279. }
  280. OutputLength_ += blockLen;
  281. Values.assign(Values.size(), NUdf::TUnboxedValuePod());
  282. return;
  283. }
  284. // shrink input block
  285. std::optional<std::vector<ui64>> blockIndicies;
  286. if (blockLen > Count_) {
  287. blockIndicies.emplace();
  288. blockIndicies->reserve(blockLen);
  289. for (ui64 row = 0; row < blockLen; ++row) {
  290. blockIndicies->emplace_back(row);
  291. }
  292. std::vector<TChunkedArrayIndex> arrayIndicies(Columns_.size());
  293. for (ui32 i = 0; i < Columns_.size(); ++i) {
  294. if (Columns_[i]->GetShape() != TBlockType::EShape::Scalar) {
  295. auto datum = TArrowBlock::From(Values[i]).GetDatum();
  296. arrayIndicies[i] = MakeChunkedArrayIndex(datum);
  297. }
  298. }
  299. const TBlockLess cmp(KeyIndicies_, *this, arrayIndicies);
  300. NYql::FastNthElement(blockIndicies->begin(), blockIndicies->begin() + Count_, blockIndicies->end(), cmp);
  301. }
  302. // copy all to builders
  303. AddTop(Columns_, blockIndicies, blockLen);
  304. if (BuilderLength_ + Count_ > BuilderMaxLength_) {
  305. CompressBuilders(false);
  306. }
  307. Values.assign(Values.size(), NUdf::TUnboxedValuePod());
  308. }
  309. ui64 GetStorageLength() const {
  310. return 2 * Count_;
  311. }
  312. void CompressBuilders(bool sort) {
  313. Y_ABORT_UNLESS(ScalarsFilled_);
  314. std::vector<TChunkedArrayIndex> arrayIndicies(Columns_.size());
  315. std::vector<arrow::Datum> tmpDatums(Columns_.size());
  316. for (ui32 i = 0; i < Columns_.size(); ++i) {
  317. if (Columns_[i]->GetShape() != TBlockType::EShape::Scalar) {
  318. auto datum = Builders_[i]->Build(false);
  319. arrayIndicies[i] = MakeChunkedArrayIndex(datum);
  320. tmpDatums[i] = std::move(datum);
  321. }
  322. }
  323. std::vector<ui64> blockIndicies;
  324. blockIndicies.reserve(BuilderLength_);
  325. for (ui64 row = 0; row < BuilderLength_; ++row) {
  326. blockIndicies.push_back(row);
  327. }
  328. const ui64 blockLen = Min(BuilderLength_, Count_);
  329. const TBlockLess cmp(KeyIndicies_, *this, arrayIndicies);
  330. if (BuilderLength_ <= Count_) {
  331. if (sort) {
  332. std::sort(blockIndicies.begin(), blockIndicies.end(), cmp);
  333. }
  334. } else {
  335. if (sort) {
  336. NYql::FastPartialSort(blockIndicies.begin(), blockIndicies.begin() + blockLen, blockIndicies.end(), cmp);
  337. } else {
  338. NYql::FastNthElement(blockIndicies.begin(), blockIndicies.begin() + blockLen, blockIndicies.end(), cmp);
  339. }
  340. }
  341. for (ui32 i = 0; i < Columns_.size(); ++i) {
  342. if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) {
  343. continue;
  344. }
  345. auto& arrayIndex = arrayIndicies[i];
  346. Builders_[i]->AddMany(arrayIndex.data(), arrayIndex.size(), blockIndicies.data(), blockLen);
  347. }
  348. BuilderLength_ = blockLen;
  349. }
  350. void SortAll() {
  351. SortPermutation_.reserve(OutputLength_);
  352. for (ui64 i = 0; i < OutputLength_; ++i) {
  353. SortPermutation_.emplace_back(i);
  354. }
  355. for (ui32 i = 0; i < Columns_.size(); ++i) {
  356. ui64 offset = 0;
  357. for (const auto& datum : SortInput_[i]) {
  358. if (datum.is_scalar()) {
  359. continue;
  360. } else if (datum.is_array()) {
  361. auto arrayData = datum.array();
  362. SortArrays_[i].push_back({ arrayData.get(), offset });
  363. offset += arrayData->length;
  364. } else {
  365. auto chunks = datum.chunks();
  366. for (auto& chunk : chunks) {
  367. auto arrayData = chunk->data();
  368. SortArrays_[i].push_back({ arrayData.get(), offset });
  369. offset += arrayData->length;
  370. }
  371. }
  372. }
  373. }
  374. TBlockLess cmp(KeyIndicies_, *this, SortArrays_);
  375. std::sort(SortPermutation_.begin(), SortPermutation_.end(), cmp);
  376. }
  377. bool FillOutput(const THolderFactory& holderFactory) {
  378. if (WritingOutput_) {
  379. FillSortOutputPart(holderFactory);
  380. } else if constexpr (!HasCount) {
  381. if (!OutputLength_) {
  382. IsFinished_ = true;
  383. return false;
  384. }
  385. SortAll();
  386. WritingOutput_ = true;
  387. FillSortOutputPart(holderFactory);
  388. } else {
  389. IsFinished_ = true;
  390. if (!BuilderLength_) {
  391. return false;
  392. }
  393. if (BuilderLength_ > Count_ || Sort) {
  394. CompressBuilders(Sort);
  395. }
  396. for (ui32 i = 0; i < Columns_.size(); ++i) {
  397. if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) {
  398. Values[i] = ScalarValues_[i];
  399. } else {
  400. Values[i] = holderFactory.CreateArrowBlock(arrow::Datum(Builders_[i]->Build(true)));
  401. }
  402. }
  403. Values.back() = holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(BuilderLength_)));
  404. }
  405. FillArrays();
  406. return true;
  407. }
  408. void FillSortOutputPart(const THolderFactory& holderFactory) {
  409. auto blockLen = Min(BuilderMaxLength_, OutputLength_ - Written_);
  410. const bool isLast = (Written_ + blockLen == OutputLength_);
  411. for (ui32 i = 0; i < Columns_.size(); ++i) {
  412. if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) {
  413. Values[i] = ScalarValues_[i];
  414. } else {
  415. Builders_[i]->AddMany(SortArrays_[i].data(), SortArrays_[i].size(), SortPermutation_.data() + Written_, blockLen);
  416. Values[i] = holderFactory.CreateArrowBlock(arrow::Datum(Builders_[i]->Build(isLast)));
  417. }
  418. }
  419. Values.back() = holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(blockLen)));
  420. Written_ += blockLen;
  421. if (Written_ >= OutputLength_)
  422. IsFinished_ = true;
  423. }
  424. void AddTop(const std::vector<TBlockType*>& columns, const std::optional<std::vector<ui64>>& blockIndicies, ui64 blockLen) {
  425. for (ui32 i = 0; i < columns.size(); ++i) {
  426. if (columns[i]->GetShape() == TBlockType::EShape::Scalar) {
  427. continue;
  428. }
  429. const auto& datum = TArrowBlock::From(Values[i]).GetDatum();
  430. auto arrayIndex = MakeChunkedArrayIndex(datum);
  431. if (blockIndicies) {
  432. Builders_[i]->AddMany(arrayIndex.data(), arrayIndex.size(), blockIndicies->data(), Count_);
  433. } else {
  434. Builders_[i]->AddMany(arrayIndex.data(), arrayIndex.size(), ui64(0), blockLen);
  435. }
  436. }
  437. if (blockIndicies) {
  438. BuilderLength_ += Count_;
  439. } else {
  440. BuilderLength_ += blockLen;
  441. }
  442. }
  443. };
  444. #ifndef MKQL_DISABLE_CODEGEN
  445. class TLLVMFieldsStructureState: public TLLVMFieldsStructureBlockState {
  446. private:
  447. using TBase = TLLVMFieldsStructureBlockState;
  448. llvm::IntegerType*const WritingOutputType;
  449. llvm::IntegerType*const IsFinishedType;
  450. protected:
  451. using TBase::Context;
  452. public:
  453. std::vector<llvm::Type*> GetFieldsArray() {
  454. std::vector<llvm::Type*> result = TBase::GetFieldsArray();
  455. result.emplace_back(WritingOutputType);
  456. result.emplace_back(IsFinishedType);
  457. return result;
  458. }
  459. llvm::Constant* GetWritingOutput() {
  460. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields);
  461. }
  462. llvm::Constant* GetIsFinished() {
  463. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields + 1);
  464. }
  465. TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
  466. : TBase(context, width)
  467. , WritingOutputType(Type::getInt1Ty(Context))
  468. , IsFinishedType(Type::getInt1Ty(Context))
  469. {}
  470. };
  471. #endif
  472. void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state, const bool* directions, ui64 count = 0ULL) const {
  473. state = ctx.HolderFactory.Create<TState>(ctx, KeyIndicies_, Columns_, directions, count);
  474. }
  475. TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  476. if (state.IsInvalid()) {
  477. std::vector<bool> dirs(Directions_.size());
  478. std::transform(Directions_.cbegin(), Directions_.cend(), dirs.begin(), [&ctx](IComputationNode* dir){ return dir->GetValue(ctx).Get<bool>(); });
  479. if constexpr (HasCount)
  480. MakeState(ctx, state, dirs.data(), Count_->GetValue(ctx).Get<ui64>());
  481. else
  482. MakeState(ctx, state, dirs.data());
  483. auto& s = *static_cast<TState*>(state.AsBoxed().Get());
  484. const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
  485. for (size_t i = 0; i < s.Values.size(); ++i) {
  486. fields[i] = &s.Values[i];
  487. }
  488. return s;
  489. }
  490. return *static_cast<TState*>(state.AsBoxed().Get());
  491. }
  492. static TChunkedArrayIndex MakeChunkedArrayIndex(const arrow::Datum& datum) {
  493. TChunkedArrayIndex result;
  494. if (datum.is_array()) {
  495. result.push_back({datum.array().get(), 0});
  496. } else {
  497. auto chunks = datum.chunks();
  498. ui64 offset = 0;
  499. for (auto& chunk : chunks) {
  500. auto arrayData = chunk->data();
  501. result.push_back({arrayData.get(), offset});
  502. offset += arrayData->length;
  503. }
  504. }
  505. return result;
  506. }
  507. class TBlockLess {
  508. public:
  509. TBlockLess(const std::vector<ui32>& keyIndicies, const TState& state, const std::vector<TChunkedArrayIndex>& arrayIndicies)
  510. : KeyIndicies_(keyIndicies)
  511. , ArrayIndicies_(arrayIndicies)
  512. , State_(state)
  513. {}
  514. bool operator()(ui64 lhs, ui64 rhs) const {
  515. if (KeyIndicies_.size() == 1) {
  516. auto i = KeyIndicies_[0];
  517. auto& arrayIndex = ArrayIndicies_[i];
  518. if (arrayIndex.empty()) {
  519. // scalar
  520. return false;
  521. }
  522. auto leftItem = GetBlockItem(*State_.LeftReaders_[i], arrayIndex, lhs);
  523. auto rightItem = GetBlockItem(*State_.RightReaders_[i], arrayIndex, rhs);
  524. if (State_.Directions_[0]) {
  525. return State_.Comparators_[0]->Less(leftItem, rightItem);
  526. } else {
  527. return State_.Comparators_[0]->Greater(leftItem, rightItem);
  528. }
  529. } else {
  530. for (ui32 k = 0; k < KeyIndicies_.size(); ++k) {
  531. auto i = KeyIndicies_[k];
  532. auto& arrayIndex = ArrayIndicies_[i];
  533. if (arrayIndex.empty()) {
  534. // scalar
  535. continue;
  536. }
  537. auto leftItem = GetBlockItem(*State_.LeftReaders_[i], arrayIndex, lhs);
  538. auto rightItem = GetBlockItem(*State_.RightReaders_[i], arrayIndex, rhs);
  539. auto cmp = State_.Comparators_[k]->Compare(leftItem, rightItem);
  540. if (cmp == 0) {
  541. continue;
  542. }
  543. if (State_.Directions_[k]) {
  544. return cmp < 0;
  545. } else {
  546. return cmp > 0;
  547. }
  548. }
  549. return false;
  550. }
  551. }
  552. private:
  553. static TBlockItem GetBlockItem(IBlockReader& reader, const TChunkedArrayIndex& arrayIndex, ui64 idx) {
  554. Y_DEBUG_ABORT_UNLESS(!arrayIndex.empty());
  555. if (arrayIndex.size() == 1) {
  556. return reader.GetItem(*arrayIndex.front().Data, idx);
  557. }
  558. auto it = LookupArrayDataItem(arrayIndex.data(), arrayIndex.size(), idx);
  559. return reader.GetItem(*it->Data, idx);
  560. }
  561. const std::vector<ui32>& KeyIndicies_;
  562. const std::vector<TChunkedArrayIndex> ArrayIndicies_;
  563. const TState& State_;
  564. };
  // Upstream wide flow that supplies the input blocks.
  565. IComputationWideFlowNode *const Flow_;
  // Node yielding the row limit; null when HasCount is false.
  566. IComputationNode *const Count_;
  // One node per sort key yielding the direction flag for that key.
  567. const TComputationNodePtrVector Directions_;
  // Column indexes of the sort keys, parallel to Directions_.
  568. const std::vector<ui32> KeyIndicies_;
  // Block types of all wide components except the last one (filled in the constructor).
  569. std::vector<TBlockType*> Columns_;
  // Start of this node's slice of ctx.WideFields.
  570. const size_t WideFieldsIndex_;
  571. };
  572. template <bool Sort, bool HasCount>
  573. IComputationNode* WrapTopOrSort(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  574. constexpr ui32 offset = HasCount ? 0 : 1;
  575. const ui32 inputsWithCount = callable.GetInputsCount() + offset;
  576. MKQL_ENSURE(inputsWithCount > 2U && !(inputsWithCount % 2U), "Expected more arguments.");
  577. const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
  578. const auto wideComponents = GetWideComponents(flowType);
  579. MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column");
  580. const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
  581. MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
  582. IComputationNode* count = nullptr;
  583. if constexpr (HasCount) {
  584. const auto countType = AS_TYPE(TDataType, callable.GetInput(1).GetStaticType());
  585. MKQL_ENSURE(countType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  586. count = LocateNode(ctx.NodeLocator, callable, 1);
  587. }
  588. TComputationNodePtrVector directions;
  589. std::vector<ui32> keyIndicies;
  590. for (ui32 i = 2; i < inputsWithCount; i += 2) {
  591. ui32 keyIndex = AS_VALUE(TDataLiteral, callable.GetInput(i - offset))->AsValue().Get<ui32>();
  592. MKQL_ENSURE(keyIndex + 1 < wideComponents.size(), "Wrong key index");
  593. keyIndicies.push_back(keyIndex);
  594. directions.push_back(LocateNode(ctx.NodeLocator, callable, i + 1 - offset));
  595. }
  596. return new TTopOrSortBlocksWrapper<Sort, HasCount>(ctx.Mutables, wideFlow, wideComponents, count, std::move(directions), std::move(keyIndicies));
  597. }
  598. } //namespace
  599. IComputationNode* WrapWideTopBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  600. return WrapTopOrSort<false, true>(callable, ctx);
  601. }
  602. IComputationNode* WrapWideTopSortBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  603. return WrapTopOrSort<true, true>(callable, ctx);
  604. }
  605. IComputationNode* WrapWideSortBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  606. return WrapTopOrSort<true, false>(callable, ctx);
  607. }
  608. }
  609. }