mkql_addmember.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. #include "mkql_addmember.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
  5. #include <yql/essentials/minikql/mkql_node_cast.h>
  6. namespace NKikimr {
  7. namespace NMiniKQL {
  8. namespace {
  9. class TAddMemberWrapper : public TMutableCodegeneratorFallbackNode<TAddMemberWrapper> {
  10. typedef TMutableCodegeneratorFallbackNode<TAddMemberWrapper> TBaseComputation;
  11. public:
  12. TAddMemberWrapper(TComputationMutables& mutables, IComputationNode* structObj, IComputationNode* member, ui32 index,
  13. std::vector<EValueRepresentation>&& representations)
  14. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  15. , StructObj_(structObj)
  16. , Member_(member)
  17. , Index_(index)
  18. , Representations_(std::move(representations))
  19. , Cache_(mutables)
  20. {}
  21. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  22. const auto& baseStruct = StructObj_->GetValue(ctx);
  23. NUdf::TUnboxedValue* itemsPtr = nullptr;
  24. const auto result = Cache_.NewArray(ctx, Representations_.size() + 1U, itemsPtr);
  25. if (const auto ptr = baseStruct.GetElements()) {
  26. for (ui32 i = 0; i < Index_; ++i) {
  27. *itemsPtr++ = ptr[i];
  28. }
  29. *itemsPtr++ = Member_->GetValue(ctx);
  30. for (ui32 i = Index_; i < Representations_.size(); ++i) {
  31. *itemsPtr++ = ptr[i];
  32. }
  33. } else {
  34. for (ui32 i = 0; i < Index_; ++i) {
  35. *itemsPtr++ = baseStruct.GetElement(i);
  36. }
  37. *itemsPtr++ = Member_->GetValue(ctx);
  38. for (ui32 i = Index_; i < Representations_.size(); ++i) {
  39. *itemsPtr++ = baseStruct.GetElement(i);
  40. }
  41. }
  42. return result;
  43. }
  44. #ifndef MKQL_DISABLE_CODEGEN
  45. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  46. if (Representations_.size() > CodegenArraysFallbackLimit)
  47. return TBaseComputation::DoGenerateGetValue(ctx, block);
  48. auto& context = ctx.Codegen.GetContext();
  49. const auto newSize = Representations_.size() + 1U;
  50. const auto valType = Type::getInt128Ty(context);
  51. const auto ptrType = PointerType::getUnqual(valType);
  52. const auto idxType = Type::getInt32Ty(context);
  53. const auto type = ArrayType::get(valType, newSize);
  54. const auto itmsType = PointerType::getUnqual(type);
  55. const auto itms = *Stateless || ctx.AlwaysInline ?
  56. new AllocaInst(itmsType, 0U, "itms", &ctx.Func->getEntryBlock().back()):
  57. new AllocaInst(itmsType, 0U, "itms", block);
  58. const auto result = Cache_.GenNewArray(newSize, itms, ctx, block);
  59. const auto itemsPtr = new LoadInst(itmsType, itms, "items", block);
  60. const auto array = GetNodeValue(StructObj_, ctx, block);
  61. const auto zero = ConstantInt::get(idxType, 0);
  62. const auto itemPtr = GetElementPtrInst::CreateInBounds(type, itemsPtr, {zero, ConstantInt::get(idxType, Index_)}, "item", block);
  63. GetNodeValue(itemPtr, Member_, ctx, block);
  64. const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(ptrType, array, ctx.Codegen, block);
  65. const auto null = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, elements, ConstantPointerNull::get(ptrType), "null", block);
  66. const auto fast = BasicBlock::Create(context, "fast", ctx.Func);
  67. const auto slow = BasicBlock::Create(context, "slow", ctx.Func);
  68. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  69. BranchInst::Create(slow, fast, null, block);
  70. {
  71. block = fast;
  72. for (ui32 i = 0; i < Index_; ++i) {
  73. const auto index = ConstantInt::get(idxType, i);
  74. const auto srcPtr = GetElementPtrInst::CreateInBounds(valType, elements, {index}, "src", block);
  75. const auto dstPtr = GetElementPtrInst::CreateInBounds(type, itemsPtr, {zero, index}, "dst", block);
  76. const auto item = new LoadInst(valType, srcPtr, "item", block);
  77. new StoreInst(item, dstPtr, block);
  78. ValueAddRef(Representations_[i], dstPtr, ctx, block);
  79. }
  80. for (ui32 i = Index_ + 1U; i < newSize; ++i) {
  81. const auto oldIndex = ConstantInt::get(idxType, --i);
  82. const auto newIndex = ConstantInt::get(idxType, ++i);
  83. const auto srcPtr = GetElementPtrInst::CreateInBounds(valType, elements, {oldIndex}, "src", block);
  84. const auto dstPtr = GetElementPtrInst::CreateInBounds(type, itemsPtr, {zero, newIndex}, "dst", block);
  85. const auto item = new LoadInst(valType, srcPtr, "item", block);
  86. new StoreInst(item, dstPtr, block);
  87. ValueAddRef(Representations_[i - 1U], dstPtr, ctx, block);
  88. }
  89. BranchInst::Create(done, block);
  90. }
  91. {
  92. block = slow;
  93. for (ui32 i = 0; i < Index_; ++i) {
  94. const auto index = ConstantInt::get(idxType, i);
  95. const auto itemPtr = GetElementPtrInst::CreateInBounds(type, itemsPtr, {zero, index}, "item", block);
  96. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElement>(itemPtr, array, ctx.Codegen, block, index);
  97. }
  98. for (ui32 i = Index_ + 1U; i < newSize; ++i) {
  99. const auto oldIndex = ConstantInt::get(idxType, --i);
  100. const auto newIndex = ConstantInt::get(idxType, ++i);
  101. const auto itemPtr = GetElementPtrInst::CreateInBounds(type, itemsPtr, {zero, newIndex}, "item", block);
  102. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElement>(itemPtr, array, ctx.Codegen, block, oldIndex);
  103. }
  104. BranchInst::Create(done, block);
  105. }
  106. block = done;
  107. if (StructObj_->IsTemporaryValue())
  108. CleanupBoxed(array, ctx, block);
  109. return result;
  110. }
  111. #endif
  112. private:
  113. void RegisterDependencies() const final {
  114. DependsOn(StructObj_);
  115. DependsOn(Member_);
  116. }
  117. IComputationNode* const StructObj_;
  118. IComputationNode* const Member_;
  119. const ui32 Index_;
  120. const std::vector<EValueRepresentation> Representations_;
  121. const TContainerCacheOnContext Cache_;
  122. };
  123. }
  124. IComputationNode* AddMember(const TComputationNodeFactoryContext& ctx, TRuntimeNode structData, TRuntimeNode memberData, TRuntimeNode indexData) {
  125. const auto structType = AS_TYPE(TStructType, structData);
  126. const ui32 index = AS_VALUE(TDataLiteral, indexData)->AsValue().Get<ui32>();
  127. MKQL_ENSURE(index <= structType->GetMembersCount(), "Bad member index");
  128. std::vector<EValueRepresentation> representations;
  129. representations.reserve(structType->GetMembersCount());
  130. for (ui32 i = 0; i < structType->GetMembersCount(); ++i) {
  131. representations.emplace_back(GetValueRepresentation(structType->GetMemberType(i)));
  132. }
  133. const auto structObj = LocateNode(ctx.NodeLocator, *structData.GetNode());
  134. const auto member = LocateNode(ctx.NodeLocator, *memberData.GetNode());
  135. return new TAddMemberWrapper(ctx.Mutables, structObj, member, index, std::move(representations));
  136. }
  137. }
  138. }