mkql_way.cpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. #include "mkql_way.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
  5. #include <yql/essentials/minikql/mkql_node_cast.h>
  6. #include <yql/essentials/minikql/mkql_node_builder.h>
  7. #include <yql/essentials/minikql/mkql_string_util.h>
  8. #include <util/string/cast.h>
  9. namespace NKikimr {
  10. namespace NMiniKQL {
  11. namespace {
  12. template <bool IsOptional>
  13. class TWayWrapper: public TMutableCodegeneratorNode<TWayWrapper<IsOptional>> {
  14. typedef TMutableCodegeneratorNode<TWayWrapper<IsOptional>> TBaseComputation;
  15. public:
  16. TWayWrapper(TComputationMutables& mutables, IComputationNode* varNode, EValueRepresentation kind, TComputationNodePtrVector&& literals)
  17. : TBaseComputation(mutables, kind)
  18. , VarNode(varNode)
  19. , Literals(std::move(literals))
  20. {}
  21. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  22. const auto& var = VarNode->GetValue(ctx);
  23. if (IsOptional && !var) {
  24. return NUdf::TUnboxedValuePod();
  25. }
  26. const ui32 index = var.GetVariantIndex();
  27. return Literals[index]->GetValue(ctx).Release();
  28. }
  29. #ifndef MKQL_DISABLE_CODEGEN
  30. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  31. auto& context = ctx.Codegen.GetContext();
  32. const auto valueType = Type::getInt128Ty(context);
  33. const auto indexType = Type::getInt32Ty(context);
  34. const auto var = GetNodeValue(VarNode, ctx, block);
  35. const auto zero = ConstantInt::get(valueType, 0ULL);
  36. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  37. const auto result = PHINode::Create(valueType, Literals.size() + IsOptional ? 2U : 1U, "result", done);
  38. if (IsOptional) {
  39. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  40. BranchInst::Create(done, good, IsEmpty(var, block, context), block);
  41. result->addIncoming(zero, block);
  42. block = good;
  43. }
  44. const auto lshr = BinaryOperator::CreateLShr(var, ConstantInt::get(valueType, 122), "lshr", block);
  45. const auto trunc = CastInst::Create(Instruction::Trunc, lshr, indexType, "trunc", block);
  46. const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, trunc, ConstantInt::get(indexType , 0), "check", block);
  47. const auto boxed = BasicBlock::Create(context, "boxed", ctx.Func);
  48. const auto embed = BasicBlock::Create(context, "embed", ctx.Func);
  49. const auto step = BasicBlock::Create(context, "step", ctx.Func);
  50. const auto index = PHINode::Create(indexType, 2, "index", step);
  51. BranchInst::Create(embed, boxed, check, block);
  52. block = embed;
  53. const auto dec = BinaryOperator::CreateSub(trunc, ConstantInt::get(indexType, 1), "dec", block);
  54. index->addIncoming(dec, block);
  55. BranchInst::Create(step, block);
  56. block = boxed;
  57. const auto idx = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetVariantIndex>(indexType, var, ctx.Codegen, block);
  58. index->addIncoming(idx, block);
  59. BranchInst::Create(step, block);
  60. block = step;
  61. const auto choise = SwitchInst::Create(index, done, Literals.size(), block);
  62. result->addIncoming(zero, block);
  63. for (ui32 i = 0; i < Literals.size(); ++i) {
  64. const auto var = BasicBlock::Create(context, (TString("case_") += ToString(i)).c_str(), ctx.Func);
  65. choise->addCase(ConstantInt::get(Type::getInt32Ty(context), i), var);
  66. block = var;
  67. const auto way = GetNodeValue(Literals[i], ctx, block);
  68. result->addIncoming(way, block);
  69. BranchInst::Create(done, block);
  70. }
  71. block = done;
  72. return result;
  73. }
  74. #endif
  75. private:
  76. void RegisterDependencies() const final {
  77. this->DependsOn(VarNode);
  78. std::for_each(Literals.cbegin(), Literals.cend(),std::bind(&TWayWrapper<IsOptional>::DependsOn, this, std::placeholders::_1));
  79. }
  80. IComputationNode *const VarNode;
  81. const TComputationNodePtrVector Literals;
  82. };
  83. }
  84. IComputationNode* WrapWay(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  85. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 argument");
  86. bool isOptional;
  87. const auto unpacked = UnpackOptional(callable.GetInput(0), isOptional);
  88. const auto varType = AS_TYPE(TVariantType, unpacked);
  89. const auto structType = varType->GetUnderlyingType()->IsTuple() ? nullptr : AS_TYPE(TStructType, varType->GetUnderlyingType());
  90. const auto size = varType->GetAlternativesCount();
  91. TComputationNodePtrVector literals(size);
  92. EValueRepresentation kind = EValueRepresentation::Embedded;
  93. for (ui32 idx = 0U; idx < size; ++idx) {
  94. const auto node = literals[idx] = ctx.NodeFactory.CreateImmutableNode(structType ? MakeString(structType->GetMemberName(idx)) : NUdf::TUnboxedValuePod(idx));
  95. ctx.NodePushBack(node);
  96. if (node->GetRepresentation() != EValueRepresentation::Embedded) {
  97. kind = EValueRepresentation::Any;
  98. }
  99. }
  100. const auto variant = LocateNode(ctx.NodeLocator, callable, 0);
  101. if (isOptional) {
  102. return new TWayWrapper<true>(ctx.Mutables, variant, kind, std::move(literals));
  103. } else {
  104. return new TWayWrapper<false>(ctx.Mutables, variant, kind, std::move(literals));
  105. }
  106. }
  107. }
  108. }