mkql_fromstring.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. #include "mkql_fromstring.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. #include <yql/essentials/minikql/mkql_node_builder.h>
  6. #include <yql/essentials/minikql/invoke_builtins/mkql_builtins_decimal.h> // Y_IGNORE
  7. #include <yql/essentials/public/udf/udf_terminator.h>
  8. #ifndef MKQL_DISABLE_CODEGEN
  9. Y_PRAGMA_DIAGNOSTIC_PUSH
  10. Y_PRAGMA("GCC diagnostic ignored \"-Wreturn-type-c-linkage\"")
  11. extern "C" NKikimr::NUdf::TUnboxedValuePod DataFromString(const NKikimr::NUdf::TUnboxedValuePod data, NKikimr::NUdf::EDataSlot slot) {
  12. return NKikimr::NMiniKQL::ValueFromString(slot, data.AsStringRef());
  13. }
  14. extern "C" NYql::NDecimal::TInt128 DecimalFromString(const NKikimr::NUdf::TUnboxedValuePod decimal, ui8 precision, ui8 scale) {
  15. return NYql::NDecimal::FromStringEx(decimal.AsStringRef(), precision, scale);
  16. }
  17. Y_PRAGMA_DIAGNOSTIC_POP
  18. #endif
  19. namespace NKikimr {
  20. namespace NMiniKQL {
  21. namespace {
  22. const unsigned ERROR_FRAGMENT_LIMIT = 5000;
  23. [[noreturn]]
  24. void ThrowConvertError(NYql::NUdf::TStringRef data, TStringBuf type) {
  25. TStringBuilder builder;
  26. builder << "could not convert \"";
  27. if (data.Size() < ERROR_FRAGMENT_LIMIT) {
  28. builder << data << "\"";
  29. } else {
  30. builder << TStringBuf(data.Data(), ERROR_FRAGMENT_LIMIT) << "\" (truncated)";
  31. }
  32. builder << " to " << type;
  33. UdfTerminate(builder.data());
  34. }
  35. template <bool IsStrict, bool IsOptional>
  36. class TDecimalFromStringWrapper : public TMutableCodegeneratorNode<TDecimalFromStringWrapper<IsStrict, IsOptional>> {
  37. typedef TMutableCodegeneratorNode<TDecimalFromStringWrapper<IsStrict, IsOptional>> TBaseComputation;
  38. public:
  39. TDecimalFromStringWrapper(TComputationMutables& mutables, IComputationNode* data, ui8 precision, ui8 scale)
  40. : TBaseComputation(mutables, EValueRepresentation::Embedded)
  41. , Data(data)
  42. , Precision(precision)
  43. , Scale(scale)
  44. {
  45. MKQL_ENSURE(precision > 0 && precision <= NYql::NDecimal::MaxPrecision, "Wrong precision.");
  46. MKQL_ENSURE(scale <= precision, "Wrong scale.");
  47. }
  48. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  49. const auto& data = Data->GetValue(ctx);
  50. if (IsOptional && !data) {
  51. return NUdf::TUnboxedValuePod();
  52. }
  53. if (const auto v = NYql::NDecimal::FromStringEx(data.AsStringRef(), Precision, Scale); !NYql::NDecimal::IsError(v)) {
  54. return NUdf::TUnboxedValuePod(v);
  55. }
  56. if constexpr (IsStrict) {
  57. Throw(data, Precision, Scale);
  58. } else {
  59. return NUdf::TUnboxedValuePod();
  60. }
  61. }
  62. #ifndef MKQL_DISABLE_CODEGEN
  63. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  64. auto& context = ctx.Codegen.GetContext();
  65. const auto valType = Type::getInt128Ty(context);
  66. const auto psType = Type::getInt8Ty(context);
  67. const auto name = "DecimalFromString";
  68. ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&DecimalFromString));
  69. const auto fnType =
  70. FunctionType::get(valType, { valType, psType, psType }, false);
  71. const auto func = ctx.Codegen.GetModule().getOrInsertFunction(name, fnType);
  72. const auto zero = ConstantInt::get(valType, 0ULL);
  73. const auto precision = ConstantInt::get(psType, Precision);
  74. const auto scale = ConstantInt::get(psType, Scale);
  75. const auto value = GetNodeValue(Data, ctx, block);
  76. const auto fail = BasicBlock::Create(context, "fail", ctx.Func);
  77. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  78. const auto ways = (IsOptional ? 1U : 0U) + (IsStrict ? 0U : 1U);
  79. const auto last = ways > 0U ? BasicBlock::Create(context, "last", ctx.Func) : nullptr;
  80. const auto phi = last ? PHINode::Create(valType, ways + 1U, "result", last) : nullptr;
  81. if constexpr (IsOptional) {
  82. phi->addIncoming(zero, block);
  83. const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, value, zero, "check", block);
  84. const auto call = BasicBlock::Create(context, "call", ctx.Func);
  85. BranchInst::Create(last, call, check, block);
  86. block = call;
  87. }
  88. const auto decimal = CallInst::Create(func, { value, precision, scale }, "from_string", block);
  89. if (Data->IsTemporaryValue())
  90. ValueCleanup(Data->GetRepresentation(), value, ctx, block);
  91. const auto test = NDecimal::GenIsError(decimal, context, block);
  92. BranchInst::Create(fail, good, test, block);
  93. {
  94. block = fail;
  95. if constexpr (IsStrict) {
  96. const auto doFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TDecimalFromStringWrapper::Throw));
  97. const auto doFuncType = FunctionType::get(Type::getVoidTy(context), {valType, psType, psType}, false);
  98. const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, doFunc, PointerType::getUnqual(doFuncType), "thrower", block);
  99. CallInst::Create(doFuncType, doFuncPtr, { value, precision, scale }, "", block);
  100. new UnreachableInst(context, block);
  101. } else {
  102. phi->addIncoming(zero, block);
  103. BranchInst::Create(last, block);
  104. }
  105. }
  106. block = good;
  107. if constexpr (IsOptional || !IsStrict) {
  108. phi->addIncoming(SetterForInt128(decimal, block), block);
  109. BranchInst::Create(last, block);
  110. block = last;
  111. return phi;
  112. } else {
  113. return SetterForInt128(decimal, block);
  114. }
  115. }
  116. #endif
  117. private:
  118. void RegisterDependencies() const final {
  119. this->DependsOn(Data);
  120. }
  121. [[noreturn]] static void Throw(const NUdf::TUnboxedValuePod data, ui8 precision, ui8 scale) {
  122. const TString type = TStringBuilder() << "Decimal(" << unsigned(precision) << ", " << unsigned(scale) << ")";
  123. ThrowConvertError(data.AsStringRef(), type);
  124. }
  125. IComputationNode* const Data;
  126. const ui8 Precision, Scale;
  127. };
  128. template <bool IsStrict, bool IsOptional>
  129. class TFromStringWrapper : public TMutableCodegeneratorNode<TFromStringWrapper<IsStrict, IsOptional>> {
  130. typedef TMutableCodegeneratorNode<TFromStringWrapper<IsStrict, IsOptional>> TBaseComputation;
  131. public:
  132. TFromStringWrapper(TComputationMutables& mutables, IComputationNode* data, NUdf::TDataTypeId schemeType)
  133. : TBaseComputation(mutables, GetValueRepresentation(schemeType))
  134. , Data(data)
  135. , SchemeType(NUdf::GetDataSlot(schemeType))
  136. {}
  137. NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
  138. const auto& data = Data->GetValue(ctx);
  139. if (IsOptional && !data) {
  140. return NUdf::TUnboxedValuePod();
  141. }
  142. if (const auto out = ValueFromString(SchemeType, data.AsStringRef())) {
  143. return out;
  144. }
  145. if constexpr (IsStrict) {
  146. Throw(data, SchemeType);
  147. } else {
  148. return NUdf::TUnboxedValuePod();
  149. }
  150. }
  151. #ifndef MKQL_DISABLE_CODEGEN
  152. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  153. auto& context = ctx.Codegen.GetContext();
  154. const auto valType = Type::getInt128Ty(context);
  155. const auto slotType = Type::getInt32Ty(context);
  156. const auto name = "DataFromString";
  157. ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&DataFromString));
  158. const auto fnType =
  159. FunctionType::get(valType, { valType, slotType }, false);
  160. const auto func = ctx.Codegen.GetModule().getOrInsertFunction(name, fnType);
  161. const auto zero = ConstantInt::get(valType, 0ULL);
  162. const auto slot = ConstantInt::get(slotType, static_cast<ui32>(SchemeType));
  163. const auto value = GetNodeValue(Data, ctx, block);
  164. const auto fail = IsStrict ? BasicBlock::Create(context, "fail", ctx.Func) : nullptr;
  165. const auto last = IsOptional || fail ? BasicBlock::Create(context, "last", ctx.Func) : nullptr;
  166. const auto phi = IsOptional ? PHINode::Create(valType, 2U, "result", last) : nullptr;
  167. if constexpr (IsOptional) {
  168. phi->addIncoming(zero, block);
  169. const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, value, zero, "check", block);
  170. const auto call = BasicBlock::Create(context, "call", ctx.Func);
  171. BranchInst::Create(last, call, check, block);
  172. block = call;
  173. }
  174. Value* data = CallInst::Create(func, { value, slot }, "from_string", block);
  175. if (Data->IsTemporaryValue())
  176. ValueCleanup(Data->GetRepresentation(), value, ctx, block);
  177. if constexpr (IsOptional) {
  178. phi->addIncoming(data, block);
  179. }
  180. if constexpr (IsStrict) {
  181. const auto test = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, data, zero, "test", block);
  182. BranchInst::Create(fail, last, test, block);
  183. block = fail;
  184. const auto doFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TFromStringWrapper::Throw));
  185. const auto doFuncType = FunctionType::get(Type::getVoidTy(context), {valType, slotType}, false);
  186. const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, doFunc, PointerType::getUnqual(doFuncType), "thrower", block);
  187. CallInst::Create(doFuncType, doFuncPtr, { value, slot }, "", block);
  188. new UnreachableInst(context, block);
  189. } else if constexpr (IsOptional) {
  190. BranchInst::Create(last, block);
  191. }
  192. if constexpr (IsOptional || IsStrict) {
  193. block = last;
  194. }
  195. return IsOptional ? phi : data;
  196. }
  197. #endif
  198. private:
  199. void RegisterDependencies() const final {
  200. this->DependsOn(Data);
  201. }
  202. [[noreturn]] static void Throw(const NUdf::TUnboxedValuePod data, NUdf::EDataSlot slot) {
  203. ThrowConvertError(data.AsStringRef(), NUdf::GetDataTypeInfo(slot).Name);
  204. }
  205. IComputationNode* const Data;
  206. const NUdf::EDataSlot SchemeType;
  207. };
  208. }
  209. IComputationNode* WrapFromString(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  210. MKQL_ENSURE(callable.GetInputsCount() >= 2, "Expected 2 args");
  211. bool isOptional;
  212. const auto dataType = UnpackOptionalData(callable.GetInput(0), isOptional);
  213. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<char*>::Id || dataType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected String");
  214. const auto schemeTypeData = AS_VALUE(TDataLiteral, callable.GetInput(1));
  215. const auto schemeType = schemeTypeData->AsValue().Get<ui32>();
  216. const auto data = LocateNode(ctx.NodeLocator, callable, 0);
  217. if (NUdf::TDataType<NUdf::TDecimal>::Id == schemeType) {
  218. MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args");
  219. const auto precision = AS_VALUE(TDataLiteral, callable.GetInput(2))->AsValue().Get<ui8>();
  220. const auto scale = AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().Get<ui8>();
  221. if (isOptional) {
  222. return new TDecimalFromStringWrapper<false, true>(ctx.Mutables, data, precision, scale);
  223. } else {
  224. return new TDecimalFromStringWrapper<false, false>(ctx.Mutables, data, precision, scale);
  225. }
  226. } else {
  227. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
  228. if (isOptional) {
  229. return new TFromStringWrapper<false, true>(ctx.Mutables, data, static_cast<NUdf::TDataTypeId>(schemeType));
  230. } else {
  231. return new TFromStringWrapper<false, false>(ctx.Mutables, data, static_cast<NUdf::TDataTypeId>(schemeType));
  232. }
  233. }
  234. }
  235. IComputationNode* WrapStrictFromString(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  236. MKQL_ENSURE(callable.GetInputsCount() >= 2, "Expected 2 args");
  237. bool isOptional;
  238. const auto dataType = UnpackOptionalData(callable.GetInput(0), isOptional);
  239. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<char*>::Id || dataType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected String");
  240. const auto schemeTypeData = AS_VALUE(TDataLiteral, callable.GetInput(1));
  241. const auto schemeType = schemeTypeData->AsValue().Get<ui32>();
  242. const auto data = LocateNode(ctx.NodeLocator, callable, 0);
  243. if (NUdf::TDataType<NUdf::TDecimal>::Id == schemeType) {
  244. MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args");
  245. const auto precision = AS_VALUE(TDataLiteral, callable.GetInput(2))->AsValue().Get<ui8>();
  246. const auto scale = AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().Get<ui8>();
  247. if (isOptional) {
  248. return new TDecimalFromStringWrapper<true, true>(ctx.Mutables, data, precision, scale);
  249. } else {
  250. return new TDecimalFromStringWrapper<true, false>(ctx.Mutables, data, precision, scale);
  251. }
  252. } else {
  253. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
  254. if (isOptional) {
  255. return new TFromStringWrapper<true, true>(ctx.Mutables, data, static_cast<NUdf::TDataTypeId>(schemeType));
  256. } else {
  257. return new TFromStringWrapper<true, false>(ctx.Mutables, data, static_cast<NUdf::TDataTypeId>(schemeType));
  258. }
  259. }
  260. }
  261. }
  262. }