mkql_if.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. #include "mkql_if.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/mkql_node_cast.h>
  4. #include <yql/essentials/minikql/mkql_node_builder.h>
  5. namespace NKikimr {
  6. namespace NMiniKQL {
  7. namespace {
  8. template<bool IsOptional>
  9. class TIfWrapper : public TMutableCodegeneratorNode<TIfWrapper<IsOptional>> {
  10. using TBaseComputation = TMutableCodegeneratorNode<TIfWrapper<IsOptional>>;
  11. public:
  12. TIfWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* predicate, IComputationNode* thenBranch, IComputationNode* elseBranch)
  13. : TBaseComputation(mutables, kind)
  14. , Predicate(predicate)
  15. , ThenBranch(thenBranch)
  16. , ElseBranch(elseBranch)
  17. {}
  18. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  19. const auto& predicate = Predicate->GetValue(ctx);
  20. if (IsOptional && !predicate) {
  21. return NUdf::TUnboxedValuePod();
  22. }
  23. return (predicate.Get<bool>() ? ThenBranch : ElseBranch)->GetValue(ctx).Release();
  24. }
  25. #ifndef MKQL_DISABLE_CODEGEN
  26. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  27. auto& context = ctx.Codegen.GetContext();
  28. const auto then = BasicBlock::Create(context, "then", ctx.Func);
  29. const auto elsb = BasicBlock::Create(context, "else", ctx.Func);
  30. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  31. const auto value = GetNodeValue(Predicate, ctx, block);
  32. const auto result = PHINode::Create(value->getType(), IsOptional ? 3U : 2U, "result", done);
  33. if (IsOptional) {
  34. result->addIncoming(value, block);
  35. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  36. BranchInst::Create(done, good, IsEmpty(value, block, context), block);
  37. block = good;
  38. }
  39. const auto cast = CastInst::Create(Instruction::Trunc, value, Type::getInt1Ty(context), "bool", block);
  40. BranchInst::Create(then, elsb, cast, block);
  41. {
  42. block = then;
  43. const auto left = GetNodeValue(ThenBranch, ctx, block);
  44. result->addIncoming(left, block);
  45. BranchInst::Create(done, block);
  46. }
  47. {
  48. block = elsb;
  49. const auto right = GetNodeValue(ElseBranch, ctx, block);
  50. result->addIncoming(right, block);
  51. BranchInst::Create(done, block);
  52. }
  53. block = done;
  54. return result;
  55. }
  56. #endif
  57. private:
  58. void RegisterDependencies() const final {
  59. this->DependsOn(Predicate);
  60. this->DependsOn(ThenBranch);
  61. this->DependsOn(ElseBranch);
  62. }
  63. IComputationNode* const Predicate;
  64. IComputationNode* const ThenBranch;
  65. IComputationNode* const ElseBranch;
  66. };
  67. template<bool IsOptional>
  68. class TFlowIfWrapper : public TStatefulFlowCodegeneratorNode<TFlowIfWrapper<IsOptional>> {
  69. using TBaseComputation = TStatefulFlowCodegeneratorNode<TFlowIfWrapper<IsOptional>>;
  70. public:
  71. TFlowIfWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* predicate, IComputationNode* thenBranch, IComputationNode* elseBranch)
  72. : TBaseComputation(mutables, nullptr, kind)
  73. , Predicate(predicate)
  74. , ThenBranch(thenBranch)
  75. , ElseBranch(elseBranch)
  76. {}
  77. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  78. if (state.IsInvalid()) {
  79. state = Predicate->GetValue(ctx);
  80. }
  81. if (IsOptional && !state) {
  82. return NUdf::TUnboxedValuePod::MakeFinish();
  83. }
  84. return (state.Get<bool>() ? ThenBranch : ElseBranch)->GetValue(ctx).Release();
  85. }
  86. #ifndef MKQL_DISABLE_CODEGEN
  87. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  88. auto& context = ctx.Codegen.GetContext();
  89. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  90. const auto test = BasicBlock::Create(context, "test", ctx.Func);
  91. const auto then = BasicBlock::Create(context, "then", ctx.Func);
  92. const auto elsb = BasicBlock::Create(context, "else", ctx.Func);
  93. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  94. BranchInst::Create(init, test, IsInvalid(statePtr, block, context), block);
  95. block = init;
  96. GetNodeValue(statePtr, Predicate, ctx, block);
  97. BranchInst::Create(test, block);
  98. block = test;
  99. const auto valueType = Type::getInt128Ty(context);
  100. const auto state = new LoadInst(valueType, statePtr, "state", block);
  101. const auto result = PHINode::Create(valueType, IsOptional ? 3U : 2U, "result", done);
  102. if (IsOptional) {
  103. result->addIncoming(state, block);
  104. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  105. BranchInst::Create(done, good, IsEmpty(state, block, context), block);
  106. block = good;
  107. }
  108. const auto cast = CastInst::Create(Instruction::Trunc, state, Type::getInt1Ty(context), "bool", block);
  109. BranchInst::Create(then, elsb, cast, block);
  110. {
  111. block = then;
  112. const auto left = GetNodeValue(ThenBranch, ctx, block);
  113. result->addIncoming(left, block);
  114. BranchInst::Create(done, block);
  115. }
  116. {
  117. block = elsb;
  118. const auto right = GetNodeValue(ElseBranch, ctx, block);
  119. result->addIncoming(right, block);
  120. BranchInst::Create(done, block);
  121. }
  122. block = done;
  123. return result;
  124. }
  125. #endif
  126. private:
  127. void RegisterDependencies() const final {
  128. if (const auto flow = this->FlowDependsOnBoth(ThenBranch, ElseBranch))
  129. this->DependsOn(flow, Predicate);
  130. }
  131. IComputationNode* const Predicate;
  132. IComputationNode* const ThenBranch;
  133. IComputationNode* const ElseBranch;
  134. };
  135. class TWideIfWrapper : public TStatefulWideFlowCodegeneratorNode<TWideIfWrapper> {
  136. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideIfWrapper>;
  137. public:
  138. TWideIfWrapper(TComputationMutables& mutables, IComputationNode* predicate, IComputationWideFlowNode* thenBranch, IComputationWideFlowNode* elseBranch)
  139. : TBaseComputation(mutables, nullptr, EValueRepresentation::Embedded)
  140. , Predicate(predicate)
  141. , ThenBranch(thenBranch)
  142. , ElseBranch(elseBranch)
  143. {}
  144. EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  145. if (state.IsInvalid()) {
  146. state = Predicate->GetValue(ctx);
  147. }
  148. return (state.Get<bool>() ? ThenBranch : ElseBranch)->FetchValues(ctx, output);
  149. }
  150. #ifndef MKQL_DISABLE_CODEGEN
  151. TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  152. auto& context = ctx.Codegen.GetContext();
  153. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  154. const auto test = BasicBlock::Create(context, "test", ctx.Func);
  155. const auto then = BasicBlock::Create(context, "then", ctx.Func);
  156. const auto elsb = BasicBlock::Create(context, "else", ctx.Func);
  157. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  158. BranchInst::Create(init, test, IsInvalid(statePtr, block, context), block);
  159. block = init;
  160. GetNodeValue(statePtr, Predicate, ctx, block);
  161. BranchInst::Create(test, block);
  162. block = test;
  163. const auto valueType = Type::getInt128Ty(context);
  164. const auto state = new LoadInst(valueType, statePtr, "state", block);
  165. const auto result = PHINode::Create(Type::getInt32Ty(context), 2, "result", done);
  166. const auto cast = CastInst::Create(Instruction::Trunc, state, Type::getInt1Ty(context), "bool", block);
  167. BranchInst::Create(then, elsb, cast, block);
  168. block = then;
  169. const auto left = GetNodeValues(ThenBranch, ctx, block);
  170. result->addIncoming(left.first, block);
  171. BranchInst::Create(done, block);
  172. block = elsb;
  173. const auto right = GetNodeValues(ElseBranch, ctx, block);
  174. result->addIncoming(right.first, block);
  175. BranchInst::Create(done, block);
  176. block = done;
  177. MKQL_ENSURE(left.second.size() == right.second.size(), "Expected same width of flows.");
  178. TGettersList getters;
  179. getters.reserve(left.second.size());
  180. const auto index = static_cast<const IComputationNode*>(this)->GetIndex();
  181. size_t idx = 0U;
  182. std::generate_n(std::back_inserter(getters), right.second.size(), [&]() {
  183. const auto i = idx++;
  184. return [index, lget = left.second[i], rget = right.second[i]](const TCodegenContext& ctx, BasicBlock*& block) {
  185. auto& context = ctx.Codegen.GetContext();
  186. const auto then = BasicBlock::Create(context, "then", ctx.Func);
  187. const auto elsb = BasicBlock::Create(context, "elsb", ctx.Func);
  188. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  189. const auto valueType = Type::getInt128Ty(context);
  190. const auto result = PHINode::Create(valueType, 2, "result", done);
  191. const auto statePtr = GetElementPtrInst::CreateInBounds(valueType, ctx.GetMutables(), {ConstantInt::get(Type::getInt32Ty(context), index)}, "state_ptr", block);
  192. const auto state = new LoadInst(valueType, statePtr, "state", block);
  193. const auto trunc = CastInst::Create(Instruction::Trunc, state, Type::getInt1Ty(context), "trunc", block);
  194. BranchInst::Create(then, elsb, trunc, block);
  195. block = then;
  196. result->addIncoming(lget(ctx, block), block);
  197. BranchInst::Create(done, block);
  198. block = elsb;
  199. result->addIncoming(rget(ctx, block), block);
  200. BranchInst::Create(done, block);
  201. block = done;
  202. return result;
  203. };
  204. });
  205. return {result, std::move(getters)};
  206. }
  207. #endif
  208. private:
  209. void RegisterDependencies() const final {
  210. if (const auto flow = FlowDependsOnBoth(ThenBranch, ElseBranch))
  211. DependsOn(flow, Predicate);
  212. }
  213. IComputationNode* const Predicate;
  214. IComputationWideFlowNode* const ThenBranch;
  215. IComputationWideFlowNode* const ElseBranch;
  216. };
  217. }
  218. IComputationNode* WrapIf(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  219. MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args");
  220. bool isOptional;
  221. const auto predicateType = UnpackOptionalData(callable.GetInput(0), isOptional);
  222. MKQL_ENSURE(predicateType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool or optional of bool.");
  223. const auto predicate = LocateNode(ctx.NodeLocator, callable, 0);
  224. const auto thenBranch = LocateNode(ctx.NodeLocator, callable, 1);
  225. const auto elseBranch = LocateNode(ctx.NodeLocator, callable, 2);
  226. const auto type = callable.GetType()->GetReturnType();
  227. if (type->IsFlow()) {
  228. const auto thenWide = dynamic_cast<IComputationWideFlowNode*>(thenBranch);
  229. const auto elseWide = dynamic_cast<IComputationWideFlowNode*>(elseBranch);
  230. if (thenWide && elseWide && !isOptional)
  231. return new TWideIfWrapper(ctx.Mutables, predicate, thenWide, elseWide);
  232. else if (!thenWide && !elseWide) {
  233. if (isOptional)
  234. return new TFlowIfWrapper<true>(ctx.Mutables, GetValueRepresentation(type), predicate, thenBranch, elseBranch);
  235. else
  236. return new TFlowIfWrapper<false>(ctx.Mutables, GetValueRepresentation(type), predicate, thenBranch, elseBranch);
  237. }
  238. } else {
  239. if (isOptional) {
  240. return new TIfWrapper<true>(ctx.Mutables, GetValueRepresentation(type), predicate, thenBranch, elseBranch);
  241. } else {
  242. return new TIfWrapper<false>(ctx.Mutables, GetValueRepresentation(type), predicate, thenBranch, elseBranch);
  243. }
  244. }
  245. THROW yexception() << "Wrong signature.";
  246. }
  247. }
  248. }