test_spec.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
#include <library/cpp/testing/unittest/registar.h>
#include <yql/essentials/public/purecalc/common/interface.h>
#include <yql/essentials/public/purecalc/io_specs/arrow/spec.h>
#include <yql/essentials/public/purecalc/ut/lib/helpers.h>
#include <yql/essentials/core/yql_type_annotation.h>
#include <yql/essentials/public/udf/arrow/udf_arrow_helpers.h>

#include <arrow/array/builder_primitive.h>

#include <algorithm>
#include <numeric>
#include <tuple>
#include <utility>
  8. namespace {
  9. #define Y_UNIT_TEST_ADD_BLOCK_TEST(N, MODE) \
  10. TCurrentTest::AddTest(#N ":BlockEngineMode=" #MODE, \
  11. static_cast<void (*)(NUnitTest::TTestContext&)>(&N<NYql::EBlockEngineMode::MODE>), false);
  12. #define Y_UNIT_TEST_BLOCKS(N) \
  13. template<NYql::EBlockEngineMode BlockEngineMode> \
  14. void N(NUnitTest::TTestContext&); \
  15. struct TTestRegistration##N { \
  16. TTestRegistration##N() { \
  17. Y_UNIT_TEST_ADD_BLOCK_TEST(N, Disable) \
  18. Y_UNIT_TEST_ADD_BLOCK_TEST(N, Auto) \
  19. Y_UNIT_TEST_ADD_BLOCK_TEST(N, Force) \
  20. } \
  21. }; \
  22. static TTestRegistration##N testRegistration##N; \
  23. template<NYql::EBlockEngineMode BlockEngineMode> \
  24. void N(NUnitTest::TTestContext&)
  25. NYql::NPureCalc::TProgramFactoryOptions TestOptions(NYql::EBlockEngineMode mode) {
  26. static const TMap<NYql::EBlockEngineMode, const TString> mode2settings = {
  27. {NYql::EBlockEngineMode::Disable, "disable"},
  28. {NYql::EBlockEngineMode::Auto, "auto"},
  29. {NYql::EBlockEngineMode::Force, "force"},
  30. };
  31. auto options = NYql::NPureCalc::TProgramFactoryOptions();
  32. options.SetBlockEngineSettings(mode2settings.at(mode));
  33. return options;
  34. }
  35. template <typename T>
  36. struct TVectorStream: public NYql::NPureCalc::IStream<T*> {
  37. TVector<T> Data_;
  38. size_t Index_ = 0;
  39. public:
  40. TVectorStream(TVector<T> items)
  41. : Data_(std::move(items))
  42. {
  43. }
  44. T* Fetch() override {
  45. return Index_ < Data_.size() ? &Data_[Index_++] : nullptr;
  46. }
  47. };
  48. template<typename T>
  49. struct TVectorConsumer: public NYql::NPureCalc::IConsumer<T*> {
  50. TVector<T>& Data_;
  51. size_t Index_ = 0;
  52. public:
  53. TVectorConsumer(TVector<T>& items)
  54. : Data_(items)
  55. {
  56. }
  57. void OnObject(T* t) override {
  58. Index_++;
  59. Data_.push_back(*t);
  60. }
  61. void OnFinish() override {
  62. UNIT_ASSERT_GT(Index_, 0);
  63. }
  64. };
  65. using ExecBatchStreamImpl = TVectorStream<arrow::compute::ExecBatch>;
  66. using ExecBatchConsumerImpl = TVectorConsumer<arrow::compute::ExecBatch>;
  67. template <typename TBuilder>
  68. arrow::Datum MakeArrayDatumFromVector(
  69. const TVector<typename TBuilder::value_type>& data,
  70. const TVector<bool>& valid
  71. ) {
  72. TBuilder builder;
  73. ARROW_OK(builder.Reserve(data.size()));
  74. ARROW_OK(builder.AppendValues(data, valid));
  75. return arrow::Datum(ARROW_RESULT(builder.Finish()));
  76. }
  77. template <typename TValue>
  78. TVector<TValue> MakeVectorFromArrayDatum(
  79. const arrow::Datum& datum,
  80. const int64_t dsize
  81. ) {
  82. Y_ENSURE(datum.is_array(), "ExecBatch layout doesn't respect the schema");
  83. const auto& array = *datum.array();
  84. Y_ENSURE(array.length == dsize,
  85. "Array Datum size differs from the given ExecBatch size");
  86. Y_ENSURE(array.GetNullCount() == 0,
  87. "Null values conversion is not supported");
  88. Y_ENSURE(array.buffers.size() == 2,
  89. "Array Datum layout doesn't respect the schema");
  90. const TValue* adata1 = array.GetValuesSafe<TValue>(1);
  91. return TVector<TValue>(adata1, adata1 + dsize);
  92. }
  93. arrow::compute::ExecBatch MakeBatch(ui64 bsize, i64 value, ui64 init = 1) {
  94. TVector<uint64_t> data1(bsize);
  95. TVector<int64_t> data2(bsize);
  96. TVector<bool> valid(bsize);
  97. std::iota(data1.begin(), data1.end(), init);
  98. std::fill(data2.begin(), data2.end(), value);
  99. std::fill(valid.begin(), valid.end(), true);
  100. TVector<arrow::Datum> batchArgs = {
  101. MakeArrayDatumFromVector<arrow::UInt64Builder>(data1, valid),
  102. MakeArrayDatumFromVector<arrow::Int64Builder>(data2, valid)
  103. };
  104. return arrow::compute::ExecBatch(std::move(batchArgs), bsize);
  105. }
  106. TVector<std::tuple<ui64, i64>> CanonBatches(const TVector<arrow::compute::ExecBatch>& batches) {
  107. TVector<std::tuple<ui64, i64>> result;
  108. for (const auto& batch : batches) {
  109. const auto bsize = batch.length;
  110. const auto& avec1 = MakeVectorFromArrayDatum<ui64>(batch.values[0], bsize);
  111. const auto& avec2 = MakeVectorFromArrayDatum<i64>(batch.values[1], bsize);
  112. for (auto i = 0; i < bsize; i++) {
  113. result.push_back(std::make_tuple(avec1[i], avec2[i]));
  114. }
  115. }
  116. std::sort(result.begin(), result.end());
  117. return result;
  118. }
  119. } // namespace
  120. Y_UNIT_TEST_SUITE(TestSimplePullListArrowIO) {
  121. Y_UNIT_TEST_BLOCKS(TestSingleInput) {
  122. using namespace NYql::NPureCalc;
  123. TVector<TString> fields = {"uint64", "int64"};
  124. auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields);
  125. auto factory = MakeProgramFactory(TestOptions(BlockEngineMode));
  126. try {
  127. auto program = factory->MakePullListProgram(
  128. TArrowInputSpec({schema}),
  129. TArrowOutputSpec(schema),
  130. "SELECT * FROM Input",
  131. ETranslationMode::SQL
  132. );
  133. const TVector<arrow::compute::ExecBatch> input({MakeBatch(9, 19)});
  134. const auto canonInput = CanonBatches(input);
  135. ExecBatchStreamImpl items(input);
  136. auto stream = program->Apply(&items);
  137. TVector<arrow::compute::ExecBatch> output;
  138. while (arrow::compute::ExecBatch* batch = stream->Fetch()) {
  139. output.push_back(*batch);
  140. }
  141. const auto canonOutput = CanonBatches(output);
  142. UNIT_ASSERT_EQUAL(canonInput, canonOutput);
  143. } catch (const TCompileError& error) {
  144. UNIT_FAIL(error.GetIssues());
  145. }
  146. }
  147. Y_UNIT_TEST_BLOCKS(TestMultiInput) {
  148. using namespace NYql::NPureCalc;
  149. TVector<TString> fields = {"uint64", "int64"};
  150. auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields);
  151. auto factory = MakeProgramFactory(TestOptions(BlockEngineMode));
  152. try {
  153. auto program = factory->MakePullListProgram(
  154. TArrowInputSpec({schema, schema}),
  155. TArrowOutputSpec(schema),
  156. R"(
  157. SELECT * FROM Input0
  158. UNION ALL
  159. SELECT * FROM Input1
  160. )",
  161. ETranslationMode::SQL
  162. );
  163. TVector<arrow::compute::ExecBatch> inputs = {
  164. MakeBatch(9, 19),
  165. MakeBatch(7, 17)
  166. };
  167. const auto canonInputs = CanonBatches(inputs);
  168. ExecBatchStreamImpl items0({inputs[0]});
  169. ExecBatchStreamImpl items1({inputs[1]});
  170. const TVector<IStream<arrow::compute::ExecBatch*>*> items({&items0, &items1});
  171. auto stream = program->Apply(items);
  172. TVector<arrow::compute::ExecBatch> output;
  173. while (arrow::compute::ExecBatch* batch = stream->Fetch()) {
  174. output.push_back(*batch);
  175. }
  176. const auto canonOutput = CanonBatches(output);
  177. UNIT_ASSERT_EQUAL(canonInputs, canonOutput);
  178. } catch (const TCompileError& error) {
  179. UNIT_FAIL(error.GetIssues());
  180. }
  181. }
  182. }
  183. Y_UNIT_TEST_SUITE(TestMorePullListArrowIO) {
  184. Y_UNIT_TEST_BLOCKS(TestInc) {
  185. using namespace NYql::NPureCalc;
  186. TVector<TString> fields = {"uint64", "int64"};
  187. auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields);
  188. auto factory = MakeProgramFactory(TestOptions(BlockEngineMode));
  189. try {
  190. auto program = factory->MakePullListProgram(
  191. TArrowInputSpec({schema}),
  192. TArrowOutputSpec(schema),
  193. R"(SELECT
  194. uint64 + 1 as uint64,
  195. int64 - 2 as int64,
  196. FROM Input)",
  197. ETranslationMode::SQL
  198. );
  199. const TVector<arrow::compute::ExecBatch> input({MakeBatch(9, 19)});
  200. const auto canonInput = CanonBatches(input);
  201. ExecBatchStreamImpl items(input);
  202. auto stream = program->Apply(&items);
  203. TVector<arrow::compute::ExecBatch> output;
  204. while (arrow::compute::ExecBatch* batch = stream->Fetch()) {
  205. output.push_back(*batch);
  206. }
  207. const auto canonOutput = CanonBatches(output);
  208. const TVector<arrow::compute::ExecBatch> check({MakeBatch(9, 17, 2)});
  209. const auto canonCheck = CanonBatches(check);
  210. UNIT_ASSERT_EQUAL(canonCheck, canonOutput);
  211. } catch (const TCompileError& error) {
  212. UNIT_FAIL(error.GetIssues());
  213. }
  214. }
  215. }
  216. Y_UNIT_TEST_SUITE(TestSimplePullStreamArrowIO) {
  217. Y_UNIT_TEST_BLOCKS(TestSingleInput) {
  218. using namespace NYql::NPureCalc;
  219. TVector<TString> fields = {"uint64", "int64"};
  220. auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields);
  221. auto factory = MakeProgramFactory(TestOptions(BlockEngineMode));
  222. try {
  223. auto program = factory->MakePullStreamProgram(
  224. TArrowInputSpec({schema}),
  225. TArrowOutputSpec(schema),
  226. "SELECT * FROM Input",
  227. ETranslationMode::SQL
  228. );
  229. const TVector<arrow::compute::ExecBatch> input({MakeBatch(9, 19)});
  230. const auto canonInput = CanonBatches(input);
  231. ExecBatchStreamImpl items(input);
  232. auto stream = program->Apply(&items);
  233. TVector<arrow::compute::ExecBatch> output;
  234. while (arrow::compute::ExecBatch* batch = stream->Fetch()) {
  235. output.push_back(*batch);
  236. }
  237. const auto canonOutput = CanonBatches(output);
  238. UNIT_ASSERT_EQUAL(canonInput, canonOutput);
  239. } catch (const TCompileError& error) {
  240. UNIT_FAIL(error.GetIssues());
  241. }
  242. }
  243. }
  244. Y_UNIT_TEST_SUITE(TestMorePullStreamArrowIO) {
  245. Y_UNIT_TEST_BLOCKS(TestInc) {
  246. using namespace NYql::NPureCalc;
  247. TVector<TString> fields = {"uint64", "int64"};
  248. auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields);
  249. auto factory = MakeProgramFactory(TestOptions(BlockEngineMode));
  250. try {
  251. auto program = factory->MakePullStreamProgram(
  252. TArrowInputSpec({schema}),
  253. TArrowOutputSpec(schema),
  254. R"(SELECT
  255. uint64 + 1 as uint64,
  256. int64 - 2 as int64,
  257. FROM Input)",
  258. ETranslationMode::SQL
  259. );
  260. const TVector<arrow::compute::ExecBatch> input({MakeBatch(9, 19)});
  261. const auto canonInput = CanonBatches(input);
  262. ExecBatchStreamImpl items(input);
  263. auto stream = program->Apply(&items);
  264. TVector<arrow::compute::ExecBatch> output;
  265. while (arrow::compute::ExecBatch* batch = stream->Fetch()) {
  266. output.push_back(*batch);
  267. }
  268. const auto canonOutput = CanonBatches(output);
  269. const TVector<arrow::compute::ExecBatch> check({MakeBatch(9, 17, 2)});
  270. const auto canonCheck = CanonBatches(check);
  271. UNIT_ASSERT_EQUAL(canonCheck, canonOutput);
  272. } catch (const TCompileError& error) {
  273. UNIT_FAIL(error.GetIssues());
  274. }
  275. }
  276. }
  277. Y_UNIT_TEST_SUITE(TestPushStreamArrowIO) {
  278. Y_UNIT_TEST_BLOCKS(TestAllColumns) {
  279. using namespace NYql::NPureCalc;
  280. TVector<TString> fields = {"uint64", "int64"};
  281. auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields);
  282. auto factory = MakeProgramFactory(TestOptions(BlockEngineMode));
  283. try {
  284. auto program = factory->MakePushStreamProgram(
  285. TArrowInputSpec({schema}),
  286. TArrowOutputSpec(schema),
  287. "SELECT * FROM Input",
  288. ETranslationMode::SQL
  289. );
  290. arrow::compute::ExecBatch input = MakeBatch(9, 19);
  291. const auto canonInput = CanonBatches({input});
  292. TVector<arrow::compute::ExecBatch> output;
  293. auto consumer = program->Apply(MakeHolder<ExecBatchConsumerImpl>(output));
  294. UNIT_ASSERT_NO_EXCEPTION([&](){ consumer->OnObject(&input); }());
  295. UNIT_ASSERT_NO_EXCEPTION([&](){ consumer->OnFinish(); }());
  296. const auto canonOutput = CanonBatches(output);
  297. UNIT_ASSERT_EQUAL(canonInput, canonOutput);
  298. } catch (const TCompileError& error) {
  299. UNIT_FAIL(error.GetIssues());
  300. }
  301. }
  302. }
  303. Y_UNIT_TEST_SUITE(TestMorePushStreamArrowIO) {
  304. Y_UNIT_TEST_BLOCKS(TestInc) {
  305. using namespace NYql::NPureCalc;
  306. TVector<TString> fields = {"uint64", "int64"};
  307. auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields);
  308. auto factory = MakeProgramFactory(TestOptions(BlockEngineMode));
  309. try {
  310. auto program = factory->MakePushStreamProgram(
  311. TArrowInputSpec({schema}),
  312. TArrowOutputSpec(schema),
  313. R"(SELECT
  314. uint64 + 1 as uint64,
  315. int64 - 2 as int64,
  316. FROM Input)",
  317. ETranslationMode::SQL
  318. );
  319. arrow::compute::ExecBatch input = MakeBatch(9, 19);
  320. const auto canonInput = CanonBatches({input});
  321. TVector<arrow::compute::ExecBatch> output;
  322. auto consumer = program->Apply(MakeHolder<ExecBatchConsumerImpl>(output));
  323. UNIT_ASSERT_NO_EXCEPTION([&](){ consumer->OnObject(&input); }());
  324. UNIT_ASSERT_NO_EXCEPTION([&](){ consumer->OnFinish(); }());
  325. const auto canonOutput = CanonBatches(output);
  326. const TVector<arrow::compute::ExecBatch> check({MakeBatch(9, 17, 2)});
  327. const auto canonCheck = CanonBatches(check);
  328. UNIT_ASSERT_EQUAL(canonCheck, canonOutput);
  329. } catch (const TCompileError& error) {
  330. UNIT_FAIL(error.GetIssues());
  331. }
  332. }
  333. }