mkql_map.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. #include "mkql_map.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. namespace NKikimr {
  6. namespace NMiniKQL {
  7. namespace {
  8. class TFlowMapWrapper : public TStatelessFlowCodegeneratorNode<TFlowMapWrapper> {
  9. typedef TStatelessFlowCodegeneratorNode<TFlowMapWrapper> TBaseComputation;
  10. public:
  11. TFlowMapWrapper(EValueRepresentation kind, IComputationNode* flow, IComputationExternalNode* item, IComputationNode* newItem)
  12. : TBaseComputation(flow, kind)
  13. , Flow(flow)
  14. , Item(item)
  15. , NewItem(newItem)
  16. {}
  17. NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
  18. if (auto item = Flow->GetValue(ctx); item.IsSpecial()) {
  19. return item;
  20. } else {
  21. Item->SetValue(ctx, std::move(item));
  22. }
  23. return NewItem->GetValue(ctx);
  24. }
  25. #ifndef MKQL_DISABLE_CODEGEN
  26. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  27. auto& context = ctx.Codegen.GetContext();
  28. const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
  29. MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
  30. const auto item = GetNodeValue(Flow, ctx, block);
  31. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  32. const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
  33. const auto result = PHINode::Create(item->getType(), 2, "result", pass);
  34. result->addIncoming(item, block);
  35. BranchInst::Create(pass, work, IsSpecial(item, block, context), block);
  36. block = work;
  37. codegenItem->CreateSetValue(ctx, block, item);
  38. const auto out = GetNodeValue(NewItem, ctx, block);
  39. result->addIncoming(out, block);
  40. BranchInst::Create(pass, block);
  41. block = pass;
  42. return result;
  43. }
  44. #endif
  45. private:
  46. void RegisterDependencies() const final {
  47. if (const auto flow = FlowDependsOn(Flow)) {
  48. Own(flow, Item);
  49. DependsOn(flow, NewItem);
  50. }
  51. }
  52. IComputationNode* const Flow;
  53. IComputationExternalNode* const Item;
  54. IComputationNode* const NewItem;
  55. };
  56. template <bool IsStream>
  57. class TBaseMapWrapper {
  58. protected:
  59. class TListValue : public TCustomListValue {
  60. public:
  61. class TIterator : public TComputationValue<TIterator> {
  62. public:
  63. TIterator(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& iter, IComputationExternalNode* item, IComputationNode* newItem)
  64. : TComputationValue<TIterator>(memInfo)
  65. , CompCtx(compCtx)
  66. , Iter(std::move(iter))
  67. , Item(item)
  68. , NewItem(newItem)
  69. {}
  70. private:
  71. bool Next(NUdf::TUnboxedValue& value) override {
  72. if (!Iter.Next(Item->RefValue(CompCtx))) {
  73. return false;
  74. }
  75. value = NewItem->GetValue(CompCtx);
  76. return true;
  77. }
  78. TComputationContext& CompCtx;
  79. const NUdf::TUnboxedValue Iter;
  80. IComputationExternalNode* const Item;
  81. IComputationNode* const NewItem;
  82. };
  83. TListValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& list, IComputationExternalNode* item, IComputationNode* newItem)
  84. : TCustomListValue(memInfo)
  85. , CompCtx(compCtx)
  86. , List(std::move(list))
  87. , Item(item)
  88. , NewItem(newItem)
  89. {}
  90. private:
  91. NUdf::TUnboxedValue GetListIterator() const final {
  92. return CompCtx.HolderFactory.Create<TIterator>(CompCtx, List.GetListIterator(), Item, NewItem);
  93. }
  94. ui64 GetListLength() const final {
  95. if (!Length) {
  96. Length = List.GetListLength();
  97. }
  98. return *Length;
  99. }
  100. bool HasListItems() const final {
  101. if (!HasItems) {
  102. HasItems = List.HasListItems();
  103. }
  104. return *HasItems;
  105. }
  106. bool HasFastListLength() const final {
  107. return List.HasFastListLength();
  108. }
  109. TComputationContext& CompCtx;
  110. const NUdf::TUnboxedValue List;
  111. IComputationExternalNode* const Item;
  112. IComputationNode* const NewItem;
  113. };
  114. class TStreamValue : public TComputationValue<TStreamValue> {
  115. public:
  116. using TBase = TComputationValue<TStreamValue>;
  117. TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& stream, IComputationExternalNode* item, IComputationNode* newItem)
  118. : TBase(memInfo)
  119. , CompCtx(compCtx)
  120. , Stream(std::move(stream))
  121. , Item(item)
  122. , NewItem(newItem)
  123. {
  124. }
  125. private:
  126. ui32 GetTraverseCount() const final {
  127. return 1U;
  128. }
  129. NUdf::TUnboxedValue GetTraverseItem(ui32) const final {
  130. return Stream;
  131. }
  132. NUdf::TUnboxedValue Save() const final {
  133. return NUdf::TUnboxedValuePod::Zero();
  134. }
  135. void Load(const NUdf::TStringRef&) final {}
  136. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) final {
  137. const auto status = Stream.Fetch(Item->RefValue(CompCtx));
  138. if (status != NUdf::EFetchStatus::Ok) {
  139. return status;
  140. }
  141. result = NewItem->GetValue(CompCtx);
  142. return NUdf::EFetchStatus::Ok;
  143. }
  144. TComputationContext& CompCtx;
  145. const NUdf::TUnboxedValue Stream;
  146. IComputationExternalNode* const Item;
  147. IComputationNode* const NewItem;
  148. };
  149. TBaseMapWrapper(IComputationNode* list, IComputationExternalNode* item, IComputationNode* newItem)
  150. : List(list), Item(item), NewItem(newItem)
  151. {}
  152. #ifndef MKQL_DISABLE_CODEGEN
  153. Function* GenerateMapper(NYql::NCodegen::ICodegen& codegen, const TString& name) const {
  154. auto& module = codegen.GetModule();
  155. auto& context = codegen.GetContext();
  156. const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
  157. MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
  158. if (const auto f = module.getFunction(name.c_str()))
  159. return f;
  160. const auto valueType = Type::getInt128Ty(context);
  161. const auto containerType = static_cast<Type*>(valueType);
  162. const auto contextType = GetCompContextType(context);
  163. const auto statusType = IsStream ? Type::getInt32Ty(context) : Type::getInt1Ty(context);
  164. const auto funcType = FunctionType::get(statusType, {PointerType::getUnqual(contextType), containerType, PointerType::getUnqual(valueType)}, false);
  165. TCodegenContext ctx(codegen);
  166. ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());
  167. DISubprogramAnnotator annotator(ctx, ctx.Func);
  168. auto args = ctx.Func->arg_begin();
  169. ctx.Ctx = &*args;
  170. const auto containerArg = &*++args;
  171. const auto valuePtr = &*++args;
  172. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  173. auto block = main;
  174. const auto container = static_cast<Value*>(containerArg);
  175. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  176. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  177. const auto itemPtr = codegenItem->CreateRefValue(ctx, block);
  178. const auto status = IsStream ?
  179. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, container, codegen, block, itemPtr):
  180. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(statusType, container, codegen, block, itemPtr);
  181. const auto icmp = IsStream ?
  182. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, status, ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Ok)), "cond", block):
  183. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, status, ConstantInt::getFalse(context), "cond", block);
  184. BranchInst::Create(done, good, icmp, block);
  185. block = good;
  186. SafeUnRefUnboxedOne(valuePtr, ctx, block);
  187. GetNodeValue(valuePtr, NewItem, ctx, block);
  188. BranchInst::Create(done, block);
  189. block = done;
  190. ReturnInst::Create(context, status, block);
  191. return ctx.Func;
  192. }
  193. using TMapPtr = std::conditional_t<IsStream, TStreamCodegenValueStateless::TFetchPtr, TListCodegenValue::TNextPtr>;
  194. Function* MapFunc = nullptr;
  195. TMapPtr Map = nullptr;
  196. #endif
  197. IComputationNode* const List;
  198. IComputationExternalNode* const Item;
  199. IComputationNode* const NewItem;
  200. };
  201. class TStreamMapWrapper : public TCustomValueCodegeneratorNode<TStreamMapWrapper>, private TBaseMapWrapper<true> {
  202. typedef TCustomValueCodegeneratorNode<TStreamMapWrapper> TBaseComputation;
  203. typedef TBaseMapWrapper<true> TBaseWrapper;
  204. public:
  205. TStreamMapWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationExternalNode* item, IComputationNode* newItem)
  206. : TBaseComputation(mutables), TBaseWrapper(list, item, newItem)
  207. {}
  208. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  209. #ifndef MKQL_DISABLE_CODEGEN
  210. if (ctx.ExecuteLLVM && Map)
  211. return ctx.HolderFactory.Create<TStreamCodegenValueStateless>(Map, &ctx, List->GetValue(ctx));
  212. #endif
  213. return ctx.HolderFactory.Create<TStreamValue>(ctx, List->GetValue(ctx), Item, NewItem);
  214. }
  215. private:
  216. void RegisterDependencies() const final {
  217. DependsOn(List);
  218. Own(Item);
  219. DependsOn(NewItem);
  220. }
  221. #ifndef MKQL_DISABLE_CODEGEN
  222. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  223. MapFunc = GenerateMapper(codegen, TBaseComputation::MakeName("Fetch"));
  224. codegen.ExportSymbol(MapFunc);
  225. }
  226. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  227. if (MapFunc)
  228. Map = reinterpret_cast<TMapPtr>(codegen.GetPointerToFunction(MapFunc));
  229. }
  230. #endif
  231. };
  232. class TListMapWrapper : public TBothWaysCodegeneratorNode<TListMapWrapper>, private TBaseMapWrapper<false> {
  233. typedef TBothWaysCodegeneratorNode<TListMapWrapper> TBaseComputation;
  234. typedef TBaseMapWrapper<false> TBaseWrapper;
  235. public:
  236. TListMapWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationExternalNode* item, IComputationNode* newItem)
  237. : TBaseComputation(mutables), TBaseWrapper(list, item, newItem)
  238. {}
  239. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  240. auto list = List->GetValue(ctx);
  241. if (auto elements = list.GetElements()) {
  242. auto size = list.GetListLength();
  243. NUdf::TUnboxedValue* items = nullptr;
  244. NUdf::TUnboxedValue result = ctx.HolderFactory.CreateDirectArrayHolder(size, items);
  245. while (size--) {
  246. Item->SetValue(ctx, NUdf::TUnboxedValue(*elements++));
  247. *items++ = NewItem->GetValue(ctx);
  248. }
  249. return result.Release();
  250. }
  251. return ctx.HolderFactory.Create<TListValue>(ctx, std::move(list), Item, NewItem);
  252. }
  253. #ifndef MKQL_DISABLE_CODEGEN
  254. NUdf::TUnboxedValuePod MakeLazyList(TComputationContext& ctx, const NUdf::TUnboxedValuePod value) const {
  255. return ctx.HolderFactory.Create<TListCodegenValue>(Map, &ctx, value);
  256. }
  257. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  258. auto& context = ctx.Codegen.GetContext();
  259. const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
  260. MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
  261. const auto list = GetNodeValue(List, ctx, block);
  262. const auto lazy = BasicBlock::Create(context, "lazy", ctx.Func);
  263. const auto hard = BasicBlock::Create(context, "hard", ctx.Func);
  264. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  265. const auto map = PHINode::Create(list->getType(), 2U, "map", done);
  266. const auto elementsType = PointerType::getUnqual(list->getType());
  267. const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(elementsType, list, ctx.Codegen, block);
  268. const auto fill = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, elements, ConstantPointerNull::get(elementsType), "fill", block);
  269. BranchInst::Create(hard, lazy, fill, block);
  270. {
  271. block = hard;
  272. const auto size = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListLength>(Type::getInt64Ty(context), list, ctx.Codegen, block);
  273. const auto itemsPtr = *Stateless || ctx.AlwaysInline ?
  274. new AllocaInst(elementsType, 0U, "items_ptr", &ctx.Func->getEntryBlock().back()):
  275. new AllocaInst(elementsType, 0U, "items_ptr", block);
  276. const auto array = GenNewArray(ctx, size, itemsPtr, block);
  277. const auto items = new LoadInst(elementsType, itemsPtr, "items", block);
  278. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  279. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  280. const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
  281. const auto index = PHINode::Create(size->getType(), 2U, "index", loop);
  282. index->addIncoming(ConstantInt::get(size->getType(), 0), block);
  283. BranchInst::Create(loop, block);
  284. block = loop;
  285. const auto more = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, index, "more", block);
  286. BranchInst::Create(next, stop, more, block);
  287. block = next;
  288. const auto src = GetElementPtrInst::CreateInBounds(list->getType(), elements, {index}, "src", block);
  289. const auto item = new LoadInst(list->getType(), src, "item", block);
  290. codegenItem->CreateSetValue(ctx, block, item);
  291. const auto dst = GetElementPtrInst::CreateInBounds(list->getType(), items, {index}, "dst", block);
  292. GetNodeValue(dst, NewItem, ctx, block);
  293. const auto plus = BinaryOperator::CreateAdd(index, ConstantInt::get(size->getType(), 1), "plus", block);
  294. index->addIncoming(plus, block);
  295. BranchInst::Create(loop, block);
  296. block = stop;
  297. if (List->IsTemporaryValue()) {
  298. CleanupBoxed(list, ctx, block);
  299. }
  300. map->addIncoming(array, block);
  301. BranchInst::Create(done, block);
  302. }
  303. {
  304. block = lazy;
  305. const auto doFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TListMapWrapper::MakeLazyList));
  306. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  307. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  308. const auto funType = FunctionType::get(list->getType() , {self->getType(), ctx.Ctx->getType(), list->getType()}, false);
  309. const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, doFunc, PointerType::getUnqual(funType), "function", block);
  310. const auto value = CallInst::Create(funType, doFuncPtr, {self, ctx.Ctx, list}, "value", block);
  311. map->addIncoming(value, block);
  312. BranchInst::Create(done, block);
  313. }
  314. block = done;
  315. return map;
  316. }
  317. #endif
  318. private:
  319. void RegisterDependencies() const final {
  320. DependsOn(List);
  321. Own(Item);
  322. DependsOn(NewItem);
  323. }
  324. #ifndef MKQL_DISABLE_CODEGEN
  325. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  326. TMutableCodegeneratorRootNode<TListMapWrapper>::GenerateFunctions(codegen);
  327. MapFunc = GenerateMapper(codegen, TBaseComputation::MakeName("Next"));
  328. codegen.ExportSymbol(MapFunc);
  329. }
  330. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  331. TMutableCodegeneratorRootNode<TListMapWrapper>::FinalizeFunctions(codegen);
  332. if (MapFunc)
  333. Map = reinterpret_cast<TMapPtr>(codegen.GetPointerToFunction(MapFunc));
  334. }
  335. #endif
  336. };
  337. }
  338. IComputationNode* WrapMap(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  339. MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args, got " << callable.GetInputsCount());
  340. const auto type = callable.GetType()->GetReturnType();
  341. const auto flow = LocateNode(ctx.NodeLocator, callable, 0);
  342. const auto newItem = LocateNode(ctx.NodeLocator, callable, 2);
  343. const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 1);
  344. if (type->IsFlow()) {
  345. return new TFlowMapWrapper(GetValueRepresentation(type), flow, itemArg, newItem);
  346. } else if (type->IsStream()) {
  347. return new TStreamMapWrapper(ctx.Mutables, flow, itemArg, newItem);
  348. } else if (type->IsList()) {
  349. return new TListMapWrapper(ctx.Mutables, flow, itemArg, newItem);
  350. }
  351. THROW yexception() << "Expected flow, list or stream.";
  352. }
  353. }
  354. }