mkql_withcontext.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. #include "mkql_withcontext.h"
  2. #include <yql/essentials/minikql/mkql_node_cast.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  5. #include <yql/essentials/parser/pg_wrapper/interface/context.h>
  6. #include <util/generic/scope.h>
  7. namespace NKikimr {
  8. namespace NMiniKQL {
  9. namespace {
  10. class TWithContextWrapper : public TMutableComputationNode<TWithContextWrapper> {
  11. typedef TMutableComputationNode<TWithContextWrapper> TBaseComputation;
  12. public:
  13. TWithContextWrapper(TComputationMutables& mutables, const std::string_view& contextType, IComputationNode* arg)
  14. : TBaseComputation(mutables)
  15. , Arg(arg)
  16. , ContextType(contextType)
  17. {}
  18. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  19. auto prev = TlsAllocState->CurrentContext;
  20. TlsAllocState->CurrentContext = PgInitializeContext(ContextType);
  21. Y_DEFER {
  22. PgDestroyContext(ContextType, TlsAllocState->CurrentContext);
  23. TlsAllocState->CurrentContext = prev;
  24. };
  25. TPAllocScope scope;
  26. return Arg->GetValue(compCtx).Release();
  27. }
  28. private:
  29. void RegisterDependencies() const final {
  30. this->DependsOn(Arg);
  31. }
  32. IComputationNode* const Arg;
  33. const std::string_view ContextType;
  34. };
  35. struct TState : public TComputationValue<TState> {
  36. TState(TMemoryUsageInfo* memInfo, const std::string_view& contextType)
  37. : TComputationValue(memInfo)
  38. , ContextType(contextType)
  39. , Ctx(PgInitializeContext(ContextType))
  40. {
  41. Scope.Detach();
  42. }
  43. void Attach() {
  44. Scope.Attach();
  45. PrevContext = TlsAllocState->CurrentContext;
  46. TlsAllocState->CurrentContext = Ctx;
  47. }
  48. void Detach(const bool cleanup) {
  49. if (cleanup)
  50. Cleanup();
  51. Scope.Detach();
  52. TlsAllocState->CurrentContext = PrevContext;
  53. }
  54. ~TState() {
  55. Cleanup();
  56. }
  57. private:
  58. void Cleanup() {
  59. if (Ctx) {
  60. PgDestroyContext(ContextType, Ctx);
  61. Ctx = nullptr;
  62. Scope.Cleanup();
  63. }
  64. }
  65. const std::string_view ContextType;
  66. void* Ctx;
  67. TPAllocScope Scope;
  68. void* PrevContext = nullptr;
  69. };
  70. class TWithContextFlowWrapper : public TStatefulFlowCodegeneratorNode<TWithContextFlowWrapper> {
  71. using TBaseComputation = TStatefulFlowCodegeneratorNode<TWithContextFlowWrapper>;
  72. public:
  73. TWithContextFlowWrapper(TComputationMutables& mutables, const std::string_view& contextType,
  74. EValueRepresentation kind, IComputationNode* flow)
  75. : TBaseComputation(mutables, flow, kind, EValueRepresentation::Any)
  76. , Flow(flow)
  77. , ContextType(contextType)
  78. {}
  79. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& stateValue, TComputationContext& ctx) const {
  80. if (!stateValue.HasValue()) {
  81. MakeState(ctx, stateValue);
  82. }
  83. auto& state = *static_cast<TState*>(stateValue.AsBoxed().Get());
  84. state.Attach();
  85. auto item = Flow->GetValue(ctx);
  86. state.Detach(item.IsFinish());
  87. return item.Release();
  88. }
  89. #ifndef MKQL_DISABLE_CODEGEN
  90. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  91. auto& context = ctx.Codegen.GetContext();
  92. const auto valueType = Type::getInt128Ty(context);
  93. const auto structPtrType = PointerType::getUnqual(StructType::get(context));
  94. const auto statePtrType = PointerType::getUnqual(structPtrType);
  95. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  96. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  97. BranchInst::Create(main, make, HasValue(statePtr, block, context), block);
  98. block = make;
  99. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  100. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  101. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWithContextFlowWrapper::MakeState));
  102. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
  103. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  104. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
  105. BranchInst::Create(main, block);
  106. block = main;
  107. const auto state = new LoadInst(valueType, statePtr, "state", block);
  108. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  109. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  110. const auto attachFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Attach));
  111. const auto attachFuncType = FunctionType::get(Type::getVoidTy(context), { statePtrType }, false);
  112. const auto attachFuncPtr = CastInst::Create(Instruction::IntToPtr, attachFunc, PointerType::getUnqual(attachFuncType), "attach", block);
  113. CallInst::Create(attachFuncType, attachFuncPtr, { stateArg }, "", block);
  114. const auto value = GetNodeValue(Flow, ctx, block);
  115. const auto finish = IsFinish(value, block, context);
  116. const auto detachFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Detach));
  117. const auto detachFuncType = FunctionType::get(Type::getVoidTy(context), { statePtrType, finish->getType() }, false);
  118. const auto detachFuncPtr = CastInst::Create(Instruction::IntToPtr, detachFunc, PointerType::getUnqual(detachFuncType), "detach", block);
  119. CallInst::Create(detachFuncType, detachFuncPtr, { stateArg, finish }, "", block);
  120. return value;
  121. }
  122. #endif
  123. private:
  124. void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
  125. state = ctx.HolderFactory.Create<TState>(ContextType);
  126. }
  127. void RegisterDependencies() const final {
  128. this->FlowDependsOn(Flow);
  129. }
  130. IComputationNode* const Flow;
  131. const std::string_view ContextType;
  132. };
  133. class TWithContextWideFlowWrapper : public TStatefulWideFlowCodegeneratorNode<TWithContextWideFlowWrapper> {
  134. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWithContextWideFlowWrapper>;
  135. public:
  136. TWithContextWideFlowWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow,
  137. const std::string_view& contextType)
  138. : TBaseComputation(mutables, flow, EValueRepresentation::Any)
  139. , Flow(flow)
  140. , ContextType(contextType)
  141. {}
  142. EFetchResult DoCalculate(NUdf::TUnboxedValue& stateValue, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  143. if (!stateValue.HasValue()) {
  144. MakeState(ctx, stateValue);
  145. }
  146. auto& state = *static_cast<TState*>(stateValue.AsBoxed().Get());
  147. state.Attach();
  148. const auto status = Flow->FetchValues(ctx, output);
  149. state.Detach(status == EFetchResult::Finish);
  150. return status;
  151. }
  152. #ifndef MKQL_DISABLE_CODEGEN
  153. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  154. auto& context = ctx.Codegen.GetContext();
  155. const auto valueType = Type::getInt128Ty(context);
  156. const auto indexType = Type::getInt32Ty(context);
  157. const auto structPtrType = PointerType::getUnqual(StructType::get(context));
  158. const auto statePtrType = PointerType::getUnqual(structPtrType);
  159. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  160. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  161. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  162. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  163. BranchInst::Create(main, make, HasValue(statePtr, block, context), block);
  164. block = make;
  165. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  166. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  167. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWithContextWideFlowWrapper::MakeState));
  168. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
  169. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  170. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
  171. BranchInst::Create(main, block);
  172. block = main;
  173. const auto state = new LoadInst(statePtrType, statePtr, "state", block);
  174. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  175. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  176. const auto attachFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Attach));
  177. const auto attachFuncType = FunctionType::get(Type::getVoidTy(context), { statePtrType }, false);
  178. const auto attachFuncPtr = CastInst::Create(Instruction::IntToPtr, attachFunc, PointerType::getUnqual(attachFuncType), "attach", block);
  179. CallInst::Create(attachFuncType, attachFuncPtr, { stateArg }, "", block);
  180. auto getres = GetNodeValues(Flow, ctx, block);
  181. const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block);
  182. BranchInst::Create(exit, good, special, block);
  183. block = good;
  184. const auto arrayType = ArrayType::get(valueType, getres.second.size());
  185. const auto arrayPtr = new AllocaInst(arrayType, 0U, "array_ptr", &ctx.Func->getEntryBlock().back());
  186. Value* array = UndefValue::get(arrayType);
  187. for (auto idx = 0U; idx < getres.second.size(); ++idx) {
  188. const auto item = getres.second[idx](ctx, block);
  189. array = InsertValueInst::Create(array, item, {idx}, (TString("value_") += ToString(idx)).c_str(), block);
  190. }
  191. new StoreInst(array, arrayPtr, block);
  192. BranchInst::Create(exit, block);
  193. block = exit;
  194. const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Finish)), "finish", block);
  195. const auto detachFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Detach));
  196. const auto detachFuncType = FunctionType::get(Type::getVoidTy(context), { statePtrType, finish->getType() }, false);
  197. const auto detachFuncPtr = CastInst::Create(Instruction::IntToPtr, detachFunc, PointerType::getUnqual(detachFuncType), "detach", block);
  198. CallInst::Create(detachFuncType, detachFuncPtr, { stateArg, finish }, "", block);
  199. for (auto idx = 0U; idx < getres.second.size(); ++idx) {
  200. getres.second[idx] = [idx, arrayPtr, arrayType, indexType, valueType] (const TCodegenContext& ctx, BasicBlock*& block) {
  201. Y_UNUSED(ctx);
  202. const auto itemPtr = GetElementPtrInst::CreateInBounds(arrayType, arrayPtr, {ConstantInt::get(indexType, 0), ConstantInt::get(indexType, idx)}, (TString("ptr_") += ToString(idx)).c_str(), block);
  203. return new LoadInst(valueType, itemPtr, (TString("item_") += ToString(idx)).c_str(), block);
  204. };
  205. }
  206. return getres;
  207. }
  208. #endif
  209. private:
  210. void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
  211. state = ctx.HolderFactory.Create<TState>(ContextType);
  212. }
  213. void RegisterDependencies() const final {
  214. this->FlowDependsOn(Flow);
  215. }
  216. IComputationWideFlowNode* const Flow;
  217. const std::string_view ContextType;
  218. };
  219. }
  220. IComputationNode* WrapWithContext(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  221. const auto contextTypeData = AS_VALUE(TDataLiteral, callable.GetInput(0));
  222. const auto contextType = contextTypeData->AsValue().AsStringRef();
  223. const auto arg = LocateNode(ctx.NodeLocator, callable, 1);
  224. if (const auto type = callable.GetType()->GetReturnType(); type->IsFlow()) {
  225. if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(arg)) {
  226. return new TWithContextWideFlowWrapper(ctx.Mutables, wide, contextType);
  227. } else {
  228. return new TWithContextFlowWrapper(ctx.Mutables, contextType, GetValueRepresentation(type), arg);
  229. }
  230. } else {
  231. MKQL_ENSURE(!callable.GetInput(1).GetStaticType()->IsStream(), "Stream is not expected here");
  232. return new TWithContextWrapper(ctx.Mutables, contextType, arg);
  233. }
  234. }
  235. }
  236. }