mkql_decimal_mod.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. #include "mkql_decimal_div.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/invoke_builtins/mkql_builtins_decimal.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_builder.h>
  5. #include <yql/essentials/public/decimal/yql_decimal.h>
  6. namespace NKikimr {
  7. namespace NMiniKQL {
  8. namespace {
  9. template<bool IsLeftOptional, bool IsRightOptional>
  10. class TDecimalModWrapper : public TMutableCodegeneratorNode<TDecimalModWrapper<IsLeftOptional, IsRightOptional>>, NYql::NDecimal::TDecimalRemainder<NYql::NDecimal::TInt128> {
  11. typedef TMutableCodegeneratorNode<TDecimalModWrapper<IsLeftOptional, IsRightOptional>> TBaseComputation;
  12. public:
  13. TDecimalModWrapper(TComputationMutables& mutables, IComputationNode* left, IComputationNode* right)
  14. : TBaseComputation(mutables, EValueRepresentation::Embedded)
  15. , Left(left)
  16. , Right(right)
  17. {}
  18. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  19. const auto& left = Left->GetValue(compCtx);
  20. const auto& right = Right->GetValue(compCtx);
  21. if (IsLeftOptional && !left)
  22. return NUdf::TUnboxedValuePod();
  23. if (IsRightOptional && !right)
  24. return NUdf::TUnboxedValuePod();
  25. return NUdf::TUnboxedValuePod(Do(left.GetInt128(), right.GetInt128()));
  26. }
  27. #ifndef MKQL_DISABLE_CODEGEN
  28. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  29. auto& context = ctx.Codegen.GetContext();
  30. const auto valType = Type::getInt128Ty(context);
  31. const auto left = GetNodeValue(Left, ctx, block);
  32. const auto right = GetNodeValue(Right, ctx, block);
  33. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  34. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  35. const auto zero = ConstantInt::get(valType, 0ULL);
  36. const auto result = PHINode::Create(valType, IsLeftOptional || IsRightOptional ? 3 : 2, "result", done);
  37. if constexpr (IsLeftOptional || IsRightOptional) {
  38. const auto test = IsLeftOptional && IsRightOptional ?
  39. BinaryOperator::CreateAnd(left, right, "test", block):
  40. IsLeftOptional ? left : right;
  41. result->addIncoming(zero, block);
  42. BranchInst::Create(done, good, IsEmpty(test, block, context), block);
  43. block = good;
  44. const auto lv = GetterForInt128(left, block);
  45. const auto rv = GetterForInt128(right, block);
  46. const auto lbad = NDecimal::GenIsAbnormal(lv, context, block);
  47. const auto rbad = NDecimal::GenIsAbnormal(rv, context, block);
  48. const auto bad = BinaryOperator::CreateOr(lbad, rbad, "bad", block);
  49. const auto nul = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rv, zero, "check", block);
  50. const auto nan = BinaryOperator::CreateOr(bad, nul, "nan", block);
  51. const auto norm = BasicBlock::Create(context, "norm", ctx.Func);
  52. result->addIncoming(SetterForInt128(GetDecimalNan(context), block), block);
  53. BranchInst::Create(done, norm, nan, block);
  54. block = norm;
  55. const auto srem = BinaryOperator::CreateSRem(lv, rv, "srem", block);
  56. result->addIncoming(SetterForInt128(srem, block), block);
  57. } else {
  58. const auto lv = GetterForInt128(left, block);
  59. const auto rv = GetterForInt128(right, block);
  60. const auto lbad = NDecimal::GenIsAbnormal(lv, context, block);
  61. const auto rbad = NDecimal::GenIsAbnormal(rv, context, block);
  62. const auto bad = BinaryOperator::CreateOr(lbad, rbad, "bad", block);
  63. const auto nul = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rv, zero, "check", block);
  64. const auto nan = BinaryOperator::CreateOr(bad, nul, "nan", block);
  65. result->addIncoming(SetterForInt128(GetDecimalNan(context), block), block);
  66. BranchInst::Create(done, good, nan, block);
  67. block = good;
  68. const auto srem = BinaryOperator::CreateSRem(lv, rv, "srem", block);
  69. result->addIncoming(SetterForInt128(srem, block), block);
  70. }
  71. BranchInst::Create(done, block);
  72. block = done;
  73. return result;
  74. }
  75. #endif
  76. private:
  77. void RegisterDependencies() const final {
  78. this->DependsOn(Left);
  79. this->DependsOn(Right);
  80. }
  81. IComputationNode* const Left;
  82. IComputationNode* const Right;
  83. };
  84. template<bool IsLeftOptional, bool IsRightOptional, typename TRight>
  85. class TDecimalModIntegralWrapper : public TMutableCodegeneratorNode<TDecimalModIntegralWrapper<IsLeftOptional, IsRightOptional, TRight>>, NYql::NDecimal::TDecimalRemainder<TRight> {
  86. typedef TMutableCodegeneratorNode<TDecimalModIntegralWrapper<IsLeftOptional, IsRightOptional, TRight>> TBaseComputation;
  87. using NYql::NDecimal::TDecimalRemainder<TRight>::Divider;
  88. using NYql::NDecimal::TDecimalRemainder<TRight>::Bound;
  89. public:
  90. TDecimalModIntegralWrapper(TComputationMutables& mutables, IComputationNode* left, IComputationNode* right, ui8 precision, ui8 scale)
  91. : TBaseComputation(mutables, EValueRepresentation::Embedded)
  92. , NYql::NDecimal::TDecimalRemainder<TRight>(precision, scale)
  93. , Left(left)
  94. , Right(right)
  95. {}
  96. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  97. auto left = Left->GetValue(compCtx);
  98. const auto& right = Right->GetValue(compCtx);
  99. if (IsLeftOptional && !left)
  100. return NUdf::TUnboxedValuePod();
  101. if (IsRightOptional && !right)
  102. return NUdf::TUnboxedValuePod();
  103. return NUdf::TUnboxedValuePod(this->Do(left.GetInt128(), right.Get<TRight>()));
  104. }
  105. #ifndef MKQL_DISABLE_CODEGEN
  106. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  107. auto& context = ctx.Codegen.GetContext();
  108. const auto valType = Type::getInt128Ty(context);
  109. const auto divider = NDecimal::GenConstant(Divider, context);
  110. const auto left = GetNodeValue(Left, ctx, block);
  111. const auto right = GetNodeValue(Right, ctx, block);
  112. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  113. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  114. const auto zero = ConstantInt::get(valType, 0ULL);
  115. const auto result = PHINode::Create(valType, IsLeftOptional || IsRightOptional ? 3 : 2, "result", done);
  116. if constexpr (IsLeftOptional || IsRightOptional) {
  117. const auto test = IsLeftOptional && IsRightOptional ?
  118. BinaryOperator::CreateAnd(left, right, "test", block):
  119. IsLeftOptional ? left : right;
  120. result->addIncoming(zero, block);
  121. BranchInst::Create(done, good, IsEmpty(test, block, context), block);
  122. block = good;
  123. const auto lv = GetterForInt128(left, block);
  124. const auto cast = std::is_signed<TRight>() ?
  125. static_cast<CastInst*>(new SExtInst(GetterFor<TRight>(right, context, block), valType, "sext", block)):
  126. static_cast<CastInst*>(new ZExtInst(GetterFor<TRight>(right, context, block), valType, "zext", block));
  127. const auto out = std::is_signed<TRight>() ?
  128. NDecimal::GenOutOfBounds(cast, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block):
  129. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGE, cast, NDecimal::GenConstant(Bound, context), "out", block);
  130. const auto nul = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, cast, zero, "check", block);
  131. const auto lbad = NDecimal::GenIsAbnormal(lv, context, block);
  132. const auto rbad = BinaryOperator::CreateOr(out, nul, "or", block);
  133. const auto bad = BinaryOperator::CreateOr(lbad, rbad, "bad", block);
  134. const auto norm = BasicBlock::Create(context, "norm", ctx.Func);
  135. const auto spec = SelectInst::Create(out, left, SetterForInt128(GetDecimalNan(context), block), "spec", block);
  136. result->addIncoming(spec, block);
  137. BranchInst::Create(done, norm, bad, block);
  138. block = norm;
  139. const auto mul = BinaryOperator::CreateMul(divider, cast, "mul", block);
  140. const auto srem = BinaryOperator::CreateSRem(lv, mul, "srem", block);
  141. result->addIncoming(SetterForInt128(srem, block), block);
  142. } else {
  143. const auto lv = GetterForInt128(left, block);
  144. const auto cast = std::is_signed<TRight>() ?
  145. static_cast<CastInst*>(new SExtInst(GetterFor<TRight>(right, context, block), valType, "sext", block)):
  146. static_cast<CastInst*>(new ZExtInst(GetterFor<TRight>(right, context, block), valType, "zext", block));
  147. const auto out = std::is_signed<TRight>() ?
  148. NDecimal::GenOutOfBounds(cast, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block):
  149. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGE, cast, NDecimal::GenConstant(Bound, context), "out", block);
  150. const auto nul = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, cast, zero, "check", block);
  151. const auto lbad = NDecimal::GenIsAbnormal(lv, context, block);
  152. const auto rbad = BinaryOperator::CreateOr(out, nul, "or", block);
  153. const auto bad = BinaryOperator::CreateOr(lbad, rbad, "bad", block);
  154. const auto spec = SelectInst::Create(out, left, SetterForInt128(GetDecimalNan(context), block), "spec", block);
  155. result->addIncoming(spec, block);
  156. BranchInst::Create(done, good, bad, block);
  157. block = good;
  158. const auto mul = BinaryOperator::CreateMul(divider, cast, "mul", block);
  159. const auto srem = BinaryOperator::CreateSRem(lv, mul, "srem", block);
  160. result->addIncoming(SetterForInt128(srem, block), block);
  161. }
  162. BranchInst::Create(done, block);
  163. block = done;
  164. return result;
  165. }
  166. #endif
  167. private:
  168. void RegisterDependencies() const final {
  169. this->DependsOn(Left);
  170. this->DependsOn(Right);
  171. }
  172. IComputationNode* const Left;
  173. IComputationNode* const Right;
  174. };
  175. }
  176. IComputationNode* WrapDecimalMod(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  177. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
  178. bool isOptionalLeft, isOptionalRight;
  179. const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(callable.GetInput(0), isOptionalLeft));
  180. const auto rightType = UnpackOptionalData(callable.GetInput(1), isOptionalRight);
  181. auto left = LocateNode(ctx.NodeLocator, callable, 0);
  182. auto right = LocateNode(ctx.NodeLocator, callable, 1);
  183. switch (rightType->GetSchemeType()) {
  184. case NUdf::TDataType<NUdf::TDecimal>::Id:
  185. MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch");
  186. if (isOptionalLeft && isOptionalRight)
  187. return new TDecimalModWrapper<true, true>(ctx.Mutables, left, right);
  188. else if (isOptionalLeft)
  189. return new TDecimalModWrapper<true, false>(ctx.Mutables, left, right);
  190. else if (isOptionalRight)
  191. return new TDecimalModWrapper<false, true>(ctx.Mutables, left, right);
  192. else
  193. return new TDecimalModWrapper<false, false>(ctx.Mutables, left, right);
  194. #define MAKE_PRIMITIVE_TYPE_MOD(type) \
  195. case NUdf::TDataType<type>::Id: \
  196. if (isOptionalLeft && isOptionalRight) \
  197. return new TDecimalModIntegralWrapper<true, true, type>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second); \
  198. else if (isOptionalLeft) \
  199. return new TDecimalModIntegralWrapper<true, false, type>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second); \
  200. else if (isOptionalRight) \
  201. return new TDecimalModIntegralWrapper<false, true, type>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second); \
  202. else \
  203. return new TDecimalModIntegralWrapper<false, false, type>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second);
  204. INTEGRAL_VALUE_TYPES(MAKE_PRIMITIVE_TYPE_MOD)
  205. #undef MAKE_PRIMITIVE_TYPE_MOD
  206. default:
  207. Y_ABORT("Unupported type.");
  208. }
  209. }
  210. }
  211. }