mkql_removemember.cpp 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. #include "mkql_removemember.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
  5. #include <yql/essentials/minikql/mkql_node_cast.h>
  6. namespace NKikimr {
  7. namespace NMiniKQL {
  8. namespace {
  9. class TRemoveMemberWrapper : public TMutableCodegeneratorFallbackNode<TRemoveMemberWrapper> {
  10. typedef TMutableCodegeneratorFallbackNode<TRemoveMemberWrapper> TBaseComputation;
  11. public:
  12. TRemoveMemberWrapper(TComputationMutables& mutables, IComputationNode* structObj, ui32 index, std::vector<EValueRepresentation>&& representations)
  13. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  14. , StructObj(structObj)
  15. , Index(index)
  16. , Representations(std::move(representations))
  17. , Cache(mutables)
  18. {}
  19. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  20. const auto& baseStruct = StructObj->GetValue(ctx);
  21. NUdf::TUnboxedValue* itemsPtr = nullptr;
  22. const auto result = Cache.NewArray(ctx, Representations.size() - 1U, itemsPtr);
  23. if (Representations.size() > 1) {
  24. Y_ABORT_UNLESS(itemsPtr);
  25. if (const auto ptr = baseStruct.GetElements()) {
  26. for (ui32 i = 0; i < Index; ++i) {
  27. *itemsPtr++ = ptr[i];
  28. }
  29. for (ui32 i = Index + 1; i < Representations.size(); ++i) {
  30. *itemsPtr++ = ptr[i];
  31. }
  32. } else {
  33. for (ui32 i = 0; i < Index; ++i) {
  34. *itemsPtr++ = baseStruct.GetElement(i);
  35. }
  36. for (ui32 i = Index + 1; i < Representations.size(); ++i) {
  37. *itemsPtr++ = baseStruct.GetElement(i);
  38. }
  39. }
  40. }
  41. return result;
  42. }
  43. #ifndef MKQL_DISABLE_CODEGEN
  44. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  45. if (Representations.size() > CodegenArraysFallbackLimit)
  46. return TBaseComputation::DoGenerateGetValue(ctx, block);
  47. auto& context = ctx.Codegen.GetContext();
  48. const auto newSize = Representations.size() - 1U;
  49. const auto valType = Type::getInt128Ty(context);
  50. const auto ptrType = PointerType::getUnqual(valType);
  51. const auto idxType = Type::getInt32Ty(context);
  52. const auto type = ArrayType::get(valType, newSize);
  53. const auto itmsType = PointerType::getUnqual(type);
  54. const auto itms = *Stateless || ctx.AlwaysInline ?
  55. new AllocaInst(itmsType, 0U, "itms", &ctx.Func->getEntryBlock().back()):
  56. new AllocaInst(itmsType, 0U, "itms", block);
  57. const auto result = Cache.GenNewArray(newSize, itms, ctx, block);
  58. const auto itemsPtr = new LoadInst(itmsType, itms, "items", block);
  59. const auto array = GetNodeValue(StructObj, ctx, block);
  60. const auto zero = ConstantInt::get(idxType, 0);
  61. const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(ptrType, array, ctx.Codegen, block);
  62. const auto null = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, elements, ConstantPointerNull::get(ptrType), "null", block);
  63. const auto fast = BasicBlock::Create(context, "fast", ctx.Func);
  64. const auto slow = BasicBlock::Create(context, "slow", ctx.Func);
  65. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  66. BranchInst::Create(slow, fast, null, block);
  67. {
  68. block = fast;
  69. for (ui32 i = 0; i < Index; ++i) {
  70. const auto index = ConstantInt::get(idxType, i);
  71. const auto srcPtr = GetElementPtrInst::CreateInBounds(valType, elements, {index}, "src", block);
  72. const auto dstPtr = GetElementPtrInst::CreateInBounds(type, itemsPtr, {zero, index}, "dst", block);
  73. const auto item = new LoadInst(valType, srcPtr, "item", block);
  74. new StoreInst(item, dstPtr, block);
  75. ValueAddRef(Representations[i], dstPtr, ctx, block);
  76. }
  77. for (ui32 i = Index + 1U; i < Representations.size(); ++i) {
  78. const auto oldIndex = ConstantInt::get(idxType, i);
  79. const auto newIndex = ConstantInt::get(idxType, i - 1U);
  80. const auto srcPtr = GetElementPtrInst::CreateInBounds(valType, elements, {oldIndex}, "src", block);
  81. const auto dstPtr = GetElementPtrInst::CreateInBounds(type, itemsPtr, {zero, newIndex}, "dst", block);
  82. const auto item = new LoadInst(valType, srcPtr, "item", block);
  83. new StoreInst(item, dstPtr, block);
  84. ValueAddRef(Representations[i], dstPtr, ctx, block);
  85. }
  86. BranchInst::Create(done, block);
  87. }
  88. {
  89. block = slow;
  90. for (ui32 i = 0; i < Index; ++i) {
  91. const auto index = ConstantInt::get(idxType, i);
  92. const auto itemPtr = GetElementPtrInst::CreateInBounds(type, itemsPtr, {zero, index}, "item", block);
  93. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElement>(itemPtr, array, ctx.Codegen, block, index);
  94. }
  95. for (ui32 i = Index + 1U; i < Representations.size(); ++i) {
  96. const auto oldIndex = ConstantInt::get(idxType, i);
  97. const auto newIndex = ConstantInt::get(idxType, i - 1U);
  98. const auto itemPtr = GetElementPtrInst::CreateInBounds(type, itemsPtr, {zero, newIndex}, "item", block);
  99. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElement>(itemPtr, array, ctx.Codegen, block, oldIndex);
  100. }
  101. BranchInst::Create(done, block);
  102. }
  103. block = done;
  104. if (StructObj->IsTemporaryValue())
  105. CleanupBoxed(array, ctx, block);
  106. return result;
  107. }
  108. #endif
  109. private:
  110. void RegisterDependencies() const final {
  111. DependsOn(StructObj);
  112. }
  113. IComputationNode* const StructObj;
  114. const ui32 Index;
  115. const std::vector<EValueRepresentation> Representations;
  116. const TContainerCacheOnContext Cache;
  117. };
  118. }
  119. IComputationNode* WrapRemoveMember(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  120. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
  121. const auto structType = AS_TYPE(TStructType, callable.GetInput(0));
  122. const auto indexData = AS_VALUE(TDataLiteral, callable.GetInput(1));
  123. const ui32 index = indexData->AsValue().Get<ui32>();
  124. MKQL_ENSURE(index < structType->GetMembersCount(), "Bad member index");
  125. std::vector<EValueRepresentation> representations;
  126. representations.reserve(structType->GetMembersCount());
  127. for (ui32 i = 0; i < structType->GetMembersCount(); ++i) {
  128. representations.emplace_back(GetValueRepresentation(structType->GetMemberType(i)));
  129. }
  130. const auto structObj = LocateNode(ctx.NodeLocator, callable, 0);
  131. return new TRemoveMemberWrapper(ctx.Mutables, structObj, index, std::move(representations));
  132. }
  133. }
  134. }