mkql_decimal_mul.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. #include "mkql_decimal_mul.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/invoke_builtins/mkql_builtins_decimal.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_builder.h>
  5. #include <yql/essentials/public/decimal/yql_decimal.h>
  6. extern "C" NYql::NDecimal::TInt128 DecimalMulAndDivNormalDivider(NYql::NDecimal::TInt128 a, NYql::NDecimal::TInt128 b, NYql::NDecimal::TInt128 c) {
  7. return NYql::NDecimal::MulAndDivNormalDivider(a, b, c);
  8. }
  9. extern "C" NYql::NDecimal::TInt128 DecimalMul(NYql::NDecimal::TInt128 a, NYql::NDecimal::TInt128 b) {
  10. return NYql::NDecimal::Mul(a, b);
  11. }
  12. namespace NKikimr {
  13. namespace NMiniKQL {
  14. namespace {
  15. template<bool IsLeftOptional, bool IsRightOptional>
  16. class TDecimalMulWrapper : public TMutableCodegeneratorNode<TDecimalMulWrapper<IsLeftOptional, IsRightOptional>>, NYql::NDecimal::TDecimalMultiplicator<NYql::NDecimal::TInt128> {
  17. typedef TMutableCodegeneratorNode<TDecimalMulWrapper<IsLeftOptional, IsRightOptional>> TBaseComputation;
  18. public:
  19. TDecimalMulWrapper(TComputationMutables& mutables, IComputationNode* left, IComputationNode* right, ui8 precision, ui8 scale)
  20. : TBaseComputation(mutables, EValueRepresentation::Embedded)
  21. , NYql::NDecimal::TDecimalMultiplicator<NYql::NDecimal::TInt128>(precision, scale)
  22. , Left(left)
  23. , Right(right)
  24. {
  25. }
  26. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  27. const auto& left = Left->GetValue(compCtx);
  28. const auto& right = Right->GetValue(compCtx);
  29. if (IsLeftOptional && !left)
  30. return NUdf::TUnboxedValuePod();
  31. if (IsRightOptional && !right)
  32. return NUdf::TUnboxedValuePod();
  33. return NUdf::TUnboxedValuePod(Do(left.GetInt128(), right.GetInt128()));
  34. }
  35. #ifndef MKQL_DISABLE_CODEGEN
  36. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  37. auto& context = ctx.Codegen.GetContext();
  38. const auto valType = Type::getInt128Ty(context);
  39. const bool useMulAddDiv = Divider > 1;
  40. const auto name = useMulAddDiv ? "DecimalMulAndDivNormalDivider" : "DecimalMul";
  41. const auto fnType = useMulAddDiv ?
  42. FunctionType::get(valType, { valType, valType, valType }, false):
  43. FunctionType::get(valType, { valType, valType}, false);
  44. ctx.Codegen.AddGlobalMapping(name, useMulAddDiv ? reinterpret_cast<const void*>(&DecimalMulAndDivNormalDivider) : reinterpret_cast<const void*>(&DecimalMul));
  45. const auto func = ctx.Codegen.GetModule().getOrInsertFunction(name, fnType);
  46. const auto left = GetNodeValue(Left, ctx, block);
  47. const auto right = GetNodeValue(Right, ctx, block);
  48. if constexpr (IsLeftOptional || IsRightOptional) {
  49. const auto test = IsLeftOptional && IsRightOptional ?
  50. BinaryOperator::CreateAnd(left, right, "test", block):
  51. IsLeftOptional ? left : right;
  52. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  53. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  54. const auto result = PHINode::Create(valType, 2, "result", done);
  55. result->addIncoming(test, block);
  56. BranchInst::Create(done, good, IsEmpty(test, block, context), block);
  57. block = good;
  58. Value* muldiv;
  59. if (useMulAddDiv) {
  60. muldiv = CallInst::Create(func, { GetterForInt128(left, block), GetterForInt128(right, block), NDecimal::GenConstant(Divider, context) }, "mul_and_div", block);
  61. } else {
  62. muldiv = CallInst::Create(func, { GetterForInt128(left, block), GetterForInt128(right, block) }, "mul", block);
  63. }
  64. const auto ok = NDecimal::GenInBounds(muldiv, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block);
  65. const auto nan = NDecimal::GenIsNonComparable(muldiv, context, block);
  66. const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, muldiv, ConstantInt::get(muldiv->getType(), 0), "plus", block);
  67. const auto inf = SelectInst::Create(plus, GetDecimalPlusInf(context), GetDecimalMinusInf(context), "inf", block);
  68. const auto bad = SelectInst::Create(nan, GetDecimalNan(context), inf, "bad", block);
  69. const auto res = SelectInst::Create(ok, muldiv, bad, "res", block);
  70. result->addIncoming(SetterForInt128(res, block), block);
  71. BranchInst::Create(done, block);
  72. block = done;
  73. return result;
  74. } else {
  75. Value* muldiv;
  76. if (useMulAddDiv) {
  77. muldiv = CallInst::Create(func, { GetterForInt128(left, block), GetterForInt128(right, block), NDecimal::GenConstant(Divider, context) }, "mul_and_div", block);
  78. } else {
  79. muldiv = CallInst::Create(func, { GetterForInt128(left, block), GetterForInt128(right, block) }, "mul", block);
  80. }
  81. const auto ok = NDecimal::GenInBounds(muldiv, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block);
  82. const auto nan = NDecimal::GenIsNonComparable(muldiv, context, block);
  83. const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, muldiv, ConstantInt::get(muldiv->getType(), 0), "plus", block);
  84. const auto inf = SelectInst::Create(plus, GetDecimalPlusInf(context), GetDecimalMinusInf(context), "inf", block);
  85. const auto bad = SelectInst::Create(nan, GetDecimalNan(context), inf, "bad", block);
  86. const auto res = SelectInst::Create(ok, muldiv, bad, "res", block);
  87. return SetterForInt128(res, block);
  88. }
  89. }
  90. #endif
  91. private:
  92. void RegisterDependencies() const final {
  93. this->DependsOn(Left);
  94. this->DependsOn(Right);
  95. }
  96. IComputationNode* const Left;
  97. IComputationNode* const Right;
  98. };
  99. template<bool IsLeftOptional, bool IsRightOptional, typename TRight>
  100. class TDecimalMulIntegralWrapper : public TMutableCodegeneratorNode<TDecimalMulIntegralWrapper<IsLeftOptional, IsRightOptional, TRight>>, NYql::NDecimal::TDecimalMultiplicator<TRight> {
  101. typedef TMutableCodegeneratorNode<TDecimalMulIntegralWrapper<IsLeftOptional, IsRightOptional, TRight>> TBaseComputation;
  102. using NYql::NDecimal::TDecimalMultiplicator<TRight>::Bound;
  103. public:
  104. TDecimalMulIntegralWrapper(TComputationMutables& mutables, IComputationNode* left, IComputationNode* right, ui8 precision)
  105. : TBaseComputation(mutables, EValueRepresentation::Embedded)
  106. , NYql::NDecimal::TDecimalMultiplicator<TRight>(precision)
  107. , Left(left)
  108. , Right(right)
  109. {}
  110. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  111. const auto& left = Left->GetValue(compCtx);
  112. const auto& right = Right->GetValue(compCtx);
  113. if (IsLeftOptional && !left)
  114. return NUdf::TUnboxedValuePod();
  115. if (IsRightOptional && !right)
  116. return NUdf::TUnboxedValuePod();
  117. return NUdf::TUnboxedValuePod(this->Do(left.GetInt128(), right.template Get<TRight>()));
  118. }
  119. #ifndef MKQL_DISABLE_CODEGEN
  120. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  121. auto& context = ctx.Codegen.GetContext();
  122. const auto valType = Type::getInt128Ty(context);
  123. const auto name = "DecimalMul";
  124. ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&DecimalMul));
  125. const auto fnType =
  126. FunctionType::get(valType, { valType, valType }, false);
  127. const auto func = ctx.Codegen.GetModule().getOrInsertFunction(name, fnType);
  128. const auto left = GetNodeValue(Left, ctx, block);
  129. const auto right = GetNodeValue(Right, ctx, block);
  130. if constexpr (IsLeftOptional || IsRightOptional) {
  131. const auto test = IsLeftOptional && IsRightOptional ?
  132. BinaryOperator::CreateAnd(left, right, "test", block):
  133. IsLeftOptional ? left : right;
  134. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  135. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  136. const auto result = PHINode::Create(valType, 2, "result", done);
  137. result->addIncoming(test, block);
  138. BranchInst::Create(done, good, IsEmpty(test, block, context), block);
  139. block = good;
  140. const auto cast = std::is_signed<TRight>() ?
  141. static_cast<CastInst*>(new SExtInst(GetterFor<TRight>(right, context, block), valType, "sext", block)):
  142. static_cast<CastInst*>(new ZExtInst(GetterFor<TRight>(right, context, block), valType, "zext", block));
  143. const auto mul = CallInst::Create(func, {GetterForInt128(left, block), cast}, "div", block);
  144. const auto ok = NDecimal::GenInBounds(mul, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block);
  145. const auto nan = NDecimal::GenIsNonComparable(mul, context, block);
  146. const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, mul, ConstantInt::get(mul->getType(), 0), "plus", block);
  147. const auto inf = SelectInst::Create(plus, GetDecimalPlusInf(context), GetDecimalMinusInf(context), "inf", block);
  148. const auto bad = SelectInst::Create(nan, GetDecimalNan(context), inf, "bad", block);
  149. const auto res = SelectInst::Create(ok, mul, bad, "res", block);
  150. result->addIncoming(SetterForInt128(res, block), block);
  151. BranchInst::Create(done, block);
  152. block = done;
  153. return result;
  154. } else {
  155. const auto cast = std::is_signed<TRight>() ?
  156. static_cast<CastInst*>(new SExtInst(GetterFor<TRight>(right, context, block), valType, "sext", block)):
  157. static_cast<CastInst*>(new ZExtInst(GetterFor<TRight>(right, context, block), valType, "zext", block));
  158. const auto mul = CallInst::Create(func, {GetterForInt128(left, block), cast}, "div", block);
  159. const auto ok = NDecimal::GenInBounds(mul, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block);
  160. const auto nan = NDecimal::GenIsNonComparable(mul, context, block);
  161. const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, mul, ConstantInt::get(mul->getType(), 0), "plus", block);
  162. const auto inf = SelectInst::Create(plus, GetDecimalPlusInf(context), GetDecimalMinusInf(context), "inf", block);
  163. const auto bad = SelectInst::Create(nan, GetDecimalNan(context), inf, "bad", block);
  164. const auto res = SelectInst::Create(ok, mul, bad, "res", block);
  165. return SetterForInt128(res, block);
  166. }
  167. }
  168. #endif
  169. private:
  170. void RegisterDependencies() const final {
  171. this->DependsOn(Left);
  172. this->DependsOn(Right);
  173. }
  174. IComputationNode* const Left;
  175. IComputationNode* const Right;
  176. };
  177. }
  178. IComputationNode* WrapDecimalMul(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  179. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
  180. bool isOptionalLeft, isOptionalRight;
  181. const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(callable.GetInput(0), isOptionalLeft));
  182. const auto rightType = UnpackOptionalData(callable.GetInput(1), isOptionalRight);
  183. auto left = LocateNode(ctx.NodeLocator, callable, 0);
  184. auto right = LocateNode(ctx.NodeLocator, callable, 1);
  185. switch (rightType->GetSchemeType()) {
  186. case NUdf::TDataType<NUdf::TDecimal>::Id:
  187. MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch");
  188. if (isOptionalLeft && isOptionalRight)
  189. return new TDecimalMulWrapper<true, true>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second);
  190. else if (isOptionalLeft)
  191. return new TDecimalMulWrapper<true, false>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second);
  192. else if (isOptionalRight)
  193. return new TDecimalMulWrapper<false, true>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second);
  194. else
  195. return new TDecimalMulWrapper<false, false>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second);
  196. #define MAKE_PRIMITIVE_TYPE_MUL(type) \
  197. case NUdf::TDataType<type>::Id: \
  198. if (isOptionalLeft && isOptionalRight) \
  199. return new TDecimalMulIntegralWrapper<true, true, type>(ctx.Mutables, left, right, leftType->GetParams().first); \
  200. else if (isOptionalLeft) \
  201. return new TDecimalMulIntegralWrapper<true, false, type>(ctx.Mutables, left, right, leftType->GetParams().first); \
  202. else if (isOptionalRight) \
  203. return new TDecimalMulIntegralWrapper<false, true, type>(ctx.Mutables, left, right, leftType->GetParams().first); \
  204. else \
  205. return new TDecimalMulIntegralWrapper<false, false, type>(ctx.Mutables, left, right, leftType->GetParams().first);
  206. INTEGRAL_VALUE_TYPES(MAKE_PRIMITIVE_TYPE_MUL)
  207. #undef MAKE_PRIMITIVE_TYPE_MUL
  208. default:
  209. Y_ABORT("Unupported type.");
  210. }
  211. }
  212. }
  213. }