mkql_wide_chopper.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. #include "mkql_chopper.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/mkql_node_cast.h>
  4. #include <yql/essentials/utils/cast.h>
  5. namespace NKikimr {
  6. namespace NMiniKQL {
  7. using NYql::EnsureDynamicCast;
  8. namespace {
  9. using namespace std::placeholders;
  10. class TWideChopperWrapper : public TStatefulWideFlowCodegeneratorNode<TWideChopperWrapper> {
  11. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideChopperWrapper>;
  12. public:
  13. enum class EState : ui64 {
  14. Work,
  15. Chop,
  16. Next,
  17. Skip
  18. };
  19. TWideChopperWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& itemArgs, TComputationNodePtrVector&& keys, TComputationExternalNodePtrVector&& keyArgs, IComputationNode* chop, IComputationWideFlowProxyNode* input, IComputationWideFlowNode* output)
  20. : TBaseComputation(mutables, flow, EValueRepresentation::Any)
  21. , Flow(flow)
  22. , ItemArgs(std::move(itemArgs))
  23. , Keys(std::move(keys))
  24. , KeyArgs(std::move(keyArgs))
  25. , Chop(chop)
  26. , Input(input)
  27. , Output(output)
  28. , ItemsOnKeys(GetPasstroughtMap(ItemArgs, Keys))
  29. , KeysOnItems(GetPasstroughtMap(Keys, ItemArgs))
  30. , SwitchItem(IsPasstrought(Chop, ItemArgs))
  31. , WideFieldsIndex(mutables.IncrementWideFieldsIndex(ItemArgs.size()))
  32. {
  33. Input->SetFetcher(std::bind(&TWideChopperWrapper::DoCalculateInput, this, std::bind(&TWideChopperWrapper::RefState, this, _1), _1, _2));
  34. }
  35. EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  36. auto** fields = ctx.WideFields.data() + WideFieldsIndex;
  37. if (state.IsInvalid()) {
  38. for (auto i = 0U; i < ItemArgs.size(); ++i)
  39. fields[i] = &ItemArgs[i]->RefValue(ctx);
  40. if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result)
  41. return result;
  42. for (ui32 i = 0U; i < Keys.size(); ++i)
  43. if (KeyArgs[i]->GetDependencesCount() > 0U)
  44. KeyArgs[i]->SetValue(ctx, Keys[i]->GetValue(ctx));
  45. state = NUdf::TUnboxedValuePod(ui64(EState::Next));
  46. } else if (EState::Skip == EState(state.Get<ui64>())) {
  47. do {
  48. for (auto i = 0U; i < ItemArgs.size(); ++i)
  49. fields[i] = &ItemArgs[i]->RefValue(ctx);
  50. if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result)
  51. return result;
  52. } while (!Chop->GetValue(ctx).Get<bool>());
  53. for (ui32 i = 0U; i < Keys.size(); ++i)
  54. if (KeyArgs[i]->GetDependencesCount() > 0U)
  55. KeyArgs[i]->SetValue(ctx, Keys[i]->GetValue(ctx));
  56. state = NUdf::TUnboxedValuePod(ui64(EState::Next));
  57. }
  58. while (true) {
  59. if (const auto result = Output->FetchValues(ctx, output); EFetchResult::Finish == result) {
  60. Input->InvalidateValue(ctx);
  61. switch (EState(state.Get<ui64>())) {
  62. case EState::Work:
  63. case EState::Next:
  64. do {
  65. for (auto i = 0U; i < ItemArgs.size(); ++i)
  66. fields[i] = &ItemArgs[i]->RefValue(ctx);
  67. switch (const auto next = Flow->FetchValues(ctx, fields)) {
  68. case EFetchResult::Yield:
  69. state = NUdf::TUnboxedValuePod(ui64(EState::Skip));
  70. case EFetchResult::Finish:
  71. return next;
  72. case EFetchResult::One:
  73. break;
  74. }
  75. } while (!Chop->GetValue(ctx).Get<bool>());
  76. case EState::Chop:
  77. for (ui32 i = 0U; i < Keys.size(); ++i)
  78. if (KeyArgs[i]->GetDependencesCount() > 0U)
  79. KeyArgs[i]->SetValue(ctx, Keys[i]->GetValue(ctx));
  80. state = NUdf::TUnboxedValuePod(ui64(EState::Next));
  81. default:
  82. continue;
  83. }
  84. } else
  85. return result;
  86. }
  87. }
  88. private:
  89. EFetchResult DoCalculateInput(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  90. if (EState::Next == EState(state.Get<ui64>())) {
  91. state = NUdf::TUnboxedValuePod(ui64(EState::Work));
  92. for (auto i = 0U; i < ItemArgs.size(); ++i)
  93. if (const auto out = output[i])
  94. *out = ItemArgs[i]->GetValue(ctx);
  95. return EFetchResult::One;
  96. }
  97. auto** fields = ctx.WideFields.data() + WideFieldsIndex;
  98. for (auto i = 0U; i < ItemArgs.size(); ++i)
  99. fields[i] = &ItemArgs[i]->RefValue(ctx);
  100. if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result)
  101. return result;
  102. for (auto i = 0U; i < ItemArgs.size(); ++i)
  103. if (const auto out = output[i])
  104. *out = *fields[i];
  105. if (Chop->GetValue(ctx).Get<bool>()) {
  106. state = NUdf::TUnboxedValuePod(ui64(EState::Chop));
  107. return EFetchResult::Finish;
  108. }
  109. return EFetchResult::One;
  110. }
  111. #ifndef MKQL_DISABLE_CODEGEN
  112. TGenerateResult DoGenGetValuesInput(const TCodegenContext& ctx, BasicBlock*& block) const {
  113. auto& context = ctx.Codegen.GetContext();
  114. const auto load = BasicBlock::Create(context, "load", ctx.Func);
  115. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  116. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  117. const auto resultType = Type::getInt32Ty(context);
  118. const auto result = PHINode::Create(resultType, 4U, "result", done);
  119. const auto valueType = Type::getInt128Ty(context);
  120. const auto statePtr = GetElementPtrInst::CreateInBounds(valueType, ctx.GetMutables(), {ConstantInt::get(Type::getInt32Ty(context), static_cast<const IComputationNode*>(this)->GetIndex())}, "state_ptr", block);
  121. const auto entry = new LoadInst(valueType, statePtr, "entry", block);
  122. const auto next = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, entry, GetConstant(ui64(EState::Next), context), "next", block);
  123. BranchInst::Create(load, work, next, block);
  124. block = load;
  125. new StoreInst(GetConstant(ui64(EState::Work), context), statePtr, block);
  126. result->addIncoming(ConstantInt::get(resultType, i32(EFetchResult::One)), block);
  127. BranchInst::Create(done, block);
  128. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  129. const auto step = BasicBlock::Create(context, "step", ctx.Func);
  130. block = work;
  131. auto getres = GetNodeValues(Flow, ctx, block);
  132. const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), 0), "special", block);
  133. result->addIncoming(getres.first, block);
  134. BranchInst::Create(done, good, special, block);
  135. block = good;
  136. std::vector<Value*> items(ItemArgs.size(), nullptr);
  137. for (ui32 i = 0U; i < items.size(); ++i) {
  138. EnsureDynamicCast<ICodegeneratorExternalNode*>(ItemArgs[i])->CreateSetValue(ctx, block, items[i] = getres.second[i](ctx, block));
  139. }
  140. const auto chop = SwitchItem ? items[*SwitchItem] : GetNodeValue(Chop, ctx, block);
  141. const auto cast = CastInst::Create(Instruction::Trunc, chop, Type::getInt1Ty(context), "bool", block);
  142. result->addIncoming(ConstantInt::get(resultType, i32(EFetchResult::One)), block);
  143. BranchInst::Create(step, done, cast, block);
  144. block = step;
  145. new StoreInst(GetConstant(ui64(EState::Chop), context), statePtr, block);
  146. result->addIncoming(ConstantInt::get(resultType, i32(EFetchResult::Finish)), block);
  147. BranchInst::Create(done, block);
  148. block = done;
  149. ICodegeneratorInlineWideNode::TGettersList getters;
  150. getters.reserve(ItemArgs.size());
  151. std::transform(ItemArgs.cbegin(), ItemArgs.cend(), std::back_inserter(getters), [&](IComputationNode* node) {
  152. return [node](const TCodegenContext& ctx, BasicBlock*& block){ return GetNodeValue(node, ctx, block); };
  153. });
  154. return {result, std::move(getters)};
  155. }
  156. public:
  157. TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  158. EnsureDynamicCast<IWideFlowProxyCodegeneratorNode*>(Input)->SetGenerator(std::bind(&TWideChopperWrapper::DoGenGetValuesInput, this, _1, _2));
  159. auto& context = ctx.Codegen.GetContext();
  160. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  161. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  162. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  163. const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
  164. const auto resultType = Type::getInt32Ty(context);
  165. const auto result = PHINode::Create(resultType, 5U, "result", exit);
  166. const auto valueType = Type::getInt128Ty(context);
  167. const auto first = new LoadInst(valueType, statePtr, "first", block);
  168. const auto enter = SwitchInst::Create(first, loop, 2U, block);
  169. enter->addCase(GetInvalid(context), init);
  170. enter->addCase(GetConstant(ui64(EState::Skip), context), pass);
  171. {
  172. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  173. block = init;
  174. const auto getfirst = GetNodeValues(Flow, ctx, block);
  175. const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getfirst.first, ConstantInt::get(getfirst.first->getType(), 0), "special", block);
  176. result->addIncoming(getfirst.first, block);
  177. BranchInst::Create(exit, next, special, block);
  178. block = next;
  179. new StoreInst(GetConstant(ui64(EState::Next), context), statePtr, block);
  180. std::vector<Value*> items(ItemArgs.size(), nullptr);
  181. for (ui32 i = 0U; i < items.size(); ++i) {
  182. EnsureDynamicCast<ICodegeneratorExternalNode*>(ItemArgs[i])->CreateSetValue(ctx, block, items[i] = getfirst.second[i](ctx, block));
  183. }
  184. for (ui32 i = 0U; i < Keys.size(); ++i) {
  185. if (KeyArgs[i]->GetDependencesCount() > 0U) {
  186. const auto map = KeysOnItems[i];
  187. const auto key = map ? items[*map] : GetNodeValue(Keys[i], ctx, block);
  188. EnsureDynamicCast<ICodegeneratorExternalNode*>(KeyArgs[i])->CreateSetValue(ctx, block, key);
  189. }
  190. }
  191. BranchInst::Create(loop, block);
  192. }
  193. const auto part = BasicBlock::Create(context, "part", ctx.Func);
  194. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  195. const auto step = BasicBlock::Create(context, "step", ctx.Func);
  196. const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
  197. block = loop;
  198. auto getres = GetNodeValues(Output, ctx, block);
  199. const auto state = new LoadInst(valueType, statePtr, "state", block);
  200. const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLT, getres.first, ConstantInt::get(getres.first->getType(), 0), "finish", block);
  201. result->addIncoming(getres.first, block);
  202. BranchInst::Create(part, exit, finish, block);
  203. block = part;
  204. EnsureDynamicCast<IWideFlowProxyCodegeneratorNode*>(Input)->CreateInvalidate(ctx, block);
  205. result->addIncoming(ConstantInt::get(resultType, i32(EFetchResult::Finish)), block);
  206. const auto choise = SwitchInst::Create(state, exit, 3U, block);
  207. choise->addCase(GetConstant(ui64(EState::Next), context), pass);
  208. choise->addCase(GetConstant(ui64(EState::Work), context), pass);
  209. choise->addCase(GetConstant(ui64(EState::Chop), context), step);
  210. block = pass;
  211. const auto getnext = GetNodeValues(Flow, ctx, block);
  212. result->addIncoming(getnext.first, block);
  213. const auto way = SwitchInst::Create(getnext.first, good, 2U, block);
  214. way->addCase(ConstantInt::get(resultType, i32(EFetchResult::Finish)), exit);
  215. way->addCase(ConstantInt::get(resultType, i32(EFetchResult::Yield)), skip);
  216. block = good;
  217. std::vector<Value*> items(ItemArgs.size(), nullptr);
  218. for (ui32 i = 0U; i < items.size(); ++i) {
  219. EnsureDynamicCast<ICodegeneratorExternalNode*>(ItemArgs[i])->CreateSetValue(ctx, block, items[i] = getnext.second[i](ctx, block));
  220. }
  221. const auto chop = SwitchItem ? items[*SwitchItem] : GetNodeValue(Chop, ctx, block);
  222. const auto cast = CastInst::Create(Instruction::Trunc, chop, Type::getInt1Ty(context), "bool", block);
  223. BranchInst::Create(step, pass, cast, block);
  224. block = step;
  225. new StoreInst(GetConstant(ui64(EState::Next), context), statePtr, block);
  226. for (ui32 i = 0U; i < Keys.size(); ++i) {
  227. if (KeyArgs[i]->GetDependencesCount() > 0U) {
  228. const auto key = GetNodeValue(Keys[i], ctx, block);
  229. EnsureDynamicCast<ICodegeneratorExternalNode*>(KeyArgs[i])->CreateSetValue(ctx, block, key);
  230. }
  231. }
  232. BranchInst::Create(loop, block);
  233. block = skip;
  234. new StoreInst(GetConstant(ui64(EState::Skip), context), statePtr, block);
  235. result->addIncoming(ConstantInt::get(resultType, i32(EFetchResult::Yield)), block);
  236. BranchInst::Create(exit, block);
  237. block = exit;
  238. return {result, std::move(getres.second)};
  239. }
  240. #endif
  241. private:
  242. void RegisterDependencies() const final {
  243. if (const auto flow = FlowDependsOn(Flow)) {
  244. std::for_each(ItemArgs.cbegin(), ItemArgs.cend(), std::bind(&TWideChopperWrapper::Own, flow, std::placeholders::_1));
  245. std::for_each(Keys.cbegin(), Keys.cend(), std::bind(&TWideChopperWrapper::DependsOn, flow, std::placeholders::_1));
  246. std::for_each(KeyArgs.cbegin(), KeyArgs.cend(), std::bind(&TWideChopperWrapper::Own, flow, std::placeholders::_1));
  247. OwnProxy(flow, Input);
  248. DependsOn(flow, Output);
  249. }
  250. }
  251. IComputationWideFlowNode *const Flow;
  252. const TComputationExternalNodePtrVector ItemArgs;
  253. const TComputationNodePtrVector Keys;
  254. const TComputationExternalNodePtrVector KeyArgs;
  255. IComputationNode *const Chop;
  256. IComputationWideFlowProxyNode *const Input;
  257. IComputationWideFlowNode *const Output;
  258. const TPasstroughtMap ItemsOnKeys, KeysOnItems;
  259. const std::optional<size_t> SwitchItem;
  260. const ui32 WideFieldsIndex;
  261. };
  262. }
  263. IComputationNode* WrapWideChopper(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  264. MKQL_ENSURE(callable.GetInputsCount() >= 4U, "Expected at least four args.");
  265. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, callable.GetInput(0U).GetStaticType()));
  266. const ui32 width = wideComponents.size();
  267. const auto flow = LocateNode(ctx.NodeLocator, callable, 0U);
  268. const auto keysSize = (callable.GetInputsCount() - width - 4U) >> 1U;
  269. TComputationNodePtrVector keys;
  270. keys.reserve(keysSize);
  271. auto index = width;
  272. std::generate_n(std::back_inserter(keys), keysSize, [&](){ return LocateNode(ctx.NodeLocator, callable, ++index); } );
  273. index += keysSize;
  274. const auto switchResult = LocateNode(ctx.NodeLocator, callable, ++index);
  275. const auto input = LocateNode(ctx.NodeLocator, callable, ++index, true);
  276. const auto output = LocateNode(ctx.NodeLocator, callable, ++index, true);
  277. TComputationExternalNodePtrVector itemArgs, keyArgs;
  278. itemArgs.reserve(width);
  279. index = 0U;
  280. std::generate_n(std::back_inserter(itemArgs), width, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, ++index); } );
  281. index += keysSize;
  282. keyArgs.reserve(keysSize);
  283. std::generate_n(std::back_inserter(keyArgs), keysSize, [&](){ return LocateExternalNode(ctx.NodeLocator, callable, ++index); } );
  284. if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) {
  285. return new TWideChopperWrapper(ctx.Mutables, wide, std::move(itemArgs), std::move(keys), std::move(keyArgs), switchResult,
  286. EnsureDynamicCast<IComputationWideFlowProxyNode*>(input),
  287. EnsureDynamicCast<IComputationWideFlowNode*>(output));
  288. }
  289. THROW yexception() << "Expected wide flow.";
  290. }
  291. }
  292. }