mkql_block_impl.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. #include "mkql_block_impl.h"
  2. #include "mkql_block_builder.h"
  3. #include "mkql_block_reader.h"
  4. #include <yql/essentials/minikql/arrow/mkql_functions.h>
  5. #include <yql/essentials/minikql/mkql_node_builder.h>
  6. #include <yql/essentials/minikql/mkql_node_cast.h>
  7. #include <yql/essentials/minikql/arrow/arrow_util.h>
  8. #include <yql/essentials/minikql/arrow/mkql_bit_utils.h>
  9. #include <yql/essentials/public/udf/arrow/args_dechunker.h>
  10. #include <yql/essentials/parser/pg_wrapper/interface/arrow.h>
  11. #include <arrow/compute/exec_internal.h>
  12. extern "C" uint64_t GetBlockCount(const NYql::NUdf::TUnboxedValuePod data) {
  13. return NKikimr::NMiniKQL::TArrowBlock::From(data).GetDatum().scalar_as<arrow::UInt64Scalar>().value;
  14. }
  15. extern "C" uint64_t GetBitmapPopCountCount(const NYql::NUdf::TUnboxedValuePod data) {
  16. const auto& arr = NKikimr::NMiniKQL::TArrowBlock::From(data).GetDatum().array();
  17. const size_t len = (size_t)arr->length;
  18. MKQL_ENSURE(arr->GetNullCount() == 0, "Bitmap block should not have nulls");
  19. const ui8* src = arr->GetValues<ui8>(1);
  20. return NKikimr::NMiniKQL::GetSparseBitmapPopCount(src, len);
  21. }
  22. extern "C" uint8_t GetBitmapScalarValue(const NYql::NUdf::TUnboxedValuePod data) {
  23. return NKikimr::NMiniKQL::TArrowBlock::From(data).GetDatum().scalar_as<arrow::UInt8Scalar>().value;
  24. }
  25. namespace NKikimr::NMiniKQL {
  26. namespace {
  27. template<typename T>
  28. arrow::Datum DoConvertScalar(TType* type, const T& value, arrow::MemoryPool& pool) {
  29. std::shared_ptr<arrow::DataType> arrowType;
  30. MKQL_ENSURE(ConvertArrowType(type, arrowType), "Unsupported type of scalar " << *type);
  31. if (!value) {
  32. return arrow::MakeNullScalar(arrowType);
  33. }
  34. bool isOptional = false;
  35. if (type->IsOptional()) {
  36. type = AS_TYPE(TOptionalType, type)->GetItemType();
  37. isOptional = true;
  38. }
  39. if (type->IsOptional() || (isOptional && type->IsPg())) {
  40. // nested optionals
  41. std::vector<std::shared_ptr<arrow::Scalar>> arrowValue;
  42. arrowValue.emplace_back(DoConvertScalar(type, value.GetOptionalValue(), pool).scalar());
  43. return arrow::Datum(std::make_shared<arrow::StructScalar>(arrowValue, arrowType));
  44. }
  45. if (type->IsStruct()) {
  46. auto structType = AS_TYPE(TStructType, type);
  47. std::vector<std::shared_ptr<arrow::Scalar>> arrowValue;
  48. for (ui32 i = 0; i < structType->GetMembersCount(); ++i) {
  49. arrowValue.emplace_back(DoConvertScalar(structType->GetMemberType(i), value.GetElement(i), pool).scalar());
  50. }
  51. return arrow::Datum(std::make_shared<arrow::StructScalar>(arrowValue, arrowType));
  52. }
  53. if (type->IsTuple()) {
  54. auto tupleType = AS_TYPE(TTupleType, type);
  55. std::vector<std::shared_ptr<arrow::Scalar>> arrowValue;
  56. for (ui32 i = 0; i < tupleType->GetElementsCount(); ++i) {
  57. arrowValue.emplace_back(DoConvertScalar(tupleType->GetElementType(i), value.GetElement(i), pool).scalar());
  58. }
  59. return arrow::Datum(std::make_shared<arrow::StructScalar>(arrowValue, arrowType));
  60. }
  61. if (type->IsData()) {
  62. auto slot = *AS_TYPE(TDataType, type)->GetDataSlot();
  63. switch (slot) {
  64. case NUdf::EDataSlot::Int8:
  65. return arrow::Datum(static_cast<int8_t>(value.template Get<i8>()));
  66. case NUdf::EDataSlot::Bool:
  67. case NUdf::EDataSlot::Uint8:
  68. return arrow::Datum(static_cast<uint8_t>(value.template Get<ui8>()));
  69. case NUdf::EDataSlot::Int16:
  70. return arrow::Datum(static_cast<int16_t>(value.template Get<i16>()));
  71. case NUdf::EDataSlot::Uint16:
  72. case NUdf::EDataSlot::Date:
  73. return arrow::Datum(static_cast<uint16_t>(value.template Get<ui16>()));
  74. case NUdf::EDataSlot::Int32:
  75. case NUdf::EDataSlot::Date32:
  76. return arrow::Datum(static_cast<int32_t>(value.template Get<i32>()));
  77. case NUdf::EDataSlot::Uint32:
  78. case NUdf::EDataSlot::Datetime:
  79. return arrow::Datum(static_cast<uint32_t>(value.template Get<ui32>()));
  80. case NUdf::EDataSlot::Int64:
  81. case NUdf::EDataSlot::Interval:
  82. case NUdf::EDataSlot::Interval64:
  83. case NUdf::EDataSlot::Datetime64:
  84. case NUdf::EDataSlot::Timestamp64:
  85. return arrow::Datum(static_cast<int64_t>(value.template Get<i64>()));
  86. case NUdf::EDataSlot::Uint64:
  87. case NUdf::EDataSlot::Timestamp:
  88. return arrow::Datum(static_cast<uint64_t>(value.template Get<ui64>()));
  89. case NUdf::EDataSlot::Float:
  90. return arrow::Datum(static_cast<float>(value.template Get<float>()));
  91. case NUdf::EDataSlot::Double:
  92. return arrow::Datum(static_cast<double>(value.template Get<double>()));
  93. case NUdf::EDataSlot::String:
  94. case NUdf::EDataSlot::Utf8:
  95. case NUdf::EDataSlot::Yson:
  96. case NUdf::EDataSlot::Json:
  97. case NUdf::EDataSlot::JsonDocument: {
  98. const auto& str = value.AsStringRef();
  99. std::shared_ptr<arrow::Buffer> buffer(ARROW_RESULT(arrow::AllocateBuffer(str.Size(), &pool)));
  100. std::memcpy(buffer->mutable_data(), str.Data(), str.Size());
  101. auto type = (slot == NUdf::EDataSlot::String || slot == NUdf::EDataSlot::Yson || slot == NUdf::EDataSlot::JsonDocument) ? arrow::binary() : arrow::utf8();
  102. std::shared_ptr<arrow::Scalar> scalar = std::make_shared<arrow::BinaryScalar>(buffer, type);
  103. return arrow::Datum(scalar);
  104. }
  105. case NUdf::EDataSlot::TzDate: {
  106. auto items = arrow::StructScalar::ValueType{
  107. std::make_shared<arrow::UInt16Scalar>(value.template Get<ui16>()),
  108. std::make_shared<arrow::UInt16Scalar>(value.GetTimezoneId())
  109. };
  110. return arrow::Datum(std::make_shared<arrow::StructScalar>(items, MakeTzDateArrowType<NUdf::EDataSlot::TzDate>()));
  111. }
  112. case NUdf::EDataSlot::TzDatetime: {
  113. auto items = arrow::StructScalar::ValueType{
  114. std::make_shared<arrow::UInt32Scalar>(value.template Get<ui32>()),
  115. std::make_shared<arrow::UInt16Scalar>(value.GetTimezoneId())
  116. };
  117. return arrow::Datum(std::make_shared<arrow::StructScalar>(items, MakeTzDateArrowType<NUdf::EDataSlot::TzDatetime>()));
  118. }
  119. case NUdf::EDataSlot::TzTimestamp: {
  120. auto items = arrow::StructScalar::ValueType{
  121. std::make_shared<arrow::UInt64Scalar>(value.template Get<ui64>()),
  122. std::make_shared<arrow::UInt16Scalar>(value.GetTimezoneId())
  123. };
  124. return arrow::Datum(std::make_shared<arrow::StructScalar>(items, MakeTzDateArrowType<NUdf::EDataSlot::TzTimestamp>()));
  125. }
  126. case NUdf::EDataSlot::TzDate32: {
  127. auto items = arrow::StructScalar::ValueType{
  128. std::make_shared<arrow::Int32Scalar>(value.template Get<i32>()),
  129. std::make_shared<arrow::UInt16Scalar>(value.GetTimezoneId())
  130. };
  131. return arrow::Datum(std::make_shared<arrow::StructScalar>(items, MakeTzDateArrowType<NUdf::EDataSlot::TzDate32>()));
  132. }
  133. case NUdf::EDataSlot::TzDatetime64: {
  134. auto items = arrow::StructScalar::ValueType{
  135. std::make_shared<arrow::Int64Scalar>(value.template Get<i64>()),
  136. std::make_shared<arrow::UInt16Scalar>(value.GetTimezoneId())
  137. };
  138. return arrow::Datum(std::make_shared<arrow::StructScalar>(items, MakeTzDateArrowType<NUdf::EDataSlot::TzDatetime64>()));
  139. }
  140. case NUdf::EDataSlot::TzTimestamp64: {
  141. auto items = arrow::StructScalar::ValueType{
  142. std::make_shared<arrow::Int64Scalar>(value.template Get<i64>()),
  143. std::make_shared<arrow::UInt16Scalar>(value.GetTimezoneId())
  144. };
  145. return arrow::Datum(std::make_shared<arrow::StructScalar>(items, MakeTzDateArrowType<NUdf::EDataSlot::TzTimestamp64>()));
  146. }
  147. case NUdf::EDataSlot::Decimal: {
  148. std::shared_ptr<arrow::Buffer> buffer(ARROW_RESULT(arrow::AllocateBuffer(16, &pool)));
  149. *reinterpret_cast<NYql::NDecimal::TInt128*>(buffer->mutable_data()) = value.GetInt128();
  150. return arrow::Datum(std::make_shared<TPrimitiveDataType<NYql::NDecimal::TInt128>::TScalarResult>(buffer));
  151. }
  152. default:
  153. MKQL_ENSURE(false, "Unsupported data slot " << slot);
  154. }
  155. }
  156. if (type->IsPg()) {
  157. return NYql::MakePgScalar(AS_TYPE(TPgType, type), value, pool);
  158. }
  159. MKQL_ENSURE(false, "Unsupported type " << *type);
  160. }
  161. } // namespace
  162. arrow::Datum ConvertScalar(TType* type, const NUdf::TUnboxedValuePod& value, arrow::MemoryPool& pool) {
  163. return DoConvertScalar(type, value, pool);
  164. }
  165. arrow::Datum ConvertScalar(TType* type, const NUdf::TBlockItem& value, arrow::MemoryPool& pool) {
  166. return DoConvertScalar(type, value, pool);
  167. }
  168. arrow::Datum MakeArrayFromScalar(const arrow::Scalar& scalar, size_t len, TType* type, arrow::MemoryPool& pool) {
  169. MKQL_ENSURE(len > 0, "Invalid block size");
  170. auto reader = MakeBlockReader(TTypeInfoHelper(), type);
  171. auto builder = MakeArrayBuilder(TTypeInfoHelper(), type, pool, len, nullptr);
  172. auto scalarItem = reader->GetScalarItem(scalar);
  173. builder->Add(scalarItem, len);
  174. return builder->Build(true);
  175. }
  176. arrow::ValueDescr ToValueDescr(TType* type) {
  177. arrow::ValueDescr ret;
  178. MKQL_ENSURE(ConvertInputArrowType(type, ret), "can't get arrow type");
  179. return ret;
  180. }
  181. std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) {
  182. std::vector<arrow::ValueDescr> res;
  183. res.reserve(types.size());
  184. for (const auto& type : types) {
  185. res.emplace_back(ToValueDescr(type));
  186. }
  187. return res;
  188. }
  189. std::vector<arrow::compute::InputType> ConvertToInputTypes(const TVector<TType*>& argTypes) {
  190. std::vector<arrow::compute::InputType> result;
  191. result.reserve(argTypes.size());
  192. for (auto& type : argTypes) {
  193. result.emplace_back(ToValueDescr(type));
  194. }
  195. return result;
  196. }
  197. arrow::compute::OutputType ConvertToOutputType(TType* output) {
  198. return arrow::compute::OutputType(ToValueDescr(output));
  199. }
  200. NUdf::TUnboxedValuePod MakeBlockCount(const THolderFactory& holderFactory, const uint64_t count) {
  201. return holderFactory.CreateArrowBlock(arrow::Datum(count));
  202. }
  203. TBlockFuncNode::TBlockFuncNode(TComputationMutables& mutables, TStringBuf name, TComputationNodePtrVector&& argsNodes,
  204. const TVector<TType*>& argsTypes, const arrow::compute::ScalarKernel& kernel,
  205. std::shared_ptr<arrow::compute::ScalarKernel> kernelHolder,
  206. const arrow::compute::FunctionOptions* functionOptions)
  207. : TMutableComputationNode(mutables)
  208. , StateIndex(mutables.CurValueIndex++)
  209. , ArgsNodes(std::move(argsNodes))
  210. , ArgsValuesDescr(ToValueDescr(argsTypes))
  211. , Kernel(kernel)
  212. , KernelHolder(std::move(kernelHolder))
  213. , Options(functionOptions)
  214. , ScalarOutput(GetResultShape(argsTypes) == TBlockType::EShape::Scalar)
  215. , Name(name.starts_with("Block") ? name.substr(5) : name)
  216. {
  217. }
  218. NUdf::TUnboxedValuePod TBlockFuncNode::DoCalculate(TComputationContext& ctx) const {
  219. auto& state = GetState(ctx);
  220. std::vector<arrow::Datum> argDatums;
  221. for (ui32 i = 0; i < ArgsNodes.size(); ++i) {
  222. const auto& value = ArgsNodes[i]->GetValue(ctx);
  223. argDatums.emplace_back(TArrowBlock::From(value).GetDatum());
  224. ARROW_DEBUG_CHECK_DATUM_TYPES(ArgsValuesDescr[i], argDatums.back().descr());
  225. }
  226. if (ScalarOutput) {
  227. auto executor = arrow::compute::detail::KernelExecutor::MakeScalar();
  228. ARROW_OK(executor->Init(&state.KernelContext, { &Kernel, ArgsValuesDescr, Options }));
  229. auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>();
  230. ARROW_OK(executor->Execute(argDatums, listener.get()));
  231. auto output = executor->WrapResults(argDatums, listener->values());
  232. return ctx.HolderFactory.CreateArrowBlock(std::move(output));
  233. }
  234. NYql::NUdf::TArgsDechunker dechunker(std::move(argDatums));
  235. std::vector<arrow::Datum> chunk;
  236. TVector<std::shared_ptr<arrow::ArrayData>> arrays;
  237. while (dechunker.Next(chunk)) {
  238. auto executor = arrow::compute::detail::KernelExecutor::MakeScalar();
  239. ARROW_OK(executor->Init(&state.KernelContext, { &Kernel, ArgsValuesDescr, Options }));
  240. arrow::compute::detail::DatumAccumulator listener;
  241. ARROW_OK(executor->Execute(chunk, &listener));
  242. auto output = executor->WrapResults(chunk, listener.values());
  243. ForEachArrayData(output, [&](const auto& arr) { arrays.push_back(arr); });
  244. }
  245. return ctx.HolderFactory.CreateArrowBlock(MakeArray(arrays));
  246. }
  247. void TBlockFuncNode::RegisterDependencies() const {
  248. for (const auto& arg : ArgsNodes) {
  249. DependsOn(arg);
  250. }
  251. }
  252. TBlockFuncNode::TState& TBlockFuncNode::GetState(TComputationContext& ctx) const {
  253. auto& result = ctx.MutableValues[StateIndex];
  254. if (!result.HasValue()) {
  255. result = ctx.HolderFactory.Create<TState>(Options, Kernel, ArgsValuesDescr, ctx);
  256. }
  257. return *static_cast<TState*>(result.AsBoxed().Get());
  258. }
  259. std::unique_ptr<IArrowKernelComputationNode> TBlockFuncNode::PrepareArrowKernelComputationNode(TComputationContext&) const {
  260. return std::make_unique<TArrowNode>(this);
  261. }
  262. TBlockFuncNode::TArrowNode::TArrowNode(const TBlockFuncNode* parent)
  263. : Parent_(parent)
  264. {}
  265. TStringBuf TBlockFuncNode::TArrowNode::GetKernelName() const {
  266. return Parent_->Name;
  267. }
  268. const arrow::compute::ScalarKernel& TBlockFuncNode::TArrowNode::GetArrowKernel() const {
  269. return Parent_->Kernel;
  270. }
  271. const std::vector<arrow::ValueDescr>& TBlockFuncNode::TArrowNode::GetArgsDesc() const {
  272. return Parent_->ArgsValuesDescr;
  273. }
  274. const IComputationNode* TBlockFuncNode::TArrowNode::GetArgument(ui32 index) const {
  275. MKQL_ENSURE(index < Parent_->ArgsNodes.size(), "Wrong index");
  276. return Parent_->ArgsNodes[index];
  277. }
  278. TBlockState::TBlockState(TMemoryUsageInfo* memInfo, size_t width)
  279. : TBase(memInfo), Values(width), Deques(width - 1ULL), Arrays(width - 1ULL)
  280. {
  281. Pointer_ = Values.data();
  282. }
  283. void TBlockState::ClearValues() {
  284. Values.assign(Values.size(), NUdf::TUnboxedValuePod());
  285. }
  286. void TBlockState::FillArrays() {
  287. MKQL_ENSURE(Count == 0, "All existing arrays have to be processed");
  288. auto& counterDatum = TArrowBlock::From(Values.back()).GetDatum();
  289. MKQL_ENSURE(counterDatum.is_scalar(), "Unexpected block length type (expecting scalar)");
  290. Count = counterDatum.scalar_as<arrow::UInt64Scalar>().value;
  291. if (!Count)
  292. return;
  293. for (size_t i = 0U; i < Deques.size(); ++i) {
  294. Deques[i].clear();
  295. if (const auto& value = Values[i]) {
  296. const auto& datum = TArrowBlock::From(value).GetDatum();
  297. if (datum.is_scalar()) {
  298. return;
  299. }
  300. MKQL_ENSURE(datum.is_arraylike(), "Unexpected block type (expecting array or chunked array)");
  301. ForEachArrayData(datum, [this, i](const auto& arrayData) {
  302. Deques[i].push_back(arrayData);
  303. });
  304. }
  305. }
  306. }
  307. ui64 TBlockState::Slice() {
  308. auto sliceSize = Count;
  309. for (size_t i = 0; i < Deques.size(); ++i) {
  310. const auto& arr = Deques[i];
  311. if (arr.empty())
  312. continue;
  313. Y_ABORT_UNLESS(ui64(arr.front()->length) <= Count);
  314. MKQL_ENSURE(ui64(arr.front()->length) <= Count, "Unexpected array length at column #" << i);
  315. sliceSize = std::min<ui64>(sliceSize, arr.front()->length);
  316. }
  317. for (size_t i = 0; i < Arrays.size(); ++i) {
  318. auto& arr = Deques[i];
  319. if (arr.empty())
  320. continue;
  321. if (auto& array = arr.front(); ui64(array->length) == sliceSize) {
  322. Arrays[i] = std::move(array);
  323. Deques[i].pop_front();
  324. } else
  325. Arrays[i] = Chop(array, sliceSize);
  326. }
  327. Count -= sliceSize;
  328. return sliceSize;
  329. }
  330. NUdf::TUnboxedValuePod TBlockState::Get(const ui64 sliceSize, const THolderFactory& holderFactory, const size_t idx) const {
  331. if (idx >= Deques.size())
  332. return holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(sliceSize)));
  333. if (auto array = Arrays[idx])
  334. return holderFactory.CreateArrowBlock(std::move(array));
  335. else
  336. return Values[idx];
  337. }
  338. }