mkql_tooptional.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. #include "mkql_tooptional.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/mkql_node_cast.h>
  4. namespace NKikimr {
  5. namespace NMiniKQL {
  6. namespace {
  7. template <bool IsOptional>
  8. class THeadWrapper : public TMutableCodegeneratorPtrNode<THeadWrapper<IsOptional>> {
  9. typedef TMutableCodegeneratorPtrNode<THeadWrapper<IsOptional>> TBaseComputation;
  10. public:
  11. THeadWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* list)
  12. : TBaseComputation(mutables, kind), List(list)
  13. {}
  14. NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
  15. const auto& value = List->GetValue(ctx);
  16. if (const auto ptr = value.GetElements()) {
  17. if (value.GetListLength() > 0ULL) {
  18. return NUdf::TUnboxedValuePod(*ptr).MakeOptionalIf<IsOptional>();
  19. }
  20. } else if (const auto iter = value.GetListIterator()) {
  21. NUdf::TUnboxedValue result;
  22. if (iter.Next(result)) {
  23. return result.Release().MakeOptionalIf<IsOptional>();
  24. }
  25. }
  26. return NUdf::TUnboxedValue();
  27. }
  28. #ifndef MKQL_DISABLE_CODEGEN
  29. void DoGenerateGetValue(const TCodegenContext& ctx, Value* result, BasicBlock*& block) const {
  30. auto& context = ctx.Codegen.GetContext();
  31. const auto valueType = Type::getInt128Ty(context);
  32. const auto ptrType = PointerType::getUnqual(valueType);
  33. const auto list = GetNodeValue(List, ctx, block);
  34. const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(ptrType, list, ctx.Codegen, block);
  35. const auto null = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, elements, ConstantPointerNull::get(ptrType), "null", block);
  36. const auto fast = BasicBlock::Create(context, "fast", ctx.Func);
  37. const auto slow = BasicBlock::Create(context, "slow", ctx.Func);
  38. const auto many = BasicBlock::Create(context, "many", ctx.Func);
  39. const auto none = BasicBlock::Create(context, "none", ctx.Func);
  40. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  41. const auto good = IsOptional ? BasicBlock::Create(context, "good", ctx.Func) : done;
  42. BranchInst::Create(slow, fast, null, block);
  43. {
  44. block = fast;
  45. const auto size = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListLength>(Type::getInt64Ty(context), list, ctx.Codegen, block);
  46. const auto test = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, ConstantInt::get(size->getType(), 0ULL), "test", block);
  47. BranchInst::Create(many, none, test, block);
  48. block = many;
  49. const auto item = new LoadInst(valueType, elements, "item", block);
  50. ValueAddRef(this->GetRepresentation(), item, ctx, block);
  51. new StoreInst(IsOptional ? MakeOptional(context, item, block) : item, result, block);
  52. BranchInst::Create(done, block);
  53. }
  54. {
  55. block = slow;
  56. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListIterator>(result, list, ctx.Codegen, block);
  57. const auto iter = new LoadInst(valueType, result, "iter", block);
  58. new StoreInst(ConstantInt::get(valueType, 0ULL), result, block);
  59. const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(Type::getInt1Ty(context), iter, ctx.Codegen, block, result);
  60. UnRefBoxed(iter, ctx, block);
  61. BranchInst::Create(good, none, status, block);
  62. if constexpr (IsOptional) {
  63. block = good;
  64. const auto item = new LoadInst(valueType, result, "item", block);
  65. new StoreInst(MakeOptional(context, item, block), result, block);
  66. BranchInst::Create(done, block);
  67. }
  68. }
  69. block = none;
  70. new StoreInst(ConstantInt::get(valueType, 0ULL), result, block);
  71. BranchInst::Create(done, block);
  72. block = done;
  73. if (List->IsTemporaryValue())
  74. CleanupBoxed(list, ctx, block);
  75. }
  76. #endif
  77. private:
  78. void RegisterDependencies() const final {
  79. this->DependsOn(List);
  80. }
  81. IComputationNode* const List;
  82. };
  83. template <bool IsOptional>
  84. class TLastWrapper : public TMutableCodegeneratorPtrNode<TLastWrapper<IsOptional>> {
  85. typedef TMutableCodegeneratorPtrNode<TLastWrapper<IsOptional>> TBaseComputation;
  86. public:
  87. TLastWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* list)
  88. : TBaseComputation(mutables, kind), List(list)
  89. {}
  90. NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
  91. const auto& value = List->GetValue(ctx);
  92. if (const auto ptr = value.GetElements()) {
  93. if (const auto size = value.GetListLength()) {
  94. return NUdf::TUnboxedValuePod(ptr[size - 1U]).MakeOptionalIf<IsOptional>();
  95. }
  96. } else if (const auto iter = value.GetListIterator()) {
  97. NUdf::TUnboxedValue result;
  98. if (iter.Next(result)) {
  99. while (iter.Next(result)) continue;
  100. return result.Release().MakeOptionalIf<IsOptional>();
  101. }
  102. }
  103. return NUdf::TUnboxedValue();
  104. }
  105. #ifndef MKQL_DISABLE_CODEGEN
  106. void DoGenerateGetValue(const TCodegenContext& ctx, Value* result, BasicBlock*& block) const {
  107. auto& context = ctx.Codegen.GetContext();
  108. const auto valueType = Type::getInt128Ty(context);
  109. const auto ptrType = PointerType::getUnqual(valueType);
  110. const auto list = GetNodeValue(List, ctx, block);
  111. const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(ptrType, list, ctx.Codegen, block);
  112. const auto null = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, elements, ConstantPointerNull::get(ptrType), "null", block);
  113. const auto fast = BasicBlock::Create(context, "fast", ctx.Func);
  114. const auto slow = BasicBlock::Create(context, "slow", ctx.Func);
  115. const auto nope = BasicBlock::Create(context, "nope", ctx.Func);
  116. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  117. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  118. const auto many = BasicBlock::Create(context, "many", ctx.Func);
  119. const auto none = BasicBlock::Create(context, "none", ctx.Func);
  120. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  121. BranchInst::Create(slow, fast, null, block);
  122. {
  123. block = fast;
  124. const auto size = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListLength>(Type::getInt64Ty(context), list, ctx.Codegen, block);
  125. const auto test = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, ConstantInt::get(size->getType(), 0ULL), "test", block);
  126. BranchInst::Create(many, none, test, block);
  127. block = many;
  128. const auto index = BinaryOperator::CreateSub(size, ConstantInt::get(size->getType(), 1), "index", block);
  129. const auto last = GetElementPtrInst::CreateInBounds(valueType, elements, {index}, "last", block);
  130. const auto item = new LoadInst(valueType, last, "item", block);
  131. ValueAddRef(this->GetRepresentation(), item, ctx, block);
  132. new StoreInst(IsOptional ? MakeOptional(context, item, block) : item, result, block);
  133. BranchInst::Create(done, block);
  134. }
  135. {
  136. block = slow;
  137. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListIterator>(result, list, ctx.Codegen, block);
  138. const auto iter = new LoadInst(valueType, result, "iter", block);
  139. new StoreInst(ConstantInt::get(valueType, 0ULL), result, block);
  140. const auto first = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(Type::getInt1Ty(context), iter, ctx.Codegen, block, result);
  141. BranchInst::Create(loop, nope, first, block);
  142. block = nope;
  143. UnRefBoxed(iter, ctx, block);
  144. BranchInst::Create(none, block);
  145. block = loop;
  146. const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(Type::getInt1Ty(context), iter, ctx.Codegen, block, result);
  147. BranchInst::Create(loop, good, status, block);
  148. block = good;
  149. UnRefBoxed(iter, ctx, block);
  150. if constexpr (IsOptional) {
  151. const auto item = new LoadInst(valueType, result, "item", block);
  152. new StoreInst(MakeOptional(context, item, block), result, block);
  153. }
  154. BranchInst::Create(done, block);
  155. }
  156. block = none;
  157. new StoreInst(ConstantInt::get(valueType, 0ULL), result, block);
  158. BranchInst::Create(done, block);
  159. block = done;
  160. if (List->IsTemporaryValue())
  161. CleanupBoxed(list, ctx, block);
  162. }
  163. #endif
  164. private:
  165. void RegisterDependencies() const final {
  166. this->DependsOn(List);
  167. }
  168. IComputationNode* const List;
  169. };
  170. }
  171. IComputationNode* WrapHead(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  172. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args");
  173. if (AS_TYPE(TOptionalType, callable.GetType()->GetReturnType())->IsOptional()) {
  174. return new THeadWrapper<true>(ctx.Mutables, GetValueRepresentation(callable.GetType()->GetReturnType()), LocateNode(ctx.NodeLocator, callable, 0));
  175. } else {
  176. return new THeadWrapper<false>(ctx.Mutables, GetValueRepresentation(callable.GetType()->GetReturnType()), LocateNode(ctx.NodeLocator, callable, 0));
  177. }
  178. }
  179. IComputationNode* WrapLast(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  180. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args");
  181. if (AS_TYPE(TOptionalType, callable.GetType()->GetReturnType())->IsOptional()) {
  182. return new TLastWrapper<true>(ctx.Mutables, GetValueRepresentation(callable.GetType()->GetReturnType()), LocateNode(ctx.NodeLocator, callable, 0));
  183. } else {
  184. return new TLastWrapper<false>(ctx.Mutables, GetValueRepresentation(callable.GetType()->GetReturnType()), LocateNode(ctx.NodeLocator, callable, 0));
  185. }
  186. }
  187. }
  188. }