mkql_heap.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. #include "mkql_heap.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. #include <yql/essentials/minikql/mkql_node_builder.h>
  6. #include <yql/essentials/utils/sort.h>
  7. namespace NKikimr {
  8. namespace NMiniKQL {
  9. namespace {
  10. using TComparator = std::function<bool(const NUdf::TUnboxedValuePod l, const NUdf::TUnboxedValuePod r)>;
  11. using TAlgorithm = void(*)(NUdf::TUnboxedValuePod*, NUdf::TUnboxedValuePod*, TComparator);
  12. using TArgsPlace = std::array<NUdf::TUnboxedValuePod, 2U>;
  13. using TComparePtr = bool (*)(TComputationContext& ctx, const NUdf::TUnboxedValuePod l, const NUdf::TUnboxedValuePod r);
  14. class THeapWrapper : public TMutableCodegeneratorNode<THeapWrapper>
  15. #ifndef MKQL_DISABLE_CODEGEN
  16. , public ICodegeneratorRootNode
  17. #endif
  18. {
  19. typedef TMutableCodegeneratorNode<THeapWrapper> TBaseComputation;
  20. public:
  21. THeapWrapper(TAlgorithm algorithm, TComputationMutables& mutables, IComputationNode* list, IComputationExternalNode* left, IComputationExternalNode* right, IComputationNode* compare)
  22. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  23. , Algorithm(algorithm)
  24. , List(list)
  25. , Left(left)
  26. , Right(right)
  27. , Compare(compare)
  28. {}
  29. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  30. auto list = List->GetValue(ctx);
  31. const auto size = list.GetListLength();
  32. if (size < 2U)
  33. return list.Release();
  34. NUdf::TUnboxedValue *items = nullptr;
  35. const auto next = ctx.HolderFactory.CloneArray(list.Release(), items);
  36. NUdf::TUnboxedValuePod *const begin = items, *const end = items + size;
  37. Do(ctx, begin, end);
  38. return next;
  39. }
  40. #ifndef MKQL_DISABLE_CODEGEN
  41. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  42. auto& context = ctx.Codegen.GetContext();
  43. const auto valueType = Type::getInt128Ty(context);
  44. const auto fact = ctx.GetFactory();
  45. const auto func = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::CloneArray));// TODO: Generate code instead of call CloneArray.
  46. const auto list = GetNodeValue(List, ctx, block);
  47. const auto size = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListLength>(Type::getInt64Ty(context), list, ctx.Codegen, block);
  48. const auto test = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, ConstantInt::get(size->getType(), 1ULL), "test", block);
  49. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  50. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  51. const auto result = PHINode::Create(valueType, 2U, "result", done);
  52. result->addIncoming(list, block);
  53. BranchInst::Create(work, done, test, block);
  54. block = work;
  55. const auto itemsType = PointerType::getUnqual(valueType);
  56. const auto itemsPtr = *Stateless || ctx.AlwaysInline ?
  57. new AllocaInst(itemsType, 0U, "items_ptr", &ctx.Func->getEntryBlock().back()):
  58. new AllocaInst(itemsType, 0U, "items_ptr", block);
  59. const auto idxType = Type::getInt32Ty(context);
  60. Value* array = nullptr;
  61. const auto funType = FunctionType::get(valueType, {fact->getType(), list->getType(), itemsPtr->getType()}, false);
  62. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
  63. array = CallInst::Create(funType, funcPtr, {fact, list, itemsPtr}, "array", block);
  64. result->addIncoming(array, block);
  65. const auto algo = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THeapWrapper::Do));
  66. const auto self = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(this));
  67. const auto items = new LoadInst(itemsType, itemsPtr, "items", block);
  68. const auto zero = ConstantInt::get(idxType, 0);
  69. const auto begin = GetElementPtrInst::CreateInBounds(valueType, items, {zero}, "begin", block);
  70. const auto end = GetElementPtrInst::CreateInBounds(valueType, items, {size}, "end", block);
  71. const auto selfPtr = CastInst::Create(Instruction::IntToPtr, self, PointerType::getUnqual(StructType::get(context)), "comp", block);
  72. const auto doType = FunctionType::get(Type::getVoidTy(context), {selfPtr->getType(), ctx.Ctx->getType(), begin->getType(), end->getType()}, false);
  73. const auto doPtr = CastInst::Create(Instruction::IntToPtr, algo, PointerType::getUnqual(doType), "do", block);
  74. CallInst::Create(doType, doPtr, {selfPtr, ctx.Ctx, begin, end}, "", block);
  75. BranchInst::Create(done, block);
  76. block = done;
  77. return result;
  78. }
  79. #endif
  80. private:
  81. void Do(TComputationContext& ctx, NUdf::TUnboxedValuePod* begin, NUdf::TUnboxedValuePod* end) const {
  82. if (ctx.ExecuteLLVM && Comparator) {
  83. return Algorithm(begin, end, std::bind(Comparator, std::ref(ctx), std::placeholders::_1, std::placeholders::_2));
  84. }
  85. TArgsPlace args;
  86. Left->SetGetter([&](TComputationContext&) { return args.front(); });
  87. Right->SetGetter([&](TComputationContext&) { return args.back(); });
  88. Algorithm(begin, end, std::bind(&THeapWrapper::Comp, this, std::ref(args), std::ref(ctx), std::placeholders::_1, std::placeholders::_2));
  89. }
  90. bool Comp(TArgsPlace& args, TComputationContext& ctx, const NUdf::TUnboxedValuePod l, const NUdf::TUnboxedValuePod r) const {
  91. args = {{l, r}};
  92. Left->InvalidateValue(ctx);
  93. Right->InvalidateValue(ctx);
  94. return Compare->GetValue(ctx).Get<bool>();
  95. }
  96. void RegisterDependencies() const final {
  97. this->DependsOn(List);
  98. this->Own(Left);
  99. this->Own(Right);
  100. this->DependsOn(Compare);
  101. }
  102. const TAlgorithm Algorithm;
  103. IComputationNode* const List;
  104. IComputationExternalNode* const Left;
  105. IComputationExternalNode* const Right;
  106. IComputationNode* const Compare;
  107. TComparePtr Comparator = nullptr;
  108. #ifndef MKQL_DISABLE_CODEGEN
  109. TString MakeName() const {
  110. TStringStream out;
  111. out << this->DebugString() << "::compare_(" << static_cast<const void*>(this) << ").";
  112. return out.Str();
  113. }
  114. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  115. if (CompareFunc) {
  116. Comparator = reinterpret_cast<TComparePtr>(codegen.GetPointerToFunction(CompareFunc));
  117. }
  118. }
  119. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  120. CompareFunc = GenerateCompareFunction(codegen, MakeName(), Left, Right, Compare);
  121. codegen.ExportSymbol(CompareFunc);
  122. }
  123. Function* CompareFunc = nullptr;
  124. #endif
  125. };
  126. IComputationNode* WrapHeap(TAlgorithm algorithm, TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  127. MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args");
  128. const auto list = LocateNode(ctx.NodeLocator, callable, 0);
  129. const auto compare = LocateNode(ctx.NodeLocator, callable, 3);
  130. const auto left = LocateExternalNode(ctx.NodeLocator, callable, 1);
  131. const auto right = LocateExternalNode(ctx.NodeLocator, callable, 2);
  132. return new THeapWrapper(algorithm, ctx.Mutables, list, left, right, compare);
  133. }
  134. using TNthAlgorithm = void(*)(NUdf::TUnboxedValuePod*, NUdf::TUnboxedValuePod*, NUdf::TUnboxedValuePod*, TComparator);
  135. class TNthWrapper : public TMutableCodegeneratorNode<TNthWrapper>
  136. #ifndef MKQL_DISABLE_CODEGEN
  137. , public ICodegeneratorRootNode
  138. #endif
  139. {
  140. typedef TMutableCodegeneratorNode<TNthWrapper> TBaseComputation;
  141. public:
  142. TNthWrapper(TNthAlgorithm algorithm, TComputationMutables& mutables, IComputationNode* list, IComputationNode* middle, IComputationExternalNode* left, IComputationExternalNode* right, IComputationNode* compare)
  143. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  144. , Algorithm(algorithm)
  145. , List(list)
  146. , Middle(middle)
  147. , Left(left)
  148. , Right(right)
  149. , Compare(compare)
  150. {}
  151. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  152. auto list = List->GetValue(ctx);
  153. auto middle = Middle->GetValue(ctx).Get<ui64>();
  154. const auto size = list.GetListLength();
  155. middle = std::min(middle, size);
  156. if (middle == 0U || size < 2U)
  157. return list.Release();
  158. NUdf::TUnboxedValue *items = nullptr;
  159. const auto next = ctx.HolderFactory.CloneArray(list.Release(), items);
  160. NUdf::TUnboxedValuePod *const begin = items, *const mid = items + middle, *const end = items + size;
  161. Do(ctx, begin, mid, end);
  162. return next;
  163. }
  164. #ifndef MKQL_DISABLE_CODEGEN
  165. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  166. auto& context = ctx.Codegen.GetContext();
  167. const auto valueType = Type::getInt128Ty(context);
  168. const auto fact = ctx.GetFactory();
  169. const auto func = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::CloneArray));// TODO: Generate code instead of call CloneArray.
  170. const auto list = GetNodeValue(List, ctx, block);
  171. const auto midv = GetNodeValue(Middle, ctx, block);
  172. const auto middle = GetterFor<ui64>(midv, context, block);
  173. const auto size = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListLength>(Type::getInt64Ty(context), list, ctx.Codegen, block);
  174. const auto greater = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, middle, size, "greater", block);
  175. const auto min = SelectInst::Create(greater, size, middle, "min", block);
  176. const auto one = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, min, ConstantInt::get(size->getType(), 0ULL), "one", block);
  177. const auto two = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, ConstantInt::get(size->getType(), 1ULL), "two", block);
  178. const auto test = BinaryOperator::CreateAnd(one, two, "and", block);
  179. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  180. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  181. const auto result = PHINode::Create(valueType, 2U, "result", done);
  182. result->addIncoming(list, block);
  183. BranchInst::Create(work, done, test, block);
  184. block = work;
  185. const auto itemsType = PointerType::getUnqual(valueType);
  186. const auto itemsPtr = *Stateless || ctx.AlwaysInline ?
  187. new AllocaInst(itemsType, 0U, "items_ptr", &ctx.Func->getEntryBlock().back()):
  188. new AllocaInst(itemsType, 0U, "items_ptr", block);
  189. const auto idxType = Type::getInt32Ty(context);
  190. Value* array = nullptr;
  191. const auto funType = FunctionType::get(valueType, {fact->getType(), list->getType(), itemsPtr->getType()}, false);
  192. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
  193. array = CallInst::Create(funType, funcPtr, {fact, list, itemsPtr}, "array", block);
  194. result->addIncoming(array, block);
  195. const auto algo = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TNthWrapper::Do));
  196. const auto self = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(this));
  197. const auto items = new LoadInst(itemsType, itemsPtr, "items", block);
  198. const auto zero = ConstantInt::get(idxType, 0);
  199. const auto begin = GetElementPtrInst::CreateInBounds(valueType, items, {zero}, "begin", block);
  200. const auto mid = GetElementPtrInst::CreateInBounds(valueType, items, {min}, "middle", block);
  201. const auto end = GetElementPtrInst::CreateInBounds(valueType, items, {size}, "end", block);
  202. const auto selfPtr = CastInst::Create(Instruction::IntToPtr, self, PointerType::getUnqual(StructType::get(context)), "comp", block);
  203. const auto doType = FunctionType::get(Type::getVoidTy(context), {selfPtr->getType(), ctx.Ctx->getType(), begin->getType(), mid->getType(), end->getType()}, false);
  204. const auto doPtr = CastInst::Create(Instruction::IntToPtr, algo, PointerType::getUnqual(doType), "do", block);
  205. CallInst::Create(doType, doPtr, {selfPtr, ctx.Ctx, begin, mid, end}, "", block);
  206. BranchInst::Create(done, block);
  207. block = done;
  208. return result;
  209. }
  210. #endif
  211. private:
  212. void Do(TComputationContext& ctx, NUdf::TUnboxedValuePod* begin, NUdf::TUnboxedValuePod* nth, NUdf::TUnboxedValuePod* end) const {
  213. if (ctx.ExecuteLLVM && Comparator) {
  214. return Algorithm(begin, nth, end, std::bind(Comparator, std::ref(ctx), std::placeholders::_1, std::placeholders::_2));
  215. }
  216. TArgsPlace args;
  217. Left->SetGetter([&](TComputationContext&) { return args.front(); });
  218. Right->SetGetter([&](TComputationContext&) { return args.back(); });
  219. Algorithm(begin, nth, end, std::bind(&TNthWrapper::Comp, this, std::ref(args), std::ref(ctx), std::placeholders::_1, std::placeholders::_2));
  220. }
  221. bool Comp(TArgsPlace& args, TComputationContext& ctx, const NUdf::TUnboxedValuePod l, const NUdf::TUnboxedValuePod r) const {
  222. args = {{l, r}};
  223. Left->InvalidateValue(ctx);
  224. Right->InvalidateValue(ctx);
  225. return Compare->GetValue(ctx).Get<bool>();
  226. }
  227. void RegisterDependencies() const final {
  228. this->DependsOn(List);
  229. this->DependsOn(Middle);
  230. this->Own(Left);
  231. this->Own(Right);
  232. this->DependsOn(Compare);
  233. }
  234. const TNthAlgorithm Algorithm;
  235. IComputationNode* const List;
  236. IComputationNode* const Middle;
  237. IComputationExternalNode* const Left;
  238. IComputationExternalNode* const Right;
  239. IComputationNode* const Compare;
  240. TComparePtr Comparator = nullptr;
  241. #ifndef MKQL_DISABLE_CODEGEN
  242. TString MakeName() const {
  243. TStringStream out;
  244. out << this->DebugString() << "::compare_(" << static_cast<const void*>(this) << ").";
  245. return out.Str();
  246. }
  247. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  248. if (CompareFunc) {
  249. Comparator = reinterpret_cast<TComparePtr>(codegen.GetPointerToFunction(CompareFunc));
  250. }
  251. }
  252. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  253. CompareFunc = GenerateCompareFunction(codegen, MakeName(), Left, Right, Compare);
  254. codegen.ExportSymbol(CompareFunc);
  255. }
  256. Function* CompareFunc = nullptr;
  257. #endif
  258. };
  259. IComputationNode* WrapNth(TNthAlgorithm algorithm, TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  260. MKQL_ENSURE(callable.GetInputsCount() == 5, "Expected 5 args");
  261. const auto list = LocateNode(ctx.NodeLocator, callable, 0);
  262. const auto middle = LocateNode(ctx.NodeLocator, callable, 1);
  263. const auto compare = LocateNode(ctx.NodeLocator, callable, 4);
  264. const auto left = LocateExternalNode(ctx.NodeLocator, callable, 2);
  265. const auto right = LocateExternalNode(ctx.NodeLocator, callable, 3);
  266. return new TNthWrapper(algorithm, ctx.Mutables, list, middle, left, right, compare);
  267. }
  268. }
  269. IComputationNode* WrapMakeHeap(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  270. return WrapHeap(&std::make_heap<NUdf::TUnboxedValuePod*, TComparator>, callable, ctx);
  271. }
  272. IComputationNode* WrapPushHeap(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  273. return WrapHeap(&std::push_heap<NUdf::TUnboxedValuePod*, TComparator>, callable, ctx);
  274. }
  275. IComputationNode* WrapPopHeap(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  276. return WrapHeap(&std::pop_heap<NUdf::TUnboxedValuePod*, TComparator>, callable, ctx);
  277. }
  278. IComputationNode* WrapSortHeap(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  279. return WrapHeap(&std::sort_heap<NUdf::TUnboxedValuePod*, TComparator>, callable, ctx);
  280. }
  281. IComputationNode* WrapStableSort(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  282. return WrapHeap(&std::stable_sort<NUdf::TUnboxedValuePod*, TComparator>, callable, ctx);
  283. }
  284. IComputationNode* WrapNthElement(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  285. return WrapNth(&NYql::FastNthElement<NUdf::TUnboxedValuePod*, TComparator>, callable, ctx);
  286. }
  287. IComputationNode* WrapPartialSort(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  288. return WrapNth(&NYql::FastPartialSort<NUdf::TUnboxedValuePod*, TComparator>, callable, ctx);
  289. }
  290. }
  291. }