mkql_simple_codegen.cpp 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  2. #include "mkql_simple_codegen.h"
  3. namespace NKikimr {
  4. namespace NMiniKQL {
  5. #ifndef MKQL_DISABLE_CODEGEN
  6. ICodegeneratorInlineWideNode::TGenerateResult TSimpleStatefulWideFlowCodegeneratorNodeLLVMBase::DoGenGetValues(const NKikimr::NMiniKQL::TCodegenContext &ctx, llvm::Value *statePtrVal, llvm::BasicBlock *&genToBlock) const {
  7. // init stuff (mainly in global entry block)
  8. auto& context = ctx.Codegen.GetContext();
  9. const auto valueType = Type::getInt128Ty(context);
  10. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  11. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  12. const auto loopFetch = BasicBlock::Create(context, "loop_fetch", ctx.Func);
  13. const auto loopCalc = BasicBlock::Create(context, "loop_calc", ctx.Func);
  14. const auto loopTail = BasicBlock::Create(context, "loop_tail", ctx.Func);
  15. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  16. const auto entryPos = &ctx.Func->getEntryBlock().back();
  17. const auto thisType = StructType::get(context)->getPointerTo();
  18. const auto thisRawVal = ConstantInt::get(Type::getInt64Ty(context), PtrTable.ThisPtr);
  19. const auto thisVal = CastInst::Create(Instruction::IntToPtr, thisRawVal, thisType, "this", entryPos);
  20. const auto valuePtrType = PointerType::getUnqual(valueType);
  21. const auto valuePtrsPtrType = PointerType::getUnqual(valuePtrType);
  22. const auto statePtrType = statePtrVal->getType();
  23. const auto ctxType = ctx.Ctx->getType();
  24. const auto i32Type = Type::getInt32Ty(context);
  25. const auto valueNullptrVal = ConstantPointerNull::get(valuePtrType);
  26. const auto valuePtrNullptrVal = ConstantPointerNull::get(valuePtrsPtrType);
  27. const auto oneVal = ConstantInt::get(i32Type, static_cast<i32>(EFetchResult::One));
  28. const auto maybeResType = TMaybeFetchResult::LLVMType(context);
  29. const auto noneVal = TMaybeFetchResult::None().LLVMConst(context);
  30. const auto outputArrayVal = new AllocaInst(valueType, 0, ConstantInt::get(i32Type, OutWidth), "output_array", entryPos);
  31. const auto outputPtrsVal = new AllocaInst(valuePtrType, 0, ConstantInt::get(Type::getInt64Ty(context), OutWidth), "output_ptrs", entryPos);
  32. for (ui32 pos = 0; pos < OutWidth; pos++) {
  33. const auto posVal = ConstantInt::get(i32Type, pos);
  34. const auto arrayPtrVal = GetElementPtrInst::CreateInBounds(valueType, outputArrayVal, {posVal}, "array_ptr", entryPos);
  35. const auto ptrsPtrVal = GetElementPtrInst::CreateInBounds(valuePtrType, outputPtrsVal, {posVal}, "ptrs_ptr", entryPos);
  36. new StoreInst(arrayPtrVal, ptrsPtrVal, &ctx.Func->getEntryBlock().back());
  37. }
  38. auto block = genToBlock; // >>> start of main code chunk
  39. const auto stateVal = new LoadInst(valueType, statePtrVal, "state", block);
  40. BranchInst::Create(init, loop, IsInvalid(stateVal, block, context), block);
  41. block = init; // state initialization block:
  42. const auto initFuncType = FunctionType::get(Type::getVoidTy(context), {thisType, statePtrType, ctxType}, false);
  43. const auto initFuncRawVal = ConstantInt::get(Type::getInt64Ty(context), PtrTable.InitStateMethPtr);
  44. const auto initFuncVal = CastInst::Create(Instruction::IntToPtr, initFuncRawVal, PointerType::getUnqual(initFuncType), "init_func", block);
  45. CallInst::Create(initFuncType, initFuncVal, {thisVal, statePtrVal, ctx.Ctx}, "", block);
  46. BranchInst::Create(loop, block);
  47. block = loop; // loop head block: (prepare inputs and decide whether to calculate row or not)
  48. const auto generated = GenFetchProcess(statePtrVal, ctx, std::bind_front(GetNodeValues, SourceFlow), block);
  49. auto processResVal = generated.first;
  50. if (processResVal == nullptr) {
  51. const auto prepareFuncType = FunctionType::get(valuePtrsPtrType, {thisType, statePtrType, ctxType, valuePtrsPtrType}, false);
  52. const auto prepareFuncRawVal = ConstantInt::get(Type::getInt64Ty(context), PtrTable.PrepareInputMethPtr);
  53. const auto prepareFuncVal = CastInst::Create(Instruction::IntToPtr, prepareFuncRawVal, PointerType::getUnqual(prepareFuncType), "prepare_func", block);
  54. const auto inputPtrsVal = CallInst::Create(prepareFuncType, prepareFuncVal, {thisVal, statePtrVal, ctx.Ctx, outputPtrsVal}, "input_ptrs", block);
  55. const auto skipFetchCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, inputPtrsVal, valuePtrNullptrVal, "skip_fetch", block);
  56. BranchInst::Create(loopTail, loopFetch, skipFetchCond, block);
  57. block = loopFetch; // loop fetch chunk:
  58. const auto [fetchResVal, getters] = GetNodeValues(SourceFlow, ctx, block);
  59. const auto fetchResExtVal = new ZExtInst(fetchResVal, maybeResType, "res_ext", block);
  60. const auto skipCalcCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, fetchResVal, oneVal, "skip_calc", block);
  61. const auto fetchSourceBlock = block;
  62. BranchInst::Create(loopTail, loopCalc, skipCalcCond, block);
  63. block = loopCalc; // loop calc chunk: (calculate needed values in the row)
  64. for (ui32 pos = 0; pos < InWidth; pos++) {
  65. const auto stor = BasicBlock::Create(context, "stor", ctx.Func);
  66. const auto cont = BasicBlock::Create(context, "cont", ctx.Func);
  67. auto innerBlock = block; // >>> start of inner chunk (calculates and stores the value if needed)
  68. const auto posVal = ConstantInt::get(i32Type, pos);
  69. const auto inputPtrPtrVal = GetElementPtrInst::CreateInBounds(valuePtrType, inputPtrsVal, {posVal}, "input_ptr_ptr", innerBlock);
  70. const auto inputPtrVal = new LoadInst(valuePtrType, inputPtrPtrVal, "input_ptr", innerBlock);
  71. const auto isNullCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, inputPtrVal, valueNullptrVal, "is_null", innerBlock);
  72. BranchInst::Create(cont, stor, isNullCond, innerBlock);
  73. innerBlock = stor; // calculate & store chunk:
  74. new StoreInst(getters[pos](ctx, innerBlock), inputPtrVal, innerBlock);
  75. BranchInst::Create(cont, innerBlock);
  76. innerBlock = cont; // skip input value block:
  77. /* nothing here yet */
  78. block = innerBlock; // <<< end of inner chunk
  79. }
  80. const auto calcSourceBlock = block;
  81. BranchInst::Create(loopTail, block);
  82. block = loopTail; // loop tail block: (process row)
  83. const auto maybeFetchResVal = PHINode::Create(maybeResType, 2, "fetch_res", block);
  84. maybeFetchResVal->addIncoming(noneVal, loop);
  85. maybeFetchResVal->addIncoming(fetchResExtVal, fetchSourceBlock);
  86. maybeFetchResVal->addIncoming(fetchResExtVal, calcSourceBlock);
  87. const auto processFuncType = FunctionType::get(maybeResType, {thisType, statePtrType, ctxType, maybeResType, valuePtrsPtrType}, false);
  88. const auto processFuncRawVal = ConstantInt::get(Type::getInt64Ty(context), PtrTable.DoProcessMethPtr);
  89. const auto processFuncVal = CastInst::Create(Instruction::IntToPtr, processFuncRawVal, PointerType::getUnqual(processFuncType), "process_func", block);
  90. processResVal = CallInst::Create(processFuncType, processFuncVal, {thisVal, statePtrVal, ctx.Ctx, maybeFetchResVal, outputPtrsVal}, "process_res", block);
  91. } else {
  92. BranchInst::Create(loopFetch, loopFetch);
  93. BranchInst::Create(loopCalc, loopCalc);
  94. BranchInst::Create(loopTail, loopTail);
  95. }
  96. const auto brkCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, processResVal, noneVal, "brk", block);
  97. BranchInst::Create(done, loop, brkCond, block);
  98. block = done; // finalization block:
  99. const auto processResTruncVal = new TruncInst(processResVal, i32Type, "res_trunc", block);
  100. genToBlock = block; // <<< end of main code chunk
  101. if (generated.first) {
  102. return {processResTruncVal, generated.second};
  103. }
  104. ICodegeneratorInlineWideNode::TGettersList new_getters;
  105. new_getters.reserve(OutWidth);
  106. for (size_t pos = 0; pos < OutWidth; pos++) {
  107. new_getters.push_back([pos, outputArrayVal, i32Type, valueType] (const TCodegenContext&, BasicBlock*& block) -> Value* {
  108. const auto posVal = ConstantInt::get(i32Type, pos);
  109. const auto arrayPtrVal = GetElementPtrInst::CreateInBounds(valueType, outputArrayVal, {posVal}, "array_ptr", block);
  110. const auto valueVal = new LoadInst(valueType, arrayPtrVal, "value", block);
  111. return valueVal;
  112. });
  113. }
  114. return {processResTruncVal, std::move(new_getters)};
  115. }
  116. #endif
  117. }
  118. }