#include "mkql_decimal_div.h" #include // Y_IGNORE #include // Y_IGNORE #include #include extern "C" NYql::NDecimal::TInt128 DecimalMulAndDivNormalMultiplier(NYql::NDecimal::TInt128 a, NYql::NDecimal::TInt128 b, NYql::NDecimal::TInt128 c) { return NYql::NDecimal::MulAndDivNormalMultiplier(a, b, c); } extern "C" NYql::NDecimal::TInt128 DecimalDiv(NYql::NDecimal::TInt128 a, NYql::NDecimal::TInt128 b) { return NYql::NDecimal::Div(a, b); } namespace NKikimr { namespace NMiniKQL { namespace { template class TDecimalDivWrapper : public TMutableCodegeneratorNode>, NYql::NDecimal::TDecimalDivisor { typedef TMutableCodegeneratorNode> TBaseComputation; public: TDecimalDivWrapper(TComputationMutables& mutables, IComputationNode* left, IComputationNode* right, ui8 precision, ui8 scale) : TBaseComputation(mutables, EValueRepresentation::Embedded) , NYql::NDecimal::TDecimalDivisor(precision, scale) , Left(left) , Right(right) {} NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const { const auto& left = Left->GetValue(compCtx); const auto& right = Right->GetValue(compCtx); if (IsLeftOptional && !left) return NUdf::TUnboxedValuePod(); if (IsRightOptional && !right) return NUdf::TUnboxedValuePod(); return NUdf::TUnboxedValuePod(Do(left.GetInt128(), right.GetInt128())); } #ifndef MKQL_DISABLE_CODEGEN Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const { auto& context = ctx.Codegen.GetContext(); const auto valType = Type::getInt128Ty(context); const auto name = "DecimalMulAndDivNormalMultiplier"; ctx.Codegen.AddGlobalMapping(name, reinterpret_cast(&DecimalMulAndDivNormalMultiplier)); const auto fnType = FunctionType::get(valType, { valType, valType, valType }, false); const auto func = ctx.Codegen.GetModule().getOrInsertFunction(name, fnType); const auto left = GetNodeValue(Left, ctx, block); const auto right = GetNodeValue(Right, ctx, block); if constexpr (IsLeftOptional || IsRightOptional) { const auto test = IsLeftOptional && IsRightOptional ? BinaryOperator::CreateAnd(left, right, "test", block): IsLeftOptional ? left : right; const auto done = BasicBlock::Create(context, "done", ctx.Func); const auto good = BasicBlock::Create(context, "good", ctx.Func); const auto result = PHINode::Create(valType, 2, "result", done); result->addIncoming(test, block); BranchInst::Create(done, good, IsEmpty(test, block, context), block); block = good; const auto muldiv = CallInst::Create(func, { GetterForInt128(left, block), NDecimal::GenConstant(Divider, context), GetterForInt128(right, block) }, "mul_and_div", block); const auto ok = NDecimal::GenInBounds(muldiv, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block); const auto nan = NDecimal::GenIsNonComparable(muldiv, context, block); const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, muldiv, ConstantInt::get(muldiv->getType(), 0), "plus", block); const auto inf = SelectInst::Create(plus, GetDecimalPlusInf(context), GetDecimalMinusInf(context), "inf", block); const auto bad = SelectInst::Create(nan, GetDecimalNan(context), inf, "bad", block); const auto res = SelectInst::Create(ok, muldiv, bad, "res", block); result->addIncoming(SetterForInt128(res, block), block); BranchInst::Create(done, block); block = done; return result; } else { const auto muldiv = CallInst::Create(func, { GetterForInt128(left, block), NDecimal::GenConstant(Divider, context), GetterForInt128(right, block) }, "mul_and_div", block); const auto ok = NDecimal::GenInBounds(muldiv, NDecimal::GenConstant(-Bound, context), NDecimal::GenConstant(+Bound, context), block); const auto nan = NDecimal::GenIsNonComparable(muldiv, context, block); const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, muldiv, ConstantInt::get(muldiv->getType(), 0), "plus", block); const auto inf = SelectInst::Create(plus, GetDecimalPlusInf(context), GetDecimalMinusInf(context), "inf", block); const auto bad = SelectInst::Create(nan, GetDecimalNan(context), inf, "bad", block); const auto res = SelectInst::Create(ok, muldiv, bad, "res", block); return SetterForInt128(res, block); } } #endif private: void RegisterDependencies() const final { this->DependsOn(Left); this->DependsOn(Right); } IComputationNode* const Left; IComputationNode* const Right; }; template class TDecimalDivIntegralWrapper : public TMutableCodegeneratorNode>, NYql::NDecimal::TDecimalDivisor { typedef TMutableCodegeneratorNode> TBaseComputation; public: TDecimalDivIntegralWrapper(TComputationMutables& mutables, IComputationNode* left, IComputationNode* right) : TBaseComputation(mutables, EValueRepresentation::Embedded) , Left(left) , Right(right) {} NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const { const auto& left = Left->GetValue(compCtx); const auto& right = Right->GetValue(compCtx); if (IsLeftOptional && !left) return NUdf::TUnboxedValuePod(); if (IsRightOptional && !right) return NUdf::TUnboxedValuePod(); return NUdf::TUnboxedValuePod(this->Do(left.GetInt128(), right.Get())); } #ifndef MKQL_DISABLE_CODEGEN Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const { auto& context = ctx.Codegen.GetContext(); const auto valType = Type::getInt128Ty(context); const auto name = "DecimalDiv"; ctx.Codegen.AddGlobalMapping(name, reinterpret_cast(&DecimalDiv)); const auto fnType = FunctionType::get(valType, { valType, valType }, false); const auto func = ctx.Codegen.GetModule().getOrInsertFunction(name, fnType); const auto left = GetNodeValue(Left, ctx, block); const auto right = GetNodeValue(Right, ctx, block); if constexpr (IsLeftOptional || IsRightOptional) { const auto test = IsLeftOptional && IsRightOptional ? BinaryOperator::CreateAnd(left, right, "test", block): IsLeftOptional ? left : right; const auto done = BasicBlock::Create(context, "done", ctx.Func); const auto good = BasicBlock::Create(context, "good", ctx.Func); const auto result = PHINode::Create(valType, 2, "result", done); result->addIncoming(test, block); BranchInst::Create(done, good, IsEmpty(test, block, context), block); block = good; const auto cast = std::is_signed() ? static_cast(new SExtInst(GetterFor(right, context, block), valType, "sext", block)): static_cast(new ZExtInst(GetterFor(right, context, block), valType, "zext", block)); const auto div = CallInst::Create(func, {GetterForInt128(left, block), cast}, "div", block); result->addIncoming(SetterForInt128(div, block), block); BranchInst::Create(done, block); block = done; return result; } else { const auto cast = std::is_signed() ? static_cast(new SExtInst(GetterFor(right, context, block), valType, "sext", block)): static_cast(new ZExtInst(GetterFor(right, context, block), valType, "zext", block)); const auto div = CallInst::Create(func, {GetterForInt128(left, block), cast}, "div", block); return SetterForInt128(div, block); } } #endif private: void RegisterDependencies() const final { this->DependsOn(Left); this->DependsOn(Right); } IComputationNode* const Left; IComputationNode* const Right; }; } IComputationNode* WrapDecimalDiv(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args"); bool isOptionalLeft, isOptionalRight; const auto leftType = static_cast(UnpackOptionalData(callable.GetInput(0), isOptionalLeft)); const auto rightType = UnpackOptionalData(callable.GetInput(1), isOptionalRight); auto left = LocateNode(ctx.NodeLocator, callable, 0); auto right = LocateNode(ctx.NodeLocator, callable, 1); switch (rightType->GetSchemeType()) { case NUdf::TDataType::Id: MKQL_ENSURE(static_cast(rightType)->IsSameType(*leftType), "Operands type mismatch"); if (isOptionalLeft && isOptionalRight) return new TDecimalDivWrapper(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second); else if (isOptionalLeft) return new TDecimalDivWrapper(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second); else if (isOptionalRight) return new TDecimalDivWrapper(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second); else return new TDecimalDivWrapper(ctx.Mutables, left, right, leftType->GetParams().first, leftType->GetParams().second); #define MAKE_PRIMITIVE_TYPE_DIV(type) \ case NUdf::TDataType::Id: \ if (isOptionalLeft && isOptionalRight) \ return new TDecimalDivIntegralWrapper(ctx.Mutables, left, right); \ else if (isOptionalLeft) \ return new TDecimalDivIntegralWrapper(ctx.Mutables, left, right); \ else if (isOptionalRight) \ return new TDecimalDivIntegralWrapper(ctx.Mutables, left, right); \ else \ return new TDecimalDivIntegralWrapper(ctx.Mutables, left, right); INTEGRAL_VALUE_TYPES(MAKE_PRIMITIVE_TYPE_DIV) #undef MAKE_PRIMITIVE_TYPE_DIV default: Y_ABORT("Unupported type."); } } } }