123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274 |
- #include "mkql_decimal_div.h"
- #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
- #include <yql/essentials/minikql/invoke_builtins/mkql_builtins_decimal.h> // Y_IGNORE
- #include <yql/essentials/minikql/mkql_node_builder.h>
- #include <yql/essentials/public/decimal/yql_decimal.h>
- namespace NKikimr {
- namespace NMiniKQL {
- namespace {
- template<bool IsLeftOptional, bool IsRightOptional>
- class TDecimalModWrapper : public TMutableCodegeneratorNode<TDecimalModWrapper<IsLeftOptional, IsRightOptional>>, NYql::NDecimal::TDecimalRemainder<NYql::NDecimal::TInt128> {
- typedef TMutableCodegeneratorNode<TDecimalModWrapper<IsLeftOptional, IsRightOptional>> TBaseComputation;
- public:
- TDecimalModWrapper(TComputationMutables& mutables, IComputationNode* left, IComputationNode* right)
- : TBaseComputation(mutables, EValueRepresentation::Embedded)
- , Left(left)
- , Right(right)
- {}
- NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
- const auto& left = Left->GetValue(compCtx);
- const auto& right = Right->GetValue(compCtx);
- if (IsLeftOptional && !left)
- return NUdf::TUnboxedValuePod();
- if (IsRightOptional && !right)
- return NUdf::TUnboxedValuePod();
- return NUdf::TUnboxedValuePod(Do(left.GetInt128(), right.GetInt128()));
- }
- #ifndef MKQL_DISABLE_CODEGEN
- Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
- auto& context = ctx.Codegen.GetContext();
- const auto valType = Type::getInt128Ty(context);
- const auto left = GetNodeValue(Left, ctx, block);
- const auto right = GetNodeValue(Right, ctx, block);
- const auto done = BasicBlock::Create(context, "done", ctx.Func);
- const auto good = BasicBlock::Create(context, "good", ctx.Func);
- const auto zero = ConstantInt::get(valType, 0ULL);
- const auto result = PHINode::Create(valType, IsLeftOptional || IsRightOptional ? 3 : 2, "result", done);
- if constexpr (IsLeftOptional || IsRightOptional) {
- const auto test = IsLeftOptional && IsRightOptional ?
- BinaryOperator::CreateAnd(left, right, "test", block):
- IsLeftOptional ? left : right;
- result->addIncoming(zero, block);
- BranchInst::Create(done, good, IsEmpty(test, block, context), block);
- block = good;
- const auto lv = GetterForInt128(left, block);
- const auto rv = GetterForInt128(right, block);
- const auto lbad = NDecimal::GenIsAbnormal(lv, context, block);
- const auto rbad = NDecimal::GenIsAbnormal(rv, context, block);
- const auto bad = BinaryOperator::CreateOr(lbad, rbad, "bad", block);
- const auto nul = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rv, zero, "check", block);
- const auto nan = BinaryOperator::CreateOr(bad, nul, "nan", block);
- const auto norm = BasicBlock::Create(context, "norm", ctx.Func);
- result->addIncoming(SetterForInt128(GetDecimalNan(context), block), block);
- BranchInst::Create(done, norm, nan, block);
- block = norm;
- const auto srem = BinaryOperator::CreateSRem(lv, rv, "srem", block);
- result->addIncoming(SetterForInt128(srem, block), block);
- } else {
- const auto lv = GetterForInt128(left, block);
- const auto rv = GetterForInt128(right, block);
- const auto lbad = NDecimal::GenIsAbnormal(lv, context, block);
- const auto rbad = NDecimal::GenIsAbnormal(rv, context, block);
- const auto bad = BinaryOperator::CreateOr(lbad, rbad, "bad", block);
- const auto nul = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rv, zero, "check", block);
- const auto nan = BinaryOperator::CreateOr(bad, nul, "nan", block);
- result->addIncoming(SetterForInt128(GetDecimalNan(context), block), block);
- BranchInst::Create(done, good, nan, block);
- block = good;
- const auto srem = BinaryOperator::CreateSRem(lv, rv, "srem", block);
- result->addIncoming(SetterForInt128(srem, block), block);
- }
- BranchInst::Create(done, block);
- block = done;
- return result;
- }
- #endif
- private:
- void RegisterDependencies() const final {
- this->DependsOn(Left);
- this->DependsOn(Right);
- }
- IComputationNode* const Left;
- IComputationNode* const Right;
- };
- template<bool IsLeftOptional, bool IsRightOptional, typename TRight>
- class TDecimalModIntegralWrapper : public TMutableCodegeneratorNode<TDecimalModIntegralWrapper<IsLeftOptional, IsRightOptional, TRight>>, NYql::NDecimal::TDecimalRemainder<TRight> {
- typedef TMutableCodegeneratorNode<TDecimalModIntegralWrapper<IsLeftOptional, IsRightOptional, TRight>> TBaseComputation;
- using NYql::NDecimal::TDecimalRemainder<TRight>::Divider;
- using NYql::NDecimal::TDecimalRemainder<TRight>::Bound;
- public:
- TDecimalModIntegralWrapper(TComputationMutables& mutables, IComputationNode* left, IComputationNode* right, ui8 precision, ui8 scale)
- : TBaseComputation(mutables, EValueRepresentation::Embedded)
- , NYql::NDecimal::TDecimalRemainder<TRight>(precision, scale)
- , Left(left)
- , Right(right)
- {}
- NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
- auto left = Left->GetValue(compCtx);
- const auto& right = Right->GetValue(compCtx);
- if (IsLeftOptional && !left)
- return NUdf::TUnboxedValuePod();
- if (IsRightOptional && !right)
- return NUdf::TUnboxedValuePod();
- return NUdf::TUnboxedValuePod(this->Do(left.GetInt128(), right.Get<TRight>()));
- }
- #ifndef MKQL_DISABLE_CODEGEN
- Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
- auto& context = ctx.Codegen.GetContext();
- const auto valType = Type::getInt128Ty(context);
- const auto divider = NDecimal::GenConstant(Divider, context);
- const auto left = GetNodeValue(Left, ctx, block);
- const auto right = GetNodeValue(Right, ctx, block);
- const auto done = BasicBlock::Create(context, "done", ctx.Func);
- const auto good = BasicBlock::Create(context, "good", ctx.Func);
- const auto zero = ConstantInt::get(valType, 0ULL);
- const auto result = PHINode::Create(valType, IsLeftOptional || IsRightOptional ? 3 : 2, "result", done);
- if constexpr (IsLeftOptional || IsRightOptional) {
- const auto test = IsLeftOptional && IsRightOptional ?
- BinaryOperator::CreateAnd(left, right, "test", block):
- IsLeftOptional ? left : right;
- result->addIncoming(zero, block);
- BranchInst::Create(done, good, IsEmpty(test, block, context), block);
- block = good;
- const auto lv = GetterForInt128(left, block);
- const auto cast = std::is_signed<TRight>() ?
- static_cast<CastInst*>(new SExtInst(GetterFor<TRight>(right, context, block), valType, "sext", block)):
- static_cast<CastInst*>(new ZExtInst(GetterFor<TRight>(right, context, block), valType, "zext", block));
- const auto out = std::is_signed<TRight>() ?
- NDecimal::GenOutOfBounds(cast, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block):
- CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGE, cast, NDecimal::GenConstant(Bound, context), "out", block);
- const auto nul = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, cast, zero, "check", block);
- const auto lbad = NDecimal::GenIsAbnormal(lv, context, block);
- const auto rbad = BinaryOperator::CreateOr(out, nul, "or", block);
- const auto bad = BinaryOperator::CreateOr(lbad, rbad, "bad", block);
- const auto norm = BasicBlock::Create(context, "norm", ctx.Func);
- const auto spec = SelectInst::Create(out, left, SetterForInt128(GetDecimalNan(context), block), "spec", block);
- result->addIncoming(spec, block);
- BranchInst::Create(done, norm, bad, block);
- block = norm;
- const auto mul = BinaryOperator::CreateMul(divider, cast, "mul", block);
- const auto srem = BinaryOperator::CreateSRem(lv, mul, "srem", block);
- result->addIncoming(SetterForInt128(srem, block), block);
- } else {
- const auto lv = GetterForInt128(left, block);
- const auto cast = std::is_signed<TRight>() ?
- static_cast<CastInst*>(new SExtInst(GetterFor<TRight>(right, context, block), valType, "sext", block)):
- static_cast<CastInst*>(new ZExtInst(GetterFor<TRight>(right, context, block), valType, "zext", block));
- const auto out = std::is_signed<TRight>() ?
- NDecimal::GenOutOfBounds(cast, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block):
- CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGE, cast, NDecimal::GenConstant(Bound, context), "out", block);
- const auto nul = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, cast, zero, "check", block);
- const auto lbad = NDecimal::GenIsAbnormal(lv, context, block);
- const auto rbad = BinaryOperator::CreateOr(out, nul, "or", block);
- const auto bad = BinaryOperator::CreateOr(lbad, rbad, "bad", block);
- const auto spec = SelectInst::Create(out, left, SetterForInt128(GetDecimalNan(context), block), "spec", block);
- result->addIncoming(spec, block);
- BranchInst::Create(done, good, bad, block);
- block = good;
- const auto mul = BinaryOperator::CreateMul(divider, cast, "mul", block);
- const auto srem = BinaryOperator::CreateSRem(lv, mul, "srem", block);
- result->addIncoming(SetterForInt128(srem, block), block);
- }
- BranchInst::Create(done, block);
- block = done;
- return result;
- }
- #endif
- private:
- void RegisterDependencies() const final {
- this->DependsOn(Left);
- this->DependsOn(Right);
- }
- IComputationNode* const Left;
- IComputationNode* const Right;
- };
- }
- IComputationNode* WrapDecimalMod(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
- MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
- bool isOptionalLeft, isOptionalRight;
- const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(callable.GetInput(0), isOptionalLeft));
- const auto rightType = UnpackOptionalData(callable.GetInput(1), isOptionalRight);
- auto left = LocateNode(ctx.NodeLocator, callable, 0);
- auto right = LocateNode(ctx.NodeLocator, callable, 1);
- switch (rightType->GetSchemeType()) {
- case NUdf::TDataType<NUdf::TDecimal>::Id:
- MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch");
- if (isOptionalLeft && isOptionalRight)
- return new TDecimalModWrapper<true, true>(ctx.Mutables, left, right);
- else if (isOptionalLeft)
- return new TDecimalModWrapper<true, false>(ctx.Mutables, left, right);
- else if (isOptionalRight)
- return new TDecimalModWrapper<false, true>(ctx.Mutables, left, right);
- else
- return new TDecimalModWrapper<false, false>(ctx.Mutables, left, right);
- #define MAKE_PRIMITIVE_TYPE_MOD(type) \
- case NUdf::TDataType<type>::Id: \
- if (isOptionalLeft && isOptionalRight) \
- return new TDecimalModIntegralWrapper<true, true, type>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second); \
- else if (isOptionalLeft) \
- return new TDecimalModIntegralWrapper<true, false, type>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second); \
- else if (isOptionalRight) \
- return new TDecimalModIntegralWrapper<false, true, type>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second); \
- else \
- return new TDecimalModIntegralWrapper<false, false, type>(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second);
- INTEGRAL_VALUE_TYPES(MAKE_PRIMITIVE_TYPE_MOD)
- #undef MAKE_PRIMITIVE_TYPE_MOD
- default:
- Y_ABORT("Unupported type.");
- }
- }
- }
- }
|