mkql_builtins_mul.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. #include "mkql_builtins_impl.h" // Y_IGNORE
  2. #include "mkql_builtins_datetime.h"
  3. #include <yql/essentials/minikql/mkql_type_ops.h>
  4. namespace NKikimr {
  5. namespace NMiniKQL {
  6. namespace {
  7. template<typename TLeft, typename TRight, typename TOutput>
  8. struct TMul : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TMul<TLeft, TRight, TOutput>> {
  9. static constexpr auto NullMode = TKernel::ENullMode::Default;
  10. static TOutput Do(TOutput left, TOutput right)
  11. {
  12. return left * right;
  13. }
  14. #ifndef MKQL_DISABLE_CODEGEN
  15. static Value* Gen(Value* left, Value* right, const TCodegenContext&, BasicBlock*& block)
  16. {
  17. return std::is_floating_point<TOutput>() ?
  18. BinaryOperator::CreateFMul(left, right, "mul", block):
  19. BinaryOperator::CreateMul(left, right, "mul", block);
  20. }
  21. #endif
  22. };
  23. template<typename TLeft, typename TRight, typename TOutput>
  24. struct TNumMulInterval {
  25. static_assert(TOutput::Features & NYql::NUdf::TimeIntervalType, "Output must be interval type");
  26. static_assert(std::is_integral_v<typename TLeft::TLayout>, "Left must be integral");
  27. static_assert(std::is_integral_v<typename TRight::TLayout>, "Right must be integral");
  28. static constexpr auto NullMode = TKernel::ENullMode::AlwaysNull;
  29. static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right)
  30. {
  31. const auto lv = static_cast<typename TOutput::TLayout>(left.template Get<typename TLeft::TLayout>());
  32. const auto rv = static_cast<typename TOutput::TLayout>(right.template Get<typename TRight::TLayout>());
  33. const auto ret = lv * rv;
  34. if (rv == 0 || lv == 0) {
  35. return NUdf::TUnboxedValuePod(ret);
  36. }
  37. i64 i64Max = std::numeric_limits<i64>::max();
  38. if constexpr (std::is_same_v<ui64, typename TLeft::TLayout>) {
  39. if (left.Get<ui64>() >= static_cast<ui64>(i64Max)) {
  40. return NUdf::TUnboxedValuePod();
  41. }
  42. }
  43. if constexpr (std::is_same_v<ui64, typename TRight::TLayout>) {
  44. if (right.Get<ui64>() >= static_cast<ui64>(i64Max)) {
  45. return NUdf::TUnboxedValuePod();
  46. }
  47. }
  48. auto div = i64Max / rv;
  49. auto divAbs = (div >= 0) ? div : -div;
  50. if ((lv >= 0) ? (lv > divAbs) : (lv < -divAbs)) {
  51. return NUdf::TUnboxedValuePod();
  52. }
  53. return IsBadInterval<TOutput>(ret) ? NUdf::TUnboxedValuePod() : NUdf::TUnboxedValuePod(ret);
  54. }
  55. #ifndef MKQL_DISABLE_CODEGEN
  56. static Value* Generate(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block)
  57. {
  58. auto& context = ctx.Codegen.GetContext();
  59. const auto bbMain = BasicBlock::Create(context, "bbMain", ctx.Func);
  60. const auto bbDone = BasicBlock::Create(context, "bbDone", ctx.Func);
  61. const auto resultType = Type::getInt128Ty(context);
  62. const auto result = PHINode::Create(resultType, 2, "result", bbDone);
  63. const auto lv = GetterFor<typename TLeft::TLayout>(left, context, block);
  64. const auto lhs = StaticCast<typename TLeft::TLayout, i64>(lv, context, block);
  65. const auto rv = GetterFor<typename TRight::TLayout>(right, context, block);
  66. const auto rhs = StaticCast<typename TRight::TLayout, i64>(rv, context, block);
  67. const auto mul = BinaryOperator::CreateMul(lhs, rhs, "mul", block);
  68. const auto zero = ConstantInt::get(Type::getInt64Ty(context), 0);
  69. const auto lhsZero = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, lhs, zero, "lhsZero", block);
  70. const auto rhsZero = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rhs, zero, "rhsZero", block);
  71. const auto res = SetterFor<typename TOutput::TLayout>(mul, context, block);
  72. BranchInst::Create(bbDone, bbMain, BinaryOperator::CreateOr(lhsZero, rhsZero, "mulZero", block), block);
  73. result->addIncoming(res, block);
  74. block = bbMain;
  75. const auto i64Max = ConstantInt::get(Type::getInt64Ty(context), std::numeric_limits<i64>::max());
  76. const auto div = BinaryOperator::CreateSDiv(i64Max, rhs, "div", block);
  77. const auto divAbs = SelectInst::Create(
  78. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGE, div, zero, "divPos", block),
  79. div,
  80. BinaryOperator::CreateNeg(div, "divNeg", block),
  81. "divAbs", block);
  82. const auto divAbsNeg = BinaryOperator::CreateNeg(divAbs, "divAbsNeg", block);
  83. const auto mulOverflow = SelectInst::Create(
  84. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGE, lhs, zero, "lhsPos", block),
  85. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, lhs, divAbs, "lhsDiv", block),
  86. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLT, lhs, divAbsNeg, "lhsDivAbsNeg", block),
  87. "mulOverflow", block);
  88. const auto i64Overflow = BinaryOperator::CreateOr(
  89. GenIsInt64Overflow<typename TLeft::TLayout>(lv, context, block),
  90. GenIsInt64Overflow<typename TRight::TLayout>(rv, context, block),
  91. "i64Overflow", block);
  92. const auto bad = BinaryOperator::CreateOr(
  93. BinaryOperator::CreateOr(i64Overflow, mulOverflow, "overflow", block),
  94. GenIsBadInterval<TOutput>(mul, context, block),
  95. "bad", block);
  96. const auto null = ConstantInt::get(resultType, 0);
  97. const auto sel = SelectInst::Create(bad, null, res, "sel", block);
  98. result->addIncoming(sel, block);
  99. BranchInst::Create(bbDone, block);
  100. block = bbDone;
  101. return result;
  102. }
  103. #endif
  104. };
  105. }
  106. template <typename TInterval>
  107. void RegisterIntervalMul(IBuiltinFunctionRegistry& registry) {
  108. RegisterFunctionBinPolyOpt<NUdf::TDataType<ui8>, TInterval,
  109. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  110. RegisterFunctionBinPolyOpt<NUdf::TDataType<i8>, TInterval,
  111. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  112. RegisterFunctionBinPolyOpt<NUdf::TDataType<ui16>, TInterval,
  113. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  114. RegisterFunctionBinPolyOpt<NUdf::TDataType<i16>, TInterval,
  115. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  116. RegisterFunctionBinPolyOpt<NUdf::TDataType<ui32>, TInterval,
  117. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  118. RegisterFunctionBinPolyOpt<NUdf::TDataType<i32>, TInterval,
  119. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  120. RegisterFunctionBinPolyOpt<NUdf::TDataType<ui64>, TInterval,
  121. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  122. RegisterFunctionBinPolyOpt<NUdf::TDataType<i64>, TInterval,
  123. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  124. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<ui8>,
  125. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  126. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<i8>,
  127. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  128. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<ui16>,
  129. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  130. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<i16>,
  131. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  132. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<ui32>,
  133. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  134. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<i32>,
  135. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  136. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<ui64>,
  137. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  138. RegisterFunctionBinPolyOpt<TInterval, NUdf::TDataType<i64>,
  139. TInterval, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
  140. }
  141. void RegisterMul(IBuiltinFunctionRegistry& registry) {
  142. RegisterBinaryNumericFunctionOpt<TMul, TBinaryArgsOpt>(registry, "Mul");
  143. RegisterIntervalMul<NUdf::TDataType<NUdf::TInterval>>(registry);
  144. RegisterIntervalMul<NUdf::TDataType<NUdf::TInterval64>>(registry);
  145. }
  146. template <typename TInterval>
  147. void RegisterIntervalMul(TKernelFamilyBase& owner) {
  148. AddBinaryKernelPoly<NUdf::TDataType<i8>, TInterval, TInterval, TNumMulInterval>(owner);
  149. AddBinaryKernelPoly<NUdf::TDataType<ui8>, TInterval, TInterval, TNumMulInterval>(owner);
  150. AddBinaryKernelPoly<NUdf::TDataType<i16>, TInterval, TInterval, TNumMulInterval>(owner);
  151. AddBinaryKernelPoly<NUdf::TDataType<ui16>, TInterval, TInterval, TNumMulInterval>(owner);
  152. AddBinaryKernelPoly<NUdf::TDataType<i32>, TInterval, TInterval, TNumMulInterval>(owner);
  153. AddBinaryKernelPoly<NUdf::TDataType<ui32>, TInterval, TInterval, TNumMulInterval>(owner);
  154. AddBinaryKernelPoly<NUdf::TDataType<i64>, TInterval, TInterval, TNumMulInterval>(owner);
  155. AddBinaryKernelPoly<NUdf::TDataType<ui64>, TInterval, TInterval, TNumMulInterval>(owner);
  156. AddBinaryKernelPoly<TInterval, NUdf::TDataType<i8>, TInterval, TNumMulInterval>(owner);
  157. AddBinaryKernelPoly<TInterval, NUdf::TDataType<ui8>, TInterval, TNumMulInterval>(owner);
  158. AddBinaryKernelPoly<TInterval, NUdf::TDataType<i16>, TInterval, TNumMulInterval>(owner);
  159. AddBinaryKernelPoly<TInterval, NUdf::TDataType<ui16>, TInterval, TNumMulInterval>(owner);
  160. AddBinaryKernelPoly<TInterval, NUdf::TDataType<i32>, TInterval, TNumMulInterval>(owner);
  161. AddBinaryKernelPoly<TInterval, NUdf::TDataType<ui32>, TInterval, TNumMulInterval>(owner);
  162. AddBinaryKernelPoly<TInterval, NUdf::TDataType<i64>, TInterval, TNumMulInterval>(owner);
  163. AddBinaryKernelPoly<TInterval, NUdf::TDataType<ui64>, TInterval, TNumMulInterval>(owner);
  164. }
  165. void RegisterMul(TKernelFamilyMap& kernelFamilyMap) {
  166. auto family = std::make_unique<TKernelFamilyBase>();
  167. AddBinaryIntegralKernels<TMul>(*family);
  168. AddBinaryRealKernels<TMul>(*family);
  169. RegisterIntervalMul<NUdf::TDataType<NUdf::TInterval>>(*family);
  170. RegisterIntervalMul<NUdf::TDataType<NUdf::TInterval64>>(*family);
  171. kernelFamilyMap["Mul"] = std::move(family);
  172. }
  173. } // namespace NMiniKQL
  174. } // namespace NKikimr