mkql_builtins_equals.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. #include "mkql_builtins_compare.h"
  2. #include "mkql_builtins_datetime.h"
  3. #include "mkql_builtins_decimal.h" // Y_IGNORE
  4. #include "mkql_builtins_string_kernels.h"
  5. #include <yql/essentials/minikql/mkql_type_ops.h>
  6. namespace NKikimr {
  7. namespace NMiniKQL {
  8. namespace {
  9. template <typename T1, typename T2,
  10. std::enable_if_t<std::is_integral<T1>::value && std::is_integral<T2>::value && std::is_signed<T1>::value == std::is_signed<T2>::value, bool> Aggr>
  11. Y_FORCE_INLINE bool Equals(T1 x, T2 y) {
  12. return x == y;
  13. }
  14. template <typename T1, typename T2,
  15. std::enable_if_t<std::is_integral<T1>::value && std::is_integral<T2>::value && std::is_signed<T1>::value && std::is_unsigned<T2>::value, bool> Aggr>
  16. Y_FORCE_INLINE bool Equals(T1 x, T2 y) {
  17. return x >= T1(0) && static_cast<std::make_unsigned_t<T1>>(x) == y;
  18. }
  19. template <typename T1, typename T2,
  20. std::enable_if_t<std::is_integral<T1>::value && std::is_integral<T2>::value && std::is_unsigned<T1>::value && std::is_signed<T2>::value, bool> Aggr>
  21. Y_FORCE_INLINE bool Equals(T1 x, T2 y) {
  22. return y >= T2(0) && x == static_cast<std::make_unsigned_t<T2>>(y);
  23. }
  24. template <typename T1, typename T2,
  25. std::enable_if_t<std::is_floating_point<T1>::value || std::is_floating_point<T2>::value, bool> Aggr>
  26. Y_FORCE_INLINE bool Equals(T1 x, T2 y) {
  27. using F1 = std::conditional_t<std::is_floating_point<T1>::value, T1, T2>;
  28. using F2 = std::conditional_t<std::is_floating_point<T2>::value, T2, T1>;
  29. using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>;
  30. const auto l = static_cast<FT>(x);
  31. const auto r = static_cast<FT>(y);
  32. if constexpr (Aggr) {
  33. if (std::isunordered(l, r))
  34. return std::isnan(l) == std::isnan(r);
  35. }
  36. return l == r;
  37. }
  38. #ifndef MKQL_DISABLE_CODEGEN
  39. Value* GenEqualsIntegral(Value* lhs, Value* rhs, BasicBlock* block) {
  40. return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, lhs, rhs, "equals", block);
  41. }
  42. template <bool Aggr>
  43. Value* GenEqualsFloats(Value* lhs, Value* rhs, BasicBlock* block);
  44. template <>
  45. Value* GenEqualsFloats<false>(Value* lhs, Value* rhs, BasicBlock* block) {
  46. return CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_OEQ, lhs, rhs, "equals", block);
  47. }
  48. template <>
  49. Value* GenEqualsFloats<true>(Value* lhs, Value* rhs, BasicBlock* block) {
  50. const auto ueq = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UEQ, lhs, rhs, "equals", block);
  51. const auto lord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(lhs->getType(), 0.0), lhs, "lord", block);
  52. const auto runo = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UNO, ConstantFP::get(rhs->getType(), 0.0), rhs, "runo", block);
  53. const auto once = BinaryOperator::CreateXor(lord, runo, "xor", block);
  54. return BinaryOperator::CreateAnd(ueq, once, "and", block);
  55. }
  56. template <typename T1, typename T2>
  57. Value* GenEqualsIntegralLeftSigned(Value* x, Value* y, LLVMContext &context, BasicBlock* block) {
  58. const auto zero = ConstantInt::get(x->getType(), 0);
  59. const auto neg = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLT, x, zero, "negative", block);
  60. using T = std::conditional_t<(sizeof(std::make_unsigned_t<T1>) > sizeof(T2)), std::make_unsigned_t<T1>, T2>;
  61. const auto comp = GenEqualsIntegral(StaticCast<T1, T>(x, context, block), StaticCast<T2, T>(y, context, block), block);
  62. return SelectInst::Create(neg, ConstantInt::getFalse(context), comp, "result", block);
  63. }
  64. template <typename T1, typename T2>
  65. Value* GenEqualsIntegralRightSigned(Value* x, Value* y, LLVMContext &context, BasicBlock* block) {
  66. const auto zero = ConstantInt::get(y->getType(), 0);
  67. const auto neg = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLT, y, zero, "negative", block);
  68. using T = std::conditional_t<(sizeof(T1) > sizeof(std::make_unsigned_t<T2>)), T1, std::make_unsigned_t<T2>>;
  69. const auto comp = GenEqualsIntegral(StaticCast<T1, T>(x, context, block), StaticCast<T2, T>(y, context, block), block);
  70. return SelectInst::Create(neg, ConstantInt::getFalse(context), comp, "result", block);
  71. }
  72. template <typename T1, typename T2,
  73. std::enable_if_t<std::is_unsigned<T1>::value && std::is_unsigned<T2>::value, bool> Aggr>
  74. inline Value* GenEquals(Value* x, Value* y, LLVMContext &context, BasicBlock* block) {
  75. using T = std::conditional_t<(sizeof(T1) > sizeof(T2)), T1, T2>;
  76. return GenEqualsIntegral(StaticCast<T1, T>(x, context, block), StaticCast<T2, T>(y, context, block), block);
  77. }
  78. template <typename T1, typename T2,
  79. std::enable_if_t<std::is_signed<T1>::value && std::is_signed<T2>::value &&
  80. std::is_integral<T1>::value && std::is_integral<T2>::value, bool> Aggr>
  81. inline Value* GenEquals(Value* x, Value* y, LLVMContext &context, BasicBlock* block) {
  82. using T = std::conditional_t<(sizeof(T1) > sizeof(T2)), T1, T2>;
  83. return GenEqualsIntegral(StaticCast<T1, T>(x, context, block), StaticCast<T2, T>(y, context, block), block);
  84. }
  85. template <typename T1, typename T2,
  86. std::enable_if_t<std::is_integral<T1>::value && std::is_integral<T2>::value
  87. && std::is_signed<T1>::value && std::is_unsigned<T2>::value, bool> Aggr>
  88. inline Value* GenEquals(Value* x, Value* y, LLVMContext &context, BasicBlock* block) {
  89. return GenEqualsIntegralLeftSigned<T1, T2>(x, y, context, block);
  90. }
  91. template <typename T1, typename T2,
  92. std::enable_if_t<std::is_integral<T1>::value && std::is_integral<T2>::value
  93. && std::is_unsigned<T1>::value && std::is_signed<T2>::value, bool> Aggr>
  94. inline Value* GenEquals(Value* x, Value* y, LLVMContext &context, BasicBlock* block) {
  95. return GenEqualsIntegralRightSigned<T1, T2>(x, y, context, block);
  96. }
  97. template <typename T1, typename T2,
  98. std::enable_if_t<std::is_floating_point<T1>::value || std::is_floating_point<T2>::value, bool> Aggr>
  99. inline Value* GenEquals(Value* x, Value* y, LLVMContext &context, BasicBlock* block) {
  100. using F1 = std::conditional_t<std::is_floating_point<T1>::value, T1, T2>;
  101. using F2 = std::conditional_t<std::is_floating_point<T2>::value, T2, T1>;
  102. using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>;
  103. return GenEqualsFloats<Aggr>(StaticCast<T1, FT>(x, context, block), StaticCast<T2, FT>(y, context, block), block);
  104. }
  105. #endif
  106. struct TAggrEquals {
  107. static bool Simple(bool left, bool right)
  108. {
  109. return left == right;
  110. }
  111. static bool Join(bool one, bool two)
  112. {
  113. return one && two;
  114. }
  115. #ifndef MKQL_DISABLE_CODEGEN
  116. static constexpr CmpInst::Predicate SimplePredicate = ICmpInst::ICMP_EQ;
  117. static Value* GenJoin(Value* one, Value* two, BasicBlock* block)
  118. {
  119. return BinaryOperator::CreateAnd(one, two, "and", block);
  120. }
  121. #endif
  122. };
  123. template<typename TLeft, typename TRight, bool Aggr>
  124. struct TEquals : public TCompareArithmeticBinary<TLeft, TRight, TEquals<TLeft, TRight, Aggr>>, public TAggrEquals {
  125. static bool Do(TLeft left, TRight right)
  126. {
  127. return Equals<TLeft, TRight, Aggr>(left, right);
  128. }
  129. #ifndef MKQL_DISABLE_CODEGEN
  130. static Value* Gen(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block)
  131. {
  132. return GenEquals<TLeft, TRight, Aggr>(left, right, ctx.Codegen.GetContext(), block);
  133. }
  134. #endif
  135. };
  136. template<typename TLeft, typename TRight, typename TOutput>
  137. struct TEqualsOp;
  138. template<typename TLeft, typename TRight>
  139. struct TEqualsOp<TLeft, TRight, bool> : public TEquals<TLeft, TRight, false> {
  140. static constexpr auto NullMode = TKernel::ENullMode::Default;
  141. };
  142. template<typename TLeft, typename TRight, bool Aggr>
  143. struct TDiffDateEquals : public TCompareArithmeticBinary<typename TLeft::TLayout, typename TRight::TLayout, TDiffDateEquals<TLeft, TRight, Aggr>>, public TAggrEquals {
  144. static bool Do(typename TLeft::TLayout left, typename TRight::TLayout right)
  145. {
  146. return std::is_same<TLeft, TRight>::value ?
  147. Equals<typename TLeft::TLayout, typename TRight::TLayout, Aggr>(left, right):
  148. Equals<TScaledDate, TScaledDate, Aggr>(ToScaledDate<TLeft>(left), ToScaledDate<TRight>(right));
  149. }
  150. #ifndef MKQL_DISABLE_CODEGEN
  151. static Value* Gen(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block)
  152. {
  153. auto& context = ctx.Codegen.GetContext();
  154. return std::is_same<TLeft, TRight>::value ?
  155. GenEquals<typename TLeft::TLayout, typename TRight::TLayout, Aggr>(left, right, context, block):
  156. GenEquals<TScaledDate, TScaledDate, Aggr>(GenToScaledDate<TLeft>(left, context, block), GenToScaledDate<TRight>(right, context, block), context, block);
  157. }
  158. #endif
  159. };
  160. template<typename TLeft, typename TRight, typename TOutput>
  161. struct TDiffDateEqualsOp;
  162. template<typename TLeft, typename TRight>
  163. struct TDiffDateEqualsOp<TLeft, TRight, NUdf::TDataType<bool>> : public TDiffDateEquals<TLeft, TRight, false> {
  164. static constexpr auto NullMode = TKernel::ENullMode::Default;
  165. };
  166. template <typename TLeft, typename TRight, bool Aggr>
  167. struct TAggrTzDateEquals : public TArithmeticConstraintsBinary<TLeft, TRight, bool>, public TAggrEquals {
  168. static_assert(std::is_same<TLeft, TRight>::value, "Must be same type.");
  169. static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right) {
  170. return NUdf::TUnboxedValuePod(Join(Equals<TLeft, TRight, Aggr>(left.template Get<TLeft>(), right.template Get<TRight>()), Equals<ui16, ui16, Aggr>(left.GetTimezoneId(), right.GetTimezoneId())));
  171. }
  172. #ifndef MKQL_DISABLE_CODEGEN
  173. static Value* Generate(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block)
  174. {
  175. auto& context = ctx.Codegen.GetContext();
  176. const auto lhs = GetterFor<TLeft>(left, context, block);
  177. const auto rhs = GetterFor<TRight>(right, context, block);
  178. const auto ltz = GetterForTimezone(context, left, block);
  179. const auto rtz = GetterForTimezone(context, right, block);
  180. const auto result = GenJoin(GenEquals<TLeft, TRight, Aggr>(lhs, rhs, context, block), GenEquals<ui16, ui16, Aggr>(ltz, rtz, context, block), block);
  181. const auto wide = MakeBoolean(result, context, block);
  182. return wide;
  183. }
  184. #endif
  185. };
  186. template<NUdf::EDataSlot Slot>
  187. struct TCustomEquals : public TAggrEquals {
  188. static NUdf::TUnboxedValuePod Execute(NUdf::TUnboxedValuePod left, NUdf::TUnboxedValuePod right) {
  189. return NUdf::TUnboxedValuePod(CompareCustomsWithCleanup<Slot>(left, right) == 0);
  190. }
  191. #ifndef MKQL_DISABLE_CODEGEN
  192. static Value* Generate(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block)
  193. {
  194. auto& context = ctx.Codegen.GetContext();
  195. const auto res = CallBinaryUnboxedValueFunction(&CompareCustoms<Slot>, Type::getInt32Ty(context), left, right, ctx.Codegen, block);
  196. const auto comp = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, res, ConstantInt::get(res->getType(), 0), "equals", block);
  197. ValueCleanup(EValueRepresentation::String, left, ctx, block);
  198. ValueCleanup(EValueRepresentation::String, right, ctx, block);
  199. return MakeBoolean(comp, context, block);
  200. }
  201. #endif
  202. };
  203. struct TDecimalEquals {
  204. static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right) {
  205. const auto l = left.GetInt128();
  206. const auto r = right.GetInt128();
  207. return NUdf::TUnboxedValuePod(NYql::NDecimal::IsComparable(l) && l == r);
  208. }
  209. #ifndef MKQL_DISABLE_CODEGEN
  210. static Value* Generate(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block)
  211. {
  212. auto& context = ctx.Codegen.GetContext();
  213. const auto l = GetterForInt128(left, block);
  214. const auto r = GetterForInt128(right, block);
  215. const auto good = NDecimal::GenIsComparable(l, context, block);
  216. const auto eq = GenEqualsIntegral(l, r, block);
  217. const auto res = BinaryOperator::CreateAnd(good, eq, "res", block);
  218. return MakeBoolean(res, context, block);
  219. }
  220. #endif
  221. };
  222. struct TDecimalAggrEquals : public TAggrEquals {
  223. static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right) {
  224. const auto l = left.GetInt128();
  225. const auto r = right.GetInt128();
  226. return NUdf::TUnboxedValuePod(l == r);
  227. }
  228. #ifndef MKQL_DISABLE_CODEGEN
  229. static Value* Generate(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block)
  230. {
  231. auto& context = ctx.Codegen.GetContext();
  232. const auto l = GetterForInt128(left, block);
  233. const auto r = GetterForInt128(right, block);
  234. const auto eq = GenEqualsIntegral(l, r, block);
  235. return MakeBoolean(eq, context, block);
  236. }
  237. #endif
  238. };
  239. }
  240. void RegisterEquals(IBuiltinFunctionRegistry& registry) {
  241. const auto name = "Equals";
  242. RegisterComparePrimitive<TEquals, TCompareArgsOpt>(registry, name);
  243. RegisterCompareDatetime<TDiffDateEquals, TCompareArgsOpt>(registry, name);
  244. RegisterCompareBigDatetime<TDiffDateEquals, TCompareArgsOpt>(registry, name);
  245. RegisterCompareStrings<TCustomEquals, TCompareArgsOpt>(registry, name);
  246. RegisterCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, NUdf::TDataType<NUdf::TDecimal>, TDecimalEquals, TCompareArgsOpt>(registry, name);
  247. const auto aggrName = "AggrEquals";
  248. RegisterAggrComparePrimitive<TEquals, TCompareArgsOpt>(registry, aggrName);
  249. RegisterAggrCompareDatetime<TDiffDateEquals, TCompareArgsOpt>(registry, aggrName);
  250. RegisterAggrCompareTzDatetime<TAggrTzDateEquals, TCompareArgsOpt>(registry, aggrName);
  251. RegisterAggrCompareBigDatetime<TDiffDateEquals, TCompareArgsOpt>(registry, aggrName);
  252. RegisterAggrCompareBigTzDatetime<TAggrTzDateEquals, TCompareArgsOpt>(registry, aggrName);
  253. RegisterAggrCompareStrings<TCustomEquals, TCompareArgsOpt>(registry, aggrName);
  254. RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrEquals, TCompareArgsOpt>(registry, aggrName);
  255. }
  256. void RegisterEquals(TKernelFamilyMap& kernelFamilyMap) {
  257. auto family = std::make_unique<TKernelFamilyBase>();
  258. AddNumericComparisonKernels<TEqualsOp>(*family);
  259. AddDateComparisonKernels<TDiffDateEqualsOp>(*family);
  260. AddDecimalComparisonKernels<TDecimalEquals>(*family);
  261. RegisterStringKernelEquals(*family);
  262. kernelFamilyMap["Equals"] = std::move(family);
  263. }
  264. } // namespace NMiniKQL
  265. } // namespace NKikimr