// Equality builtins: scalar helpers, optional LLVM code generation and registration of the
// "Equals" / "AggrEquals" functions plus the "Equals" kernel family.
//
// NOTE(review): in this copy of the file the contents of angle brackets look stripped
// (e.g. `template ::value`, a bare `#include`, `static_cast>(x)`, `std::make_unique()`), so the
// template parameter lists and explicit template arguments are not visible here. Presumably an
// extraction artifact -- confirm against the repository before editing any signature.
//
// Scalar helpers (anonymous namespace):
//   * Equals(x, y) overloads, selected through enable_if conditions on integral / signed /
//     unsigned / floating-point traits:
//       - one overload compares directly with `x == y`;
//       - one returns false when x is negative (`x >= T1(0) && ...`) before comparing the cast x with y;
//       - one returns false when y is negative (`y >= T2(0) && ...`) before comparing x with the cast y;
//       - the floating-point overload casts both operands and, inside `if constexpr (Aggr)`, treats
//         an unordered pair as equal only when `std::isnan(l) == std::isnan(r)`; otherwise `l == r`.
//
// Code generation (everything under `#ifndef MKQL_DISABLE_CODEGEN`):
//   * GenEqualsIntegral emits an ICMP_EQ instruction.
//   * GenEqualsFloats has two specializations: one emits FCMP_OEQ; the other emits FCMP_UEQ and
//     ANDs it with XOR(FCMP_ORD on lhs, FCMP_UNO on rhs). Which template argument selects which
//     specialization is not visible in this copy -- TODO confirm.
//   * GenEqualsIntegralLeftSigned / GenEqualsIntegralRightSigned emit an ICMP_SLT against zero for
//     the x / y operand respectively and select constant false when that operand is negative,
//     otherwise the ICMP_EQ of the cast operands.
//   * GenEquals overloads dispatch to the helpers above, mirroring the Equals overload set.
//
// Operation structs:
//   * TAggrEquals: Simple is `left == right`, Join is `one && two`, GenJoin emits an AND.
//   * TEquals / TEqualsOp: Do calls Equals, Gen calls GenEquals; TEqualsOp sets
//     NullMode = TKernel::ENullMode::Default.
//   * TDiffDateEquals / TDiffDateEqualsOp: chooses via std::is_same between comparing the raw
//     layouts and comparing the ToScaledDate / GenToScaledDate results.
//   * TAggrTzDateEquals: joins equality of the stored values with equality of GetTimezoneId().
//   * TCustomEquals: Execute tests `CompareCustomsWithCleanup(left, right) == 0`; Generate calls
//     CompareCustoms, compares the result with 0 and then calls ValueCleanup on both operands.
//   * TDecimalEquals: `NYql::NDecimal::IsComparable(l) && l == r` on the Int128 payloads.
//   * TDecimalAggrEquals: plain `l == r` on the Int128 payloads (no IsComparable check).
//
// Registration:
//   * RegisterEquals(IBuiltinFunctionRegistry&) registers primitive, datetime, big datetime, string
//     and decimal comparisons under "Equals", and the aggregate variants (additionally tz datetime
//     and big tz datetime) under "AggrEquals".
//   * RegisterEquals(TKernelFamilyMap&) fills a kernel family with numeric, date, decimal and string
//     equality kernels and stores it under the key "Equals".
#include "mkql_builtins_compare.h" #include "mkql_builtins_datetime.h" #include "mkql_builtins_decimal.h" // Y_IGNORE #include "mkql_builtins_string_kernels.h" #include namespace NKikimr { namespace NMiniKQL { namespace { template ::value && std::is_integral::value && std::is_signed::value == std::is_signed::value, bool> Aggr> Y_FORCE_INLINE bool Equals(T1 x, T2 y) { return x == y; } template ::value && std::is_integral::value && std::is_signed::value && std::is_unsigned::value, bool> Aggr> Y_FORCE_INLINE bool Equals(T1 x, T2 y) { return x >= T1(0) && static_cast>(x) == y; } template ::value && std::is_integral::value && std::is_unsigned::value && std::is_signed::value, bool> Aggr> Y_FORCE_INLINE bool Equals(T1 x, T2 y) { return y >= T2(0) && x == static_cast>(y); } template ::value || std::is_floating_point::value, bool> Aggr> Y_FORCE_INLINE bool Equals(T1 x, T2 y) { using F1 = std::conditional_t::value, T1, T2>; using F2 = std::conditional_t::value, T2, T1>; using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>; const auto l = static_cast(x); const auto r = static_cast(y); if constexpr (Aggr) { if (std::isunordered(l, r)) return std::isnan(l) == std::isnan(r); } return l == r; } #ifndef MKQL_DISABLE_CODEGEN Value* GenEqualsIntegral(Value* lhs, Value* rhs, BasicBlock* block) { return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, lhs, rhs, "equals", block); } template Value* GenEqualsFloats(Value* lhs, Value* rhs, BasicBlock* block); template <> Value* GenEqualsFloats(Value* lhs, Value* rhs, BasicBlock* block) { return CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_OEQ, lhs, rhs, "equals", block); } template <> Value* GenEqualsFloats(Value* lhs, Value* rhs, BasicBlock* block) { const auto ueq = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UEQ, lhs, rhs, "equals", block); const auto lord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(lhs->getType(), 0.0), lhs, "lord", block); const auto runo = 
CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UNO, ConstantFP::get(rhs->getType(), 0.0), rhs, "runo", block); const auto once = BinaryOperator::CreateXor(lord, runo, "xor", block); return BinaryOperator::CreateAnd(ueq, once, "and", block); } template Value* GenEqualsIntegralLeftSigned(Value* x, Value* y, LLVMContext &context, BasicBlock* block) { const auto zero = ConstantInt::get(x->getType(), 0); const auto neg = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLT, x, zero, "negative", block); using T = std::conditional_t<(sizeof(std::make_unsigned_t) > sizeof(T2)), std::make_unsigned_t, T2>; const auto comp = GenEqualsIntegral(StaticCast(x, context, block), StaticCast(y, context, block), block); return SelectInst::Create(neg, ConstantInt::getFalse(context), comp, "result", block); } template Value* GenEqualsIntegralRightSigned(Value* x, Value* y, LLVMContext &context, BasicBlock* block) { const auto zero = ConstantInt::get(y->getType(), 0); const auto neg = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLT, y, zero, "negative", block); using T = std::conditional_t<(sizeof(T1) > sizeof(std::make_unsigned_t)), T1, std::make_unsigned_t>; const auto comp = GenEqualsIntegral(StaticCast(x, context, block), StaticCast(y, context, block), block); return SelectInst::Create(neg, ConstantInt::getFalse(context), comp, "result", block); } template ::value && std::is_unsigned::value, bool> Aggr> inline Value* GenEquals(Value* x, Value* y, LLVMContext &context, BasicBlock* block) { using T = std::conditional_t<(sizeof(T1) > sizeof(T2)), T1, T2>; return GenEqualsIntegral(StaticCast(x, context, block), StaticCast(y, context, block), block); } template ::value && std::is_signed::value && std::is_integral::value && std::is_integral::value, bool> Aggr> inline Value* GenEquals(Value* x, Value* y, LLVMContext &context, BasicBlock* block) { using T = std::conditional_t<(sizeof(T1) > sizeof(T2)), T1, T2>; return GenEqualsIntegral(StaticCast(x, context, block), StaticCast(y, 
context, block), block); } template ::value && std::is_integral::value && std::is_signed::value && std::is_unsigned::value, bool> Aggr> inline Value* GenEquals(Value* x, Value* y, LLVMContext &context, BasicBlock* block) { return GenEqualsIntegralLeftSigned(x, y, context, block); } template ::value && std::is_integral::value && std::is_unsigned::value && std::is_signed::value, bool> Aggr> inline Value* GenEquals(Value* x, Value* y, LLVMContext &context, BasicBlock* block) { return GenEqualsIntegralRightSigned(x, y, context, block); } template ::value || std::is_floating_point::value, bool> Aggr> inline Value* GenEquals(Value* x, Value* y, LLVMContext &context, BasicBlock* block) { using F1 = std::conditional_t::value, T1, T2>; using F2 = std::conditional_t::value, T2, T1>; using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>; return GenEqualsFloats(StaticCast(x, context, block), StaticCast(y, context, block), block); } #endif struct TAggrEquals { static bool Simple(bool left, bool right) { return left == right; } static bool Join(bool one, bool two) { return one && two; } #ifndef MKQL_DISABLE_CODEGEN static constexpr CmpInst::Predicate SimplePredicate = ICmpInst::ICMP_EQ; static Value* GenJoin(Value* one, Value* two, BasicBlock* block) { return BinaryOperator::CreateAnd(one, two, "and", block); } #endif }; template struct TEquals : public TCompareArithmeticBinary>, public TAggrEquals { static bool Do(TLeft left, TRight right) { return Equals(left, right); } #ifndef MKQL_DISABLE_CODEGEN static Value* Gen(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block) { return GenEquals(left, right, ctx.Codegen.GetContext(), block); } #endif }; template struct TEqualsOp; template struct TEqualsOp : public TEquals { static constexpr auto NullMode = TKernel::ENullMode::Default; }; template struct TDiffDateEquals : public TCompareArithmeticBinary>, public TAggrEquals { static bool Do(typename TLeft::TLayout left, typename TRight::TLayout right) { 
return std::is_same::value ? Equals(left, right): Equals(ToScaledDate(left), ToScaledDate(right)); } #ifndef MKQL_DISABLE_CODEGEN static Value* Gen(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block) { auto& context = ctx.Codegen.GetContext(); return std::is_same::value ? GenEquals(left, right, context, block): GenEquals(GenToScaledDate(left, context, block), GenToScaledDate(right, context, block), context, block); } #endif }; template struct TDiffDateEqualsOp; template struct TDiffDateEqualsOp> : public TDiffDateEquals { static constexpr auto NullMode = TKernel::ENullMode::Default; }; template struct TAggrTzDateEquals : public TArithmeticConstraintsBinary, public TAggrEquals { static_assert(std::is_same::value, "Must be same type."); static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right) { return NUdf::TUnboxedValuePod(Join(Equals(left.template Get(), right.template Get()), Equals(left.GetTimezoneId(), right.GetTimezoneId()))); } #ifndef MKQL_DISABLE_CODEGEN static Value* Generate(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block) { auto& context = ctx.Codegen.GetContext(); const auto lhs = GetterFor(left, context, block); const auto rhs = GetterFor(right, context, block); const auto ltz = GetterForTimezone(context, left, block); const auto rtz = GetterForTimezone(context, right, block); const auto result = GenJoin(GenEquals(lhs, rhs, context, block), GenEquals(ltz, rtz, context, block), block); const auto wide = MakeBoolean(result, context, block); return wide; } #endif }; template struct TCustomEquals : public TAggrEquals { static NUdf::TUnboxedValuePod Execute(NUdf::TUnboxedValuePod left, NUdf::TUnboxedValuePod right) { return NUdf::TUnboxedValuePod(CompareCustomsWithCleanup(left, right) == 0); } #ifndef MKQL_DISABLE_CODEGEN static Value* Generate(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block) { auto& context = 
ctx.Codegen.GetContext(); const auto res = CallBinaryUnboxedValueFunction(&CompareCustoms, Type::getInt32Ty(context), left, right, ctx.Codegen, block); const auto comp = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, res, ConstantInt::get(res->getType(), 0), "equals", block); ValueCleanup(EValueRepresentation::String, left, ctx, block); ValueCleanup(EValueRepresentation::String, right, ctx, block); return MakeBoolean(comp, context, block); } #endif }; struct TDecimalEquals { static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right) { const auto l = left.GetInt128(); const auto r = right.GetInt128(); return NUdf::TUnboxedValuePod(NYql::NDecimal::IsComparable(l) && l == r); } #ifndef MKQL_DISABLE_CODEGEN static Value* Generate(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block) { auto& context = ctx.Codegen.GetContext(); const auto l = GetterForInt128(left, block); const auto r = GetterForInt128(right, block); const auto good = NDecimal::GenIsComparable(l, context, block); const auto eq = GenEqualsIntegral(l, r, block); const auto res = BinaryOperator::CreateAnd(good, eq, "res", block); return MakeBoolean(res, context, block); } #endif }; struct TDecimalAggrEquals : public TAggrEquals { static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right) { const auto l = left.GetInt128(); const auto r = right.GetInt128(); return NUdf::TUnboxedValuePod(l == r); } #ifndef MKQL_DISABLE_CODEGEN static Value* Generate(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block) { auto& context = ctx.Codegen.GetContext(); const auto l = GetterForInt128(left, block); const auto r = GetterForInt128(right, block); const auto eq = GenEqualsIntegral(l, r, block); return MakeBoolean(eq, context, block); } #endif }; } void RegisterEquals(IBuiltinFunctionRegistry& registry) { const auto name = "Equals"; RegisterComparePrimitive(registry, name); 
RegisterCompareDatetime(registry, name); RegisterCompareBigDatetime(registry, name); RegisterCompareStrings(registry, name); RegisterCompareCustomOpt, NUdf::TDataType, TDecimalEquals, TCompareArgsOpt>(registry, name); const auto aggrName = "AggrEquals"; RegisterAggrComparePrimitive(registry, aggrName); RegisterAggrCompareDatetime(registry, aggrName); RegisterAggrCompareTzDatetime(registry, aggrName); RegisterAggrCompareBigDatetime(registry, aggrName); RegisterAggrCompareBigTzDatetime(registry, aggrName); RegisterAggrCompareStrings(registry, aggrName); RegisterAggrCompareCustomOpt, TDecimalAggrEquals, TCompareArgsOpt>(registry, aggrName); } void RegisterEquals(TKernelFamilyMap& kernelFamilyMap) { auto family = std::make_unique(); AddNumericComparisonKernels(*family); AddDateComparisonKernels(*family); AddDecimalComparisonKernels(*family); RegisterStringKernelEquals(*family); kernelFamilyMap["Equals"] = std::move(family); } } // namespace NMiniKQL } // namespace NKikimr