mkql_element.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. #include "mkql_element.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/mkql_node_cast.h>
  4. #include <yql/essentials/minikql/mkql_node_builder.h>
  5. namespace NKikimr {
  6. namespace NMiniKQL {
  7. namespace {
  8. template <bool IsOptional>
  9. class TElementsWrapper : public TMutableCodegeneratorNode<TElementsWrapper<IsOptional>> {
  10. typedef TMutableCodegeneratorNode<TElementsWrapper<IsOptional>> TBaseComputation;
  11. public:
  12. TElementsWrapper(TComputationMutables& mutables, IComputationNode* array)
  13. : TBaseComputation(mutables, EValueRepresentation::Embedded), Array(array)
  14. {}
  15. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  16. const auto& array = Array->GetValue(compCtx);
  17. if constexpr (IsOptional) {
  18. return array ? NUdf::TUnboxedValuePod(reinterpret_cast<ui64>(array.GetElements())) : NUdf::TUnboxedValuePod();
  19. } else {
  20. return NUdf::TUnboxedValuePod(reinterpret_cast<ui64>(array.GetElements()));
  21. }
  22. }
  23. #ifndef MKQL_DISABLE_CODEGEN
  24. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  25. auto& context = ctx.Codegen.GetContext();
  26. const auto array = GetNodeValue(Array, ctx, block);
  27. const auto elementsType = PointerType::getUnqual(array->getType());
  28. if constexpr (IsOptional) {
  29. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  30. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  31. const auto result = PHINode::Create(array->getType(), 2U, "result", done);
  32. result->addIncoming(ConstantInt::get(array->getType(), 0ULL), block);
  33. BranchInst::Create(done, good, IsEmpty(array, block, context), block);
  34. block = good;
  35. const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(elementsType, array, ctx.Codegen, block);
  36. const auto cast = CastInst::Create(Instruction::PtrToInt, elements, Type::getInt64Ty(context), "cast", block);
  37. const auto wide = SetterFor<ui64>(cast, context, block);
  38. result->addIncoming(wide, block);
  39. BranchInst::Create(done, block);
  40. block = done;
  41. return result;
  42. } else {
  43. const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(elementsType, array, ctx.Codegen, block);
  44. const auto cast = CastInst::Create(Instruction::PtrToInt, elements, Type::getInt64Ty(context), "cast", block);
  45. return SetterFor<ui64>(cast, context, block);
  46. }
  47. }
  48. #endif
  49. private:
  50. void RegisterDependencies() const final {
  51. this->DependsOn(Array);
  52. }
  53. IComputationNode* const Array;
  54. };
/// Extracts the element at a fixed position (`Index`) from the tuple or
/// struct value produced by the `Array` node.
///
/// There are two ways to read the element:
///  - fast: dereference the element array whose pointer is produced by the
///    `Cache` node (an embedded ui64; in this file `Cache` is always a
///    TElementsWrapper over the same input). This path is used only when
///    `Cache` has more than one dependent node.
///  - slow: call GetElement(Index) on the input value itself.
/// When IsOptional is true the input may be empty, and then the result is
/// empty as well.
template <bool IsOptional>
class TElementWrapper : public TMutableCodegeneratorPtrNode<TElementWrapper<IsOptional>> {
typedef TMutableCodegeneratorPtrNode<TElementWrapper<IsOptional>> TBaseComputation;
public:
    // kind:  value representation of the extracted element (forwarded to the base node).
    // cache: node yielding the input's element-array pointer, shared by all extractors of that input.
    // array: the (possibly optional) tuple/struct input node.
    // index: position of the element to extract.
    TElementWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* cache, IComputationNode* array, ui32 index)
        : TBaseComputation(mutables, kind), Cache(cache), Array(array), Index(index)
    {}

    // Interpreted (non-codegen) evaluation.
    NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
        // Consult the cached pointer only when the cache node is shared.
        // NOTE(review): presumably computing the pointer only pays off when it is reused — confirm.
        if (Cache->GetDependencesCount() > 1U) {
            const auto cache = Cache->GetValue(ctx);
            // The cache node returns empty exactly when the optional input is empty.
            if (IsOptional && !cache) {
                return NUdf::TUnboxedValue();
            }
            // A zero pointer means GetElements() returned null; fall through to GetElement() below.
            if (const auto elements = cache.Get<ui64>()) {
                return reinterpret_cast<const NUdf::TUnboxedValuePod*>(elements)[Index];
            }
        }
        // Slow path: ask the value itself for the element.
        const auto& array = Array->GetValue(ctx);
        if constexpr (IsOptional) {
            return array ? array.GetElement(Index) : NUdf::TUnboxedValue();
        } else {
            return array.GetElement(Index);
        }
    }
#ifndef MKQL_DISABLE_CODEGEN
    // Emits the slow path: writes array.GetElement(Index) into *pointer, or an
    // empty (all-zero) value when the optional input is empty.
    void DoGenerateGetElement(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();
        const auto array = GetNodeValue(Array, ctx, block);
        const auto index = ConstantInt::get(Type::getInt32Ty(context), Index);
        if constexpr (IsOptional) {
            const auto good = BasicBlock::Create(context, "good", ctx.Func);
            const auto zero = BasicBlock::Create(context, "zero", ctx.Func);
            const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
            BranchInst::Create(zero, good, IsEmpty(array, block, context), block);
            // Empty input: store an all-zero (empty) value.
            block = zero;
            new StoreInst(ConstantInt::get(array->getType(), 0ULL), pointer, block);
            BranchInst::Create(exit, block);
            // Non-empty input: the virtual GetElement call writes into *pointer.
            block = good;
            CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElement>(pointer, array, ctx.Codegen, block, index);
            // Clean up the input value when its node reports it as temporary.
            if (Array->IsTemporaryValue())
                CleanupBoxed(array, ctx, block);
            BranchInst::Create(exit, block);
            block = exit;
        } else {
            CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElement>(pointer, array, ctx.Codegen, block, index);
            // Clean up the input value when its node reports it as temporary.
            if (Array->IsTemporaryValue())
                CleanupBoxed(array, ctx, block);
        }
    }
    // Emits code mirroring DoCalculate: the cached-pointer fast path with a
    // GetElement() fallback. The result is written into *pointer.
    void DoGenerateGetValue(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
        // Unshared cache: emit only the slow path, matching DoCalculate.
        if (Cache->GetDependencesCount() <= 1U) {
            return DoGenerateGetElement(ctx, pointer, block);
        }
        auto& context = ctx.Codegen.GetContext();
        const auto cache = GetNodeValue(Cache, ctx, block);
        const auto fast = BasicBlock::Create(context, "fast", ctx.Func);
        const auto slow = BasicBlock::Create(context, "slow", ctx.Func);
        const auto done = BasicBlock::Create(context, "done", ctx.Func);
        if constexpr (IsOptional) {
            // An all-zero cache value means an empty input: store empty and finish.
            const auto zero = ConstantInt::get(cache->getType(), 0ULL);
            const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, cache, zero, "check", block);
            const auto good = BasicBlock::Create(context, "good", ctx.Func);
            const auto none = BasicBlock::Create(context, "none", ctx.Func);
            BranchInst::Create(none, good, check, block);
            block = none;
            new StoreInst(zero, pointer, block);
            BranchInst::Create(done, block);
            block = good;
        }
        // The element-array pointer is taken from the low 64 bits of the cache value.
        const auto trunc = CastInst::Create(Instruction::Trunc, cache, Type::getInt64Ty(context), "trunc", block);
        const auto type = PointerType::getUnqual(cache->getType());
        const auto elements = CastInst::Create(Instruction::IntToPtr, trunc, type, "elements", block);
        // Null pointer -> slow path; otherwise load the element directly.
        const auto fill = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, elements, ConstantPointerNull::get(type), "fill", block);
        BranchInst::Create(fast, slow, fill, block);
        block = fast;
        const auto index = ConstantInt::get(Type::getInt32Ty(context), this->Index);
        const auto ptr = GetElementPtrInst::CreateInBounds(cache->getType(), elements, {index}, "ptr", block);
        const auto item = new LoadInst(cache->getType(), ptr, "item", block);
        // The copy stored into *pointer takes its own reference.
        ValueAddRef(this->GetRepresentation(), item, ctx, block);
        new StoreInst(item, pointer, block);
        BranchInst::Create(done, block);
        block = slow;
        DoGenerateGetElement(ctx, pointer, block);
        BranchInst::Create(done, block);
        block = done;
    }
#endif
private:
    void RegisterDependencies() const final {
        this->DependsOn(Array);
        this->DependsOn(Cache);
    }
    IComputationNode *const Cache;
    IComputationNode *const Array;
    const ui32 Index;
};
  151. IComputationNode* WrapElements(IComputationNode* array, const TComputationNodeFactoryContext& ctx, bool isOptional) {
  152. if (isOptional) {
  153. return new TElementsWrapper<true>(ctx.Mutables, array);
  154. } else {
  155. return new TElementsWrapper<false>(ctx.Mutables, array);
  156. }
  157. }
  158. }
  159. IComputationNode* WrapNth(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  160. MKQL_ENSURE(callable.GetInputsCount() == 2U, "Expected two args.");
  161. const auto input = callable.GetInput(0U);
  162. bool isOptional;
  163. const auto tupleType = AS_TYPE(TTupleType, UnpackOptional(input.GetStaticType(), isOptional));
  164. const auto indexData = AS_VALUE(TDataLiteral, callable.GetInput(1U));
  165. const auto index = indexData->AsValue().Get<ui32>();
  166. MKQL_ENSURE(index < tupleType->GetElementsCount(), "Bad tuple index");
  167. const auto tuple = LocateNode(ctx.NodeLocator, callable, 0);
  168. const auto ins = ctx.ElementsCache.emplace(tuple, nullptr);
  169. if (ins.second) {
  170. ctx.NodePushBack(ins.first->second = WrapElements(tuple, ctx, isOptional));
  171. }
  172. if (isOptional) {
  173. return new TElementWrapper<true>(ctx.Mutables, GetValueRepresentation(tupleType->GetElementType(index)), ins.first->second, tuple, index);
  174. } else {
  175. return new TElementWrapper<false>(ctx.Mutables, GetValueRepresentation(tupleType->GetElementType(index)), ins.first->second, tuple, index);
  176. }
  177. }
  178. IComputationNode* WrapMember(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  179. MKQL_ENSURE(callable.GetInputsCount() == 2U, "Expected two args.");
  180. const auto input = callable.GetInput(0U);
  181. bool isOptional;
  182. const auto structType = AS_TYPE(TStructType, UnpackOptional(input.GetStaticType(), isOptional));
  183. const auto indexData = AS_VALUE(TDataLiteral, callable.GetInput(1U));
  184. const auto index = indexData->AsValue().Get<ui32>();
  185. MKQL_ENSURE(index < structType->GetMembersCount(), "Bad member index");
  186. const auto structObj = LocateNode(ctx.NodeLocator, callable, 0U);
  187. const auto ins = ctx.ElementsCache.emplace(structObj, nullptr);
  188. if (ins.second) {
  189. ctx.NodePushBack(ins.first->second = WrapElements(structObj, ctx, isOptional));
  190. }
  191. if (isOptional) {
  192. return new TElementWrapper<true>(ctx.Mutables, GetValueRepresentation(structType->GetMemberType(index)), ins.first->second, structObj, index);
  193. } else {
  194. return new TElementWrapper<false>(ctx.Mutables, GetValueRepresentation(structType->GetMemberType(index)), ins.first->second, structObj, index);
  195. }
  196. }
  197. }
  198. }