dsv_udf.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. #include <yql/essentials/public/udf/udf_helpers.h>
  2. #include <yql/essentials/public/udf/udf_value_builder.h>
  3. #include <yql/essentials/public/udf/udf_type_inspection.h>
  4. #include <util/generic/yexception.h>
  5. #include <library/cpp/deprecated/split/split_iterator.h>
  6. #include <util/string/vector.h>
  7. using namespace NKikimr;
  8. using namespace NUdf;
  9. namespace {
  10. struct TKsvIndexes
  11. {
  12. ui32 key;
  13. ui32 subkey;
  14. ui32 value;
  15. };
  16. struct TResultIndexes
  17. {
  18. TType* DictType;
  19. ui32 key;
  20. ui32 subkey;
  21. ui32 dict;
  22. static constexpr ui32 FieldsCount = 3U;
  23. };
  24. void ParseDsv(const TUnboxedValuePod& value,
  25. const std::string_view& separator,
  26. const IValueBuilder* valueBuilder,
  27. IDictValueBuilder* builder) {
  28. const std::string_view input(value.AsStringRef());
  29. const std::vector<std::string_view> parts = StringSplitter(input).SplitByString(separator);
  30. for (const auto& part : parts) {
  31. const auto pos = part.find('=');
  32. if (std::string_view::npos != pos) {
  33. const auto from = std::distance(input.begin(), part.begin());
  34. builder->Add(
  35. valueBuilder->SubString(value, from, pos),
  36. valueBuilder->SubString(value, from + pos + 1U, part.length() - pos - 1U)
  37. );
  38. }
  39. }
  40. }
  41. class TDsvReadRecord: public TBoxedValue
  42. {
  43. public:
  44. class TFactory : public TBoxedValue {
  45. public:
  46. TFactory(const TResultIndexes& fieldIndexes,
  47. const TKsvIndexes& ksvIndexes)
  48. : ResultIndexes_(fieldIndexes)
  49. , KsvIndexes_(ksvIndexes)
  50. {
  51. }
  52. private:
  53. TUnboxedValue Run(
  54. const IValueBuilder* valueBuilder,
  55. const TUnboxedValuePod* args) const final try
  56. {
  57. const auto optRunConfig = args[0];
  58. TUnboxedValue separator;
  59. if (optRunConfig && !optRunConfig.AsStringRef().Empty()) {
  60. separator = optRunConfig;
  61. } else {
  62. separator = valueBuilder->NewString("\t");
  63. }
  64. return TUnboxedValuePod(new TDsvReadRecord(separator, ResultIndexes_, KsvIndexes_));
  65. }
  66. catch (const std::exception& e) {
  67. UdfTerminate(e.what());
  68. }
  69. const TResultIndexes ResultIndexes_;
  70. const TKsvIndexes KsvIndexes_;
  71. };
  72. explicit TDsvReadRecord(const TUnboxedValue& separator,
  73. const TResultIndexes& fieldIndexes,
  74. const TKsvIndexes& ksvIndexes)
  75. : Separator_(std::move(separator))
  76. , ResultIndexes_(fieldIndexes)
  77. , KsvIndexes_(ksvIndexes)
  78. {
  79. }
  80. private:
  81. TUnboxedValue Run(
  82. const IValueBuilder* valueBuilder,
  83. const TUnboxedValuePod* args) const final try
  84. {
  85. auto keyData = args[0].GetElement(KsvIndexes_.key);
  86. auto subkeyData = args[0].GetElement(KsvIndexes_.subkey);
  87. auto valueData = args[0].GetElement(KsvIndexes_.value);
  88. auto dict = valueBuilder->NewDict(ResultIndexes_.DictType, 0);
  89. ParseDsv(valueData, Separator_.AsStringRef(), valueBuilder, dict.Get());
  90. TUnboxedValue* items = nullptr;
  91. const auto result = valueBuilder->NewArray(ResultIndexes_.FieldsCount, items);
  92. items[ResultIndexes_.key] = keyData;
  93. items[ResultIndexes_.subkey] = subkeyData;
  94. items[ResultIndexes_.dict] = dict->Build();
  95. return result;
  96. }
  97. catch (const std::exception& e) {
  98. UdfTerminate(e.what());
  99. }
  100. const TUnboxedValue Separator_;
  101. const TResultIndexes ResultIndexes_;
  102. const TKsvIndexes KsvIndexes_;
  103. };
  104. class TDsvParse: public TBoxedValue
  105. {
  106. public:
  107. explicit TDsvParse(TType* dictType)
  108. : DictType(dictType)
  109. {}
  110. private:
  111. TUnboxedValue Run(
  112. const IValueBuilder* valueBuilder,
  113. const TUnboxedValuePod* args) const final try
  114. {
  115. const std::string_view separator = args[1] ?
  116. std::string_view(args[1].AsStringRef()):
  117. std::string_view("\t");
  118. auto dict = valueBuilder->NewDict(DictType, 0);
  119. ParseDsv(args[0], separator, valueBuilder, dict.Get());
  120. return dict->Build();
  121. }
  122. catch (const std::exception& e) {
  123. UdfTerminate(e.what());
  124. }
  125. const TType* DictType;
  126. };
  127. #define TYPE_TO_STRING(type) \
  128. case TDataType<type>::Id: part += ToString(member.Get<type>()); break;
  129. class TDsvSerialize: public TBoxedValue
  130. {
  131. public:
  132. explicit TDsvSerialize(const TVector<TDataTypeId>& typeIds, TStructTypeInspector* structInspector)
  133. : TypeIds(typeIds)
  134. , StructInspector(structInspector)
  135. {}
  136. private:
  137. TUnboxedValue Run(
  138. const IValueBuilder* valueBuilder,
  139. const TUnboxedValuePod* args) const final try
  140. {
  141. TVector<TString> result;
  142. if (const ui32 structSize = StructInspector->GetMembersCount()) {
  143. result.reserve(structSize);
  144. for (ui32 i = 0; i < structSize; ++i) {
  145. auto part = TString(StructInspector->GetMemberName(i));
  146. part += '=';
  147. const TUnboxedValue& member = args[0].GetElement(i);
  148. switch (TypeIds[i]) {
  149. TYPE_TO_STRING(i32)
  150. TYPE_TO_STRING(ui32)
  151. TYPE_TO_STRING(i64)
  152. TYPE_TO_STRING(ui64)
  153. TYPE_TO_STRING(ui8)
  154. TYPE_TO_STRING(bool)
  155. TYPE_TO_STRING(double)
  156. TYPE_TO_STRING(float)
  157. default:
  158. part += member.AsStringRef();
  159. break;
  160. }
  161. result.emplace_back(std::move(part));
  162. }
  163. }
  164. return valueBuilder->NewString(JoinStrings(result, "\t"));
  165. }
  166. catch (const std::exception& e) {
  167. UdfTerminate(e.what());
  168. }
  169. const TVector<TDataTypeId> TypeIds;
  170. THolder<TStructTypeInspector> StructInspector;
  171. };
  172. class TDsvModule: public IUdfModule
  173. {
  174. public:
  175. TStringRef Name() const {
  176. return TStringRef::Of("Dsv");
  177. }
  178. void CleanupOnTerminate() const final {}
  179. void GetAllFunctions(IFunctionsSink& sink) const final {
  180. sink.Add(TStringRef::Of("ReadRecord"));
  181. sink.Add(TStringRef::Of("Parse"));
  182. sink.Add(TStringRef::Of("Serialize"))->SetTypeAwareness();
  183. }
  184. void BuildFunctionTypeInfo(
  185. const TStringRef& name,
  186. TType* userType,
  187. const TStringRef& typeConfig,
  188. ui32 flags,
  189. IFunctionTypeInfoBuilder& builder) const final try
  190. {
  191. Y_UNUSED(typeConfig);
  192. bool typesOnly = (flags & TFlags::TypesOnly);
  193. if (TStringRef::Of("ReadRecord") == name) {
  194. TKsvIndexes ksvIndexes;
  195. auto recordType = builder.Struct(3U)->
  196. AddField<char*>("key", &ksvIndexes.key)
  197. .AddField<char*>("subkey", &ksvIndexes.subkey)
  198. .AddField<char*>("value", &ksvIndexes.value)
  199. .Build();
  200. TResultIndexes resultIndexes;
  201. resultIndexes.DictType = builder.Dict()->Key<char*>().Value<char*>().Build();
  202. const auto structType = builder.Struct(resultIndexes.FieldsCount)
  203. ->AddField<char*>("key", &resultIndexes.key)
  204. .AddField<char*>("subkey", &resultIndexes.subkey)
  205. .AddField("dict", resultIndexes.DictType, &resultIndexes.dict)
  206. .Build();
  207. builder.Returns(structType)
  208. .Args()->Add(recordType).Done()
  209. .RunConfig<TOptional<char*>>();
  210. if (!typesOnly) {
  211. builder.Implementation(new TDsvReadRecord::TFactory(
  212. resultIndexes, ksvIndexes));
  213. }
  214. builder.IsStrict();
  215. } else if (TStringRef::Of("Parse") == name) {
  216. auto optionalStringType = builder.Optional()->Item<char*>().Build();
  217. auto dictType = builder.Dict()->Key<char*>().Value<char*>().Build();
  218. builder.Returns(dictType)
  219. .Args()->Add<char*>().Flags(ICallablePayload::TArgumentFlags::AutoMap).Add(optionalStringType).Done()
  220. .OptionalArgs(1);
  221. if (!typesOnly) {
  222. builder.Implementation(new TDsvParse(dictType));
  223. }
  224. builder.IsStrict();
  225. } else if (TStringRef::Of("Serialize") == name) {
  226. auto typeHelper = builder.TypeInfoHelper();
  227. auto userTypeInspector = TTupleTypeInspector(*typeHelper, userType);
  228. if (!userTypeInspector || userTypeInspector.GetElementsCount() < 1) {
  229. builder.SetError("Expected user type");
  230. return;
  231. }
  232. auto argsTypeTuple = userTypeInspector.GetElementType(0);
  233. auto argsTypeInspector = TTupleTypeInspector(*typeHelper, argsTypeTuple);
  234. if (!(argsTypeInspector && argsTypeInspector.GetElementsCount() == 1)) {
  235. builder.SetError("Only one argument is expected " + ToString(argsTypeInspector.GetElementsCount()));
  236. return;
  237. }
  238. TVector<TDataTypeId> typeIds;
  239. const auto structType = argsTypeInspector.GetElementType(0);
  240. THolder<TStructTypeInspector> structInspector(new TStructTypeInspector(*typeHelper, structType));
  241. if (structInspector) {
  242. ui32 memberCount = structInspector->GetMembersCount();
  243. typeIds.reserve(memberCount);
  244. if (memberCount) {
  245. for (ui32 i = 0; i < memberCount; ++i) {
  246. const TString memberName(structInspector->GetMemberName(i));
  247. const auto memberType = structInspector->GetMemberType(i);
  248. auto memberInspector = TDataTypeInspector(*typeHelper, memberType);
  249. if (!memberInspector) {
  250. builder.SetError("Only DataType members are supported at the moment, failed at " + memberName);
  251. return;
  252. }
  253. typeIds.push_back(memberInspector.GetTypeId());
  254. }
  255. } else {
  256. builder.SetError("Zero members in input Struct");
  257. return;
  258. }
  259. } else {
  260. builder.SetError("Only Structs are supported at the moment");
  261. return;
  262. }
  263. builder.UserType(userType).Returns<char*>().Args()->Add(structType).Done();
  264. if (!typesOnly) {
  265. builder.Implementation(new TDsvSerialize(typeIds, structInspector.Release()));
  266. }
  267. builder.IsStrict();
  268. }
  269. } catch (const std::exception& e) {
  270. builder.SetError(CurrentExceptionMessage());
  271. }
  272. };
  273. } // namespace
  274. REGISTER_MODULES(TDsvModule)