mkql_ifpresent.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. #include "mkql_ifpresent.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/mkql_node_cast.h>
  4. namespace NKikimr {
  5. namespace NMiniKQL {
  6. namespace {
  7. template<bool IsMultiOptional>
  8. class TIfPresentWrapper : public TMutableCodegeneratorNode<TIfPresentWrapper<IsMultiOptional>> {
  9. using TBaseComputation = TMutableCodegeneratorNode<TIfPresentWrapper<IsMultiOptional>>;
  10. public:
  11. TIfPresentWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* optional, IComputationExternalNode* item, IComputationNode* presentBranch,
  12. IComputationNode* missingBranch)
  13. : TBaseComputation(mutables, kind)
  14. , Optional(optional)
  15. , Item(item)
  16. , PresentBranch(presentBranch)
  17. , MissingBranch(missingBranch)
  18. {}
  19. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  20. if (const auto& previous = Item->GetValue(ctx); previous.IsInvalid()) {
  21. const auto optional = Optional->GetValue(ctx);
  22. if (optional)
  23. Item->SetValue(ctx, optional.GetOptionalValueIf<IsMultiOptional>());
  24. return (optional ? PresentBranch : MissingBranch)->GetValue(ctx).Release();
  25. } else {
  26. return (previous ? PresentBranch : MissingBranch)->GetValue(ctx).Release();
  27. }
  28. }
  29. #ifndef MKQL_DISABLE_CODEGEN
  30. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  31. auto& context = ctx.Codegen.GetContext();
  32. const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
  33. MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
  34. const auto previous = codegenItem->CreateGetValue(ctx, block);
  35. const auto fast = BasicBlock::Create(context, "fast", ctx.Func);
  36. const auto slow = BasicBlock::Create(context, "slow", ctx.Func);
  37. const auto pres = BasicBlock::Create(context, "pres", ctx.Func);
  38. const auto miss = BasicBlock::Create(context, "miss", ctx.Func);
  39. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  40. const auto result = PHINode::Create(previous->getType(), 2, "result", done);
  41. const auto choise = SwitchInst::Create(previous, fast, 2U, block);
  42. choise->addCase(GetEmpty(context), miss);
  43. choise->addCase(GetInvalid(context), slow);
  44. block = slow;
  45. const auto value = GetNodeValue(Optional, ctx, block);
  46. BranchInst::Create(pres, miss, IsExists(value, block, context), block);
  47. block = pres;
  48. codegenItem->CreateSetValue(ctx, block, IsMultiOptional ? GetOptionalValue(context, value, block) : value);
  49. BranchInst::Create(fast, block);
  50. block = fast;
  51. const auto left = GetNodeValue(PresentBranch, ctx, block);
  52. result->addIncoming(left, block);
  53. BranchInst::Create(done, block);
  54. block = miss;
  55. const auto right = GetNodeValue(MissingBranch, ctx, block);
  56. result->addIncoming(right, block);
  57. BranchInst::Create(done, block);
  58. block = done;
  59. return result;
  60. }
  61. #endif
  62. private:
  63. void RegisterDependencies() const final {
  64. this->DependsOn(Optional);
  65. this->DependsOn(MissingBranch);
  66. Optional->AddDependence(Item);
  67. this->Own(Item);
  68. this->DependsOn(PresentBranch);
  69. }
  70. IComputationNode* const Optional;
  71. IComputationExternalNode* const Item;
  72. IComputationNode* const PresentBranch;
  73. IComputationNode* const MissingBranch;
  74. };
  75. template<bool IsMultiOptional>
  76. class TFlowIfPresentWrapper : public TStatelessFlowCodegeneratorNode<TFlowIfPresentWrapper<IsMultiOptional>> {
  77. using TBaseComputation = TStatelessFlowCodegeneratorNode<TFlowIfPresentWrapper<IsMultiOptional>>;
  78. public:
  79. TFlowIfPresentWrapper(EValueRepresentation kind, IComputationNode* optional, IComputationExternalNode* item, IComputationNode* presentBranch,
  80. IComputationNode* missingBranch)
  81. : TBaseComputation(nullptr, kind)
  82. , Optional(optional)
  83. , Item(item)
  84. , PresentBranch(presentBranch)
  85. , MissingBranch(missingBranch)
  86. {}
  87. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  88. if (const auto& previous = Item->GetValue(ctx); previous.IsInvalid()) {
  89. const auto optional = Optional->GetValue(ctx);
  90. if (optional)
  91. Item->SetValue(ctx, optional.GetOptionalValueIf<IsMultiOptional>());
  92. return (optional ? PresentBranch : MissingBranch)->GetValue(ctx).Release();
  93. } else {
  94. return (previous ? PresentBranch : MissingBranch)->GetValue(ctx).Release();
  95. }
  96. }
  97. #ifndef MKQL_DISABLE_CODEGEN
  98. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  99. auto& context = ctx.Codegen.GetContext();
  100. const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
  101. MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
  102. const auto previous = codegenItem->CreateGetValue(ctx, block);
  103. const auto fast = BasicBlock::Create(context, "fast", ctx.Func);
  104. const auto slow = BasicBlock::Create(context, "slow", ctx.Func);
  105. const auto pres = BasicBlock::Create(context, "pres", ctx.Func);
  106. const auto miss = BasicBlock::Create(context, "miss", ctx.Func);
  107. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  108. const auto result = PHINode::Create(previous->getType(), 2, "result", done);
  109. const auto choise = SwitchInst::Create(previous, fast, 2U, block);
  110. choise->addCase(GetEmpty(context), miss);
  111. choise->addCase(GetInvalid(context), slow);
  112. block = slow;
  113. const auto value = GetNodeValue(Optional, ctx, block);
  114. BranchInst::Create(pres, miss, IsExists(value, block, context), block);
  115. block = pres;
  116. codegenItem->CreateSetValue(ctx, block, IsMultiOptional ? GetOptionalValue(context, value, block) : value);
  117. BranchInst::Create(fast, block);
  118. block = fast;
  119. const auto left = GetNodeValue(PresentBranch, ctx, block);
  120. result->addIncoming(left, block);
  121. BranchInst::Create(done, block);
  122. block = miss;
  123. const auto right = GetNodeValue(MissingBranch, ctx, block);
  124. result->addIncoming(right, block);
  125. BranchInst::Create(done, block);
  126. block = done;
  127. return result;
  128. }
  129. #endif
  130. private:
  131. void RegisterDependencies() const final {
  132. if (const auto flow = this->FlowDependsOnBoth(PresentBranch, MissingBranch)) {
  133. this->DependsOn(flow, Optional);
  134. this->Own(flow, Item);
  135. }
  136. Optional->AddDependence(Item);
  137. }
  138. IComputationNode* const Optional;
  139. IComputationExternalNode* const Item;
  140. IComputationNode* const PresentBranch;
  141. IComputationNode* const MissingBranch;
  142. };
  143. template<bool IsMultiOptional>
  144. class TWideIfPresentWrapper : public TStatelessWideFlowCodegeneratorNode<TWideIfPresentWrapper<IsMultiOptional>> {
  145. using TBaseComputation = TStatelessWideFlowCodegeneratorNode<TWideIfPresentWrapper<IsMultiOptional>>;
  146. public:
  147. TWideIfPresentWrapper(IComputationNode* optional, IComputationExternalNode* item, IComputationWideFlowNode* presentBranch,
  148. IComputationWideFlowNode* missingBranch)
  149. : TBaseComputation(nullptr)
  150. , Optional(optional)
  151. , Item(item)
  152. , PresentBranch(presentBranch)
  153. , MissingBranch(missingBranch)
  154. {}
  155. EFetchResult DoCalculate(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  156. if (const auto& previous = Item->GetValue(ctx); previous.IsInvalid()) {
  157. const auto optional = Optional->GetValue(ctx);
  158. if (optional)
  159. Item->SetValue(ctx, optional.GetOptionalValueIf<IsMultiOptional>());
  160. return (optional ? PresentBranch : MissingBranch)->FetchValues(ctx, output);
  161. } else {
  162. return (previous ? PresentBranch : MissingBranch)->FetchValues(ctx, output);
  163. }
  164. }
  165. #ifndef MKQL_DISABLE_CODEGEN
  166. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, BasicBlock*& block) const {
  167. auto& context = ctx.Codegen.GetContext();
  168. const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
  169. MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
  170. const auto previous = codegenItem->CreateGetValue(ctx, block);
  171. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  172. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  173. const auto pres = BasicBlock::Create(context, "pres", ctx.Func);
  174. const auto miss = BasicBlock::Create(context, "miss", ctx.Func);
  175. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  176. const auto result = PHINode::Create(Type::getInt32Ty(context), 2, "result", done);
  177. const auto choise = SwitchInst::Create(previous, pres, 2U, block);
  178. choise->addCase(GetEmpty(context), miss);
  179. choise->addCase(GetInvalid(context), init);
  180. block = init;
  181. const auto value = GetNodeValue(Optional, ctx, block);
  182. BranchInst::Create(good, miss, IsExists(value, block, context), block);
  183. block = good;
  184. codegenItem->CreateSetValue(ctx, block, IsMultiOptional ? GetOptionalValue(context, value, block) : value);
  185. BranchInst::Create(pres, block);
  186. block = pres;
  187. const auto left = GetNodeValues(PresentBranch, ctx, block);
  188. result->addIncoming(left.first, block);
  189. BranchInst::Create(done, block);
  190. block = miss;
  191. const auto right = GetNodeValues(MissingBranch, ctx, block);
  192. result->addIncoming(right.first, block);
  193. BranchInst::Create(done, block);
  194. block = done;
  195. MKQL_ENSURE(left.second.size() == right.second.size(), "Expected same width of flows.");
  196. ICodegeneratorInlineWideNode::TGettersList getters;
  197. getters.reserve(left.second.size());
  198. size_t idx = 0U;
  199. std::generate_n(std::back_inserter(getters), right.second.size(), [&]() {
  200. const auto i = idx++;
  201. return [codegenItem, lget = left.second[i], rget = right.second[i]](const TCodegenContext& ctx, BasicBlock*& block) {
  202. auto& context = ctx.Codegen.GetContext();
  203. const auto pres = BasicBlock::Create(context, "pres", ctx.Func);
  204. const auto miss = BasicBlock::Create(context, "miss", ctx.Func);
  205. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  206. const auto current = codegenItem->CreateGetValue(ctx, block);
  207. const auto result = PHINode::Create(current->getType(), 2, "result", done);
  208. const auto choise = SwitchInst::Create(current, pres, 2U, block);
  209. choise->addCase(GetEmpty(context), miss);
  210. choise->addCase(GetInvalid(context), miss);
  211. block = pres;
  212. result->addIncoming(lget(ctx, block), block);
  213. BranchInst::Create(done, block);
  214. block = miss;
  215. result->addIncoming(rget(ctx, block), block);
  216. BranchInst::Create(done, block);
  217. block = done;
  218. return result;
  219. };
  220. });
  221. return {result, std::move(getters)};
  222. }
  223. #endif
  224. private:
  225. void RegisterDependencies() const final {
  226. if (const auto flow = this->FlowDependsOnBoth(PresentBranch, MissingBranch)) {
  227. this->DependsOn(flow, Optional);
  228. this->Own(flow, Item);
  229. }
  230. Optional->AddDependence(Item);
  231. }
  232. IComputationNode* const Optional;
  233. IComputationExternalNode* const Item;
  234. IComputationWideFlowNode* const PresentBranch;
  235. IComputationWideFlowNode* const MissingBranch;
  236. };
  237. }
  238. IComputationNode* WrapIfPresent(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  239. MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args");
  240. const auto optional = LocateNode(ctx.NodeLocator, callable, 0);
  241. const auto presentBranch = LocateNode(ctx.NodeLocator, callable, 2);
  242. const auto missingBranch = LocateNode(ctx.NodeLocator, callable, 3);
  243. const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 1);
  244. const auto innerType = AS_TYPE(TOptionalType, callable.GetInput(0U).GetStaticType())->GetItemType();
  245. const bool multiOptional = innerType->IsOptional() || innerType->IsPg();
  246. if (const auto type = callable.GetType()->GetReturnType(); type->IsFlow()) {
  247. const auto presWide = dynamic_cast<IComputationWideFlowNode*>(presentBranch);
  248. const auto missWide = dynamic_cast<IComputationWideFlowNode*>(missingBranch);
  249. if (presWide && missWide) {
  250. if (multiOptional)
  251. return new TWideIfPresentWrapper<true>(optional, itemArg, presWide, missWide);
  252. else
  253. return new TWideIfPresentWrapper<false>(optional, itemArg, presWide, missWide);
  254. } else if (!presWide && !missWide) {
  255. if (multiOptional)
  256. return new TFlowIfPresentWrapper<true>(GetValueRepresentation(type), optional, itemArg, presentBranch, missingBranch);
  257. else
  258. return new TFlowIfPresentWrapper<false>(GetValueRepresentation(type), optional, itemArg, presentBranch, missingBranch);
  259. }
  260. } else if (multiOptional) {
  261. return new TIfPresentWrapper<true>(ctx.Mutables, GetValueRepresentation(type), optional, itemArg, presentBranch, missingBranch);
  262. } else {
  263. return new TIfPresentWrapper<false>(ctx.Mutables, GetValueRepresentation(type), optional, itemArg, presentBranch, missingBranch);
  264. }
  265. THROW yexception() << "Wrong signature.";
  266. }
  267. }
  268. }