mkql_builtins_abs.cpp 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
#include "mkql_builtins_decimal.h" // Y_IGNORE

#include <cmath>
#include <type_traits>
  3. namespace NKikimr {
  4. namespace NMiniKQL {
  5. namespace {
  6. template <typename T, std::enable_if_t<std::is_unsigned<T>::value>* = nullptr>
  7. inline T Abs(T v) {
  8. return v;
  9. }
  10. template <typename T, std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
  11. inline T Abs(T v) {
  12. return std::fabs(v);
  13. }
  14. template <typename T, std::enable_if_t<std::is_signed<T>::value && std::is_integral<T>::value>* = nullptr>
  15. inline T Abs(T v) {
  16. return std::abs(v);
  17. }
  18. template<typename TInput, typename TOutput>
  19. struct TAbs : public TSimpleArithmeticUnary<TInput, TOutput, TAbs<TInput, TOutput>> {
  20. static constexpr auto NullMode = TKernel::ENullMode::Default;
  21. static TOutput Do(TInput val)
  22. {
  23. return Abs<TInput>(val);
  24. }
  25. #ifndef MKQL_DISABLE_CODEGEN
  26. static Value* Gen(Value* arg, const TCodegenContext& ctx, BasicBlock*& block)
  27. {
  28. if (std::is_unsigned<TInput>())
  29. return arg;
  30. if (std::is_floating_point<TInput>()) {
  31. auto& module = ctx.Codegen.GetModule();
  32. const auto fnType = FunctionType::get(arg->getType(), {arg->getType()}, false);
  33. const auto& name = GetFuncNameForType<TInput>("llvm.fabs");
  34. const auto func = module.getOrInsertFunction(name, fnType).getCallee();
  35. const auto res = CallInst::Create(fnType, func, {arg}, "fabs", block);
  36. return res;
  37. } else {
  38. const auto zero = ConstantInt::get(arg->getType(), 0);
  39. const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLT, arg, zero, "check", block);
  40. const auto neg = BinaryOperator::CreateNeg(arg, "neg", block);
  41. const auto res = SelectInst::Create(check, neg, arg, "result", block);
  42. return res;
  43. }
  44. }
  45. #endif
  46. };
  47. struct TDecimalAbs : public TDecimalUnary<TDecimalAbs> {
  48. static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& arg) {
  49. const auto a = arg.GetInt128();
  50. return a < 0 ? NUdf::TUnboxedValuePod(-a) : arg;
  51. }
  52. #ifndef MKQL_DISABLE_CODEGEN
  53. static Value* Generate(Value* arg, const TCodegenContext&, BasicBlock*& block)
  54. {
  55. const auto val = GetterForInt128(arg, block);
  56. const auto zero = ConstantInt::get(val->getType(), 0);
  57. const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLT, val, zero, "check", block);
  58. const auto neg = BinaryOperator::CreateNeg(val, "neg", block);
  59. const auto res = SelectInst::Create(check, SetterForInt128(neg, block), arg, "result", block);
  60. return res;
  61. }
  62. #endif
  63. };
  64. }
  65. void RegisterAbs(IBuiltinFunctionRegistry& registry) {
  66. RegisterUnaryNumericFunctionOpt<TAbs, TUnaryArgsOpt>(registry, "Abs");
  67. RegisterFunctionUnOpt<NUdf::TDataType<NUdf::TInterval>, NUdf::TDataType<NUdf::TInterval>, TAbs, TUnaryArgsOpt>(registry, "Abs");
  68. RegisterFunctionUnOpt<NUdf::TDataType<NUdf::TInterval64>, NUdf::TDataType<NUdf::TInterval64>, TAbs, TUnaryArgsOpt>(registry, "Abs");
  69. NDecimal::RegisterUnaryFunction<TDecimalAbs, TUnaryArgsOpt>(registry, "Abs");
  70. }
  71. void RegisterAbs(TKernelFamilyMap& kernelFamilyMap) {
  72. auto family = std::make_unique<TUnaryNumericKernelFamily<TAbs>>();
  73. AddUnaryDecimalKernels<TDecimalAbs>(*family);
  74. kernelFamilyMap["Abs"] = std::move(family);
  75. }
  76. } // namespace NMiniKQL
  77. } // namespace NKikimr