mkql_guess.cpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. #include "mkql_guess.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/mkql_node_cast.h>
  4. #include <yql/essentials/minikql/mkql_node_builder.h>
  5. namespace NKikimr {
  6. namespace NMiniKQL {
  7. namespace {
  8. template <bool IsOptional>
  9. class TGuessWrapper: public TMutableCodegeneratorPtrNode<TGuessWrapper<IsOptional>> {
  10. typedef TMutableCodegeneratorPtrNode<TGuessWrapper<IsOptional>> TBaseComputation;
  11. public:
  12. TGuessWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* varNode, ui32 index)
  13. : TBaseComputation(mutables, kind)
  14. , VarNode(varNode)
  15. , Index(index)
  16. {
  17. }
  18. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  19. auto var = VarNode->GetValue(compCtx);
  20. if (IsOptional && !var) {
  21. return NUdf::TUnboxedValuePod();
  22. }
  23. const auto currentIndex = var.GetVariantIndex();
  24. if (Index == currentIndex) {
  25. return var.Release().GetVariantItem().MakeOptional();
  26. } else {
  27. return NUdf::TUnboxedValuePod();
  28. }
  29. }
  30. #ifndef MKQL_DISABLE_CODEGEN
  31. void DoGenerateGetValue(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
  32. auto& context = ctx.Codegen.GetContext();
  33. const auto valueType = Type::getInt128Ty(context);
  34. const auto indexType = Type::getInt32Ty(context);
  35. const auto var = GetNodeValue(VarNode, ctx, block);
  36. const auto ind = ConstantInt::get(indexType, Index);
  37. const auto zero = ConstantInt::get(valueType, 0ULL);
  38. const auto none = BasicBlock::Create(context, "none", ctx.Func);
  39. if constexpr (IsOptional) {
  40. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  41. BranchInst::Create(none, good, IsEmpty(var, block, context), block);
  42. block = good;
  43. }
  44. const auto lshr = BinaryOperator::CreateLShr(var, ConstantInt::get(valueType, 122), "lshr", block);
  45. const auto trunc = CastInst::Create(Instruction::Trunc, lshr, indexType, "trunc", block);
  46. const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, trunc, ConstantInt::get(indexType , 0), "check", block);
  47. const auto boxed = BasicBlock::Create(context, "boxed", ctx.Func);
  48. const auto embed = BasicBlock::Create(context, "embed", ctx.Func);
  49. const auto step = BasicBlock::Create(context, "step", ctx.Func);
  50. const auto index = PHINode::Create(indexType, 2, "index", step);
  51. BranchInst::Create(embed, boxed, check, block);
  52. block = embed;
  53. const auto dec = BinaryOperator::CreateSub(trunc, ConstantInt::get(indexType, 1), "dec", block);
  54. index->addIncoming(dec, block);
  55. BranchInst::Create(step, block);
  56. block = boxed;
  57. const auto idx = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetVariantIndex>(indexType, var, ctx.Codegen, block);
  58. index->addIncoming(idx, block);
  59. BranchInst::Create(step, block);
  60. block = step;
  61. const auto equal = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, index, ind, "equal", block);
  62. const auto same = BasicBlock::Create(context, "same", ctx.Func);
  63. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  64. BranchInst::Create(same, none, equal, block);
  65. block = none;
  66. new StoreInst(zero, pointer, block);
  67. BranchInst::Create(done, block);
  68. block = same;
  69. const auto box = BasicBlock::Create(context, "box", ctx.Func);
  70. const auto emb = BasicBlock::Create(context, "emb", ctx.Func);
  71. BranchInst::Create(emb, box, check, block);
  72. block = emb;
  73. const uint64_t init[] = {0xFFFFFFFFFFFFFFFFULL, 0x3FFFFFFFFFFFFFFULL};
  74. const auto mask = ConstantInt::get(valueType, APInt(128, 2, init));
  75. const auto clean = BinaryOperator::CreateAnd(var, mask, "clean", block);
  76. new StoreInst(MakeOptional(context, clean, block), pointer, block);
  77. ValueAddRef(this->RepresentationKind, pointer, ctx, block);
  78. BranchInst::Create(done, block);
  79. block = box;
  80. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetVariantItem>(pointer, var, ctx.Codegen, block);
  81. const auto load = new LoadInst(valueType, pointer, "load", block);
  82. new StoreInst(MakeOptional(context, load, block), pointer, block);
  83. BranchInst::Create(done, block);
  84. block = done;
  85. }
  86. #endif
  87. private:
  88. void RegisterDependencies() const final {
  89. this->DependsOn(VarNode);
  90. }
  91. IComputationNode *const VarNode;
  92. const ui32 Index;
  93. };
  94. }
  95. IComputationNode* WrapGuess(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  96. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 arguments");
  97. bool isOptional;
  98. const auto unpacked = UnpackOptional(callable.GetInput(0), isOptional);
  99. const auto varType = AS_TYPE(TVariantType, unpacked);
  100. const auto indexData = AS_VALUE(TDataLiteral, callable.GetInput(1));
  101. const ui32 index = indexData->AsValue().Get<ui32>();
  102. MKQL_ENSURE(index < varType->GetAlternativesCount(), "Bad alternative index");
  103. const auto variant = LocateNode(ctx.NodeLocator, callable, 0);
  104. if (isOptional) {
  105. return new TGuessWrapper<true>(ctx.Mutables, GetValueRepresentation(varType->GetAlternativeType(index)), variant, index);
  106. } else {
  107. return new TGuessWrapper<false>(ctx.Mutables, GetValueRepresentation(varType->GetAlternativeType(index)), variant, index);
  108. }
  109. }
  110. }
  111. }