// Remainder (Mod) computation nodes for a decimal left operand, plus the
// WrapDecimalMod factory that picks the node implementation.
//
// NOTE(review): in this copy of the file the #include paths and all
// angle-bracket template parameter/argument lists appear to be missing
// (e.g. the bare `#include` lines, `template class ...`,
// `TMutableCodegeneratorNode>`, `std::is_signed()`), so it will not compile as
// shown — restore them from version control before building.
#include "mkql_decimal_div.h"
#include // Y_IGNORE
#include // Y_IGNORE
#include
#include

namespace NKikimr {
namespace NMiniKQL {

namespace {

// Remainder of two decimal operands (the factory below only builds it when both
// operands have the same type).
//
// The template flags IsLeftOptional / IsRightOptional (referenced in the body)
// say whether each operand may be an empty optional; if such an operand is
// empty the result is empty.
// Interpreted path: delegates to the inherited `Do` on the two 128-bit payloads.
// Codegen path: NaN when either operand is abnormal or the divisor is zero,
// otherwise the signed remainder of the two payloads.
template
class TDecimalModWrapper : public TMutableCodegeneratorNode>, NYql::NDecimal::TDecimalRemainder {
    typedef TMutableCodegeneratorNode> TBaseComputation;
public:
    TDecimalModWrapper(TComputationMutables& mutables, IComputationNode* left, IComputationNode* right)
        : TBaseComputation(mutables, EValueRepresentation::Embedded)
        , Left(left)
        , Right(right)
    {}

    // Interpreted evaluation: short-circuit on an empty optional operand, then
    // hand both 128-bit payloads to the inherited `Do`.
    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
        const auto& left = Left->GetValue(compCtx);
        const auto& right = Right->GetValue(compCtx);

        if (IsLeftOptional && !left)
            return NUdf::TUnboxedValuePod();

        if (IsRightOptional && !right)
            return NUdf::TUnboxedValuePod();

        return NUdf::TUnboxedValuePod(Do(left.GetInt128(), right.GetInt128()));
    }

#ifndef MKQL_DISABLE_CODEGEN
    // LLVM IR for the same computation. All exits meet in the `done` block,
    // where a PHI node selects the result.
    Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();

        const auto valType = Type::getInt128Ty(context);

        const auto left = GetNodeValue(Left, ctx, block);
        const auto right = GetNodeValue(Right, ctx, block);

        const auto done = BasicBlock::Create(context, "done", ctx.Func);
        const auto good = BasicBlock::Create(context, "good", ctx.Func);

        const auto zero = ConstantInt::get(valType, 0ULL);
        // One incoming edge per exit: empty operand (optional variants only),
        // NaN, and the computed remainder.
        const auto result = PHINode::Create(valType, IsLeftOptional || IsRightOptional ? 3 : 2, "result", done);

        if constexpr (IsLeftOptional || IsRightOptional) {
            // When both operands are optional their raw values are ANDed and the
            // combination is tested once. NOTE(review): this assumes any two
            // present values share a non-zero bit, so the AND is non-empty only
            // when both are present — confirm against IsEmpty().
            const auto test = IsLeftOptional && IsRightOptional ?
                BinaryOperator::CreateAnd(left, right, "test", block):
                IsLeftOptional ? left : right;

            // Empty-operand exit: the zero constant is the result (the codegen
            // counterpart of the empty TUnboxedValuePod() in DoCalculate).
            result->addIncoming(zero, block);
            BranchInst::Create(done, good, IsEmpty(test, block, context), block);

            block = good;

            const auto lv = GetterForInt128(left, block);
            const auto rv = GetterForInt128(right, block);

            // NaN if either operand is abnormal or the divisor is zero.
            const auto lbad = NDecimal::GenIsAbnormal(lv, context, block);
            const auto rbad = NDecimal::GenIsAbnormal(rv, context, block);
            const auto bad = BinaryOperator::CreateOr(lbad, rbad, "bad", block);
            const auto nul = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rv, zero, "check", block);
            const auto nan = BinaryOperator::CreateOr(bad, nul, "nan", block);

            const auto norm = BasicBlock::Create(context, "norm", ctx.Func);

            result->addIncoming(SetterForInt128(GetDecimalNan(context), block), block);
            BranchInst::Create(done, norm, nan, block);

            block = norm;

            const auto srem = BinaryOperator::CreateSRem(lv, rv, "srem", block);
            result->addIncoming(SetterForInt128(srem, block), block);
        } else {
            const auto lv = GetterForInt128(left, block);
            const auto rv = GetterForInt128(right, block);

            // NaN if either operand is abnormal or the divisor is zero.
            const auto lbad = NDecimal::GenIsAbnormal(lv, context, block);
            const auto rbad = NDecimal::GenIsAbnormal(rv, context, block);
            const auto bad = BinaryOperator::CreateOr(lbad, rbad, "bad", block);
            const auto nul = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rv, zero, "check", block);
            const auto nan = BinaryOperator::CreateOr(bad, nul, "nan", block);

            result->addIncoming(SetterForInt128(GetDecimalNan(context), block), block);
            BranchInst::Create(done, good, nan, block);

            block = good;

            const auto srem = BinaryOperator::CreateSRem(lv, rv, "srem", block);
            result->addIncoming(SetterForInt128(srem, block), block);
        }

        BranchInst::Create(done, block);
        block = done;
        return result;
    }
#endif
private:
    // Declares both operand nodes as dependencies of this node.
    void RegisterDependencies() const final {
        this->DependsOn(Left);
        this->DependsOn(Right);
    }

    IComputationNode* const Left;
    IComputationNode* const Right;
};

// Remainder of a decimal left operand by an integral right operand.
//
// The TDecimalRemainder base is constructed from the decimal's precision and
// scale. It supplies the `Divider` and `Bound` constants used by codegen.
// Interpreted path: delegates to the inherited `Do`.
// Codegen path, after the empty-optional check:
//   * divisor out of bounds -> `left` returned unchanged;
//   * abnormal `left` or zero divisor -> NaN;
//   * otherwise -> left payload srem (Divider * divisor).
template
class TDecimalModIntegralWrapper : public TMutableCodegeneratorNode>, NYql::NDecimal::TDecimalRemainder {
    typedef TMutableCodegeneratorNode> TBaseComputation;
    // Constants provided by the TDecimalRemainder base class.
    using NYql::NDecimal::TDecimalRemainder::Divider;
    using NYql::NDecimal::TDecimalRemainder::Bound;
public:
    TDecimalModIntegralWrapper(TComputationMutables& mutables, IComputationNode* left, IComputationNode* right, ui8 precision, ui8 scale)
        : TBaseComputation(mutables, EValueRepresentation::Embedded)
        , NYql::NDecimal::TDecimalRemainder(precision, scale)
        , Left(left)
        , Right(right)
    {}

    // Interpreted evaluation: short-circuit on an empty optional operand, then
    // pass the decimal payload and the integer value of the right operand to
    // the inherited `Do`.
    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
        auto left = Left->GetValue(compCtx);
        const auto& right = Right->GetValue(compCtx);

        if (IsLeftOptional && !left)
            return NUdf::TUnboxedValuePod();

        if (IsRightOptional && !right)
            return NUdf::TUnboxedValuePod();

        return NUdf::TUnboxedValuePod(this->Do(left.GetInt128(), right.Get()));
    }

#ifndef MKQL_DISABLE_CODEGEN
    // LLVM IR for the same computation. All exits meet in the `done` block,
    // where a PHI node selects the result.
    Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();

        const auto valType = Type::getInt128Ty(context);
        const auto divider = NDecimal::GenConstant(Divider, context);

        const auto left = GetNodeValue(Left, ctx, block);
        const auto right = GetNodeValue(Right, ctx, block);

        const auto done = BasicBlock::Create(context, "done", ctx.Func);
        const auto good = BasicBlock::Create(context, "good", ctx.Func);

        const auto zero = ConstantInt::get(valType, 0ULL);
        // One incoming edge per exit: empty operand (optional variants only),
        // the special-case value, and the computed remainder.
        const auto result = PHINode::Create(valType, IsLeftOptional || IsRightOptional ? 3 : 2, "result", done);

        if constexpr (IsLeftOptional || IsRightOptional) {
            // NOTE(review): the combined AND test has the same encoding
            // assumption as in TDecimalModWrapper — confirm against IsEmpty().
            const auto test = IsLeftOptional && IsRightOptional ?
                BinaryOperator::CreateAnd(left, right, "test", block):
                IsLeftOptional ? left : right;

            // Empty-operand exit: the zero constant is the result.
            result->addIncoming(zero, block);
            BranchInst::Create(done, good, IsEmpty(test, block, context), block);

            block = good;

            const auto lv = GetterForInt128(left, block);
            // Widen the integer divisor to i128: sign-extend signed types,
            // zero-extend unsigned ones.
            const auto cast = std::is_signed() ?
                static_cast(new SExtInst(GetterFor(right, context, block), valType, "sext", block)):
                static_cast(new ZExtInst(GetterFor(right, context, block), valType, "zext", block));
            // Divisor range check against Bound: GenOutOfBounds with -Bound/+Bound
            // limits for signed types, unsigned `>= Bound` otherwise.
            const auto out = std::is_signed() ?
                NDecimal::GenOutOfBounds(cast, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block):
                CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGE, cast, NDecimal::GenConstant(Bound, context), "out", block);
            const auto nul = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, cast, zero, "check", block);

            const auto lbad = NDecimal::GenIsAbnormal(lv, context, block);
            const auto rbad = BinaryOperator::CreateOr(out, nul, "or", block);
            const auto bad = BinaryOperator::CreateOr(lbad, rbad, "bad", block);

            const auto norm = BasicBlock::Create(context, "norm", ctx.Func);

            // Special exit. An out-of-bounds divisor yields `left` unchanged,
            // even when `left` is abnormal. Otherwise (abnormal `left` or zero
            // divisor) the result is NaN.
            const auto spec = SelectInst::Create(out, left, SetterForInt128(GetDecimalNan(context), block), "spec", block);
            result->addIncoming(spec, block);
            BranchInst::Create(done, norm, bad, block);

            block = norm;

            // Scale the integer divisor by Divider before the signed remainder.
            // Divider is presumably 10^scale, putting the divisor into the
            // decimal's fixed-point units — confirm in TDecimalRemainder.
            const auto mul = BinaryOperator::CreateMul(divider, cast, "mul", block);
            const auto srem = BinaryOperator::CreateSRem(lv, mul, "srem", block);
            result->addIncoming(SetterForInt128(srem, block), block);
        } else {
            const auto lv = GetterForInt128(left, block);
            // Widen the integer divisor to i128: sign-extend signed types,
            // zero-extend unsigned ones.
            const auto cast = std::is_signed() ?
                static_cast(new SExtInst(GetterFor(right, context, block), valType, "sext", block)):
                static_cast(new ZExtInst(GetterFor(right, context, block), valType, "zext", block));
            // Divisor range check against Bound (see the optional branch above).
            const auto out = std::is_signed() ?
                NDecimal::GenOutOfBounds(cast, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block):
                CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGE, cast, NDecimal::GenConstant(Bound, context), "out", block);
            const auto nul = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, cast, zero, "check", block);

            const auto lbad = NDecimal::GenIsAbnormal(lv, context, block);
            const auto rbad = BinaryOperator::CreateOr(out, nul, "or", block);
            const auto bad = BinaryOperator::CreateOr(lbad, rbad, "bad", block);

            // Special exit: `left` unchanged for an out-of-bounds divisor,
            // NaN for an abnormal `left` or a zero divisor.
            const auto spec = SelectInst::Create(out, left, SetterForInt128(GetDecimalNan(context), block), "spec", block);
            result->addIncoming(spec, block);
            BranchInst::Create(done, good, bad, block);

            block = good;

            // Signed remainder by the Divider-scaled integer divisor.
            const auto mul = BinaryOperator::CreateMul(divider, cast, "mul", block);
            const auto srem = BinaryOperator::CreateSRem(lv, mul, "srem", block);
            result->addIncoming(SetterForInt128(srem, block), block);
        }

        BranchInst::Create(done, block);
        block = done;
        return result;
    }
#endif
private:
    // Declares both operand nodes as dependencies of this node.
    void RegisterDependencies() const final {
        this->DependsOn(Left);
        this->DependsOn(Right);
    }

    IComputationNode* const Left;
    IComputationNode* const Right;
};

}

// Builds the computation node for a two-argument Mod callable.
//
// The left operand type is unwrapped from optional and cast to a type that
// exposes GetParams() / IsSameType() (presumably the decimal data type —
// confirm). The choice of node depends on the right operand's scheme type:
//   * the first case (presumably Decimal — its template argument is missing in
//     this copy): the types must be identical, and a TDecimalModWrapper is
//     built;
//   * each type listed by INTEGRAL_VALUE_TYPES: a TDecimalModIntegralWrapper is
//     built, with the left type's GetParams().first passed as precision and
//     .second as scale;
//   * anything else aborts.
// In every case the four branches presumably instantiate the wrapper with the
// optional-flag combination that matches isOptionalLeft/isOptionalRight (the
// template arguments are missing in this copy).
IComputationNode* WrapDecimalMod(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");

    bool isOptionalLeft, isOptionalRight;
    const auto leftType = static_cast(UnpackOptionalData(callable.GetInput(0), isOptionalLeft));
    const auto rightType = UnpackOptionalData(callable.GetInput(1), isOptionalRight);

    auto left = LocateNode(ctx.NodeLocator, callable, 0);
    auto right = LocateNode(ctx.NodeLocator, callable, 1);

    switch (rightType->GetSchemeType()) {
        case NUdf::TDataType::Id:
            MKQL_ENSURE(static_cast(rightType)->IsSameType(*leftType), "Operands type mismatch");

            if (isOptionalLeft && isOptionalRight)
                return new TDecimalModWrapper(ctx.Mutables, left, right);
            else if (isOptionalLeft)
                return new TDecimalModWrapper(ctx.Mutables, left, right);
            else if (isOptionalRight)
                return new TDecimalModWrapper(ctx.Mutables, left, right);
            else
                return new TDecimalModWrapper(ctx.Mutables, left, right);

// One `case` per integral right-operand type, expanded by INTEGRAL_VALUE_TYPES.
#define MAKE_PRIMITIVE_TYPE_MOD(type) \
        case NUdf::TDataType::Id: \
            if (isOptionalLeft && isOptionalRight) \
                return new TDecimalModIntegralWrapper(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second); \
            else if (isOptionalLeft) \
                return new TDecimalModIntegralWrapper(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second); \
            else if (isOptionalRight) \
                return new TDecimalModIntegralWrapper(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second); \
            else \
                return new TDecimalModIntegralWrapper(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second);

        INTEGRAL_VALUE_TYPES(MAKE_PRIMITIVE_TYPE_MOD)

#undef MAKE_PRIMITIVE_TYPE_MOD
        default:
            // NOTE(review): the abort message contains a typo ("Unupported");
            // it is left as-is here because it is runtime text.
            Y_ABORT("Unupported type.");
    }
}

}
}