mkql_decimal_div.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. #include "mkql_decimal_div.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/invoke_builtins/mkql_builtins_decimal.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_builder.h>
  5. #include <yql/essentials/public/decimal/yql_decimal.h>
  6. extern "C" NYql::NDecimal::TInt128 DecimalMulAndDivNormalMultiplier(NYql::NDecimal::TInt128 a, NYql::NDecimal::TInt128 b, NYql::NDecimal::TInt128 c) {
  7. return NYql::NDecimal::MulAndDivNormalMultiplier(a, b, c);
  8. }
  9. extern "C" NYql::NDecimal::TInt128 DecimalDiv(NYql::NDecimal::TInt128 a, NYql::NDecimal::TInt128 b) {
  10. return NYql::NDecimal::Div(a, b);
  11. }
  12. namespace NKikimr {
  13. namespace NMiniKQL {
  14. namespace {
  15. template<bool IsLeftOptional, bool IsRightOptional>
  16. class TDecimalDivWrapper : public TMutableCodegeneratorNode<TDecimalDivWrapper<IsLeftOptional, IsRightOptional>>, NYql::NDecimal::TDecimalDivisor<NYql::NDecimal::TInt128> {
  17. typedef TMutableCodegeneratorNode<TDecimalDivWrapper<IsLeftOptional, IsRightOptional>> TBaseComputation;
  18. public:
  19. TDecimalDivWrapper(TComputationMutables& mutables, IComputationNode* left, IComputationNode* right, ui8 precision, ui8 scale)
  20. : TBaseComputation(mutables, EValueRepresentation::Embedded)
  21. , NYql::NDecimal::TDecimalDivisor<NYql::NDecimal::TInt128>(precision, scale)
  22. , Left(left)
  23. , Right(right)
  24. {}
  25. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  26. const auto& left = Left->GetValue(compCtx);
  27. const auto& right = Right->GetValue(compCtx);
  28. if (IsLeftOptional && !left)
  29. return NUdf::TUnboxedValuePod();
  30. if (IsRightOptional && !right)
  31. return NUdf::TUnboxedValuePod();
  32. return NUdf::TUnboxedValuePod(Do(left.GetInt128(), right.GetInt128()));
  33. }
  34. #ifndef MKQL_DISABLE_CODEGEN
  35. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  36. auto& context = ctx.Codegen.GetContext();
  37. const auto valType = Type::getInt128Ty(context);
  38. const auto name = "DecimalMulAndDivNormalMultiplier";
  39. ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&DecimalMulAndDivNormalMultiplier));
  40. const auto fnType =
  41. FunctionType::get(valType, { valType, valType, valType }, false);
  42. const auto func = ctx.Codegen.GetModule().getOrInsertFunction(name, fnType);
  43. const auto left = GetNodeValue(Left, ctx, block);
  44. const auto right = GetNodeValue(Right, ctx, block);
  45. if constexpr (IsLeftOptional || IsRightOptional) {
  46. const auto test = IsLeftOptional && IsRightOptional ?
  47. BinaryOperator::CreateAnd(left, right, "test", block):
  48. IsLeftOptional ? left : right;
  49. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  50. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  51. const auto result = PHINode::Create(valType, 2, "result", done);
  52. result->addIncoming(test, block);
  53. BranchInst::Create(done, good, IsEmpty(test, block, context), block);
  54. block = good;
  55. const auto muldiv = CallInst::Create(func, { GetterForInt128(left, block), NDecimal::GenConstant(Divider, context), GetterForInt128(right, block) }, "mul_and_div", block);
  56. const auto ok = NDecimal::GenInBounds(muldiv, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block);
  57. const auto nan = NDecimal::GenIsNonComparable(muldiv, context, block);
  58. const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, muldiv, ConstantInt::get(muldiv->getType(), 0), "plus", block);
  59. const auto inf = SelectInst::Create(plus, GetDecimalPlusInf(context), GetDecimalMinusInf(context), "inf", block);
  60. const auto bad = SelectInst::Create(nan, GetDecimalNan(context), inf, "bad", block);
  61. const auto res = SelectInst::Create(ok, muldiv, bad, "res", block);
  62. result->addIncoming(SetterForInt128(res, block), block);
  63. BranchInst::Create(done, block);
  64. block = done;
  65. return result;
  66. } else {
  67. const auto muldiv = CallInst::Create(func, { GetterForInt128(left, block), NDecimal::GenConstant(Divider, context), GetterForInt128(right, block) }, "mul_and_div", block);
  68. const auto ok = NDecimal::GenInBounds(muldiv, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block);
  69. const auto nan = NDecimal::GenIsNonComparable(muldiv, context, block);
  70. const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, muldiv, ConstantInt::get(muldiv->getType(), 0), "plus", block);
  71. const auto inf = SelectInst::Create(plus, GetDecimalPlusInf(context), GetDecimalMinusInf(context), "inf", block);
  72. const auto bad = SelectInst::Create(nan, GetDecimalNan(context), inf, "bad", block);
  73. const auto res = SelectInst::Create(ok, muldiv, bad, "res", block);
  74. return SetterForInt128(res, block);
  75. }
  76. }
  77. #endif
  78. private:
  79. void RegisterDependencies() const final {
  80. this->DependsOn(Left);
  81. this->DependsOn(Right);
  82. }
  83. IComputationNode* const Left;
  84. IComputationNode* const Right;
  85. };
  86. template<bool IsLeftOptional, bool IsRightOptional, typename TRight>
  87. class TDecimalDivIntegralWrapper : public TMutableCodegeneratorNode<TDecimalDivIntegralWrapper<IsLeftOptional, IsRightOptional, TRight>>, NYql::NDecimal::TDecimalDivisor<TRight> {
  88. typedef TMutableCodegeneratorNode<TDecimalDivIntegralWrapper<IsLeftOptional, IsRightOptional, TRight>> TBaseComputation;
  89. public:
  90. TDecimalDivIntegralWrapper(TComputationMutables& mutables, IComputationNode* left, IComputationNode* right)
  91. : TBaseComputation(mutables, EValueRepresentation::Embedded)
  92. , Left(left)
  93. , Right(right)
  94. {}
  95. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  96. const auto& left = Left->GetValue(compCtx);
  97. const auto& right = Right->GetValue(compCtx);
  98. if (IsLeftOptional && !left)
  99. return NUdf::TUnboxedValuePod();
  100. if (IsRightOptional && !right)
  101. return NUdf::TUnboxedValuePod();
  102. return NUdf::TUnboxedValuePod(this->Do(left.GetInt128(), right.Get<TRight>()));
  103. }
  104. #ifndef MKQL_DISABLE_CODEGEN
  105. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  106. auto& context = ctx.Codegen.GetContext();
  107. const auto valType = Type::getInt128Ty(context);
  108. const auto name = "DecimalDiv";
  109. ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&DecimalDiv));
  110. const auto fnType =
  111. FunctionType::get(valType, { valType, valType }, false);
  112. const auto func = ctx.Codegen.GetModule().getOrInsertFunction(name, fnType);
  113. const auto left = GetNodeValue(Left, ctx, block);
  114. const auto right = GetNodeValue(Right, ctx, block);
  115. if constexpr (IsLeftOptional || IsRightOptional) {
  116. const auto test = IsLeftOptional && IsRightOptional ?
  117. BinaryOperator::CreateAnd(left, right, "test", block):
  118. IsLeftOptional ? left : right;
  119. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  120. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  121. const auto result = PHINode::Create(valType, 2, "result", done);
  122. result->addIncoming(test, block);
  123. BranchInst::Create(done, good, IsEmpty(test, block, context), block);
  124. block = good;
  125. const auto cast = std::is_signed<TRight>() ?
  126. static_cast<CastInst*>(new SExtInst(GetterFor<TRight>(right, context, block), valType, "sext", block)):
  127. static_cast<CastInst*>(new ZExtInst(GetterFor<TRight>(right, context, block), valType, "zext", block));
  128. const auto div = CallInst::Create(func, {GetterForInt128(left, block), cast}, "div", block);
  129. result->addIncoming(SetterForInt128(div, block), block);
  130. BranchInst::Create(done, block);
  131. block = done;
  132. return result;
  133. } else {
  134. const auto cast = std::is_signed<TRight>() ?
  135. static_cast<CastInst*>(new SExtInst(GetterFor<TRight>(right, context, block), valType, "sext", block)):
  136. static_cast<CastInst*>(new ZExtInst(GetterFor<TRight>(right, context, block), valType, "zext", block));
  137. const auto div = CallInst::Create(func, {GetterForInt128(left, block), cast}, "div", block);
  138. return SetterForInt128(div, block);
  139. }
  140. }
  141. #endif
  142. private:
  143. void RegisterDependencies() const final {
  144. this->DependsOn(Left);
  145. this->DependsOn(Right);
  146. }
  147. IComputationNode* const Left;
  148. IComputationNode* const Right;
  149. };
  150. }
  151. IComputationNode* WrapDecimalDiv(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  152. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
  153. bool isOptionalLeft, isOptionalRight;
  154. const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(callable.GetInput(0), isOptionalLeft));
  155. const auto rightType = UnpackOptionalData(callable.GetInput(1), isOptionalRight);
  156. auto left = LocateNode(ctx.NodeLocator, callable, 0);
  157. auto right = LocateNode(ctx.NodeLocator, callable, 1);
  158. switch (rightType->GetSchemeType()) {
  159. case NUdf::TDataType<NUdf::TDecimal>::Id:
  160. MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch");
  161. if (isOptionalLeft && isOptionalRight)
  162. return new TDecimalDivWrapper<true, true>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second);
  163. else if (isOptionalLeft)
  164. return new TDecimalDivWrapper<true, false>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second);
  165. else if (isOptionalRight)
  166. return new TDecimalDivWrapper<false, true>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second);
  167. else
  168. return new TDecimalDivWrapper<false, false>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second);
  169. #define MAKE_PRIMITIVE_TYPE_DIV(type) \
  170. case NUdf::TDataType<type>::Id: \
  171. if (isOptionalLeft && isOptionalRight) \
  172. return new TDecimalDivIntegralWrapper<true, true, type>(ctx.Mutables, left, right); \
  173. else if (isOptionalLeft) \
  174. return new TDecimalDivIntegralWrapper<true, false, type>(ctx.Mutables, left, right); \
  175. else if (isOptionalRight) \
  176. return new TDecimalDivIntegralWrapper<false, true, type>(ctx.Mutables, left, right); \
  177. else \
  178. return new TDecimalDivIntegralWrapper<false, false, type>(ctx.Mutables, left, right);
  179. INTEGRAL_VALUE_TYPES(MAKE_PRIMITIVE_TYPE_DIV)
  180. #undef MAKE_PRIMITIVE_TYPE_DIV
  181. default:
  182. Y_ABORT("Unupported type.");
  183. }
  184. }
  185. }
  186. }