mkql_fromstring.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. #include "mkql_fromstring.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. #include <yql/essentials/minikql/mkql_node_builder.h>
  6. #include <yql/essentials/minikql/invoke_builtins/mkql_builtins_decimal.h> // Y_IGNORE
  7. #include <yql/essentials/public/udf/udf_terminator.h>
  8. #ifndef MKQL_DISABLE_CODEGEN
  9. Y_PRAGMA_DIAGNOSTIC_PUSH
  10. Y_PRAGMA("GCC diagnostic ignored \"-Wreturn-type-c-linkage\"")
  11. extern "C" NKikimr::NUdf::TUnboxedValuePod DataFromString(const NKikimr::NUdf::TUnboxedValuePod data, NKikimr::NUdf::EDataSlot slot) {
  12. return NKikimr::NMiniKQL::ValueFromString(slot, data.AsStringRef());
  13. }
  14. extern "C" NYql::NDecimal::TInt128 DecimalFromString(const NKikimr::NUdf::TUnboxedValuePod decimal, ui8 precision, ui8 scale) {
  15. return NYql::NDecimal::FromStringEx(decimal.AsStringRef(), precision, scale);
  16. }
  17. Y_PRAGMA_DIAGNOSTIC_POP
  18. #endif
  19. namespace NKikimr {
  20. namespace NMiniKQL {
  21. namespace {
  22. const unsigned ERROR_FRAGMENT_LIMIT = 5000;
  23. [[noreturn]]
  24. void ThrowConvertError(NYql::NUdf::TStringRef data, TStringBuf type) {
  25. TStringBuilder builder;
  26. builder << "could not convert \"";
  27. if (data.Size() < ERROR_FRAGMENT_LIMIT) {
  28. builder << data << "\"";
  29. } else {
  30. builder << TStringBuf(data.Data(), ERROR_FRAGMENT_LIMIT) << "\" (truncated)";
  31. }
  32. builder << " to " << type;
  33. UdfTerminate(builder.data());
  34. }
// Computation node implementing FromString / StrictFromString for
// Decimal(Precision, Scale). The input node yields a string (optional when
// IsOptional); the result is an embedded 128-bit decimal value.
//   IsStrict   - a parse error terminates via Throw instead of yielding an
//                empty value.
//   IsOptional - an empty (null) input is passed through as an empty result.
template <bool IsStrict, bool IsOptional>
class TDecimalFromStringWrapper : public TMutableCodegeneratorNode<TDecimalFromStringWrapper<IsStrict, IsOptional>> {
    typedef TMutableCodegeneratorNode<TDecimalFromStringWrapper<IsStrict, IsOptional>> TBaseComputation;
public:
    // Rejects precision outside (0, MaxPrecision] and scale greater than precision.
    TDecimalFromStringWrapper(TComputationMutables& mutables, IComputationNode* data, ui8 precision, ui8 scale)
        : TBaseComputation(mutables, EValueRepresentation::Embedded)
        , Data(data)
        , Precision(precision)
        , Scale(scale)
    {
        MKQL_ENSURE(precision > 0 && precision <= NYql::NDecimal::MaxPrecision, "Wrong precision.");
        MKQL_ENSURE(scale <= precision, "Wrong scale.");
    }

    // Interpreted path: parse with FromStringEx; an error result either
    // terminates (strict) or yields an empty value (lenient).
    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
        const auto& data = Data->GetValue(ctx);
        if (IsOptional && !data) {
            return NUdf::TUnboxedValuePod();
        }
        if (const auto v = NYql::NDecimal::FromStringEx(data.AsStringRef(), Precision, Scale); !NYql::NDecimal::IsError(v)) {
            return NUdf::TUnboxedValuePod(v);
        }
        if constexpr (IsStrict) {
            Throw(data, Precision, Scale);
        } else {
            return NUdf::TUnboxedValuePod();
        }
    }
#ifndef MKQL_DISABLE_CODEGEN
    // LLVM path mirroring DoCalculate. Emitted control flow:
    //   [optional] null input -------------------------------> last (zero)
    //   call DecimalFromString -> GenIsError ? fail : good
    //   fail: strict  -> Throw, unreachable
    //         lenient -------------------------------------> last (zero)
    //   good: --------------------------------------------> last (decimal)
    //         or returned directly when strict and non-optional (no phi).
    Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();
        const auto valType = Type::getInt128Ty(context);
        const auto psType = Type::getInt8Ty(context);
        const auto valTypePtr = PointerType::getUnqual(valType);
        const auto name = "DecimalFromString";
        ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&DecimalFromString));
        // On Windows targets the helper is declared as void(ret*, arg*, precision, scale):
        // the 128-bit argument and result travel through memory instead of by value.
        const auto fnType = NYql::NCodegen::ETarget::Windows != ctx.Codegen.GetEffectiveTarget() ?
            FunctionType::get(valType, { valType, psType, psType }, false):
            FunctionType::get(Type::getVoidTy(context), { valTypePtr, valTypePtr, psType, psType }, false);
        const auto func = ctx.Codegen.GetModule().getOrInsertFunction(name, fnType);
        const auto zero = ConstantInt::get(valType, 0ULL);
        const auto precision = ConstantInt::get(psType, Precision);
        const auto scale = ConstantInt::get(psType, Scale);
        const auto value = GetNodeValue(Data, ctx, block);
        const auto fail = BasicBlock::Create(context, "fail", ctx.Func);
        const auto good = BasicBlock::Create(context, "good", ctx.Func);
        // Incoming edges into `last` besides the good path: one for a null input
        // (optional), one for a parse error (lenient). No edges -> no `last`/phi.
        const auto ways = (IsOptional ? 1U : 0U) + (IsStrict ? 0U : 1U);
        const auto last = ways > 0U ? BasicBlock::Create(context, "last", ctx.Func) : nullptr;
        const auto phi = last ? PHINode::Create(valType, ways + 1U, "result", last) : nullptr;
        if constexpr (IsOptional) {
            phi->addIncoming(zero, block);
            const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, value, zero, "check", block);
            const auto call = BasicBlock::Create(context, "call", ctx.Func);
            BranchInst::Create(last, call, check, block);
            block = call;
        }
        Value* decimal;
        if (NYql::NCodegen::ETarget::Windows != ctx.Codegen.GetEffectiveTarget()) {
            decimal = CallInst::Create(func, { value, precision, scale }, "from_string", block);
        } else {
            // The same stack slot carries the argument in and the result out.
            const auto retPtr = new AllocaInst(valType, 0U, "ret_ptr", block);
            new StoreInst(value, retPtr, block);
            CallInst::Create(func, { retPtr, retPtr, precision, scale }, "", block);
            decimal = new LoadInst(valType, retPtr, "res", block);
        }
        // NOTE(review): for temporary inputs the cleanup is emitted here, before the
        // fail/good split, yet the strict fail path below still passes `value` to Throw,
        // which reads its string for the message (DoCalculate keeps `data` alive until
        // Throw). Confirm the string cannot already be released at that point.
        if (Data->IsTemporaryValue())
            ValueCleanup(Data->GetRepresentation(), value, ctx, block);
        const auto test = NDecimal::GenIsError(decimal, context, block);
        BranchInst::Create(fail, good, test, block);
        {
            block = fail;
            if constexpr (IsStrict) {
                // Call the static Throw through a raw function pointer baked in as a constant.
                const auto doFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TDecimalFromStringWrapper::Throw));
                const auto doFuncType = FunctionType::get(Type::getVoidTy(context), {valType, psType, psType}, false);
                const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, doFunc, PointerType::getUnqual(doFuncType), "thrower", block);
                CallInst::Create(doFuncType, doFuncPtr, { value, precision, scale }, "", block);
                new UnreachableInst(context, block);
            } else {
                phi->addIncoming(zero, block);
                BranchInst::Create(last, block);
            }
        }
        block = good;
        if constexpr (IsOptional || !IsStrict) {
            phi->addIncoming(SetterForInt128(decimal, block), block);
            BranchInst::Create(last, block);
            block = last;
            return phi;
        } else {
            return SetterForInt128(decimal, block);
        }
    }
#endif
private:
    void RegisterDependencies() const final {
        this->DependsOn(Data);
    }

    // Reports that `data` could not be parsed as Decimal(precision, scale).
    // Static so that generated code can call it through a plain function pointer.
    [[noreturn]] static void Throw(const NUdf::TUnboxedValuePod data, ui8 precision, ui8 scale) {
        const TString type = TStringBuilder() << "Decimal(" << unsigned(precision) << ", " << unsigned(scale) << ")";
        ThrowConvertError(data.AsStringRef(), type);
    }

    IComputationNode* const Data;  // input string node
    const ui8 Precision, Scale;
};
  138. template <bool IsStrict, bool IsOptional>
  139. class TFromStringWrapper : public TMutableCodegeneratorNode<TFromStringWrapper<IsStrict, IsOptional>> {
  140. typedef TMutableCodegeneratorNode<TFromStringWrapper<IsStrict, IsOptional>> TBaseComputation;
  141. public:
  142. TFromStringWrapper(TComputationMutables& mutables, IComputationNode* data, NUdf::TDataTypeId schemeType)
  143. : TBaseComputation(mutables, GetValueRepresentation(schemeType))
  144. , Data(data)
  145. , SchemeType(NUdf::GetDataSlot(schemeType))
  146. {}
  147. NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
  148. const auto& data = Data->GetValue(ctx);
  149. if (IsOptional && !data) {
  150. return NUdf::TUnboxedValuePod();
  151. }
  152. if (const auto out = ValueFromString(SchemeType, data.AsStringRef())) {
  153. return out;
  154. }
  155. if constexpr (IsStrict) {
  156. Throw(data, SchemeType);
  157. } else {
  158. return NUdf::TUnboxedValuePod();
  159. }
  160. }
  161. #ifndef MKQL_DISABLE_CODEGEN
  162. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  163. auto& context = ctx.Codegen.GetContext();
  164. const auto valType = Type::getInt128Ty(context);
  165. const auto slotType = Type::getInt32Ty(context);
  166. const auto valTypePtr = PointerType::getUnqual(valType);
  167. const auto name = "DataFromString";
  168. ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&DataFromString));
  169. const auto fnType = NYql::NCodegen::ETarget::Windows != ctx.Codegen.GetEffectiveTarget() ?
  170. FunctionType::get(valType, { valType, slotType }, false):
  171. FunctionType::get(Type::getVoidTy(context), { valTypePtr, valTypePtr, slotType }, false);
  172. const auto func = ctx.Codegen.GetModule().getOrInsertFunction(name, fnType);
  173. const auto zero = ConstantInt::get(valType, 0ULL);
  174. const auto slot = ConstantInt::get(slotType, static_cast<ui32>(SchemeType));
  175. const auto value = GetNodeValue(Data, ctx, block);
  176. const auto fail = IsStrict ? BasicBlock::Create(context, "fail", ctx.Func) : nullptr;
  177. const auto last = IsOptional || fail ? BasicBlock::Create(context, "last", ctx.Func) : nullptr;
  178. const auto phi = IsOptional ? PHINode::Create(valType, 2U, "result", last) : nullptr;
  179. if constexpr (IsOptional) {
  180. phi->addIncoming(zero, block);
  181. const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, value, zero, "check", block);
  182. const auto call = BasicBlock::Create(context, "call", ctx.Func);
  183. BranchInst::Create(last, call, check, block);
  184. block = call;
  185. }
  186. Value* data;
  187. if (NYql::NCodegen::ETarget::Windows != ctx.Codegen.GetEffectiveTarget()) {
  188. data = CallInst::Create(func, { value, slot }, "from_string", block);
  189. } else {
  190. const auto retPtr = new AllocaInst(valType, 0U, "ret_ptr", block);
  191. new StoreInst(value, retPtr, block);
  192. CallInst::Create(func, { retPtr, retPtr, slot }, "", block);
  193. data = new LoadInst(valType, retPtr, "res", block);
  194. }
  195. if (Data->IsTemporaryValue())
  196. ValueCleanup(Data->GetRepresentation(), value, ctx, block);
  197. if constexpr (IsOptional) {
  198. phi->addIncoming(data, block);
  199. }
  200. if constexpr (IsStrict) {
  201. const auto test = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, data, zero, "test", block);
  202. BranchInst::Create(fail, last, test, block);
  203. block = fail;
  204. const auto doFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TFromStringWrapper::Throw));
  205. const auto doFuncType = FunctionType::get(Type::getVoidTy(context), {valType, slotType}, false);
  206. const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, doFunc, PointerType::getUnqual(doFuncType), "thrower", block);
  207. CallInst::Create(doFuncType, doFuncPtr, { value, slot }, "", block);
  208. new UnreachableInst(context, block);
  209. } else if constexpr (IsOptional) {
  210. BranchInst::Create(last, block);
  211. }
  212. if constexpr (IsOptional || IsStrict) {
  213. block = last;
  214. }
  215. return IsOptional ? phi : data;
  216. }
  217. #endif
  218. private:
  219. void RegisterDependencies() const final {
  220. this->DependsOn(Data);
  221. }
  222. [[noreturn]] static void Throw(const NUdf::TUnboxedValuePod data, NUdf::EDataSlot slot) {
  223. ThrowConvertError(data.AsStringRef(), NUdf::GetDataTypeInfo(slot).Name);
  224. }
  225. IComputationNode* const Data;
  226. const NUdf::EDataSlot SchemeType;
  227. };
  228. }
  229. IComputationNode* WrapFromString(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  230. MKQL_ENSURE(callable.GetInputsCount() >= 2, "Expected 2 args");
  231. bool isOptional;
  232. const auto dataType = UnpackOptionalData(callable.GetInput(0), isOptional);
  233. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<char*>::Id || dataType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected String");
  234. const auto schemeTypeData = AS_VALUE(TDataLiteral, callable.GetInput(1));
  235. const auto schemeType = schemeTypeData->AsValue().Get<ui32>();
  236. const auto data = LocateNode(ctx.NodeLocator, callable, 0);
  237. if (NUdf::TDataType<NUdf::TDecimal>::Id == schemeType) {
  238. MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args");
  239. const auto precision = AS_VALUE(TDataLiteral, callable.GetInput(2))->AsValue().Get<ui8>();
  240. const auto scale = AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().Get<ui8>();
  241. if (isOptional) {
  242. return new TDecimalFromStringWrapper<false, true>(ctx.Mutables, data, precision, scale);
  243. } else {
  244. return new TDecimalFromStringWrapper<false, false>(ctx.Mutables, data, precision, scale);
  245. }
  246. } else {
  247. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
  248. if (isOptional) {
  249. return new TFromStringWrapper<false, true>(ctx.Mutables, data, static_cast<NUdf::TDataTypeId>(schemeType));
  250. } else {
  251. return new TFromStringWrapper<false, false>(ctx.Mutables, data, static_cast<NUdf::TDataTypeId>(schemeType));
  252. }
  253. }
  254. }
  255. IComputationNode* WrapStrictFromString(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  256. MKQL_ENSURE(callable.GetInputsCount() >= 2, "Expected 2 args");
  257. bool isOptional;
  258. const auto dataType = UnpackOptionalData(callable.GetInput(0), isOptional);
  259. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<char*>::Id || dataType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected String");
  260. const auto schemeTypeData = AS_VALUE(TDataLiteral, callable.GetInput(1));
  261. const auto schemeType = schemeTypeData->AsValue().Get<ui32>();
  262. const auto data = LocateNode(ctx.NodeLocator, callable, 0);
  263. if (NUdf::TDataType<NUdf::TDecimal>::Id == schemeType) {
  264. MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args");
  265. const auto precision = AS_VALUE(TDataLiteral, callable.GetInput(2))->AsValue().Get<ui8>();
  266. const auto scale = AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().Get<ui8>();
  267. if (isOptional) {
  268. return new TDecimalFromStringWrapper<true, true>(ctx.Mutables, data, precision, scale);
  269. } else {
  270. return new TDecimalFromStringWrapper<true, false>(ctx.Mutables, data, precision, scale);
  271. }
  272. } else {
  273. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
  274. if (isOptional) {
  275. return new TFromStringWrapper<true, true>(ctx.Mutables, data, static_cast<NUdf::TDataTypeId>(schemeType));
  276. } else {
  277. return new TFromStringWrapper<true, false>(ctx.Mutables, data, static_cast<NUdf::TDataTypeId>(schemeType));
  278. }
  279. }
  280. }
  281. }
  282. }