mkql_wide_chain_map.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. #include "mkql_wide_chain_map.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/computation/mkql_custom_list.h>
  5. #include <yql/essentials/minikql/mkql_node_cast.h>
  6. #include <yql/essentials/utils/cast.h>
  7. namespace NKikimr {
  8. namespace NMiniKQL {
  9. using NYql::EnsureDynamicCast;
  10. namespace {
  11. class TWideChain1MapWrapper : public TStatefulWideFlowCodegeneratorNode<TWideChain1MapWrapper> {
  12. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideChain1MapWrapper>;
  13. public:
  14. TWideChain1MapWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow,
  15. TComputationExternalNodePtrVector&& inputs,
  16. TComputationNodePtrVector&& initItems,
  17. TComputationExternalNodePtrVector&& outputs,
  18. TComputationNodePtrVector&& updateItems)
  19. : TBaseComputation(mutables, flow, EValueRepresentation::Embedded)
  20. , Flow(flow)
  21. , Inputs(std::move(inputs))
  22. , InitItems(std::move(initItems))
  23. , Outputs(std::move(outputs))
  24. , UpdateItems(std::move(updateItems))
  25. , InputsOnInit(GetPasstroughtMapOneToOne(Inputs, InitItems))
  26. , InputsOnUpdate(GetPasstroughtMapOneToOne(Inputs, UpdateItems))
  27. , InitOnInputs(GetPasstroughtMapOneToOne(InitItems, Inputs))
  28. , UpdateOnInputs(GetPasstroughtMapOneToOne(UpdateItems, Inputs))
  29. , OutputsOnUpdate(GetPasstroughtMapOneToOne(Outputs, UpdateItems))
  30. , UpdateOnOutputs(GetPasstroughtMapOneToOne(UpdateItems, Outputs))
  31. , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Inputs.size()))
  32. , TempStateIndex(std::exchange(mutables.CurValueIndex, mutables.CurValueIndex + Outputs.size()))
  33. {}
  34. EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  35. if (state.IsInvalid()) {
  36. state = NUdf::TUnboxedValuePod();
  37. return CalculateFirst(ctx, output);
  38. }
  39. return CalculateOther(ctx, output);
  40. }
  41. #ifndef MKQL_DISABLE_CODEGEN
  42. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  43. auto& context = ctx.Codegen.GetContext();
  44. const auto flagType = Type::getInt1Ty(context);
  45. const auto flagPtr = new AllocaInst(flagType, 0U, "flag_ptr", &ctx.Func->getEntryBlock().back());
  46. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  47. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  48. const auto getres = GetNodeValues(Flow, ctx, block);
  49. const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), 0), "special", block);
  50. BranchInst::Create(done, good, special, block);
  51. block = good;
  52. for (auto i = 0U; i < Inputs.size(); ++i)
  53. if (Inputs[i]->GetDependencesCount() > 0U || !InputsOnInit[i] || !InputsOnUpdate[i])
  54. EnsureDynamicCast<ICodegeneratorExternalNode*>(Inputs[i])->CreateSetValue(ctx, block, getres.second[i](ctx, block));
  55. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  56. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  57. const auto flag = IsInvalid(statePtr, block, context);
  58. new StoreInst(flag, flagPtr, block);
  59. BranchInst::Create(init, next, flag, block);
  60. block = init;
  61. for (auto i = 0U; i < Outputs.size(); ++i) {
  62. if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i]) {
  63. const auto& map = InitOnInputs[i];
  64. const auto value = map ? getres.second[*map](ctx, block) : GetNodeValue(InitItems[i], ctx, block);
  65. EnsureDynamicCast<ICodegeneratorExternalNode*>(Outputs[i])->CreateSetValue(ctx, block, value);
  66. }
  67. }
  68. new StoreInst(GetEmpty(context), statePtr, block);
  69. BranchInst::Create(done, block);
  70. block = next;
  71. std::vector<Value*> outputs(Outputs.size(), nullptr);
  72. for (auto i = 0U; i < outputs.size(); ++i) {
  73. if (const auto& dep = OutputsOnUpdate[i]; Outputs[i]->GetDependencesCount() > 0U || (dep && *dep != i)) {
  74. const auto& map = UpdateOnInputs[i];
  75. outputs[i] = map ? getres.second[*map](ctx, block) : GetNodeValue(UpdateItems[i], ctx, block);
  76. }
  77. }
  78. for (auto i = 0U; i < outputs.size(); ++i)
  79. if (const auto out = outputs[i])
  80. EnsureDynamicCast<ICodegeneratorExternalNode*>(Outputs[i])->CreateSetValue(ctx, block, out);
  81. BranchInst::Create(done, block);
  82. block = done;
  83. ICodegeneratorInlineWideNode::TGettersList result;
  84. result.reserve(Outputs.size());
  85. for (auto i = 0U; i < Outputs.size(); ++i) {
  86. if (const auto& one = InitOnInputs[i], two = UpdateOnInputs[i]; one && two && *one == *two)
  87. result.emplace_back(getres.second[*two]);
  88. else if (Outputs[i]->GetDependencesCount() > 0 || OutputsOnUpdate[i])
  89. result.emplace_back([output = Outputs[i]] (const TCodegenContext& ctx, BasicBlock*& block) { return GetNodeValue(output, ctx, block); });
  90. else
  91. result.emplace_back([this, i, source = getres.second, flagPtr, flagType] (const TCodegenContext& ctx, BasicBlock*& block) {
  92. auto& context = ctx.Codegen.GetContext();
  93. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  94. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  95. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  96. const auto result = PHINode::Create(Type::getInt128Ty(context), 2U, "result", done);
  97. const auto flag = new LoadInst(flagType, flagPtr, "flag", block);
  98. BranchInst::Create(init, next, flag, block);
  99. block = init;
  100. if (const auto& map = InitOnInputs[i])
  101. result->addIncoming(source[*map](ctx, block), block);
  102. else
  103. result->addIncoming(GetNodeValue(InitItems[i], ctx, block), block);
  104. BranchInst::Create(done, block);
  105. block = next;
  106. if (const auto& map = UpdateOnInputs[i])
  107. result->addIncoming(source[*map](ctx, block), block);
  108. else
  109. result->addIncoming(GetNodeValue(UpdateItems[i], ctx, block), block);
  110. BranchInst::Create(done, block);
  111. block = done;
  112. return result;
  113. });
  114. };
  115. return {getres.first, std::move(result)};
  116. }
  117. #endif
  118. private:
  119. EFetchResult CalculateFirst(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  120. auto** fields = ctx.WideFields.data() + WideFieldsIndex;
  121. for (auto i = 0U; i < Inputs.size(); ++i) {
  122. if (const auto& map = InputsOnInit[i]; map && !Inputs[i]->GetDependencesCount()) {
  123. if (const auto& to = UpdateOnOutputs[*map]) {
  124. fields[i] = &Outputs[*to]->RefValue(ctx);
  125. continue;
  126. } else if (const auto out = output[*map]) {
  127. fields[i] = out;
  128. continue;
  129. }
  130. } else {
  131. fields[i] = &Inputs[i]->RefValue(ctx);
  132. continue;
  133. }
  134. fields[i] = nullptr;
  135. }
  136. if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result)
  137. return result;
  138. for (auto i = 0U; i < Outputs.size(); ++i) {
  139. if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i]) {
  140. if (const auto& map = InitOnInputs[i]; !map || Inputs[*map]->GetDependencesCount() > 0U) {
  141. Outputs[i]->SetValue(ctx, InitItems[i]->GetValue(ctx));
  142. }
  143. }
  144. }
  145. for (auto i = 0U; i < Outputs.size(); ++i) {
  146. if (const auto out = output[i]) {
  147. if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i])
  148. *out = Outputs[i]->GetValue(ctx);
  149. else {
  150. if (const auto& map = InitOnInputs[i]) {
  151. if (const auto from = *map; !Inputs[from]->GetDependencesCount()) {
  152. if (const auto first = *InputsOnInit[from]; first != i)
  153. *out = *output[first];
  154. continue;
  155. }
  156. }
  157. *out = InitItems[i]->GetValue(ctx);
  158. }
  159. }
  160. }
  161. return EFetchResult::One;
  162. }
  163. EFetchResult CalculateOther(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  164. auto** fields = ctx.WideFields.data() + WideFieldsIndex;
  165. for (auto i = 0U; i < Inputs.size(); ++i) {
  166. if (const auto& map = InputsOnUpdate[i]; map && !Inputs[i]->GetDependencesCount()) {
  167. if (const auto out = output[*map]) {
  168. fields[i] = out;
  169. continue;
  170. }
  171. } else {
  172. fields[i] = &Inputs[i]->RefValue(ctx);
  173. continue;
  174. }
  175. fields[i] = nullptr;
  176. }
  177. if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result)
  178. return result;
  179. for (auto i = 0U; i < Outputs.size(); ++i) {
  180. if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i]) {
  181. if (const auto& map = UpdateOnInputs[i]; !map || Inputs[*map]->GetDependencesCount() > 0U) {
  182. ctx.MutableValues[TempStateIndex + i] = UpdateItems[i]->GetValue(ctx);
  183. }
  184. }
  185. }
  186. for (auto i = 0U; i < Outputs.size(); ++i) {
  187. if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i]) {
  188. if (const auto& map = UpdateOnInputs[i]; !map || Inputs[*map]->GetDependencesCount() > 0U) {
  189. Outputs[i]->SetValue(ctx, std::move(ctx.MutableValues[TempStateIndex + i]));
  190. }
  191. }
  192. }
  193. for (auto i = 0U; i < Outputs.size(); ++i) {
  194. if (const auto out = output[i]) {
  195. if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i])
  196. *out = Outputs[i]->GetValue(ctx);
  197. else {
  198. if (const auto& map = UpdateOnInputs[i]) {
  199. if (const auto from = *map; !Inputs[from]->GetDependencesCount()) {
  200. if (const auto first = *InputsOnUpdate[from]; first != i)
  201. *out = *output[first];
  202. continue;
  203. }
  204. }
  205. *out = UpdateItems[i]->GetValue(ctx);
  206. }
  207. }
  208. }
  209. return EFetchResult::One;
  210. }
  211. void RegisterDependencies() const final {
  212. if (const auto flow = FlowDependsOn(Flow)) {
  213. std::for_each(Inputs.cbegin(), Inputs.cend(), std::bind(&TWideChain1MapWrapper::Own, flow, std::placeholders::_1));
  214. std::for_each(Outputs.cbegin(), Outputs.cend(), std::bind(&TWideChain1MapWrapper::Own, flow, std::placeholders::_1));
  215. std::for_each(InitItems.cbegin(), InitItems.cend(), std::bind(&TWideChain1MapWrapper::DependsOn, flow, std::placeholders::_1));
  216. std::for_each(UpdateItems.cbegin(), UpdateItems.cend(), std::bind(&TWideChain1MapWrapper::DependsOn, flow, std::placeholders::_1));
  217. }
  218. }
  219. IComputationWideFlowNode* const Flow;
  220. const TComputationExternalNodePtrVector Inputs;
  221. const TComputationNodePtrVector InitItems;
  222. const TComputationExternalNodePtrVector Outputs;
  223. const TComputationNodePtrVector UpdateItems;
  224. const TPasstroughtMap InputsOnInit, InputsOnUpdate, InitOnInputs, UpdateOnInputs, OutputsOnUpdate, UpdateOnOutputs;
  225. const ui32 WideFieldsIndex;
  226. const ui32 TempStateIndex;
  227. };
  228. }
  229. IComputationNode* WrapWideChain1Map(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  230. MKQL_ENSURE(callable.GetInputsCount() > 0U, "Expected argument.");
  231. const auto inputWidth = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetInput(0U).GetStaticType()));
  232. const auto outputWidth = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetType()->GetReturnType()));
  233. MKQL_ENSURE(callable.GetInputsCount() == inputWidth + outputWidth * 3U + 1U, "Wrong signature.");
  234. const auto flow = LocateNode(ctx.NodeLocator, callable, 0U);
  235. if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) {
  236. TComputationNodePtrVector initOutput(outputWidth, nullptr), updateOutput(outputWidth, nullptr);
  237. auto index = inputWidth;
  238. std::generate(initOutput.begin(), initOutput.end(), [&](){ return LocateNode(ctx.NodeLocator, callable, ++index); });
  239. index += outputWidth;
  240. std::generate(updateOutput.begin(), updateOutput.end(), [&](){ return LocateNode(ctx.NodeLocator, callable, ++index); });
  241. TComputationExternalNodePtrVector inputs(inputWidth, nullptr), outputs(outputWidth, nullptr);
  242. index = 0U;
  243. std::generate(inputs.begin(), inputs.end(), [&](){ return LocateExternalNode(ctx.NodeLocator, callable, ++index); });
  244. index += outputWidth;
  245. std::generate(outputs.begin(), outputs.end(), [&](){ return LocateExternalNode(ctx.NodeLocator, callable, ++index); });
  246. return new TWideChain1MapWrapper(ctx.Mutables, wide, std::move(inputs), std::move(initOutput), std::move(outputs), std::move(updateOutput));
  247. }
  248. THROW yexception() << "Expected wide flow.";
  249. }
  250. }
  251. }