arrow.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. #include "pg_compat.h"
  2. #include "arrow.h"
  3. #include "arrow_impl.h"
  4. #include <yql/essentials/minikql/defs.h>
  5. #include <yql/essentials/parser/pg_wrapper/interface/arrow.h>
  6. #include <yql/essentials/parser/pg_wrapper/interface/utils.h>
  7. #include <yql/essentials/minikql/mkql_node_cast.h>
  8. #include <yql/essentials/minikql/arrow/arrow_util.h>
  9. #include <yql/essentials/types/dynumber/dynumber.h>
  10. #include <yql/essentials/public/decimal/yql_decimal.h>
  11. #include <util/generic/singleton.h>
  12. #include <arrow/compute/cast.h>
  13. #include <arrow/array.h>
  14. #include <arrow/array/builder_binary.h>
  15. #include <util/system/mutex.h>
  16. extern "C" {
  17. #include "utils/date.h"
  18. #include "utils/timestamp.h"
  19. #include "utils/fmgrprotos.h"
  20. }
  21. namespace NYql {
  22. extern "C" {
  23. Y_PRAGMA_DIAGNOSTIC_PUSH
  24. Y_PRAGMA("GCC diagnostic ignored \"-Wreturn-type-c-linkage\"")
  25. #include "pg_kernels_fwd.inc"
  26. Y_PRAGMA_DIAGNOSTIC_POP
  27. }
  28. struct TExecs {
  29. static TExecs& Instance() {
  30. return *Singleton<TExecs>();
  31. }
  32. TExecs();
  33. THashMap<Oid, TExecFunc> Table;
  34. };
  35. TExecFunc FindExec(Oid oid) {
  36. const auto& table = TExecs::Instance().Table;
  37. auto it = table.find(oid);
  38. if (it == table.end()) {
  39. return nullptr;
  40. }
  41. return it->second;
  42. }
  43. bool HasPgKernel(ui32 procOid) {
  44. return FindExec(procOid) != nullptr;
  45. }
  46. TExecs::TExecs()
  47. {
  48. #define RegisterExec(oid, func) Table[oid] = func
  49. #include "pg_kernels_register.all.inc"
  50. #undef RegisterExec
  51. }
  52. const NPg::TAggregateDesc& ResolveAggregation(const TString& name, NKikimr::NMiniKQL::TTupleType* tupleType,
  53. const std::vector<ui32>& argsColumns, NKikimr::NMiniKQL::TType* returnType, ui32 hint) {
  54. using namespace NKikimr::NMiniKQL;
  55. if (returnType) {
  56. MKQL_ENSURE(argsColumns.size() == 1, "Expected one column");
  57. TType* stateType = AS_TYPE(TBlockType, tupleType->GetElementType(argsColumns[0]))->GetItemType();
  58. TType* returnItemType = AS_TYPE(TBlockType, returnType)->GetItemType();
  59. return NPg::LookupAggregation(name + "#" + ToString(hint), AS_TYPE(TPgType, stateType)->GetTypeId(), AS_TYPE(TPgType, returnItemType)->GetTypeId());
  60. } else {
  61. TVector<ui32> argTypeIds;
  62. for (const auto col : argsColumns) {
  63. argTypeIds.push_back(AS_TYPE(TPgType, AS_TYPE(TBlockType, tupleType->GetElementType(col))->GetItemType())->GetTypeId());
  64. }
  65. return NPg::LookupAggregation(name, argTypeIds);
  66. }
  67. }
  68. std::shared_ptr<arrow::Array> PgConvertBool(const std::shared_ptr<arrow::Array>& value) {
  69. const auto& data = value->data();
  70. size_t length = data->length;
  71. NUdf::TFixedSizeArrayBuilder<ui64, false> builder(NKikimr::NMiniKQL::TTypeInfoHelper(), arrow::uint64(), *arrow::default_memory_pool(), length);
  72. auto input = data->GetValues<ui8>(1, 0);
  73. builder.UnsafeReserve(length);
  74. auto output = builder.MutableData();
  75. for (size_t i = 0; i < length; ++i) {
  76. auto fullIndex = i + data->offset;
  77. output[i] = BoolGetDatum(arrow::BitUtil::GetBit(input, fullIndex));
  78. }
  79. auto dataBuffer = builder.Build(true).array()->buffers[1];
  80. return arrow::MakeArray(arrow::ArrayData::Make(arrow::uint64(), length, { data->buffers[0], dataBuffer }));
  81. }
  82. template <typename T, typename F>
  83. std::shared_ptr<arrow::Array> PgConvertFixed(const std::shared_ptr<arrow::Array>& value, const F& f) {
  84. const auto& data = value->data();
  85. size_t length = data->length;
  86. NUdf::TFixedSizeArrayBuilder<ui64, false> builder(NKikimr::NMiniKQL::TTypeInfoHelper(), arrow::uint64(), *arrow::default_memory_pool(), length);
  87. auto input = data->GetValues<T>(1);
  88. builder.UnsafeReserve(length);
  89. auto output = builder.MutableData();
  90. for (size_t i = 0; i < length; ++i) {
  91. output[i] = f(input[i]);
  92. }
  93. auto dataBuffer = builder.Build(true).array()->buffers[1];
  94. return arrow::MakeArray(arrow::ArrayData::Make(arrow::uint64(), length, { data->buffers[0], dataBuffer }));
  95. }
  96. template <bool IsCString>
  97. std::shared_ptr<arrow::Array> PgConvertString(const std::shared_ptr<arrow::Array>& value) {
  98. const auto& data = value->data();
  99. size_t length = data->length;
  100. arrow::BinaryBuilder builder;
  101. ARROW_OK(builder.Reserve(length));
  102. auto inputDataSize = arrow::BinaryArray(data).total_values_length();
  103. ARROW_OK(builder.ReserveData(inputDataSize + length * (sizeof(void*) + (IsCString ? 1 : VARHDRSZ))));
  104. NUdf::TStringBlockReader<arrow::BinaryType, true> reader;
  105. std::vector<char> tmp;
  106. for (size_t i = 0; i < length; ++i) {
  107. auto item = reader.GetItem(*data, i);
  108. if (!item) {
  109. ARROW_OK(builder.AppendNull());
  110. continue;
  111. }
  112. auto originalLen = item.AsStringRef().Size();
  113. ui32 len;
  114. if constexpr (IsCString) {
  115. len = sizeof(void*) + 1 + originalLen;
  116. } else {
  117. len = sizeof(void*) + VARHDRSZ + originalLen;
  118. }
  119. if (Y_UNLIKELY(len < originalLen)) {
  120. ythrow yexception() << "Too long string";
  121. }
  122. if (tmp.capacity() < len) {
  123. tmp.reserve(Max<ui64>(len, tmp.capacity() * 2));
  124. }
  125. tmp.resize(len);
  126. NUdf::ZeroMemoryContext(tmp.data() + sizeof(void*));
  127. if constexpr (IsCString) {
  128. memcpy(tmp.data() + sizeof(void*), item.AsStringRef().Data(), originalLen);
  129. tmp[len - 1] = 0;
  130. } else {
  131. memcpy(tmp.data() + sizeof(void*) + VARHDRSZ, item.AsStringRef().Data(), originalLen);
  132. UpdateCleanVarSize((text*)(tmp.data() + sizeof(void*)), originalLen);
  133. }
  134. ARROW_OK(builder.Append(tmp.data(), len));
  135. }
  136. std::shared_ptr<arrow::BinaryArray> ret;
  137. ARROW_OK(builder.Finish(&ret));
  138. return ret;
  139. }
  140. Numeric Uint64ToPgNumeric(ui64 value) {
  141. if (value <= (ui64)Max<i64>()) {
  142. return int64_to_numeric((i64)value);
  143. }
  144. auto ret1 = int64_to_numeric((i64)(value & ~(1ull << 63)));
  145. auto bit = int64_to_numeric(Min<i64>());
  146. bool haveError = false;
  147. auto ret2 = numeric_sub_opt_error(ret1, bit, &haveError);
  148. Y_ENSURE(!haveError);
  149. pfree(ret1);
  150. pfree(bit);
  151. return ret2;
  152. }
  153. Numeric DecimalToPgNumeric(const NUdf::TUnboxedValuePod& value, ui8 precision, ui8 scale) {
  154. const auto str = NYql::NDecimal::ToString(value.GetInt128(), precision, scale);
  155. Y_ENSURE(str);
  156. return (Numeric)DirectFunctionCall3Coll(numeric_in, DEFAULT_COLLATION_OID,
  157. PointerGetDatum(str), Int32GetDatum(0), Int32GetDatum(-1));
  158. }
  159. Numeric DyNumberToPgNumeric(const NUdf::TUnboxedValuePod& value) {
  160. auto str = NKikimr::NDyNumber::DyNumberToString(value.AsStringRef());
  161. Y_ENSURE(str);
  162. return (Numeric)DirectFunctionCall3Coll(numeric_in, DEFAULT_COLLATION_OID,
  163. PointerGetDatum(str->c_str()), Int32GetDatum(0), Int32GetDatum(-1));
  164. }
  165. Numeric PgFloatToNumeric(double item, ui64 scale, int digits) {
  166. double intPart, fracPart;
  167. bool error;
  168. fracPart = modf(item, &intPart);
  169. i64 fracInt = round(fracPart * scale);
  170. // scale compaction: represent 711.56000 as 711.56
  171. while (digits > 0 && fracInt % 10 == 0) {
  172. fracInt /= 10;
  173. digits -= 1;
  174. }
  175. if (digits == 0) {
  176. return int64_to_numeric(intPart);
  177. } else {
  178. return numeric_add_opt_error(
  179. int64_to_numeric(intPart),
  180. int64_div_fast_to_numeric(fracInt, digits),
  181. &error);
  182. }
  183. }
  184. std::shared_ptr<arrow::Array> PgDecimal128ConvertNumeric(const std::shared_ptr<arrow::Array>& value, int32_t precision, int32_t scale) {
  185. TArenaMemoryContext arena;
  186. const auto& data = value->data();
  187. size_t length = data->length;
  188. arrow::BinaryBuilder builder;
  189. bool error;
  190. Numeric high_bits_mul = numeric_mul_opt_error(int64_to_numeric(int64_t(1) << 62), int64_to_numeric(4), &error);
  191. auto input = data->GetValues<arrow::Decimal128>(1);
  192. for (size_t i = 0; i < length; ++i) {
  193. if (value->IsNull(i)) {
  194. ARROW_OK(builder.AppendNull());
  195. continue;
  196. }
  197. Numeric v = PgDecimal128ToNumeric(input[i], precision, scale, high_bits_mul);
  198. auto datum = NumericGetDatum(v);
  199. auto ptr = (char*)datum;
  200. auto len = GetFullVarSize((const text*)datum);
  201. NUdf::ZeroMemoryContext(ptr);
  202. ARROW_OK(builder.Append(ptr - sizeof(void*), len + sizeof(void*)));
  203. }
  204. std::shared_ptr<arrow::BinaryArray> ret;
  205. ARROW_OK(builder.Finish(&ret));
  206. return ret;
  207. }
  208. Numeric PgDecimal128ToNumeric(arrow::Decimal128 value, int32_t precision, int32_t scale, Numeric high_bits_mul) {
  209. uint64_t low_bits = value.low_bits();
  210. int64 high_bits = value.high_bits();
  211. if (low_bits > INT64_MAX){
  212. high_bits += 1;
  213. }
  214. bool error;
  215. Numeric low_bits_res = int64_div_fast_to_numeric(low_bits, scale);
  216. Numeric high_bits_res = numeric_mul_opt_error(int64_div_fast_to_numeric(high_bits, scale), high_bits_mul, &error);
  217. MKQL_ENSURE(error == false, "Bad numeric multiplication.");
  218. Numeric res = numeric_add_opt_error(high_bits_res, low_bits_res, &error);
  219. MKQL_ENSURE(error == false, "Bad numeric addition.");
  220. return res;
  221. }
  222. TColumnConverter BuildPgNumericColumnConverter(const std::shared_ptr<arrow::DataType>& originalType) {
  223. switch (originalType->id()) {
  224. case arrow::Type::INT16:
  225. return [](const std::shared_ptr<arrow::Array>& value) {
  226. return PgConvertNumeric<i16>(value);
  227. };
  228. case arrow::Type::INT32:
  229. return [](const std::shared_ptr<arrow::Array>& value) {
  230. return PgConvertNumeric<i32>(value);
  231. };
  232. case arrow::Type::INT64:
  233. return [](const std::shared_ptr<arrow::Array>& value) {
  234. return PgConvertNumeric<i64>(value);
  235. };
  236. case arrow::Type::FLOAT:
  237. return [](const std::shared_ptr<arrow::Array>& value) {
  238. return PgConvertNumeric<float>(value);
  239. };
  240. case arrow::Type::DOUBLE:
  241. return [](const std::shared_ptr<arrow::Array>& value) {
  242. return PgConvertNumeric<double>(value);
  243. };
  244. case arrow::Type::DECIMAL128: {
  245. auto decimal128Ptr = std::static_pointer_cast<arrow::Decimal128Type>(originalType);
  246. int32_t precision = decimal128Ptr->precision();
  247. int32_t scale = decimal128Ptr->scale();
  248. return [precision, scale](const std::shared_ptr<arrow::Array>& value) {
  249. return PgDecimal128ConvertNumeric(value, precision, scale);
  250. };
  251. }
  252. default:
  253. return {};
  254. }
  255. }
  256. template <typename T, typename F>
  257. TColumnConverter BuildPgFixedColumnConverter(const std::shared_ptr<arrow::DataType>& originalType, const F& f) {
  258. auto primaryType = NKikimr::NMiniKQL::GetPrimitiveDataType<T>();
  259. if (!originalType->Equals(*primaryType) && !arrow::compute::CanCast(*originalType, *primaryType)) {
  260. return {};
  261. }
  262. return [primaryType, originalType, f](const std::shared_ptr<arrow::Array>& value) {
  263. auto res = originalType->Equals(*primaryType) ? value : ARROW_RESULT(arrow::compute::Cast(*value, primaryType));
  264. return PgConvertFixed<T, F>(res, f);
  265. };
  266. }
  267. Datum MakePgDateFromUint16(ui16 value) {
  268. return DatumGetDateADT(UNIX_EPOCH_JDATE - POSTGRES_EPOCH_JDATE + value);
  269. }
  270. Datum MakePgTimestampFromInt64(i64 value) {
  271. return DatumGetTimestamp(USECS_PER_SEC * ((UNIX_EPOCH_JDATE - POSTGRES_EPOCH_JDATE) * SECS_PER_DAY + value));
  272. }
  273. TColumnConverter BuildPgColumnConverter(const std::shared_ptr<arrow::DataType>& originalType, NKikimr::NMiniKQL::TPgType* targetType) {
  274. switch (targetType->GetTypeId()) {
  275. case BOOLOID: {
  276. auto primaryType = arrow::boolean();
  277. if (!originalType->Equals(*primaryType) && !arrow::compute::CanCast(*originalType, *primaryType)) {
  278. return {};
  279. }
  280. return [primaryType, originalType](const std::shared_ptr<arrow::Array>& value) {
  281. auto res = originalType->Equals(*primaryType) ? value : ARROW_RESULT(arrow::compute::Cast(*value, primaryType));
  282. return PgConvertBool(res);
  283. };
  284. }
  285. case INT2OID: {
  286. return BuildPgFixedColumnConverter<i16>(originalType, [](auto value){ return Int16GetDatum(value); });
  287. }
  288. case INT4OID: {
  289. return BuildPgFixedColumnConverter<i32>(originalType, [](auto value){ return Int32GetDatum(value); });
  290. }
  291. case INT8OID: {
  292. return BuildPgFixedColumnConverter<i64>(originalType, [](auto value){ return Int64GetDatum(value); });
  293. }
  294. case FLOAT4OID: {
  295. return BuildPgFixedColumnConverter<float>(originalType, [](auto value){ return Float4GetDatum(value); });
  296. }
  297. case FLOAT8OID: {
  298. return BuildPgFixedColumnConverter<double>(originalType, [](auto value){ return Float8GetDatum(value); });
  299. }
  300. case NUMERICOID: {
  301. return BuildPgNumericColumnConverter(originalType);
  302. }
  303. case BYTEAOID:
  304. case VARCHAROID:
  305. case TEXTOID:
  306. case CSTRINGOID: {
  307. auto primaryType = (targetType->GetTypeId() == BYTEAOID) ? arrow::binary() : arrow::utf8();
  308. if (!arrow::compute::CanCast(*originalType, *primaryType)) {
  309. return {};
  310. }
  311. return [primaryType, originalType, isCString = NPg::LookupType(targetType->GetTypeId()).TypeLen == -2](const std::shared_ptr<arrow::Array>& value) {
  312. auto res = originalType->Equals(*primaryType) ? value : ARROW_RESULT(arrow::compute::Cast(*value, primaryType));
  313. if (isCString) {
  314. return PgConvertString<true>(res);
  315. } else {
  316. return PgConvertString<false>(res);
  317. }
  318. };
  319. }
  320. case DATEOID: {
  321. if (originalType->Equals(arrow::uint16())) {
  322. return [](const std::shared_ptr<arrow::Array>& value) {
  323. return PgConvertFixed<ui16>(value, [](auto value){ return MakePgDateFromUint16(value); });
  324. };
  325. } else if (originalType->Equals(arrow::date32())) {
  326. return [](const std::shared_ptr<arrow::Array>& value) {
  327. return PgConvertFixed<i32>(value, [](auto value){ return MakePgDateFromUint16(value); });
  328. };
  329. } else {
  330. return {};
  331. }
  332. }
  333. case TIMESTAMPOID: {
  334. if (originalType->Equals(arrow::int64())) {
  335. return [](const std::shared_ptr<arrow::Array>& value) {
  336. return PgConvertFixed<i64>(value, [](auto value){ return MakePgTimestampFromInt64(value); });
  337. };
  338. } else {
  339. return {};
  340. }
  341. }
  342. }
  343. return {};
  344. }
  345. }