mkql_builtins_div.cpp 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. #include "mkql_builtins_impl.h" // Y_IGNORE
  2. #include "mkql_builtins_datetime.h"
  3. #include <yql/essentials/minikql/mkql_type_ops.h>
  4. namespace NKikimr {
  5. namespace NMiniKQL {
  6. namespace {
  7. template<typename TLeft, typename TRight, typename TOutput>
  8. struct TDiv : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TDiv<TLeft, TRight, TOutput>> {
  9. static_assert(std::is_floating_point<TOutput>::value, "expected floating point");
  10. static constexpr auto NullMode = TKernel::ENullMode::Default;
  11. static TOutput Do(TOutput left, TOutput right)
  12. {
  13. return left / right;
  14. }
  15. #ifndef MKQL_DISABLE_CODEGEN
  16. static Value* Gen(Value* left, Value* right, const TCodegenContext&, BasicBlock*& block)
  17. {
  18. return BinaryOperator::CreateFDiv(left, right, "div", block);
  19. }
  20. #endif
  21. };
  22. template <typename TLeft, typename TRight, typename TOutput>
  23. struct TIntegralDiv {
  24. static_assert(std::is_integral<TOutput>::value, "integral type expected");
  25. static constexpr auto NullMode = TKernel::ENullMode::AlwaysNull;
  26. static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right)
  27. {
  28. const auto lv = static_cast<TOutput>(left.template Get<TLeft>());
  29. const auto rv = static_cast<TOutput>(right.template Get<TRight>());
  30. if (rv == 0 ||
  31. (std::is_signed<TOutput>::value && sizeof(TOutput) <= sizeof(TLeft) && rv == TOutput(-1) && lv == Min<TOutput>()))
  32. {
  33. return NUdf::TUnboxedValuePod();
  34. }
  35. return NUdf::TUnboxedValuePod(lv / rv);
  36. }
  37. #ifndef MKQL_DISABLE_CODEGEN
  38. static Value* Generate(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block)
  39. {
  40. auto& context = ctx.Codegen.GetContext();
  41. const auto lv = StaticCast<TLeft, TOutput>(GetterFor<TLeft>(left, context, block), context, block);
  42. const auto rv = StaticCast<TRight, TOutput>(GetterFor<TRight>(right, context, block), context, block);
  43. const auto type = Type::getInt128Ty(context);
  44. const auto zero = ConstantInt::get(type, 0);
  45. const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rv, ConstantInt::get(rv->getType(), 0), "check", block);
  46. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  47. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  48. const auto result = PHINode::Create(type, 2, "result", done);
  49. result->addIncoming(zero, block);
  50. if constexpr (std::is_signed<TOutput>() && sizeof(TOutput) <= sizeof(TLeft)) {
  51. const auto min = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, lv, ConstantInt::get(lv->getType(), Min<TOutput>()), "min", block);
  52. const auto one = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rv, ConstantInt::get(rv->getType(), -1), "one", block);
  53. const auto two = BinaryOperator::CreateAnd(min, one, "two", block);
  54. const auto all = BinaryOperator::CreateOr(check, two, "all", block);
  55. BranchInst::Create(done, good, all, block);
  56. } else {
  57. BranchInst::Create(done, good, check, block);
  58. }
  59. block = good;
  60. const auto div = std::is_signed<TOutput>() ? BinaryOperator::CreateSDiv(lv, rv, "div", block) : BinaryOperator::CreateUDiv(lv, rv, "div", block);
  61. const auto full = SetterFor<TOutput>(div, context, block);
  62. result->addIncoming(full, block);
  63. BranchInst::Create(done, block);
  64. block = done;
  65. return result;
  66. }
  67. #endif
  68. };
  69. template <typename TLeft, typename TRight, typename TOutput>
  70. struct TNumDivInterval {
  71. static_assert(TLeft::Features & NYql::NUdf::TimeIntervalType, "Left must be interval type");
  72. static_assert(TRight::Features & NYql::NUdf::IntegralType, "Right must be integral type");
  73. static_assert(TOutput::Features & NYql::NUdf::TimeIntervalType, "Output must be interval type");
  74. static_assert(std::is_same_v<typename TOutput::TLayout, i64>, "Output layout type must be i64");
  75. static constexpr auto NullMode = TKernel::ENullMode::AlwaysNull;
  76. static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right)
  77. {
  78. if constexpr (std::is_same_v<ui64, typename TRight::TLayout>) {
  79. if (right.Get<ui64>() > static_cast<ui64>(std::numeric_limits<i64>::max())) {
  80. return NUdf::TUnboxedValuePod(i64(0));
  81. }
  82. }
  83. const auto lv = static_cast<typename TOutput::TLayout>(left.template Get<typename TLeft::TLayout>());
  84. const auto rv = static_cast<typename TOutput::TLayout>(right.template Get<typename TRight::TLayout>());
  85. if (rv == 0) {
  86. return NUdf::TUnboxedValuePod();
  87. }
  88. return NUdf::TUnboxedValuePod(lv / rv);
  89. }
  90. #ifndef MKQL_DISABLE_CODEGEN
  91. static Value* Generate(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block)
  92. {
  93. auto& context = ctx.Codegen.GetContext();
  94. const auto bbMain = BasicBlock::Create(context, "bbMain", ctx.Func);
  95. const auto bbDone = BasicBlock::Create(context, "bbDone", ctx.Func);
  96. const auto resultType = Type::getInt128Ty(context);
  97. const auto null = ConstantInt::get(resultType, 0);
  98. const auto result = PHINode::Create(resultType, 3, "result", bbDone);
  99. const auto rv = GetterFor<typename TRight::TLayout>(right, context, block);
  100. const auto rvZero = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ,
  101. rv, ConstantInt::get(rv->getType(), 0), "rvZero", block);
  102. BranchInst::Create(bbDone, bbMain, rvZero, block);
  103. result->addIncoming(null, block);
  104. block = bbMain;
  105. const auto rvOverflow = GenIsInt64Overflow<typename TRight::TLayout>(rv, context, block);
  106. const auto zero = SetterFor<typename TOutput::TLayout>(
  107. ConstantInt::get(Type::getInt64Ty(context), 0), context, block);
  108. const auto lval = StaticCast<typename TLeft::TLayout, typename TOutput::TLayout>(
  109. GetterFor<typename TLeft::TLayout>(left, context, block), context, block);
  110. const auto rval = StaticCast<typename TRight::TLayout, typename TOutput::TLayout>(
  111. rv, context, block);
  112. const auto div = BinaryOperator::CreateSDiv(lval, rval, "div", block);
  113. const auto divResult = SetterFor<typename TOutput::TLayout>(div, context, block);
  114. const auto res = SelectInst::Create(rvOverflow, zero, divResult, "res", block);
  115. result->addIncoming(res, block);
  116. BranchInst::Create(bbDone, block);
  117. block = bbDone;
  118. return result;
  119. }
  120. #endif
  121. };
  122. }
  123. template <typename TInterval>
  124. void RegisterIntegralDiv(IBuiltinFunctionRegistry& registry) {
  125. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<ui8>,
  126. TInterval, TNumDivInterval, TBinaryArgsOptWithNullableResult>(registry, "Div");
  127. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<i8>,
  128. TInterval, TNumDivInterval, TBinaryArgsOptWithNullableResult>(registry, "Div");
  129. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<ui16>,
  130. TInterval, TNumDivInterval, TBinaryArgsOptWithNullableResult>(registry, "Div");
  131. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<i16>,
  132. TInterval, TNumDivInterval, TBinaryArgsOptWithNullableResult>(registry, "Div");
  133. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<ui32>,
  134. TInterval, TNumDivInterval, TBinaryArgsOptWithNullableResult>(registry, "Div");
  135. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<i32>,
  136. TInterval, TNumDivInterval, TBinaryArgsOptWithNullableResult>(registry, "Div");
  137. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<ui64>,
  138. TInterval, TNumDivInterval, TBinaryArgsOptWithNullableResult>(registry, "Div");
  139. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<i64>,
  140. TInterval, TNumDivInterval, TBinaryArgsOptWithNullableResult>(registry, "Div");
  141. }
  142. void RegisterDiv(IBuiltinFunctionRegistry& registry) {
  143. RegisterBinaryRealFunctionOpt<TDiv, TBinaryArgsOpt>(registry, "Div");
  144. RegisterBinaryIntegralFunctionOpt<TIntegralDiv, TBinaryArgsOptWithNullableResult>(registry, "Div");
  145. RegisterIntegralDiv<NUdf::TDataType<NUdf::TInterval>>(registry);
  146. RegisterIntegralDiv<NUdf::TDataType<NUdf::TInterval64>>(registry);
  147. }
  148. template <typename TInterval>
  149. void RegisterIntervalDiv(TKernelFamilyBase& owner) {
  150. AddBinaryKernelPoly<TInterval, NUdf::TDataType<i8>, TInterval, TNumDivInterval>(owner);
  151. AddBinaryKernelPoly<TInterval, NUdf::TDataType<ui8>, TInterval, TNumDivInterval>(owner);
  152. AddBinaryKernelPoly<TInterval, NUdf::TDataType<i16>, TInterval, TNumDivInterval>(owner);
  153. AddBinaryKernelPoly<TInterval, NUdf::TDataType<ui16>, TInterval, TNumDivInterval>(owner);
  154. AddBinaryKernelPoly<TInterval, NUdf::TDataType<i32>, TInterval, TNumDivInterval>(owner);
  155. AddBinaryKernelPoly<TInterval, NUdf::TDataType<ui32>, TInterval, TNumDivInterval>(owner);
  156. AddBinaryKernelPoly<TInterval, NUdf::TDataType<i64>, TInterval, TNumDivInterval>(owner);
  157. AddBinaryKernelPoly<TInterval, NUdf::TDataType<ui64>, TInterval, TNumDivInterval>(owner);
  158. }
  159. void RegisterDiv(TKernelFamilyMap& kernelFamilyMap) {
  160. auto family = std::make_unique<TKernelFamilyBase>();
  161. AddBinaryIntegralKernels<TIntegralDiv>(*family);
  162. AddBinaryRealKernels<TDiv>(*family);
  163. RegisterIntervalDiv<NUdf::TDataType<NUdf::TInterval>>(*family);
  164. RegisterIntervalDiv<NUdf::TDataType<NUdf::TInterval64>>(*family);
  165. kernelFamilyMap["Div"] = std::move(family);
  166. }
  167. } // namespace NMiniKQL
  168. } // namespace NKikimr