mkql_round.cpp 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. #include "mkql_round.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/presort.h>
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. #include <yql/essentials/minikql/mkql_node_builder.h>
  6. #include <yql/essentials/minikql/mkql_type_builder.h>
  7. #include <yql/essentials/minikql/mkql_string_util.h>
  8. #include <yql/essentials/public/udf/udf_data_type.h>
  9. #include <yql/essentials/utils/utf8.h>
  10. #include <algorithm>
  11. namespace NKikimr {
  12. namespace NMiniKQL {
  13. using namespace NYql::NUdf;
  14. namespace {
  15. template<typename From, typename To>
  16. class TRoundIntegralWrapper : public TMutableComputationNode<TRoundIntegralWrapper<From, To>> {
  17. using TSelf = TRoundIntegralWrapper<From, To>;
  18. using TBase = TMutableComputationNode<TSelf>;
  19. typedef TBase TBaseComputation;
  20. public:
  21. TRoundIntegralWrapper(TComputationMutables& mutables, IComputationNode* source, bool down)
  22. : TBaseComputation(mutables)
  23. , Source(source)
  24. , Down(down)
  25. {
  26. }
  27. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  28. const auto value = Source->GetValue(ctx).Get<From>();
  29. constexpr auto toMin = std::numeric_limits<To>::min();
  30. constexpr auto toMax = std::numeric_limits<To>::max();
  31. if constexpr (std::is_signed<From>::value && std::is_unsigned<To>::value) {
  32. if (value < 0) {
  33. return Down ? TUnboxedValuePod() : TUnboxedValuePod(toMin);
  34. }
  35. if (static_cast<std::make_unsigned_t<From>>(value) > toMax) {
  36. return Down ? TUnboxedValuePod(toMax) : TUnboxedValuePod();
  37. }
  38. return TUnboxedValuePod(static_cast<To>(value));
  39. }
  40. if constexpr (std::is_unsigned<From>::value && std::is_signed<To>::value) {
  41. if (value > static_cast<std::make_unsigned_t<To>>(toMax)) {
  42. return Down ? TUnboxedValuePod(toMax) : TUnboxedValuePod();
  43. }
  44. return TUnboxedValuePod(static_cast<To>(value));
  45. }
  46. if (value < toMin) {
  47. return Down ? TUnboxedValuePod() : TUnboxedValuePod(toMin);
  48. }
  49. if (value > toMax) {
  50. return Down ? TUnboxedValuePod(toMax) : TUnboxedValuePod();
  51. }
  52. return TUnboxedValuePod(static_cast<To>(value));
  53. }
  54. private:
  55. void RegisterDependencies() const final {
  56. this->DependsOn(Source);
  57. }
  58. IComputationNode* const Source;
  59. const bool Down;
  60. };
  61. class TRoundDateTypeWrapper : public TMutableComputationNode<TRoundDateTypeWrapper> {
  62. using TSelf = TRoundDateTypeWrapper;
  63. using TBase = TMutableComputationNode<TSelf>;
  64. typedef TBase TBaseComputation;
  65. public:
  66. TRoundDateTypeWrapper(TComputationMutables& mutables, IComputationNode* source, bool down, EDataSlot from, EDataSlot to)
  67. : TBaseComputation(mutables)
  68. , Source(source)
  69. , Down(down)
  70. , From(from)
  71. , To(to)
  72. {
  73. }
  74. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  75. constexpr i64 usInDay = 86400'000'000ll;
  76. constexpr i64 usInSec = 1000'000ll;
  77. i64 us;
  78. if (From == EDataSlot::Timestamp64) {
  79. us = Source->GetValue(ctx).Get<i64>();
  80. } else if (From == EDataSlot::Datetime64) {
  81. us = usInSec * Source->GetValue(ctx).Get<i64>();
  82. } else if (From == EDataSlot::Timestamp) {
  83. us = static_cast<i64>(Source->GetValue(ctx).Get<ui64>());
  84. } else if (From == EDataSlot::Datetime) {
  85. us = usInSec * static_cast<i64>(Source->GetValue(ctx).Get<ui32>());
  86. } else {
  87. Y_ENSURE(From == EDataSlot::Date32);
  88. us = usInDay * static_cast<i64>(Source->GetValue(ctx).Get<i32>());
  89. }
  90. if (To == EDataSlot::Date || To == EDataSlot::Date32) {
  91. i64 rounded = us / usInDay;
  92. i64 rem = us % usInDay;
  93. if (rem > 0 && !Down) {
  94. rounded += 1;
  95. } else if (rem < 0 && Down) {
  96. rounded -= 1;
  97. }
  98. if (To == EDataSlot::Date32 && rounded <= MAX_DATE32) {
  99. // lower bound check is not needed as RoundDown(MinTimestamp64) is valid value
  100. return TUnboxedValuePod(static_cast<i32>(rounded));
  101. } else if (To == EDataSlot::Date && rounded >= 0 && rounded < MAX_DATE) {
  102. return TUnboxedValuePod(static_cast<ui16>(rounded));
  103. }
  104. } else if (To == EDataSlot::Datetime || To == EDataSlot::Datetime64) {
  105. i64 rounded = us / usInSec;
  106. i64 rem = us % usInSec;
  107. if (rem > 0 && !Down) {
  108. rounded += 1;
  109. } else if (rem < 0 && Down) {
  110. rounded -= 1;
  111. }
  112. if (To == EDataSlot::Datetime64 && rounded <= MAX_DATETIME64) {
  113. // lower bound check is not needed as RoundDown(MinTimestamp64) is valid value
  114. return TUnboxedValuePod(rounded);
  115. } else if (To == EDataSlot::Datetime && rounded >= 0 && rounded < MAX_DATETIME) {
  116. return TUnboxedValuePod(static_cast<ui32>(rounded));
  117. }
  118. } else {
  119. Y_ENSURE(To == EDataSlot::Timestamp);
  120. if (0 <= us && us < static_cast<i64>(MAX_TIMESTAMP)) {
  121. return TUnboxedValuePod(static_cast<ui64>(us));
  122. }
  123. }
  124. return {};
  125. }
  126. private:
  127. void RegisterDependencies() const final {
  128. this->DependsOn(Source);
  129. }
  130. IComputationNode* const Source;
  131. const bool Down;
  132. const EDataSlot From;
  133. const EDataSlot To;
  134. };
  135. class TRoundStringWrapper : public TMutableComputationNode<TRoundStringWrapper> {
  136. using TSelf = TRoundStringWrapper;
  137. using TBase = TMutableComputationNode<TSelf>;
  138. typedef TBase TBaseComputation;
  139. public:
  140. TRoundStringWrapper(TComputationMutables& mutables, IComputationNode* source, bool down)
  141. : TBaseComputation(mutables)
  142. , Source(source)
  143. , Down(down)
  144. {
  145. }
  146. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  147. TUnboxedValue input = Source->GetValue(ctx);
  148. auto output = NYql::RoundToNearestValidUtf8(input.AsStringRef(), Down);
  149. if (!output) {
  150. return {};
  151. }
  152. return MakeString(*output);
  153. }
  154. private:
  155. void RegisterDependencies() const final {
  156. this->DependsOn(Source);
  157. }
  158. IComputationNode* const Source;
  159. const bool Down;
  160. };
  161. template<typename From>
  162. IComputationNode* FromIntegral(TComputationMutables& mutables, IComputationNode* source, bool down, EDataSlot target) {
  163. switch (target) {
  164. case EDataSlot::Int8: return new TRoundIntegralWrapper<From, i8>(mutables, source, down);
  165. case EDataSlot::Uint8: return new TRoundIntegralWrapper<From, ui8>(mutables, source, down);
  166. case EDataSlot::Int16: return new TRoundIntegralWrapper<From, i16>(mutables, source, down);
  167. case EDataSlot::Uint16: return new TRoundIntegralWrapper<From, ui16>(mutables, source, down);
  168. case EDataSlot::Int32: return new TRoundIntegralWrapper<From, i32>(mutables, source, down);
  169. case EDataSlot::Uint32: return new TRoundIntegralWrapper<From, ui32>(mutables, source, down);
  170. case EDataSlot::Int64: return new TRoundIntegralWrapper<From, i64>(mutables, source, down);
  171. case EDataSlot::Uint64: return new TRoundIntegralWrapper<From, ui64>(mutables, source, down);
  172. default: Y_ENSURE(false, "Unsupported integral rounding");
  173. }
  174. return nullptr;
  175. }
  176. } // namespace
  177. IComputationNode* WrapRound(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  178. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expecting exactly one argument");
  179. auto type = callable.GetInput(0).GetStaticType();
  180. MKQL_ENSURE(type->IsData(), "Expecting data as argument");
  181. auto returnType = callable.GetType()->GetReturnType();
  182. MKQL_ENSURE(returnType->IsOptional(), "Expecting optional as return type");
  183. auto targetType = static_cast<TOptionalType*>(returnType)->GetItemType();
  184. MKQL_ENSURE(targetType->IsData(), "Expecting Data as target type");
  185. auto from = GetDataSlot(static_cast<TDataType*>(type)->GetSchemeType());
  186. auto to = GetDataSlot(static_cast<TDataType*>(targetType)->GetSchemeType());
  187. bool down = callable.GetType()->GetName() == "RoundDown";
  188. auto source = LocateNode(ctx.NodeLocator, callable, 0);
  189. switch (from) {
  190. case EDataSlot::Int8: return FromIntegral<i8>(ctx.Mutables, source, down, to);
  191. case EDataSlot::Uint8: return FromIntegral<ui8>(ctx.Mutables, source, down, to);
  192. case EDataSlot::Int16: return FromIntegral<i16>(ctx.Mutables, source, down, to);
  193. case EDataSlot::Uint16: return FromIntegral<ui16>(ctx.Mutables, source, down, to);
  194. case EDataSlot::Int32: return FromIntegral<i32>(ctx.Mutables, source, down, to);
  195. case EDataSlot::Uint32: return FromIntegral<ui32>(ctx.Mutables, source, down, to);
  196. case EDataSlot::Int64: return FromIntegral<i64>(ctx.Mutables, source, down, to);
  197. case EDataSlot::Uint64: return FromIntegral<ui64>(ctx.Mutables, source, down, to);
  198. case EDataSlot::Datetime:
  199. case EDataSlot::Timestamp:
  200. case EDataSlot::Date32: // From Date cases are covered in NYql::NTypeAnnImpl::RoundWrapper
  201. case EDataSlot::Datetime64:
  202. case EDataSlot::Timestamp64:
  203. Y_ENSURE(GetDataTypeInfo(to).Features & DateType);
  204. return new TRoundDateTypeWrapper(ctx.Mutables, source, down, from, to);
  205. case EDataSlot::String:
  206. Y_ENSURE(to == EDataSlot::Utf8);
  207. return new TRoundStringWrapper(ctx.Mutables, source, down);
  208. default:
  209. Y_ENSURE(false,
  210. "Unsupported rounding from " << GetDataTypeInfo(from).Name << " to " << GetDataTypeInfo(to).Name);
  211. }
  212. return nullptr;
  213. }
  214. }
  215. }