mkql_program_builder.cpp 275 KB


  1. #include "mkql_program_builder.h"
  2. #include "mkql_node_visitor.h"
  3. #include "mkql_node_cast.h"
  4. #include "mkql_runtime_version.h"
  5. #include "yql/essentials/minikql/mkql_node_printer.h"
  6. #include "yql/essentials/minikql/mkql_function_registry.h"
  7. #include "yql/essentials/minikql/mkql_utils.h"
  8. #include "yql/essentials/minikql/mkql_type_builder.h"
  9. #include "yql/essentials/core/sql_types/match_recognize.h"
  10. #include "yql/essentials/core/sql_types/time_order_recover.h"
  11. #include <yql/essentials/parser/pg_catalog/catalog.h>
  12. #include <util/generic/overloaded.h>
  13. #include <util/string/cast.h>
  14. #include <util/string/printf.h>
  15. #include <array>
  16. using namespace std::string_view_literals;
  17. namespace NKikimr {
  18. namespace NMiniKQL {
  19. namespace {
  // Bit flags attached to built-in data functions. Every named value occupies its own
  // bit so that several flags can be OR-ed together; Default is the empty set.
  // The names indicate the argument/result constraints each flag selects; the code that
  // interprets them is outside this chunk.
  struct TDataFunctionFlags {
      enum {
          HasBooleanResult     = 1 << 0,
          RequiresBooleanArgs  = 1 << 1,
          HasOptionalResult    = 1 << 2,
          AllowOptionalArgs    = 1 << 3,
          HasUi32Result        = 1 << 4,
          RequiresCompare      = 1 << 5,
          HasStringResult      = 1 << 6,
          RequiresStringArgs   = 1 << 7,
          RequiresHash         = 1 << 8,
          RequiresEquals       = 1 << 9,
          AllowNull            = 1 << 10,
          CommonOptionalResult = 1 << 11,
          SupportsTuple        = 1 << 12,
          SameOptionalArgs     = 1 << 13,
          Default              = 0
      };
  };
  39. #define MKQL_BAD_TYPE_VISIT(NodeType, ScriptName) \
  40. void Visit(NodeType& node) override { \
  41. Y_UNUSED(node); \
  42. MKQL_ENSURE(false, "Can't convert " #NodeType " to " ScriptName " object"); \
  43. }
  44. class TPythonTypeChecker : public TExploringNodeVisitor {
  45. using TExploringNodeVisitor::Visit;
  46. MKQL_BAD_TYPE_VISIT(TAnyType, "Python");
  47. };
  48. class TLuaTypeChecker : public TExploringNodeVisitor {
  49. using TExploringNodeVisitor::Visit;
  50. MKQL_BAD_TYPE_VISIT(TVoidType, "Lua");
  51. MKQL_BAD_TYPE_VISIT(TAnyType, "Lua");
  52. MKQL_BAD_TYPE_VISIT(TVariantType, "Lua");
  53. };
  54. class TJavascriptTypeChecker : public TExploringNodeVisitor {
  55. using TExploringNodeVisitor::Visit;
  56. MKQL_BAD_TYPE_VISIT(TAnyType, "Javascript");
  57. };
  58. #undef MKQL_BAD_TYPE_VISIT
  59. void EnsureScriptSpecificTypes(
  60. EScriptType scriptType,
  61. TCallableType* funcType,
  62. const TTypeEnvironment& env)
  63. {
  64. switch (scriptType) {
  65. case EScriptType::Lua:
  66. return TLuaTypeChecker().Walk(funcType, env);
  67. case EScriptType::Python:
  68. case EScriptType::Python2:
  69. case EScriptType::Python3:
  70. case EScriptType::ArcPython:
  71. case EScriptType::ArcPython2:
  72. case EScriptType::ArcPython3:
  73. case EScriptType::CustomPython:
  74. case EScriptType::CustomPython2:
  75. case EScriptType::CustomPython3:
  76. case EScriptType::SystemPython2:
  77. case EScriptType::SystemPython3:
  78. case EScriptType::SystemPython3_8:
  79. case EScriptType::SystemPython3_9:
  80. case EScriptType::SystemPython3_10:
  81. case EScriptType::SystemPython3_11:
  82. case EScriptType::SystemPython3_12:
  83. case EScriptType::SystemPython3_13:
  84. return TPythonTypeChecker().Walk(funcType, env);
  85. case EScriptType::Javascript:
  86. return TJavascriptTypeChecker().Walk(funcType, env);
  87. default:
  88. MKQL_ENSURE(false, "Unknown script type " << static_cast<ui32>(scriptType));
  89. }
  90. }
  91. ui32 GetNumericSchemeTypeLevel(NUdf::TDataTypeId typeId) {
  92. switch (typeId) {
  93. case NUdf::TDataType<ui8>::Id:
  94. return 0;
  95. case NUdf::TDataType<i8>::Id:
  96. return 1;
  97. case NUdf::TDataType<ui16>::Id:
  98. return 2;
  99. case NUdf::TDataType<i16>::Id:
  100. return 3;
  101. case NUdf::TDataType<ui32>::Id:
  102. return 4;
  103. case NUdf::TDataType<i32>::Id:
  104. return 5;
  105. case NUdf::TDataType<ui64>::Id:
  106. return 6;
  107. case NUdf::TDataType<i64>::Id:
  108. return 7;
  109. case NUdf::TDataType<float>::Id:
  110. return 8;
  111. case NUdf::TDataType<double>::Id:
  112. return 9;
  113. default:
  114. ythrow yexception() << "Unknown numeric type: " << typeId;
  115. }
  116. }
  117. NUdf::TDataTypeId GetNumericSchemeTypeByLevel(ui32 level) {
  118. switch (level) {
  119. case 0:
  120. return NUdf::TDataType<ui8>::Id;
  121. case 1:
  122. return NUdf::TDataType<i8>::Id;
  123. case 2:
  124. return NUdf::TDataType<ui16>::Id;
  125. case 3:
  126. return NUdf::TDataType<i16>::Id;
  127. case 4:
  128. return NUdf::TDataType<ui32>::Id;
  129. case 5:
  130. return NUdf::TDataType<i32>::Id;
  131. case 6:
  132. return NUdf::TDataType<ui64>::Id;
  133. case 7:
  134. return NUdf::TDataType<i64>::Id;
  135. case 8:
  136. return NUdf::TDataType<float>::Id;
  137. case 9:
  138. return NUdf::TDataType<double>::Id;
  139. default:
  140. ythrow yexception() << "Unknown numeric level: " << level;
  141. }
  142. }
  143. NUdf::TDataTypeId MakeNumericDataSuperType(NUdf::TDataTypeId typeId1, NUdf::TDataTypeId typeId2) {
  144. return typeId1 == typeId2 ? typeId1 :
  145. GetNumericSchemeTypeByLevel(std::max(GetNumericSchemeTypeLevel(typeId1), GetNumericSchemeTypeLevel(typeId2)));
  146. }
  147. template<bool IsFilter>
  148. bool CollectOptionalElements(const TType* type, std::vector<std::string_view>& test, std::vector<std::pair<std::string_view, TType*>>& output) {
  149. const auto structType = AS_TYPE(TStructType, type);
  150. test.reserve(structType->GetMembersCount());
  151. output.reserve(structType->GetMembersCount());
  152. bool multiOptional = false;
  153. for (ui32 i = 0; i < structType->GetMembersCount(); ++i) {
  154. output.emplace_back(structType->GetMemberName(i), structType->GetMemberType(i));
  155. auto& memberType = output.back().second;
  156. if (memberType->IsOptional()) {
  157. test.emplace_back(output.back().first);
  158. if constexpr (IsFilter) {
  159. memberType = AS_TYPE(TOptionalType, memberType)->GetItemType();
  160. multiOptional = multiOptional || memberType->IsOptional();
  161. }
  162. }
  163. }
  164. return multiOptional;
  165. }
  166. template<bool IsFilter>
  167. bool CollectOptionalElements(const TType* type, std::vector<ui32>& test, std::vector<TType*>& output) {
  168. const auto typleType = AS_TYPE(TTupleType, type);
  169. test.reserve(typleType->GetElementsCount());
  170. output.reserve(typleType->GetElementsCount());
  171. bool multiOptional = false;
  172. for (ui32 i = 0; i < typleType->GetElementsCount(); ++i) {
  173. output.emplace_back(typleType->GetElementType(i));
  174. auto& elementType = output.back();
  175. if (elementType->IsOptional()) {
  176. test.emplace_back(i);
  177. if constexpr (IsFilter) {
  178. elementType = AS_TYPE(TOptionalType, elementType)->GetItemType();
  179. multiOptional = multiOptional || elementType->IsOptional();
  180. }
  181. }
  182. }
  183. return multiOptional;
  184. }
  185. bool ReduceOptionalElements(const TType* type, const TArrayRef<const std::string_view>& test, std::vector<std::pair<std::string_view, TType*>>& output) {
  186. const auto structType = AS_TYPE(TStructType, type);
  187. output.reserve(structType->GetMembersCount());
  188. for (ui32 i = 0U; i < structType->GetMembersCount(); ++i) {
  189. output.emplace_back(structType->GetMemberName(i), structType->GetMemberType(i));
  190. }
  191. bool multiOptional = false;
  192. for (const auto& member : test) {
  193. auto& memberType = output[structType->GetMemberIndex(member)].second;
  194. MKQL_ENSURE(memberType->IsOptional(), "Required optional column type");
  195. memberType = AS_TYPE(TOptionalType, memberType)->GetItemType();
  196. multiOptional = multiOptional || memberType->IsOptional();
  197. }
  198. return multiOptional;
  199. }
  200. bool ReduceOptionalElements(const TType* type, const TArrayRef<const ui32>& test, std::vector<TType*>& output) {
  201. const auto typleType = AS_TYPE(TTupleType, type);
  202. output.reserve(typleType->GetElementsCount());
  203. for (ui32 i = 0U; i < typleType->GetElementsCount(); ++i) {
  204. output.emplace_back(typleType->GetElementType(i));
  205. }
  206. bool multiOptional = false;
  207. for (const auto& member : test) {
  208. auto& memberType = output[member];
  209. MKQL_ENSURE(memberType->IsOptional(), "Required optional column type");
  210. memberType = AS_TYPE(TOptionalType, memberType)->GetItemType();
  211. multiOptional = multiOptional || memberType->IsOptional();
  212. }
  213. return multiOptional;
  214. }
  215. static std::vector<TType*> ValidateBlockItems(const TArrayRef<TType* const>& wideComponents, bool unwrap) {
  216. MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column");
  217. std::vector<TType*> items;
  218. items.reserve(wideComponents.size());
  219. // XXX: Declare these variables outside the loop body to use for the last
  220. // item (i.e. block length column) in the assertions below.
  221. bool isScalar;
  222. TType* itemType;
  223. for (const auto& wideComponent : wideComponents) {
  224. auto blockType = AS_TYPE(TBlockType, wideComponent);
  225. isScalar = blockType->GetShape() == TBlockType::EShape::Scalar;
  226. itemType = blockType->GetItemType();
  227. items.push_back(unwrap ? itemType : blockType);
  228. }
  229. MKQL_ENSURE(isScalar, "Last column should be scalar");
  230. MKQL_ENSURE(AS_TYPE(TDataType, itemType)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64");
  231. return items;
  232. }
  233. } // namespace
  234. std::string_view ScriptTypeAsStr(EScriptType type) {
  235. switch (type) {
  236. #define MKQL_SCRIPT_TYPE_CASE(name, value, ...) \
  237. case EScriptType::name: return std::string_view(#name);
  238. MKQL_SCRIPT_TYPES(MKQL_SCRIPT_TYPE_CASE)
  239. #undef MKQL_SCRIPT_TYPE_CASE
  240. } // switch
  241. return std::string_view("Unknown");
  242. }
  243. EScriptType ScriptTypeFromStr(std::string_view str) {
  244. TString lowerStr = TString(str);
  245. lowerStr.to_lower();
  246. #define MKQL_SCRIPT_TYPE_FROM_STR(name, value, lowerName, allowSuffix) \
  247. if ((allowSuffix && lowerStr.StartsWith(#lowerName)) || lowerStr == #lowerName) return EScriptType::name;
  248. MKQL_SCRIPT_TYPES(MKQL_SCRIPT_TYPE_FROM_STR)
  249. #undef MKQL_SCRIPT_TYPE_FROM_STR
  250. return EScriptType::Unknown;
  251. }
  252. bool IsCustomPython(EScriptType type) {
  253. return type == EScriptType::CustomPython ||
  254. type == EScriptType::CustomPython2 ||
  255. type == EScriptType::CustomPython3;
  256. }
  257. bool IsSystemPython(EScriptType type) {
  258. return type == EScriptType::SystemPython2
  259. || type == EScriptType::SystemPython3
  260. || type == EScriptType::SystemPython3_8
  261. || type == EScriptType::SystemPython3_9
  262. || type == EScriptType::SystemPython3_10
  263. || type == EScriptType::SystemPython3_11
  264. || type == EScriptType::SystemPython3_12
  265. || type == EScriptType::SystemPython3_13
  266. || type == EScriptType::Python
  267. || type == EScriptType::Python2;
  268. }
  269. EScriptType CanonizeScriptType(EScriptType type) {
  270. if (type == EScriptType::Python) {
  271. return EScriptType::Python2;
  272. }
  273. if (type == EScriptType::ArcPython) {
  274. return EScriptType::ArcPython2;
  275. }
  276. return type;
  277. }
  278. void EnsureDataOrOptionalOfData(TRuntimeNode node) {
  279. MKQL_ENSURE(node.GetStaticType()->IsData() ||
  280. node.GetStaticType()->IsOptional() && AS_TYPE(TOptionalType, node.GetStaticType())
  281. ->GetItemType()->IsData(), "Expected data or optional of data");
  282. }
  283. std::vector<TType*> ValidateBlockStreamType(const TType* streamType, bool unwrap) {
  284. const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, streamType));
  285. return ValidateBlockItems(wideComponents, unwrap);
  286. }
  287. std::vector<TType*> ValidateBlockFlowType(const TType* flowType, bool unwrap) {
  288. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flowType));
  289. return ValidateBlockItems(wideComponents, unwrap);
  290. }
  291. TProgramBuilder::TProgramBuilder(const TTypeEnvironment& env, const IFunctionRegistry& functionRegistry, bool voidWithEffects)
  292. : TTypeBuilder(env)
  293. , FunctionRegistry(functionRegistry)
  294. , VoidWithEffects(voidWithEffects)
  295. {}
  296. const TTypeEnvironment& TProgramBuilder::GetTypeEnvironment() const {
  297. return Env;
  298. }
  299. const IFunctionRegistry& TProgramBuilder::GetFunctionRegistry() const {
  300. return FunctionRegistry;
  301. }
  302. TType* TProgramBuilder::ChooseCommonType(TType* type1, TType* type2) {
  303. bool isOptional1, isOptional2;
  304. const auto data1 = UnpackOptionalData(type1, isOptional1);
  305. const auto data2 = UnpackOptionalData(type2, isOptional2);
  306. if (data1->IsSameType(*data2)) {
  307. return isOptional1 ? type1 : type2;
  308. }
  309. MKQL_ENSURE(!
  310. ((NUdf::GetDataTypeInfo(*data1->GetDataSlot()).Features | NUdf::GetDataTypeInfo(*data2->GetDataSlot()).Features) & (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType)),
  311. "Not same date types: " << *type1 << " and " << *type2
  312. );
  313. const auto data = NewDataType(MakeNumericDataSuperType(data1->GetSchemeType(), data2->GetSchemeType()));
  314. return isOptional1 || isOptional2 ? NewOptionalType(data) : data;
  315. }
  316. TType* TProgramBuilder::BuildArithmeticCommonType(TType* type1, TType* type2) {
  317. bool isOptional1, isOptional2;
  318. const auto data1 = UnpackOptionalData(type1, isOptional1);
  319. const auto data2 = UnpackOptionalData(type2, isOptional2);
  320. const auto features1 = NUdf::GetDataTypeInfo(*data1->GetDataSlot()).Features;
  321. const auto features2 = NUdf::GetDataTypeInfo(*data2->GetDataSlot()).Features;
  322. const bool isOptional = isOptional1 || isOptional2;
  323. if (features1 & features2 & NUdf::EDataTypeFeatures::TimeIntervalType) {
  324. return NewOptionalType(features1 & NUdf::EDataTypeFeatures::BigDateType ? data1 : data2);
  325. } else if (features1 & NUdf::EDataTypeFeatures::TimeIntervalType) {
  326. return NewOptionalType(features2 & NUdf::EDataTypeFeatures::IntegralType ? data1 : data2);
  327. } else if (features2 & NUdf::EDataTypeFeatures::TimeIntervalType) {
  328. return NewOptionalType(features1 & NUdf::EDataTypeFeatures::IntegralType ? data2 : data1);
  329. } else if (
  330. features1 & (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType) &&
  331. features2 & (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType)
  332. ) {
  333. const auto used = ((features1 | features2) & NUdf::EDataTypeFeatures::BigDateType)
  334. ? NewDataType(NUdf::EDataSlot::Interval64)
  335. : NewDataType(NUdf::EDataSlot::Interval);
  336. return isOptional ? NewOptionalType(used) : used;
  337. } else if (data1->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  338. MKQL_ENSURE(data1->IsSameType(*data2), "Must be same type.");
  339. return isOptional ? NewOptionalType(data1) : data2;
  340. }
  341. const auto data = NewDataType(MakeNumericDataSuperType(data1->GetSchemeType(), data2->GetSchemeType()));
  342. return isOptional ? NewOptionalType(data) : data;
  343. }
  344. TRuntimeNode TProgramBuilder::Arg(TType* type) const {
  345. TCallableBuilder builder(Env, __func__, type, true);
  346. return TRuntimeNode(builder.Build(), false);
  347. }
  348. TRuntimeNode TProgramBuilder::WideFlowArg(TType* type) const {
  349. if constexpr (RuntimeVersion < 18U) {
  350. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  351. }
  352. TCallableBuilder builder(Env, __func__, type, true);
  353. return TRuntimeNode(builder.Build(), false);
  354. }
  355. TRuntimeNode TProgramBuilder::Member(TRuntimeNode structObj, const std::string_view& memberName) {
  356. bool isOptional;
  357. const auto type = AS_TYPE(TStructType, UnpackOptional(structObj.GetStaticType(), isOptional));
  358. const auto memberIndex = type->GetMemberIndex(memberName);
  359. auto memberType = type->GetMemberType(memberIndex);
  360. if (isOptional && !memberType->IsOptional() && !memberType->IsNull() && !memberType->IsPg()) {
  361. memberType = NewOptionalType(memberType);
  362. }
  363. TCallableBuilder callableBuilder(Env, __func__, memberType);
  364. callableBuilder.Add(structObj);
  365. callableBuilder.Add(NewDataLiteral<ui32>(memberIndex));
  366. return TRuntimeNode(callableBuilder.Build(), false);
  367. }
  368. TRuntimeNode TProgramBuilder::Element(TRuntimeNode structObj, const std::string_view& memberName) {
  369. return Member(structObj, memberName);
  370. }
  371. TRuntimeNode TProgramBuilder::AddMember(TRuntimeNode structObj, const std::string_view& memberName, TRuntimeNode memberValue) {
  372. auto oldType = structObj.GetStaticType();
  373. MKQL_ENSURE(oldType->IsStruct(), "Expected struct");
  374. const auto& oldTypeDetailed = static_cast<const TStructType&>(*oldType);
  375. TStructTypeBuilder newTypeBuilder(Env);
  376. newTypeBuilder.Reserve(oldTypeDetailed.GetMembersCount() + 1);
  377. for (ui32 i = 0, e = oldTypeDetailed.GetMembersCount(); i < e; ++i) {
  378. newTypeBuilder.Add(oldTypeDetailed.GetMemberName(i), oldTypeDetailed.GetMemberType(i));
  379. }
  380. newTypeBuilder.Add(memberName, memberValue.GetStaticType());
  381. auto newType = newTypeBuilder.Build();
  382. for (ui32 i = 0, e = newType->GetMembersCount(); i < e; ++i) {
  383. if (newType->GetMemberName(i) == memberName) {
  384. // insert at position i in the struct
  385. TCallableBuilder callableBuilder(Env, __func__, newType);
  386. callableBuilder.Add(structObj);
  387. callableBuilder.Add(memberValue);
  388. callableBuilder.Add(NewDataLiteral<ui32>(i));
  389. return TRuntimeNode(callableBuilder.Build(), false);
  390. }
  391. }
  392. Y_ABORT();
  393. }
  394. TRuntimeNode TProgramBuilder::RemoveMember(TRuntimeNode structObj, const std::string_view& memberName, bool forced) {
  395. auto oldType = structObj.GetStaticType();
  396. MKQL_ENSURE(oldType->IsStruct(), "Expected struct");
  397. const auto& oldTypeDetailed = static_cast<const TStructType&>(*oldType);
  398. MKQL_ENSURE(oldTypeDetailed.GetMembersCount() > 0, "Expected non-empty struct");
  399. TStructTypeBuilder newTypeBuilder(Env);
  400. newTypeBuilder.Reserve(oldTypeDetailed.GetMembersCount() - 1);
  401. std::optional<ui32> memberIndex;
  402. for (ui32 i = 0, e = oldTypeDetailed.GetMembersCount(); i < e; ++i) {
  403. if (oldTypeDetailed.GetMemberName(i) != memberName) {
  404. newTypeBuilder.Add(oldTypeDetailed.GetMemberName(i), oldTypeDetailed.GetMemberType(i));
  405. }
  406. else {
  407. memberIndex = i;
  408. }
  409. }
  410. if (!memberIndex && forced) {
  411. return structObj;
  412. }
  413. MKQL_ENSURE(memberIndex, "Unknown member name: " << memberName);
  414. // remove at position i in the struct
  415. auto newType = newTypeBuilder.Build();
  416. TCallableBuilder callableBuilder(Env, __func__, newType);
  417. callableBuilder.Add(structObj);
  418. callableBuilder.Add(NewDataLiteral<ui32>(*memberIndex));
  419. return TRuntimeNode(callableBuilder.Build(), false);
  420. }
  421. TRuntimeNode TProgramBuilder::Zip(const TArrayRef<const TRuntimeNode>& lists) {
  422. if (lists.empty()) {
  423. return NewEmptyList(Env.GetEmptyTupleLazy()->GetGenericType());
  424. }
  425. std::vector<TType*> tupleTypes;
  426. tupleTypes.reserve(lists.size());
  427. for (auto& list : lists) {
  428. if (list.GetStaticType()->IsEmptyList()) {
  429. tupleTypes.push_back(Env.GetTypeOfVoidLazy());
  430. continue;
  431. }
  432. AS_TYPE(TListType, list.GetStaticType());
  433. auto itemType = static_cast<const TListType&>(*list.GetStaticType()).GetItemType();
  434. tupleTypes.push_back(itemType);
  435. }
  436. auto returnType = TListType::Create(TTupleType::Create(tupleTypes.size(), tupleTypes.data(), Env), Env);
  437. TCallableBuilder callableBuilder(Env, __func__, returnType);
  438. for (auto& list : lists) {
  439. callableBuilder.Add(list);
  440. }
  441. return TRuntimeNode(callableBuilder.Build(), false);
  442. }
  443. TRuntimeNode TProgramBuilder::ZipAll(const TArrayRef<const TRuntimeNode>& lists) {
  444. if (lists.empty()) {
  445. return NewEmptyList(Env.GetEmptyTupleLazy()->GetGenericType());
  446. }
  447. std::vector<TType*> tupleTypes;
  448. tupleTypes.reserve(lists.size());
  449. for (auto& list : lists) {
  450. if (list.GetStaticType()->IsEmptyList()) {
  451. tupleTypes.push_back(TOptionalType::Create(Env.GetTypeOfVoidLazy(), Env));
  452. continue;
  453. }
  454. AS_TYPE(TListType, list.GetStaticType());
  455. auto itemType = static_cast<const TListType&>(*list.GetStaticType()).GetItemType();
  456. tupleTypes.push_back(TOptionalType::Create(itemType, Env));
  457. }
  458. auto returnType = TListType::Create(TTupleType::Create(tupleTypes.size(), tupleTypes.data(), Env), Env);
  459. TCallableBuilder callableBuilder(Env, __func__, returnType);
  460. for (auto& list : lists) {
  461. callableBuilder.Add(list);
  462. }
  463. return TRuntimeNode(callableBuilder.Build(), false);
  464. }
  465. TRuntimeNode TProgramBuilder::Enumerate(TRuntimeNode list, TRuntimeNode start, TRuntimeNode step) {
  466. const auto itemType = AS_TYPE(TListType, list.GetStaticType())->GetItemType();
  467. ThrowIfListOfVoid(itemType);
  468. MKQL_ENSURE(AS_TYPE(TDataType, start)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64 as start");
  469. MKQL_ENSURE(AS_TYPE(TDataType, step)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64 as step");
  470. const std::array<TType*, 2U> tupleTypes = {{ NewDataType(NUdf::EDataSlot::Uint64), itemType }};
  471. const auto returnType = NewListType(NewTupleType(tupleTypes));
  472. TCallableBuilder callableBuilder(Env, __func__, returnType);
  473. callableBuilder.Add(list);
  474. callableBuilder.Add(start);
  475. callableBuilder.Add(step);
  476. return TRuntimeNode(callableBuilder.Build(), false);
  477. }
  478. TRuntimeNode TProgramBuilder::Enumerate(TRuntimeNode list) {
  479. return TProgramBuilder::Enumerate(list, NewDataLiteral<ui64>(0), NewDataLiteral<ui64>(1));
  480. }
  481. TRuntimeNode TProgramBuilder::Fold(TRuntimeNode list, TRuntimeNode state, const TBinaryLambda& handler) {
  482. const auto itemType = AS_TYPE(TListType, list.GetStaticType())->GetItemType();
  483. ThrowIfListOfVoid(itemType);
  484. const auto stateNodeArg = Arg(state.GetStaticType());
  485. const auto itemArg = Arg(itemType);
  486. const auto newState = handler(itemArg, stateNodeArg);
  487. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*state.GetStaticType()), "State type is changed by the handler");
  488. TCallableBuilder callableBuilder(Env, __func__, state.GetStaticType());
  489. callableBuilder.Add(list);
  490. callableBuilder.Add(state);
  491. callableBuilder.Add(itemArg);
  492. callableBuilder.Add(stateNodeArg);
  493. callableBuilder.Add(newState);
  494. return TRuntimeNode(callableBuilder.Build(), false);
  495. }
  496. TRuntimeNode TProgramBuilder::Fold1(TRuntimeNode list, const TUnaryLambda& init, const TBinaryLambda& handler) {
  497. const auto itemType = AS_TYPE(TListType, list.GetStaticType())->GetItemType();
  498. ThrowIfListOfVoid(itemType);
  499. const auto itemArg = Arg(itemType);
  500. const auto initState = init(itemArg);
  501. const auto stateNodeArg = Arg(initState.GetStaticType());
  502. const auto newState = handler(itemArg, stateNodeArg);
  503. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*initState.GetStaticType()), "State type is changed by the handler");
  504. TCallableBuilder callableBuilder(Env, __func__, NewOptionalType(newState.GetStaticType()));
  505. callableBuilder.Add(list);
  506. callableBuilder.Add(itemArg);
  507. callableBuilder.Add(initState);
  508. callableBuilder.Add(stateNodeArg);
  509. callableBuilder.Add(newState);
  510. return TRuntimeNode(callableBuilder.Build(), false);
  511. }
  512. TRuntimeNode TProgramBuilder::Reduce(TRuntimeNode list, TRuntimeNode state1,
  513. const TBinaryLambda& handler1,
  514. const TUnaryLambda& handler2,
  515. TRuntimeNode state3,
  516. const TBinaryLambda& handler3) {
  517. const auto listType = list.GetStaticType();
  518. MKQL_ENSURE(listType->IsList() || listType->IsStream(), "Expected list or stream");
  519. const auto itemType = listType->IsList()?
  520. static_cast<const TListType&>(*listType).GetItemType():
  521. static_cast<const TStreamType&>(*listType).GetItemType();
  522. ThrowIfListOfVoid(itemType);
  523. const auto state1NodeArg = Arg(state1.GetStaticType());
  524. const auto state3NodeArg = Arg(state3.GetStaticType());
  525. const auto itemArg = Arg(itemType);
  526. const auto newState1 = handler1(itemArg, state1NodeArg);
  527. MKQL_ENSURE(newState1.GetStaticType()->IsSameType(*state1.GetStaticType()), "State 1 type is changed by the handler");
  528. const auto newState2 = handler2(state1NodeArg);
  529. TRuntimeNode itemState2Arg = Arg(newState2.GetStaticType());
  530. const auto newState3 = handler3(itemState2Arg, state3NodeArg);
  531. MKQL_ENSURE(newState3.GetStaticType()->IsSameType(*state3.GetStaticType()), "State 3 type is changed by the handler");
  532. TCallableBuilder callableBuilder(Env, __func__, newState3.GetStaticType());
  533. callableBuilder.Add(list);
  534. callableBuilder.Add(state1);
  535. callableBuilder.Add(state3);
  536. callableBuilder.Add(itemArg);
  537. callableBuilder.Add(state1NodeArg);
  538. callableBuilder.Add(newState1);
  539. callableBuilder.Add(newState2);
  540. callableBuilder.Add(itemState2Arg);
  541. callableBuilder.Add(state3NodeArg);
  542. callableBuilder.Add(newState3);
  543. return TRuntimeNode(callableBuilder.Build(), false);
  544. }
  545. TRuntimeNode TProgramBuilder::Condense(TRuntimeNode flow, TRuntimeNode state,
  546. const TBinaryLambda& switcher,
  547. const TBinaryLambda& handler, bool useCtx) {
  548. const auto flowType = flow.GetStaticType();
  549. if (flowType->IsList()) {
  550. // TODO: Native implementation for list.
  551. return Collect(Condense(ToFlow(flow), state, switcher, handler));
  552. }
  553. MKQL_ENSURE(flowType->IsFlow() || flowType->IsStream(), "Expected flow or stream.");
  554. const auto itemType = flowType->IsFlow() ?
  555. static_cast<const TFlowType&>(*flowType).GetItemType():
  556. static_cast<const TStreamType&>(*flowType).GetItemType();
  557. const auto itemArg = Arg(itemType);
  558. const auto stateArg = Arg(state.GetStaticType());
  559. const auto outSwitch = switcher(itemArg, stateArg);
  560. const auto newState = handler(itemArg, stateArg);
  561. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*state.GetStaticType()), "State type is changed by the handler");
  562. TCallableBuilder callableBuilder(Env, __func__, flowType->IsFlow() ? NewFlowType(state.GetStaticType()) : NewStreamType(state.GetStaticType()));
  563. callableBuilder.Add(flow);
  564. callableBuilder.Add(state);
  565. callableBuilder.Add(itemArg);
  566. callableBuilder.Add(stateArg);
  567. callableBuilder.Add(outSwitch);
  568. callableBuilder.Add(newState);
  569. if (useCtx) {
  570. MKQL_ENSURE(RuntimeVersion >= 30U, "Too old runtime version");
  571. callableBuilder.Add(NewDataLiteral<bool>(useCtx));
  572. }
  573. return TRuntimeNode(callableBuilder.Build(), false);
  574. }
  575. TRuntimeNode TProgramBuilder::Condense1(TRuntimeNode flow, const TUnaryLambda& init,
  576. const TBinaryLambda& switcher,
  577. const TBinaryLambda& handler, bool useCtx) {
  578. const auto flowType = flow.GetStaticType();
  579. if (flowType->IsList()) {
  580. // TODO: Native implementation for list.
  581. return Collect(Condense1(ToFlow(flow), init, switcher, handler));
  582. }
  583. MKQL_ENSURE(flowType->IsFlow() || flowType->IsStream(), "Expected flow or stream.");
  584. const auto itemType = flowType->IsFlow() ?
  585. static_cast<const TFlowType&>(*flowType).GetItemType():
  586. static_cast<const TStreamType&>(*flowType).GetItemType();
  587. const auto itemArg = Arg(itemType);
  588. const auto initState = init(itemArg);
  589. const auto stateArg = Arg(initState.GetStaticType());
  590. const auto outSwitch = switcher(itemArg, stateArg);
  591. const auto newState = handler(itemArg, stateArg);
  592. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*initState.GetStaticType()), "State type is changed by the handler");
  593. TCallableBuilder callableBuilder(Env, __func__, flowType->IsFlow() ? NewFlowType(newState.GetStaticType()) : NewStreamType(newState.GetStaticType()));
  594. callableBuilder.Add(flow);
  595. callableBuilder.Add(itemArg);
  596. callableBuilder.Add(initState);
  597. callableBuilder.Add(stateArg);
  598. callableBuilder.Add(outSwitch);
  599. callableBuilder.Add(newState);
  600. if (useCtx) {
  601. MKQL_ENSURE(RuntimeVersion >= 30U, "Too old runtime version");
  602. callableBuilder.Add(NewDataLiteral<bool>(useCtx));
  603. }
  604. return TRuntimeNode(callableBuilder.Build(), false);
  605. }
  606. TRuntimeNode TProgramBuilder::Squeeze(TRuntimeNode stream, TRuntimeNode state,
  607. const TBinaryLambda& handler,
  608. const TUnaryLambda& save,
  609. const TUnaryLambda& load) {
  610. const auto streamType = stream.GetStaticType();
  611. MKQL_ENSURE(streamType->IsStream(), "Expected stream");
  612. const auto& streamDetailedType = static_cast<const TStreamType&>(*streamType);
  613. const auto itemType = streamDetailedType.GetItemType();
  614. ThrowIfListOfVoid(itemType);
  615. const auto stateNodeArg = Arg(state.GetStaticType());
  616. const auto itemArg = Arg(itemType);
  617. const auto newState = handler(itemArg, stateNodeArg);
  618. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*state.GetStaticType()), "State type is changed by the handler");
  619. TRuntimeNode saveArg, outSave, loadArg, outLoad;
  620. if (save && load) {
  621. outSave = save(saveArg = Arg(state.GetStaticType()));
  622. outLoad = load(loadArg = Arg(outSave.GetStaticType()));
  623. MKQL_ENSURE(outLoad.GetStaticType()->IsSameType(*state.GetStaticType()), "Loaded type is changed by the load handler");
  624. } else {
  625. saveArg = outSave = loadArg = outLoad = NewVoid();
  626. }
  627. TCallableBuilder callableBuilder(Env, __func__, TStreamType::Create(state.GetStaticType(), Env));
  628. callableBuilder.Add(stream);
  629. callableBuilder.Add(state);
  630. callableBuilder.Add(itemArg);
  631. callableBuilder.Add(stateNodeArg);
  632. callableBuilder.Add(newState);
  633. callableBuilder.Add(saveArg);
  634. callableBuilder.Add(outSave);
  635. callableBuilder.Add(loadArg);
  636. callableBuilder.Add(outLoad);
  637. return TRuntimeNode(callableBuilder.Build(), false);
  638. }
  639. TRuntimeNode TProgramBuilder::Squeeze1(TRuntimeNode stream, const TUnaryLambda& init,
  640. const TBinaryLambda& handler,
  641. const TUnaryLambda& save,
  642. const TUnaryLambda& load) {
  643. const auto streamType = stream.GetStaticType();
  644. MKQL_ENSURE(streamType->IsStream(), "Expected stream");
  645. const auto& streamDetailedType = static_cast<const TStreamType&>(*streamType);
  646. const auto itemType = streamDetailedType.GetItemType();
  647. ThrowIfListOfVoid(itemType);
  648. const auto itemArg = Arg(itemType);
  649. const auto initState = init(itemArg);
  650. const auto stateNodeArg = Arg(initState.GetStaticType());
  651. const auto newState = handler(itemArg, stateNodeArg);
  652. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*initState.GetStaticType()), "State type is changed by the handler");
  653. TRuntimeNode saveArg, outSave, loadArg, outLoad;
  654. if (save && load) {
  655. outSave = save(saveArg = Arg(initState.GetStaticType()));
  656. outLoad = load(loadArg = Arg(outSave.GetStaticType()));
  657. MKQL_ENSURE(outLoad.GetStaticType()->IsSameType(*initState.GetStaticType()), "Loaded type is changed by the load handler");
  658. } else {
  659. saveArg = outSave = loadArg = outLoad = NewVoid();
  660. }
  661. TCallableBuilder callableBuilder(Env, __func__, NewStreamType(newState.GetStaticType()));
  662. callableBuilder.Add(stream);
  663. callableBuilder.Add(itemArg);
  664. callableBuilder.Add(initState);
  665. callableBuilder.Add(stateNodeArg);
  666. callableBuilder.Add(newState);
  667. callableBuilder.Add(saveArg);
  668. callableBuilder.Add(outSave);
  669. callableBuilder.Add(loadArg);
  670. callableBuilder.Add(outLoad);
  671. return TRuntimeNode(callableBuilder.Build(), false);
  672. }
  673. TRuntimeNode TProgramBuilder::Discard(TRuntimeNode stream) {
  674. const auto streamType = stream.GetStaticType();
  675. MKQL_ENSURE(streamType->IsStream() || streamType->IsFlow(), "Expected stream or flow.");
  676. TCallableBuilder callableBuilder(Env, __func__, streamType);
  677. callableBuilder.Add(stream);
  678. return TRuntimeNode(callableBuilder.Build(), false);
  679. }
  680. TRuntimeNode TProgramBuilder::Map(TRuntimeNode list, const TUnaryLambda& handler) {
  681. return BuildMap(__func__, list, handler);
  682. }
  683. TRuntimeNode TProgramBuilder::OrderedMap(TRuntimeNode list, const TUnaryLambda& handler) {
  684. return BuildMap(__func__, list, handler);
  685. }
  686. TRuntimeNode TProgramBuilder::MapNext(TRuntimeNode list, const TBinaryLambda& handler) {
  687. const auto listType = list.GetStaticType();
  688. MKQL_ENSURE(listType->IsStream() || listType->IsFlow(), "Expected stream or flow");
  689. const auto itemType = listType->IsFlow() ?
  690. AS_TYPE(TFlowType, listType)->GetItemType():
  691. AS_TYPE(TStreamType, listType)->GetItemType();
  692. ThrowIfListOfVoid(itemType);
  693. TType* nextItemType = TOptionalType::Create(itemType, Env);
  694. const auto itemArg = Arg(itemType);
  695. const auto nextItemArg = Arg(nextItemType);
  696. const auto newItem = handler(itemArg, nextItemArg);
  697. const auto resultListType = listType->IsFlow() ?
  698. (TType*)TFlowType::Create(newItem.GetStaticType(), Env):
  699. (TType*)TStreamType::Create(newItem.GetStaticType(), Env);
  700. TCallableBuilder callableBuilder(Env, __func__, resultListType);
  701. callableBuilder.Add(list);
  702. callableBuilder.Add(itemArg);
  703. callableBuilder.Add(nextItemArg);
  704. callableBuilder.Add(newItem);
  705. return TRuntimeNode(callableBuilder.Build(), false);
  706. }
  707. template <bool Ordered>
  708. TRuntimeNode TProgramBuilder::BuildExtract(TRuntimeNode list, const std::string_view& name) {
  709. const auto listType = list.GetStaticType();
  710. MKQL_ENSURE(listType->IsList() || listType->IsOptional(), "Expected list or optional.");
  711. const auto itemType = listType->IsList() ?
  712. AS_TYPE(TListType, listType)->GetItemType():
  713. AS_TYPE(TOptionalType, listType)->GetItemType();
  714. const auto lambda = [&](TRuntimeNode item) {
  715. return itemType->IsStruct() ? Member(item, name) : Nth(item, ::FromString<ui32>(name));
  716. };
  717. return Ordered ? OrderedMap(list, lambda) : Map(list, lambda);
  718. }
  719. TRuntimeNode TProgramBuilder::Extract(TRuntimeNode list, const std::string_view& name) {
  720. return BuildExtract<false>(list, name);
  721. }
  722. TRuntimeNode TProgramBuilder::OrderedExtract(TRuntimeNode list, const std::string_view& name) {
  723. return BuildExtract<true>(list, name);
  724. }
  725. TRuntimeNode TProgramBuilder::ChainMap(TRuntimeNode list, TRuntimeNode state, const TBinaryLambda& handler) {
  726. return ChainMap(list, state, [&](TRuntimeNode item, TRuntimeNode state) -> TRuntimeNodePair {
  727. const auto result = handler(item, state);
  728. return {result, result};
  729. });
  730. }
  731. TRuntimeNode TProgramBuilder::ChainMap(TRuntimeNode list, TRuntimeNode state, const TBinarySplitLambda& handler) {
  732. const auto listType = list.GetStaticType();
  733. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream");
  734. const auto itemType = listType->IsFlow() ?
  735. AS_TYPE(TFlowType, listType)->GetItemType():
  736. listType->IsList() ?
  737. AS_TYPE(TListType, listType)->GetItemType():
  738. AS_TYPE(TStreamType, listType)->GetItemType();
  739. ThrowIfListOfVoid(itemType);
  740. const auto stateNodeArg = Arg(state.GetStaticType());
  741. const auto itemArg = Arg(itemType);
  742. const auto newItemAndState = handler(itemArg, stateNodeArg);
  743. MKQL_ENSURE(std::get<1U>(newItemAndState).GetStaticType()->IsSameType(*state.GetStaticType()), "State type is changed by the handler");
  744. const auto resultItemType = std::get<0U>(newItemAndState).GetStaticType();
  745. TType* resultListType = nullptr;
  746. if (listType->IsFlow()) {
  747. resultListType = TFlowType::Create(resultItemType, Env);
  748. } else if (listType->IsList()) {
  749. resultListType = TListType::Create(resultItemType, Env);
  750. } else if (listType->IsStream()) {
  751. resultListType = TStreamType::Create(resultItemType, Env);
  752. }
  753. TCallableBuilder callableBuilder(Env, __func__, resultListType);
  754. callableBuilder.Add(list);
  755. callableBuilder.Add(state);
  756. callableBuilder.Add(itemArg);
  757. callableBuilder.Add(stateNodeArg);
  758. callableBuilder.Add(std::get<0U>(newItemAndState));
  759. callableBuilder.Add(std::get<1U>(newItemAndState));
  760. return TRuntimeNode(callableBuilder.Build(), false);
  761. }
  762. TRuntimeNode TProgramBuilder::Chain1Map(TRuntimeNode list, const TUnaryLambda& init, const TBinaryLambda& handler) {
  763. return Chain1Map(list,
  764. [&](TRuntimeNode item) -> TRuntimeNodePair {
  765. const auto result = init(item);
  766. return {result, result};
  767. },
  768. [&](TRuntimeNode item, TRuntimeNode state) -> TRuntimeNodePair {
  769. const auto result = handler(item, state);
  770. return {result, result};
  771. }
  772. );
  773. }
  774. TRuntimeNode TProgramBuilder::Chain1Map(TRuntimeNode list, const TUnarySplitLambda& init, const TBinarySplitLambda& handler) {
  775. const auto listType = list.GetStaticType();
  776. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream");
  777. const auto itemType = listType->IsFlow() ?
  778. AS_TYPE(TFlowType, listType)->GetItemType():
  779. listType->IsList() ?
  780. AS_TYPE(TListType, listType)->GetItemType():
  781. AS_TYPE(TStreamType, listType)->GetItemType();
  782. ThrowIfListOfVoid(itemType);
  783. const auto itemArg = Arg(itemType);
  784. const auto initItemAndState = init(itemArg);
  785. const auto resultItemType = std::get<0U>(initItemAndState).GetStaticType();
  786. const auto stateType = std::get<1U>(initItemAndState).GetStaticType();;
  787. TType* resultListType = nullptr;
  788. if (listType->IsFlow()) {
  789. resultListType = TFlowType::Create(resultItemType, Env);
  790. } else if (listType->IsList()) {
  791. resultListType = TListType::Create(resultItemType, Env);
  792. } else if (listType->IsStream()) {
  793. resultListType = TStreamType::Create(resultItemType, Env);
  794. }
  795. const auto stateArg = Arg(stateType);
  796. const auto updateItemAndState = handler(itemArg, stateArg);
  797. MKQL_ENSURE(std::get<0U>(updateItemAndState).GetStaticType()->IsSameType(*resultItemType), "Item type is changed by the handler");
  798. MKQL_ENSURE(std::get<1U>(updateItemAndState).GetStaticType()->IsSameType(*stateType), "State type is changed by the handler");
  799. TCallableBuilder callableBuilder(Env, __func__, resultListType);
  800. callableBuilder.Add(list);
  801. callableBuilder.Add(itemArg);
  802. callableBuilder.Add(std::get<0U>(initItemAndState));
  803. callableBuilder.Add(std::get<1U>(initItemAndState));
  804. callableBuilder.Add(stateArg);
  805. callableBuilder.Add(std::get<0U>(updateItemAndState));
  806. callableBuilder.Add(std::get<1U>(updateItemAndState));
  807. return TRuntimeNode(callableBuilder.Build(), false);
  808. }
  809. TRuntimeNode TProgramBuilder::ToList(TRuntimeNode optional) {
  810. const auto optionalType = optional.GetStaticType();
  811. MKQL_ENSURE(optionalType->IsOptional(), "Expected optional");
  812. const auto& optionalDetailedType = static_cast<const TOptionalType&>(*optionalType);
  813. const auto itemType = optionalDetailedType.GetItemType();
  814. return IfPresent(optional, [&](TRuntimeNode item) { return AsList(item); }, NewEmptyList(itemType));
  815. }
  816. TRuntimeNode TProgramBuilder::Iterable(TZeroLambda lambda) {
  817. if constexpr (RuntimeVersion < 19U) {
  818. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  819. }
  820. const auto itemArg = Arg(NewNull().GetStaticType());
  821. auto lambdaRes = lambda();
  822. const auto resultType = NewListType(AS_TYPE(TStreamType, lambdaRes.GetStaticType())->GetItemType());
  823. TCallableBuilder callableBuilder(Env, __func__, resultType);
  824. callableBuilder.Add(lambdaRes);
  825. callableBuilder.Add(itemArg);
  826. return TRuntimeNode(callableBuilder.Build(), false);
  827. }
  828. TRuntimeNode TProgramBuilder::ToOptional(TRuntimeNode list) {
  829. return Head(list);
  830. }
  831. TRuntimeNode TProgramBuilder::Head(TRuntimeNode list) {
  832. const auto resultType = NewOptionalType(AS_TYPE(TListType, list.GetStaticType())->GetItemType());
  833. TCallableBuilder callableBuilder(Env, __func__, resultType);
  834. callableBuilder.Add(list);
  835. return TRuntimeNode(callableBuilder.Build(), false);
  836. }
  837. TRuntimeNode TProgramBuilder::Last(TRuntimeNode list) {
  838. const auto resultType = NewOptionalType(AS_TYPE(TListType, list.GetStaticType())->GetItemType());
  839. TCallableBuilder callableBuilder(Env, __func__, resultType);
  840. callableBuilder.Add(list);
  841. return TRuntimeNode(callableBuilder.Build(), false);
  842. }
  843. TRuntimeNode TProgramBuilder::Nanvl(TRuntimeNode data, TRuntimeNode dataIfNaN) {
  844. const std::array<TRuntimeNode, 2> args = {{ data, dataIfNaN }};
  845. return Invoke(__func__, BuildArithmeticCommonType(data.GetStaticType(), dataIfNaN.GetStaticType()), args);
  846. }
  847. TRuntimeNode TProgramBuilder::FlatMap(TRuntimeNode list, const TUnaryLambda& handler)
  848. {
  849. return BuildFlatMap(__func__, list, handler);
  850. }
  851. TRuntimeNode TProgramBuilder::OrderedFlatMap(TRuntimeNode list, const TUnaryLambda& handler)
  852. {
  853. return BuildFlatMap(__func__, list, handler);
  854. }
  855. TRuntimeNode TProgramBuilder::Filter(TRuntimeNode list, const TUnaryLambda& handler)
  856. {
  857. return BuildFilter(__func__, list, handler);
  858. }
  859. TRuntimeNode TProgramBuilder::Filter(TRuntimeNode list, TRuntimeNode limit, const TUnaryLambda& handler)
  860. {
  861. return BuildFilter(__func__, list, limit, handler);
  862. }
  863. TRuntimeNode TProgramBuilder::OrderedFilter(TRuntimeNode list, const TUnaryLambda& handler)
  864. {
  865. return BuildFilter(__func__, list, handler);
  866. }
  867. TRuntimeNode TProgramBuilder::OrderedFilter(TRuntimeNode list, TRuntimeNode limit, const TUnaryLambda& handler)
  868. {
  869. return BuildFilter(__func__, list, limit, handler);
  870. }
  871. TRuntimeNode TProgramBuilder::TakeWhile(TRuntimeNode list, const TUnaryLambda& handler)
  872. {
  873. return BuildFilter(__func__, list, handler);
  874. }
  875. TRuntimeNode TProgramBuilder::SkipWhile(TRuntimeNode list, const TUnaryLambda& handler)
  876. {
  877. return BuildFilter(__func__, list, handler);
  878. }
  879. TRuntimeNode TProgramBuilder::TakeWhileInclusive(TRuntimeNode list, const TUnaryLambda& handler)
  880. {
  881. return BuildFilter(__func__, list, handler);
  882. }
  883. TRuntimeNode TProgramBuilder::SkipWhileInclusive(TRuntimeNode list, const TUnaryLambda& handler)
  884. {
  885. return BuildFilter(__func__, list, handler);
  886. }
  887. TRuntimeNode TProgramBuilder::BuildListSort(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode ascending,
  888. const TUnaryLambda& keyExtractor)
  889. {
  890. const auto listType = list.GetStaticType();
  891. MKQL_ENSURE(listType->IsList(), "Expected list.");
  892. const auto itemType = static_cast<const TListType&>(*listType).GetItemType();
  893. ThrowIfListOfVoid(itemType);
  894. const auto ascendingType = ascending.GetStaticType();
  895. const auto itemArg = Arg(itemType);
  896. auto key = keyExtractor(itemArg);
  897. if (ascendingType->IsTuple()) {
  898. const auto ascendingTuple = AS_TYPE(TTupleType, ascendingType);
  899. if (ascendingTuple->GetElementsCount() == 0) {
  900. return list;
  901. }
  902. if (ascendingTuple->GetElementsCount() == 1) {
  903. ascending = Nth(ascending, 0);
  904. key = Nth(key, 0);
  905. }
  906. }
  907. TCallableBuilder callableBuilder(Env, callableName, listType);
  908. callableBuilder.Add(list);
  909. callableBuilder.Add(itemArg);
  910. callableBuilder.Add(key);
  911. callableBuilder.Add(ascending);
  912. return TRuntimeNode(callableBuilder.Build(), false);
  913. }
  914. TRuntimeNode TProgramBuilder::BuildListNth(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode n, TRuntimeNode ascending,
  915. const TUnaryLambda& keyExtractor)
  916. {
  917. const auto listType = list.GetStaticType();
  918. MKQL_ENSURE(listType->IsList(), "Expected list.");
  919. const auto itemType = static_cast<const TListType&>(*listType).GetItemType();
  920. ThrowIfListOfVoid(itemType);
  921. MKQL_ENSURE(n.GetStaticType()->IsData(), "Expected data");
  922. MKQL_ENSURE(static_cast<const TDataType&>(*n.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  923. const auto ascendingType = ascending.GetStaticType();
  924. const auto itemArg = Arg(itemType);
  925. auto key = keyExtractor(itemArg);
  926. if (ascendingType->IsTuple()) {
  927. const auto ascendingTuple = AS_TYPE(TTupleType, ascendingType);
  928. if (ascendingTuple->GetElementsCount() == 0) {
  929. return Take(list, n);
  930. }
  931. if (ascendingTuple->GetElementsCount() == 1) {
  932. ascending = Nth(ascending, 0);
  933. key = Nth(key, 0);
  934. }
  935. }
  936. TCallableBuilder callableBuilder(Env, callableName, listType);
  937. callableBuilder.Add(list);
  938. callableBuilder.Add(n);
  939. callableBuilder.Add(itemArg);
  940. callableBuilder.Add(key);
  941. callableBuilder.Add(ascending);
  942. return TRuntimeNode(callableBuilder.Build(), false);
  943. }
  944. TRuntimeNode TProgramBuilder::BuildSort(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode ascending,
  945. const TUnaryLambda& keyExtractor)
  946. {
  947. if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
  948. const bool newVersion = RuntimeVersion >= 25U && flowType->IsFlow();
  949. const auto condense = newVersion ?
  950. SqueezeToList(Map(flow, [&](TRuntimeNode item) { return Pickle(item); }), NewEmptyOptionalDataLiteral(NUdf::TDataType<ui64>::Id)) :
  951. Condense1(flow,
  952. [this](TRuntimeNode item) { return AsList(item); },
  953. [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
  954. [this](TRuntimeNode item, TRuntimeNode state) { return Append(state, item); }
  955. );
  956. const auto finalKeyExtractor = newVersion ? [&](TRuntimeNode item) {
  957. auto itemType = AS_TYPE(TFlowType, flowType)->GetItemType();
  958. return keyExtractor(Unpickle(itemType, item));
  959. } : keyExtractor;
  960. return FlatMap(condense, [&](TRuntimeNode list) {
  961. auto stealed = RuntimeVersion >= 27U ? Steal(list) : list;
  962. auto sorted = BuildSort(RuntimeVersion >= 26U ? "UnstableSort" : callableName, stealed, ascending, finalKeyExtractor);
  963. return newVersion ? Map(LazyList(sorted), [&](TRuntimeNode item) {
  964. auto itemType = AS_TYPE(TFlowType, flowType)->GetItemType();
  965. return Unpickle(itemType, item);
  966. }) : sorted;
  967. });
  968. }
  969. return BuildListSort(callableName, flow, ascending, keyExtractor);
  970. }
  971. TRuntimeNode TProgramBuilder::BuildNth(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode n, TRuntimeNode ascending,
  972. const TUnaryLambda& keyExtractor)
  973. {
  974. if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
  975. return FlatMap(Condense1(flow,
  976. [this](TRuntimeNode item) { return AsList(item); },
  977. [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
  978. [this](TRuntimeNode item, TRuntimeNode state) { return Append(state, item); }
  979. ),
  980. [&](TRuntimeNode list) { return BuildNth(callableName, list, n, ascending, keyExtractor); }
  981. );
  982. }
  983. return BuildListNth(callableName, flow, n, ascending, keyExtractor);
  984. }
  985. TRuntimeNode TProgramBuilder::BuildTake(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count) {
  986. const auto listType = flow.GetStaticType();
  987. TType* itemType = nullptr;
  988. if (listType->IsFlow()) {
  989. itemType = AS_TYPE(TFlowType, listType)->GetItemType();
  990. } else if (listType->IsList()) {
  991. itemType = AS_TYPE(TListType, listType)->GetItemType();
  992. } else if (listType->IsStream()) {
  993. itemType = AS_TYPE(TStreamType, listType)->GetItemType();
  994. }
  995. MKQL_ENSURE(itemType, "Expected flow, list or stream.");
  996. ThrowIfListOfVoid(itemType);
  997. MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data");
  998. MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  999. TCallableBuilder callableBuilder(Env, callableName, listType);
  1000. callableBuilder.Add(flow);
  1001. callableBuilder.Add(count);
  1002. return TRuntimeNode(callableBuilder.Build(), false);
  1003. }
  1004. template<bool IsFilter, bool OnStruct>
  1005. TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list) {
  1006. const auto listType = list.GetStaticType();
  1007. TType* itemType;
  1008. if (listType->IsFlow()) {
  1009. itemType = AS_TYPE(TFlowType, listType)->GetItemType();
  1010. } else if (listType->IsList()) {
  1011. itemType = AS_TYPE(TListType, listType)->GetItemType();
  1012. } else if (listType->IsStream()) {
  1013. itemType = AS_TYPE(TStreamType, listType)->GetItemType();
  1014. } else if (listType->IsOptional()) {
  1015. itemType = AS_TYPE(TOptionalType, listType)->GetItemType();
  1016. } else {
  1017. THROW yexception() << "Expected flow or list or stream or optional of " << (OnStruct ? "struct." : "tuple.");
  1018. }
  1019. std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>> filteredItems;
  1020. std::vector<std::conditional_t<OnStruct, std::string_view, ui32>> members;
  1021. const bool multiOptional = CollectOptionalElements<IsFilter>(itemType, members, filteredItems);
  1022. const auto predicate = [=](TRuntimeNode item) {
  1023. std::vector<TRuntimeNode> checkMembers;
  1024. checkMembers.reserve(members.size());
  1025. std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers),
  1026. [=](const auto& i){ return Exists(Element(item, i)); });
  1027. return And(checkMembers);
  1028. };
  1029. auto resultType = listType;
  1030. if constexpr (IsFilter) {
  1031. if (const auto filteredItemType = NewArrayType(filteredItems); multiOptional) {
  1032. return BuildFilterNulls<OnStruct>(list, members, filteredItems);
  1033. } else {
  1034. resultType = listType->IsFlow() ?
  1035. NewFlowType(filteredItemType):
  1036. listType->IsList() ?
  1037. NewListType(filteredItemType):
  1038. listType->IsStream() ? NewStreamType(filteredItemType) : NewOptionalType(filteredItemType);
  1039. }
  1040. }
  1041. return Filter(list, predicate, resultType);
  1042. }
  1043. template<bool IsFilter, bool OnStruct>
  1044. TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, const TArrayRef<std::conditional_t<OnStruct, const std::string_view, const ui32>>& members) {
  1045. if (members.empty()) {
  1046. return list;
  1047. }
  1048. const auto listType = list.GetStaticType();
  1049. TType* itemType;
  1050. if (listType->IsFlow()) {
  1051. itemType = AS_TYPE(TFlowType, listType)->GetItemType();
  1052. } else if (listType->IsList()) {
  1053. itemType = AS_TYPE(TListType, listType)->GetItemType();
  1054. } else if (listType->IsStream()) {
  1055. itemType = AS_TYPE(TStreamType, listType)->GetItemType();
  1056. } else if (listType->IsOptional()) {
  1057. itemType = AS_TYPE(TOptionalType, listType)->GetItemType();
  1058. } else {
  1059. THROW yexception() << "Expected flow or list or stream or optional of struct.";
  1060. }
  1061. const auto predicate = [=](TRuntimeNode item) {
  1062. TRuntimeNode::TList checkMembers;
  1063. checkMembers.reserve(members.size());
  1064. std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers),
  1065. [=](const auto& i){ return Exists(Element(item, i)); });
  1066. return And(checkMembers);
  1067. };
  1068. auto resultType = listType;
  1069. if constexpr (IsFilter) {
  1070. if (std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>> filteredItems;
  1071. ReduceOptionalElements(itemType, members, filteredItems)) {
  1072. return BuildFilterNulls<OnStruct>(list, members, filteredItems);
  1073. } else {
  1074. const auto filteredItemType = NewArrayType(filteredItems);
  1075. resultType = listType->IsFlow() ?
  1076. NewFlowType(filteredItemType):
  1077. listType->IsList() ?
  1078. NewListType(filteredItemType):
  1079. listType->IsStream() ? NewStreamType(filteredItemType) : NewOptionalType(filteredItemType);
  1080. }
  1081. }
  1082. return Filter(list, predicate, resultType);
  1083. }
  1084. template<bool OnStruct>
  1085. TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, const TArrayRef<std::conditional_t<OnStruct, const std::string_view, const ui32>>& members,
  1086. const std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>>& filteredItems) {
  1087. return FlatMap(list, [&](TRuntimeNode item) {
  1088. TRuntimeNode::TList checkMembers;
  1089. checkMembers.reserve(members.size());
  1090. std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers),
  1091. [=](const auto& i){ return Element(item, i); });
  1092. return IfPresent(checkMembers, [&](TRuntimeNode::TList items) {
  1093. std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TRuntimeNode>>, TRuntimeNode::TList> row;
  1094. row.reserve(filteredItems.size());
  1095. auto j = 0U;
  1096. if constexpr (OnStruct) {
  1097. std::transform(filteredItems.cbegin(), filteredItems.cend(), std::back_inserter(row),
  1098. [&](const std::pair<std::string_view, TType*>& i) {
  1099. const auto& member = i.first;
  1100. const bool passtrought = members.cend() == std::find(members.cbegin(), members.cend(), member);
  1101. return std::make_pair(member, passtrought ? Element(item, member) : items[j++]);
  1102. }
  1103. );
  1104. return NewOptional(NewStruct(row));
  1105. } else {
  1106. auto i = 0U;
  1107. std::generate_n(std::back_inserter(row), filteredItems.size(),
  1108. [&]() {
  1109. const auto index = i++;
  1110. const bool passtrought = members.cend() == std::find(members.cbegin(), members.cend(), index);
  1111. return passtrought ? Element(item, index) : items[j++];
  1112. }
  1113. );
  1114. return NewOptional(NewTuple(row));
  1115. }
  1116. }, NewEmptyOptional(NewOptionalType(NewArrayType(filteredItems))));
  1117. });
  1118. }
  1119. TRuntimeNode TProgramBuilder::SkipNullMembers(TRuntimeNode list) {
  1120. return BuildFilterNulls<false, true>(list);
  1121. }
  1122. TRuntimeNode TProgramBuilder::FilterNullMembers(TRuntimeNode list) {
  1123. return BuildFilterNulls<true, true>(list);
  1124. }
  1125. TRuntimeNode TProgramBuilder::SkipNullMembers(TRuntimeNode list, const TArrayRef<const std::string_view>& members) {
  1126. return BuildFilterNulls<false, true>(list, members);
  1127. }
  1128. TRuntimeNode TProgramBuilder::FilterNullMembers(TRuntimeNode list, const TArrayRef<const std::string_view>& members) {
  1129. return BuildFilterNulls<true, true>(list, members);
  1130. }
  1131. TRuntimeNode TProgramBuilder::FilterNullElements(TRuntimeNode list) {
  1132. return BuildFilterNulls<true, false>(list);
  1133. }
  1134. TRuntimeNode TProgramBuilder::SkipNullElements(TRuntimeNode list) {
  1135. return BuildFilterNulls<false, false>(list);
  1136. }
  1137. TRuntimeNode TProgramBuilder::FilterNullElements(TRuntimeNode list, const TArrayRef<const ui32>& elements) {
  1138. return BuildFilterNulls<true, false>(list, elements);
  1139. }
  1140. TRuntimeNode TProgramBuilder::SkipNullElements(TRuntimeNode list, const TArrayRef<const ui32>& elements) {
  1141. return BuildFilterNulls<false, false>(list, elements);
  1142. }
  1143. template <typename ResultType>
  1144. TRuntimeNode TProgramBuilder::BuildContainerProperty(const std::string_view& callableName, TRuntimeNode listOrDict) {
  1145. const auto type = listOrDict.GetStaticType();
  1146. MKQL_ENSURE(type->IsList() || type->IsDict() || type->IsEmptyList() || type->IsEmptyDict(), "Expected list or dict.");
  1147. if (type->IsList()) {
  1148. const auto itemType = AS_TYPE(TListType, type)->GetItemType();
  1149. ThrowIfListOfVoid(itemType);
  1150. }
  1151. TCallableBuilder callableBuilder(Env, callableName, NewDataType(NUdf::TDataType<ResultType>::Id));
  1152. callableBuilder.Add(listOrDict);
  1153. return TRuntimeNode(callableBuilder.Build(), false);
  1154. }
  1155. TRuntimeNode TProgramBuilder::Length(TRuntimeNode listOrDict) {
  1156. return BuildContainerProperty<ui64>(__func__, listOrDict);
  1157. }
  1158. TRuntimeNode TProgramBuilder::Iterator(TRuntimeNode list, const TArrayRef<const TRuntimeNode>& dependentNodes) {
  1159. const auto streamType = NewStreamType(AS_TYPE(TListType, list.GetStaticType())->GetItemType());
  1160. TCallableBuilder callableBuilder(Env, __func__, streamType);
  1161. callableBuilder.Add(list);
  1162. for (auto node : dependentNodes) {
  1163. callableBuilder.Add(node);
  1164. }
  1165. return TRuntimeNode(callableBuilder.Build(), false);
  1166. }
  1167. TRuntimeNode TProgramBuilder::EmptyIterator(TType* streamType) {
  1168. MKQL_ENSURE(streamType->IsStream() || streamType->IsFlow(), "Expected stream or flow.");
  1169. if (RuntimeVersion < 7U && streamType->IsFlow()) {
  1170. return ToFlow(EmptyIterator(NewStreamType(AS_TYPE(TFlowType, streamType)->GetItemType())));
  1171. }
  1172. TCallableBuilder callableBuilder(Env, __func__, streamType);
  1173. return TRuntimeNode(callableBuilder.Build(), false);
  1174. }
  1175. TRuntimeNode TProgramBuilder::Collect(TRuntimeNode flow) {
  1176. const auto seqType = flow.GetStaticType();
  1177. TType* itemType = nullptr;
  1178. if (seqType->IsFlow()) {
  1179. itemType = AS_TYPE(TFlowType, seqType)->GetItemType();
  1180. } else if (seqType->IsList()) {
  1181. itemType = AS_TYPE(TListType, seqType)->GetItemType();
  1182. } else if (seqType->IsStream()) {
  1183. itemType = AS_TYPE(TStreamType, seqType)->GetItemType();
  1184. } else {
  1185. THROW yexception() << "Expected flow, list or stream.";
  1186. }
  1187. TCallableBuilder callableBuilder(Env, __func__, NewListType(itemType));
  1188. callableBuilder.Add(flow);
  1189. return TRuntimeNode(callableBuilder.Build(), false);
  1190. }
  1191. TRuntimeNode TProgramBuilder::LazyList(TRuntimeNode list) {
  1192. const auto type = list.GetStaticType();
  1193. bool isOptional;
  1194. const auto listType = UnpackOptional(type, isOptional);
  1195. MKQL_ENSURE(listType->IsList(), "Expected list");
  1196. TCallableBuilder callableBuilder(Env, __func__, type);
  1197. callableBuilder.Add(list);
  1198. return TRuntimeNode(callableBuilder.Build(), false);
  1199. }
  1200. TRuntimeNode TProgramBuilder::ForwardList(TRuntimeNode stream) {
  1201. const auto type = stream.GetStaticType();
  1202. MKQL_ENSURE(type->IsStream() || type->IsFlow(), "Expected flow or stream.");
  1203. if constexpr (RuntimeVersion < 10U) {
  1204. if (type->IsFlow()) {
  1205. return ForwardList(FromFlow(stream));
  1206. }
  1207. }
  1208. TCallableBuilder callableBuilder(Env, __func__, NewListType(type->IsFlow() ? AS_TYPE(TFlowType, stream)->GetItemType() : AS_TYPE(TStreamType, stream)->GetItemType()));
  1209. callableBuilder.Add(stream);
  1210. return TRuntimeNode(callableBuilder.Build(), false);
  1211. }
  1212. TRuntimeNode TProgramBuilder::ToFlow(TRuntimeNode stream) {
  1213. const auto type = stream.GetStaticType();
  1214. MKQL_ENSURE(type->IsStream() || type->IsList() || type->IsOptional(), "Expected stream, list or optional.");
  1215. const auto itemType = type->IsStream() ? AS_TYPE(TStreamType, stream)->GetItemType() :
  1216. type->IsList() ? AS_TYPE(TListType, stream)->GetItemType() : AS_TYPE(TOptionalType, stream)->GetItemType();
  1217. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(itemType));
  1218. callableBuilder.Add(stream);
  1219. return TRuntimeNode(callableBuilder.Build(), false);
  1220. }
  1221. TRuntimeNode TProgramBuilder::FromFlow(TRuntimeNode flow) {
  1222. MKQL_ENSURE(flow.GetStaticType()->IsFlow(), "Expected flow.");
  1223. TCallableBuilder callableBuilder(Env, __func__, NewStreamType(AS_TYPE(TFlowType, flow)->GetItemType()));
  1224. callableBuilder.Add(flow);
  1225. return TRuntimeNode(callableBuilder.Build(), false);
  1226. }
  1227. TRuntimeNode TProgramBuilder::Steal(TRuntimeNode input) {
  1228. if constexpr (RuntimeVersion < 27U) {
  1229. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1230. }
  1231. TCallableBuilder callableBuilder(Env, __func__, input.GetStaticType(), true);
  1232. callableBuilder.Add(input);
  1233. return TRuntimeNode(callableBuilder.Build(), false);
  1234. }
  1235. TRuntimeNode TProgramBuilder::ToBlocks(TRuntimeNode flow) {
  1236. auto* flowType = AS_TYPE(TFlowType, flow.GetStaticType());
  1237. auto* blockType = NewBlockType(flowType->GetItemType(), TBlockType::EShape::Many);
  1238. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(blockType));
  1239. callableBuilder.Add(flow);
  1240. return TRuntimeNode(callableBuilder.Build(), false);
  1241. }
  1242. TRuntimeNode TProgramBuilder::WideToBlocks(TRuntimeNode flow) {
  1243. TType* outputItemType;
  1244. {
  1245. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  1246. std::vector<TType*> outputItems;
  1247. outputItems.reserve(wideComponents.size());
  1248. for (size_t i = 0; i < wideComponents.size(); ++i) {
  1249. outputItems.push_back(NewBlockType(wideComponents[i], TBlockType::EShape::Many));
  1250. }
  1251. outputItems.push_back(NewBlockType(NewDataType(NUdf::TDataType<ui64>::Id), TBlockType::EShape::Scalar));
  1252. outputItemType = NewMultiType(outputItems);
  1253. }
  1254. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(outputItemType));
  1255. callableBuilder.Add(flow);
  1256. return TRuntimeNode(callableBuilder.Build(), false);
  1257. }
  1258. TRuntimeNode TProgramBuilder::FromBlocks(TRuntimeNode flow) {
  1259. auto* flowType = AS_TYPE(TFlowType, flow.GetStaticType());
  1260. auto* blockType = AS_TYPE(TBlockType, flowType->GetItemType());
  1261. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(blockType->GetItemType()));
  1262. callableBuilder.Add(flow);
  1263. return TRuntimeNode(callableBuilder.Build(), false);
  1264. }
  1265. TRuntimeNode TProgramBuilder::WideFromBlocks(TRuntimeNode stream) {
  1266. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected WideStream as input type");
  1267. if constexpr (RuntimeVersion < 55U) {
  1268. // Preserve the old behaviour for ABI compatibility.
  1269. // Emit (FromFlow (WideFromBlocks (ToFlow (<stream>)))) to
  1270. // process the flow in favor to the given stream following
  1271. // the older MKQL ABI.
  1272. // FIXME: Drop the branch below, when the time comes.
  1273. const auto inputFlow = ToFlow(stream);
  1274. auto outputItems = ValidateBlockFlowType(inputFlow.GetStaticType());
  1275. outputItems.pop_back();
  1276. TType* outputMultiType = NewMultiType(outputItems);
  1277. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(outputMultiType));
  1278. callableBuilder.Add(inputFlow);
  1279. const auto outputFlow = TRuntimeNode(callableBuilder.Build(), false);
  1280. return FromFlow(outputFlow);
  1281. }
  1282. auto outputItems = ValidateBlockStreamType(stream.GetStaticType());
  1283. outputItems.pop_back();
  1284. TType* outputMultiType = NewMultiType(outputItems);
  1285. TCallableBuilder callableBuilder(Env, __func__, NewStreamType(outputMultiType));
  1286. callableBuilder.Add(stream);
  1287. return TRuntimeNode(callableBuilder.Build(), false);
  1288. }
  1289. TRuntimeNode TProgramBuilder::WideSkipBlocks(TRuntimeNode flow, TRuntimeNode count) {
  1290. return BuildWideSkipTakeBlocks(__func__, flow, count);
  1291. }
  1292. TRuntimeNode TProgramBuilder::WideTakeBlocks(TRuntimeNode flow, TRuntimeNode count) {
  1293. return BuildWideSkipTakeBlocks(__func__, flow, count);
  1294. }
  1295. TRuntimeNode TProgramBuilder::WideTopBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
  1296. return BuildWideTopOrSort(__func__, flow, count, keys);
  1297. }
  1298. TRuntimeNode TProgramBuilder::WideTopSortBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
  1299. return BuildWideTopOrSort(__func__, flow, count, keys);
  1300. }
  1301. TRuntimeNode TProgramBuilder::WideSortBlocks(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
  1302. return BuildWideTopOrSort(__func__, flow, Nothing(), keys);
  1303. }
  1304. TRuntimeNode TProgramBuilder::AsScalar(TRuntimeNode value) {
  1305. TCallableBuilder callableBuilder(Env, __func__, NewBlockType(value.GetStaticType(), TBlockType::EShape::Scalar));
  1306. callableBuilder.Add(value);
  1307. return TRuntimeNode(callableBuilder.Build(), false);
  1308. }
  1309. TRuntimeNode TProgramBuilder::ReplicateScalar(TRuntimeNode value, TRuntimeNode count) {
  1310. if constexpr (RuntimeVersion < 43U) {
  1311. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1312. }
  1313. auto valueType = AS_TYPE(TBlockType, value.GetStaticType());
  1314. auto countType = AS_TYPE(TBlockType, count.GetStaticType());
  1315. MKQL_ENSURE(valueType->GetShape() == TBlockType::EShape::Scalar, "Expecting scalar as first arguemnt");
  1316. MKQL_ENSURE(countType->GetShape() == TBlockType::EShape::Scalar, "Expecting scalar as second arguemnt");
  1317. MKQL_ENSURE(countType->GetItemType()->IsData(), "Expected scalar data as second argument");
  1318. MKQL_ENSURE(AS_TYPE(TDataType, countType->GetItemType())->GetSchemeType() ==
  1319. NUdf::TDataType<ui64>::Id, "Expected scalar ui64 as second argument");
  1320. auto outputType = NewBlockType(valueType->GetItemType(), TBlockType::EShape::Many);
  1321. TCallableBuilder callableBuilder(Env, __func__, outputType);
  1322. callableBuilder.Add(value);
  1323. callableBuilder.Add(count);
  1324. return TRuntimeNode(callableBuilder.Build(), false);
  1325. }
  1326. TRuntimeNode TProgramBuilder::BlockCompress(TRuntimeNode flow, ui32 bitmapIndex) {
  1327. auto blockItemTypes = ValidateBlockFlowType(flow.GetStaticType());
  1328. MKQL_ENSURE(blockItemTypes.size() >= 2, "Expected at least two input columns");
  1329. MKQL_ENSURE(bitmapIndex < blockItemTypes.size() - 1, "Invalid bitmap index");
  1330. MKQL_ENSURE(AS_TYPE(TDataType, blockItemTypes[bitmapIndex])->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected Bool as bitmap column type");
  1331. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  1332. MKQL_ENSURE(wideComponents.size() == blockItemTypes.size(), "Unexpected tuple size");
  1333. std::vector<TType*> flowItems;
  1334. for (size_t i = 0; i < wideComponents.size(); ++i) {
  1335. if (i == bitmapIndex) {
  1336. continue;
  1337. }
  1338. flowItems.push_back(wideComponents[i]);
  1339. }
  1340. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(flowItems)));
  1341. callableBuilder.Add(flow);
  1342. callableBuilder.Add(NewDataLiteral<ui32>(bitmapIndex));
  1343. return TRuntimeNode(callableBuilder.Build(), false);
  1344. }
  1345. TRuntimeNode TProgramBuilder::BlockExpandChunked(TRuntimeNode comp) {
  1346. if (comp.GetStaticType()->IsStream()) {
  1347. ValidateBlockStreamType(comp.GetStaticType());
  1348. } else {
  1349. ValidateBlockFlowType(comp.GetStaticType());
  1350. }
  1351. TCallableBuilder callableBuilder(Env, __func__, comp.GetStaticType());
  1352. callableBuilder.Add(comp);
  1353. return TRuntimeNode(callableBuilder.Build(), false);
  1354. }
  1355. TRuntimeNode TProgramBuilder::BlockCoalesce(TRuntimeNode first, TRuntimeNode second) {
  1356. auto firstType = AS_TYPE(TBlockType, first.GetStaticType());
  1357. auto secondType = AS_TYPE(TBlockType, second.GetStaticType());
  1358. auto firstItemType = firstType->GetItemType();
  1359. auto secondItemType = secondType->GetItemType();
  1360. MKQL_ENSURE(firstItemType->IsOptional() || firstItemType->IsPg(), "Expecting Optional or Pg type as first argument");
  1361. if (!firstItemType->IsSameType(*secondItemType)) {
  1362. bool firstOptional;
  1363. firstItemType = UnpackOptional(firstItemType, firstOptional);
  1364. MKQL_ENSURE(firstItemType->IsSameType(*secondItemType), "Uncompatible arguemnt types");
  1365. }
  1366. auto outputType = NewBlockType(secondType->GetItemType(), GetResultShape({firstType, secondType}));
  1367. TCallableBuilder callableBuilder(Env, __func__, outputType);
  1368. callableBuilder.Add(first);
  1369. callableBuilder.Add(second);
  1370. return TRuntimeNode(callableBuilder.Build(), false);
  1371. }
  1372. TRuntimeNode TProgramBuilder::BlockExists(TRuntimeNode data) {
  1373. auto dataType = AS_TYPE(TBlockType, data.GetStaticType());
  1374. auto outputType = NewBlockType(NewDataType(NUdf::TDataType<bool>::Id), dataType->GetShape());
  1375. TCallableBuilder callableBuilder(Env, __func__, outputType);
  1376. callableBuilder.Add(data);
  1377. return TRuntimeNode(callableBuilder.Build(), false);
  1378. }
  1379. TRuntimeNode TProgramBuilder::BlockMember(TRuntimeNode structObj, const std::string_view& memberName) {
  1380. auto blockType = AS_TYPE(TBlockType, structObj.GetStaticType());
  1381. bool isOptional;
  1382. const auto type = AS_TYPE(TStructType, UnpackOptional(blockType->GetItemType(), isOptional));
  1383. const auto memberIndex = type->GetMemberIndex(memberName);
  1384. auto memberType = type->GetMemberType(memberIndex);
  1385. if (isOptional && !memberType->IsOptional() && !memberType->IsNull() && !memberType->IsPg()) {
  1386. memberType = NewOptionalType(memberType);
  1387. }
  1388. auto returnType = NewBlockType(memberType, blockType->GetShape());
  1389. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1390. callableBuilder.Add(structObj);
  1391. callableBuilder.Add(NewDataLiteral<ui32>(memberIndex));
  1392. return TRuntimeNode(callableBuilder.Build(), false);
  1393. }
  1394. TRuntimeNode TProgramBuilder::BlockNth(TRuntimeNode tuple, ui32 index) {
  1395. auto blockType = AS_TYPE(TBlockType, tuple.GetStaticType());
  1396. bool isOptional;
  1397. const auto type = AS_TYPE(TTupleType, UnpackOptional(blockType->GetItemType(), isOptional));
  1398. MKQL_ENSURE(index < type->GetElementsCount(), "Index out of range: " << index <<
  1399. " is not less than " << type->GetElementsCount());
  1400. auto itemType = type->GetElementType(index);
  1401. if (isOptional && !itemType->IsOptional() && !itemType->IsNull() && !itemType->IsPg()) {
  1402. itemType = TOptionalType::Create(itemType, Env);
  1403. }
  1404. auto returnType = NewBlockType(itemType, blockType->GetShape());
  1405. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1406. callableBuilder.Add(tuple);
  1407. callableBuilder.Add(NewDataLiteral<ui32>(index));
  1408. return TRuntimeNode(callableBuilder.Build(), false);
  1409. }
  1410. TRuntimeNode TProgramBuilder::BlockAsStruct(const TArrayRef<std::pair<std::string_view, TRuntimeNode>>& args) {
  1411. MKQL_ENSURE(!args.empty(), "Expected at least one argument");
  1412. TBlockType::EShape resultShape = TBlockType::EShape::Scalar;
  1413. TVector<std::pair<std::string_view, TType*>> members;
  1414. for (const auto& x : args) {
  1415. auto blockType = AS_TYPE(TBlockType, x.second.GetStaticType());
  1416. members.emplace_back(x.first, blockType->GetItemType());
  1417. if (blockType->GetShape() == TBlockType::EShape::Many) {
  1418. resultShape = TBlockType::EShape::Many;
  1419. }
  1420. }
  1421. auto returnType = NewBlockType(NewStructType(members), resultShape);
  1422. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1423. for (const auto& x : args) {
  1424. callableBuilder.Add(x.second);
  1425. }
  1426. return TRuntimeNode(callableBuilder.Build(), false);
  1427. }
  1428. TRuntimeNode TProgramBuilder::BlockAsTuple(const TArrayRef<const TRuntimeNode>& args) {
  1429. MKQL_ENSURE(!args.empty(), "Expected at least one argument");
  1430. TBlockType::EShape resultShape = TBlockType::EShape::Scalar;
  1431. TVector<TType*> types;
  1432. for (const auto& x : args) {
  1433. auto blockType = AS_TYPE(TBlockType, x.GetStaticType());
  1434. types.push_back(blockType->GetItemType());
  1435. if (blockType->GetShape() == TBlockType::EShape::Many) {
  1436. resultShape = TBlockType::EShape::Many;
  1437. }
  1438. }
  1439. auto tupleType = NewTupleType(types);
  1440. auto returnType = NewBlockType(tupleType, resultShape);
  1441. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1442. for (const auto& x : args) {
  1443. callableBuilder.Add(x);
  1444. }
  1445. return TRuntimeNode(callableBuilder.Build(), false);
  1446. }
  1447. TRuntimeNode TProgramBuilder::BlockToPg(TRuntimeNode input, TType* returnType) {
  1448. if constexpr (RuntimeVersion < 37U) {
  1449. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1450. }
  1451. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1452. callableBuilder.Add(input);
  1453. return TRuntimeNode(callableBuilder.Build(), false);
  1454. }
  1455. TRuntimeNode TProgramBuilder::BlockFromPg(TRuntimeNode input, TType* returnType) {
  1456. if constexpr (RuntimeVersion < 37U) {
  1457. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1458. }
  1459. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1460. callableBuilder.Add(input);
  1461. return TRuntimeNode(callableBuilder.Build(), false);
  1462. }
  1463. TRuntimeNode TProgramBuilder::BlockNot(TRuntimeNode data) {
  1464. auto dataType = AS_TYPE(TBlockType, data.GetStaticType());
  1465. bool isOpt;
  1466. MKQL_ENSURE(UnpackOptionalData(dataType->GetItemType(), isOpt)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  1467. TCallableBuilder callableBuilder(Env, __func__, data.GetStaticType());
  1468. callableBuilder.Add(data);
  1469. return TRuntimeNode(callableBuilder.Build(), false);
  1470. }
  1471. TRuntimeNode TProgramBuilder::BlockAnd(TRuntimeNode first, TRuntimeNode second) {
  1472. return BuildBlockLogical(__func__, first, second);
  1473. }
  1474. TRuntimeNode TProgramBuilder::BlockOr(TRuntimeNode first, TRuntimeNode second) {
  1475. return BuildBlockLogical(__func__, first, second);
  1476. }
  1477. TRuntimeNode TProgramBuilder::BlockXor(TRuntimeNode first, TRuntimeNode second) {
  1478. return BuildBlockLogical(__func__, first, second);
  1479. }
  1480. TRuntimeNode TProgramBuilder::BlockDecimalDiv(TRuntimeNode first, TRuntimeNode second) {
  1481. return BuildBlockDecimalBinary(__func__, first, second);
  1482. }
  1483. TRuntimeNode TProgramBuilder::BlockDecimalMod(TRuntimeNode first, TRuntimeNode second) {
  1484. return BuildBlockDecimalBinary(__func__, first, second);
  1485. }
  1486. TRuntimeNode TProgramBuilder::BlockDecimalMul(TRuntimeNode first, TRuntimeNode second) {
  1487. return BuildBlockDecimalBinary(__func__, first, second);
  1488. }
  1489. TRuntimeNode TProgramBuilder::ListFromRange(TRuntimeNode start, TRuntimeNode end, TRuntimeNode step) {
  1490. MKQL_ENSURE(start.GetStaticType()->IsData(), "Expected data");
  1491. MKQL_ENSURE(end.GetStaticType()->IsSameType(*start.GetStaticType()), "Mismatch type");
  1492. if constexpr (RuntimeVersion < 24U) {
  1493. MKQL_ENSURE(IsNumericType(AS_TYPE(TDataType, start)->GetSchemeType()), "Expected numeric");
  1494. } else {
  1495. MKQL_ENSURE(IsNumericType(AS_TYPE(TDataType, start)->GetSchemeType()) ||
  1496. IsDateType(AS_TYPE(TDataType, start)->GetSchemeType()) ||
  1497. IsTzDateType(AS_TYPE(TDataType, start)->GetSchemeType()) ||
  1498. IsIntervalType(AS_TYPE(TDataType, start)->GetSchemeType()),
  1499. "Expected numeric, date or tzdate");
  1500. if (IsNumericType(AS_TYPE(TDataType, start)->GetSchemeType())) {
  1501. MKQL_ENSURE(IsNumericType(AS_TYPE(TDataType, step)->GetSchemeType()), "Expected numeric");
  1502. } else {
  1503. MKQL_ENSURE(IsIntervalType(AS_TYPE(TDataType, step)->GetSchemeType()), "Expected interval");
  1504. }
  1505. }
  1506. TCallableBuilder callableBuilder(Env, __func__, TListType::Create(start.GetStaticType(), Env));
  1507. callableBuilder.Add(start);
  1508. callableBuilder.Add(end);
  1509. callableBuilder.Add(step);
  1510. return TRuntimeNode(callableBuilder.Build(), false);
  1511. }
  1512. TRuntimeNode TProgramBuilder::Switch(TRuntimeNode stream,
  1513. const TArrayRef<const TSwitchInput>& handlerInputs,
  1514. std::function<TRuntimeNode(ui32 index, TRuntimeNode item)> handler,
  1515. ui64 memoryLimitBytes, TType* returnType) {
  1516. MKQL_ENSURE(stream.GetStaticType()->IsStream() || stream.GetStaticType()->IsFlow(), "Expected stream or flow.");
  1517. std::vector<TRuntimeNode> argNodes(handlerInputs.size());
  1518. std::vector<TRuntimeNode> outputNodes(handlerInputs.size());
  1519. for (ui32 i = 0; i < handlerInputs.size(); ++i) {
  1520. TRuntimeNode arg = Arg(handlerInputs[i].InputType);
  1521. argNodes[i] = arg;
  1522. outputNodes[i] = handler(i, arg);
  1523. }
  1524. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1525. callableBuilder.Add(stream);
  1526. callableBuilder.Add(NewDataLiteral<ui64>(memoryLimitBytes));
  1527. for (ui32 i = 0; i < handlerInputs.size(); ++i) {
  1528. std::vector<TRuntimeNode> tupleElems;
  1529. for (auto index : handlerInputs[i].Indicies) {
  1530. tupleElems.push_back(NewDataLiteral<ui32>(index));
  1531. }
  1532. auto indiciesTuple = NewTuple(tupleElems);
  1533. callableBuilder.Add(indiciesTuple);
  1534. callableBuilder.Add(argNodes[i]);
  1535. callableBuilder.Add(outputNodes[i]);
  1536. if (!handlerInputs[i].ResultVariantOffset) {
  1537. callableBuilder.Add(NewVoid());
  1538. } else {
  1539. callableBuilder.Add(NewDataLiteral<ui32>(*handlerInputs[i].ResultVariantOffset));
  1540. }
  1541. }
  1542. return TRuntimeNode(callableBuilder.Build(), false);
  1543. }
  1544. TRuntimeNode TProgramBuilder::HasItems(TRuntimeNode listOrDict) {
  1545. return BuildContainerProperty<bool>(__func__, listOrDict);
  1546. }
  1547. TRuntimeNode TProgramBuilder::Reverse(TRuntimeNode list) {
  1548. bool isOptional = false;
  1549. const auto listType = UnpackOptional(list, isOptional);
  1550. if (isOptional) {
  1551. return Map(list, [&](TRuntimeNode unpacked) { return Reverse(unpacked); } );
  1552. }
  1553. const auto listDetailedType = AS_TYPE(TListType, listType);
  1554. const auto itemType = listDetailedType->GetItemType();
  1555. ThrowIfListOfVoid(itemType);
  1556. TCallableBuilder callableBuilder(Env, __func__, listType);
  1557. callableBuilder.Add(list);
  1558. return TRuntimeNode(callableBuilder.Build(), false);
  1559. }
  1560. TRuntimeNode TProgramBuilder::Skip(TRuntimeNode list, TRuntimeNode count) {
  1561. return BuildTake(__func__, list, count);
  1562. }
  1563. TRuntimeNode TProgramBuilder::Take(TRuntimeNode list, TRuntimeNode count) {
  1564. return BuildTake(__func__, list, count);
  1565. }
  1566. TRuntimeNode TProgramBuilder::Sort(TRuntimeNode list, TRuntimeNode ascending, const TUnaryLambda& keyExtractor)
  1567. {
  1568. return BuildSort(__func__, list, ascending, keyExtractor);
  1569. }
  1570. TRuntimeNode TProgramBuilder::WideTop(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys)
  1571. {
  1572. return BuildWideTopOrSort(__func__, flow, count, keys);
  1573. }
  1574. TRuntimeNode TProgramBuilder::WideTopSort(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys)
  1575. {
  1576. return BuildWideTopOrSort(__func__, flow, count, keys);
  1577. }
  1578. TRuntimeNode TProgramBuilder::WideSort(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys)
  1579. {
  1580. return BuildWideTopOrSort(__func__, flow, Nothing(), keys);
  1581. }
  1582. TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callableName, TRuntimeNode flow, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
  1583. if (count) {
  1584. if constexpr (RuntimeVersion < 33U) {
  1585. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << callableName;
  1586. }
  1587. } else {
  1588. if constexpr (RuntimeVersion < 34U) {
  1589. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << callableName;
  1590. }
  1591. }
  1592. const auto width = GetWideComponentsCount(AS_TYPE(TFlowType, flow.GetStaticType()));
  1593. MKQL_ENSURE(!keys.empty() && keys.size() <= width, "Unexpected keys count: " << keys.size());
  1594. TCallableBuilder callableBuilder(Env, callableName, flow.GetStaticType());
  1595. callableBuilder.Add(flow);
  1596. if (count) {
  1597. callableBuilder.Add(*count);
  1598. }
  1599. std::for_each(keys.cbegin(), keys.cend(), [&](const std::pair<ui32, TRuntimeNode>& key) {
  1600. MKQL_ENSURE(key.first < width, "Key index too large: " << key.first);
  1601. callableBuilder.Add(NewDataLiteral(key.first));
  1602. callableBuilder.Add(key.second);
  1603. });
  1604. return TRuntimeNode(callableBuilder.Build(), false);
  1605. }
  1606. TRuntimeNode TProgramBuilder::Top(TRuntimeNode flow, TRuntimeNode count, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) {
  1607. if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
  1608. const TUnaryLambda getKey = [&](TRuntimeNode item) { return Nth(item, 0U); };
  1609. const TUnaryLambda getItem = [&](TRuntimeNode item) { return Nth(item, 1U); };
  1610. const TUnaryLambda cacheKeyExtractor = [&](TRuntimeNode item) {
  1611. return NewTuple({keyExtractor(item), item});
  1612. };
  1613. return FlatMap(Condense1(Map(flow, cacheKeyExtractor),
  1614. [&](TRuntimeNode item) { return AsList(item); },
  1615. [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
  1616. [&](TRuntimeNode item, TRuntimeNode state) {
  1617. return KeepTop(count, state, item, ascending, getKey);
  1618. }
  1619. ),
  1620. [&](TRuntimeNode list) { return Map(Top(list, count, ascending, getKey), getItem); }
  1621. );
  1622. }
  1623. return BuildListNth(__func__, flow, count, ascending, keyExtractor);
  1624. }
  1625. TRuntimeNode TProgramBuilder::TopSort(TRuntimeNode flow, TRuntimeNode count, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) {
  1626. if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
  1627. const TUnaryLambda getKey = [&](TRuntimeNode item) { return Nth(item, 0U); };
  1628. const TUnaryLambda getItem = [&](TRuntimeNode item) { return Nth(item, 1U); };
  1629. const TUnaryLambda cacheKeyExtractor = [&](TRuntimeNode item) {
  1630. return NewTuple({keyExtractor(item), item});
  1631. };
  1632. return FlatMap(Condense1(Map(flow, cacheKeyExtractor),
  1633. [&](TRuntimeNode item) { return AsList(item); },
  1634. [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
  1635. [&](TRuntimeNode item, TRuntimeNode state) {
  1636. return KeepTop(count, state, item, ascending, getKey);
  1637. }
  1638. ),
  1639. [&](TRuntimeNode list) { return Map(TopSort(list, count, ascending, getKey), getItem); }
  1640. );
  1641. }
  1642. if constexpr (RuntimeVersion >= 25U)
  1643. return BuildListNth(__func__, flow, count, ascending, keyExtractor);
  1644. else
  1645. return BuildListSort("Sort", BuildListNth("Top", flow, count, ascending, keyExtractor), ascending, keyExtractor);
  1646. }
  1647. TRuntimeNode TProgramBuilder::KeepTop(TRuntimeNode count, TRuntimeNode list, TRuntimeNode item, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) {
  1648. const auto listType = list.GetStaticType();
  1649. MKQL_ENSURE(listType->IsList(), "Expected list.");
  1650. const auto itemType = static_cast<const TListType&>(*listType).GetItemType();
  1651. ThrowIfListOfVoid(itemType);
  1652. MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data");
  1653. MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  1654. MKQL_ENSURE(itemType->IsSameType(*item.GetStaticType()), "Types of list and item are different.");
  1655. const auto ascendingType = ascending.GetStaticType();
  1656. const auto itemArg = Arg(itemType);
  1657. auto key = keyExtractor(itemArg);
  1658. const auto hotkey = Arg(key.GetStaticType());
  1659. if (ascendingType->IsTuple()) {
  1660. const auto ascendingTuple = AS_TYPE(TTupleType, ascendingType);
  1661. if (ascendingTuple->GetElementsCount() == 0) {
  1662. return If(AggrLess(Length(list), count), Append(list, item), list);
  1663. }
  1664. if (ascendingTuple->GetElementsCount() == 1) {
  1665. ascending = Nth(ascending, 0);
  1666. key = Nth(key, 0);
  1667. }
  1668. }
  1669. TCallableBuilder callableBuilder(Env, __func__, listType);
  1670. callableBuilder.Add(count);
  1671. callableBuilder.Add(list);
  1672. callableBuilder.Add(item);
  1673. callableBuilder.Add(itemArg);
  1674. callableBuilder.Add(key);
  1675. callableBuilder.Add(ascending);
  1676. callableBuilder.Add(hotkey);
  1677. return TRuntimeNode(callableBuilder.Build(), false);
  1678. }
  1679. TRuntimeNode TProgramBuilder::Contains(TRuntimeNode dict, TRuntimeNode key) {
  1680. if constexpr (RuntimeVersion >= 25U)
  1681. if (!dict.GetStaticType()->IsDict())
  1682. return DataCompare(__func__, dict, key);
  1683. const auto keyType = AS_TYPE(TDictType, dict.GetStaticType())->GetKeyType();
  1684. MKQL_ENSURE(keyType->IsSameType(*key.GetStaticType()), "Key type mismatch. Requred: " << *keyType << ", but got: " << *key.GetStaticType());
  1685. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<bool>::Id));
  1686. callableBuilder.Add(dict);
  1687. callableBuilder.Add(key);
  1688. return TRuntimeNode(callableBuilder.Build(), false);
  1689. }
  1690. TRuntimeNode TProgramBuilder::Lookup(TRuntimeNode dict, TRuntimeNode key) {
  1691. const auto dictType = AS_TYPE(TDictType, dict.GetStaticType());
  1692. const auto keyType = dictType->GetKeyType();
  1693. MKQL_ENSURE(keyType->IsSameType(*key.GetStaticType()), "Key type mismatch. Requred: " << *keyType << ", but got: " << *key.GetStaticType());
  1694. TCallableBuilder callableBuilder(Env, __func__, NewOptionalType(dictType->GetPayloadType()));
  1695. callableBuilder.Add(dict);
  1696. callableBuilder.Add(key);
  1697. return TRuntimeNode(callableBuilder.Build(), false);
  1698. }
  1699. TRuntimeNode TProgramBuilder::DictItems(TRuntimeNode dict, EDictItems mode) {
  1700. const auto dictTypeChecked = AS_TYPE(TDictType, dict.GetStaticType());
  1701. TType* itemType;
  1702. switch (mode) {
  1703. case EDictItems::Both: {
  1704. const std::array<TType*, 2U> tupleTypes = {{ dictTypeChecked->GetKeyType(), dictTypeChecked->GetPayloadType() }};
  1705. itemType = NewTupleType(tupleTypes);
  1706. break;
  1707. }
  1708. case EDictItems::Keys: itemType = dictTypeChecked->GetKeyType(); break;
  1709. case EDictItems::Payloads: itemType = dictTypeChecked->GetPayloadType(); break;
  1710. }
  1711. TCallableBuilder callableBuilder(Env, __func__, NewListType(itemType));
  1712. callableBuilder.Add(dict);
  1713. callableBuilder.Add(NewDataLiteral((ui32)mode));
  1714. return TRuntimeNode(callableBuilder.Build(), false);
  1715. }
  1716. TRuntimeNode TProgramBuilder::DictItems(TRuntimeNode dict) {
  1717. if constexpr (RuntimeVersion < 6U) {
  1718. return DictItems(dict, EDictItems::Both);
  1719. }
  1720. const auto dictTypeChecked = AS_TYPE(TDictType, dict.GetStaticType());
  1721. const auto itemType = NewTupleType({ dictTypeChecked->GetKeyType(), dictTypeChecked->GetPayloadType() });
  1722. TCallableBuilder callableBuilder(Env, __func__, NewListType(itemType));
  1723. callableBuilder.Add(dict);
  1724. return TRuntimeNode(callableBuilder.Build(), false);
  1725. }
  1726. TRuntimeNode TProgramBuilder::DictKeys(TRuntimeNode dict) {
  1727. if constexpr (RuntimeVersion < 6U) {
  1728. return DictItems(dict, EDictItems::Keys);
  1729. }
  1730. const auto dictTypeChecked = AS_TYPE(TDictType, dict.GetStaticType());
  1731. TCallableBuilder callableBuilder(Env, __func__, NewListType(dictTypeChecked->GetKeyType()));
  1732. callableBuilder.Add(dict);
  1733. return TRuntimeNode(callableBuilder.Build(), false);
  1734. }
  1735. TRuntimeNode TProgramBuilder::DictPayloads(TRuntimeNode dict) {
  1736. if constexpr (RuntimeVersion < 6U) {
  1737. return DictItems(dict, EDictItems::Payloads);
  1738. }
  1739. const auto dictTypeChecked = AS_TYPE(TDictType, dict.GetStaticType());
  1740. TCallableBuilder callableBuilder(Env, __func__, NewListType(dictTypeChecked->GetPayloadType()));
  1741. callableBuilder.Add(dict);
  1742. return TRuntimeNode(callableBuilder.Build(), false);
  1743. }
  1744. TRuntimeNode TProgramBuilder::ToIndexDict(TRuntimeNode list) {
  1745. const auto itemType = AS_TYPE(TListType, list.GetStaticType())->GetItemType();
  1746. ThrowIfListOfVoid(itemType);
  1747. const auto keyType = NewDataType(NUdf::TDataType<ui64>::Id);
  1748. const auto dictType = NewDictType(keyType, itemType, false);
  1749. TCallableBuilder callableBuilder(Env, __func__, dictType);
  1750. callableBuilder.Add(list);
  1751. return TRuntimeNode(callableBuilder.Build(), false);
  1752. }
  1753. TRuntimeNode TProgramBuilder::JoinDict(TRuntimeNode dict1, bool isMulti1, TRuntimeNode dict2, bool isMulti2, EJoinKind joinKind) {
  1754. const auto dict1type = AS_TYPE(TDictType, dict1);
  1755. const auto dict2type = AS_TYPE(TDictType, dict2);
  1756. MKQL_ENSURE(dict1type->GetKeyType()->IsSameType(*dict2type->GetKeyType()), "Dict key types must be the same");
  1757. if (joinKind == EJoinKind::RightOnly || joinKind == EJoinKind::RightSemi)
  1758. MKQL_ENSURE(dict1type->GetPayloadType()->IsVoid(), "Void required for first dict payload.");
  1759. else if (isMulti1)
  1760. MKQL_ENSURE(dict1type->GetPayloadType()->IsList(), "List required for first dict payload.");
  1761. if (joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::LeftSemi)
  1762. MKQL_ENSURE(dict2type->GetPayloadType()->IsVoid(), "Void required for second dict payload.");
  1763. else if (isMulti2)
  1764. MKQL_ENSURE(dict2type->GetPayloadType()->IsList(), "List required for second dict payload.");
  1765. std::array<TType*, 2> tupleItems = {{ dict1type->GetPayloadType(), dict2type->GetPayloadType() }};
  1766. if (isMulti1 && tupleItems.front()->IsList())
  1767. tupleItems.front() = AS_TYPE(TListType, tupleItems.front())->GetItemType();
  1768. if (isMulti2 && tupleItems.back()->IsList())
  1769. tupleItems.back() = AS_TYPE(TListType, tupleItems.back())->GetItemType();
  1770. if (IsLeftOptional(joinKind))
  1771. tupleItems.front() = NewOptionalType(tupleItems.front());
  1772. if (IsRightOptional(joinKind))
  1773. tupleItems.back() = NewOptionalType(tupleItems.back());
  1774. TType* itemType;
  1775. if (joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::LeftSemi)
  1776. itemType = tupleItems.front();
  1777. else if (joinKind == EJoinKind::RightOnly || joinKind == EJoinKind::RightSemi)
  1778. itemType = tupleItems.back();
  1779. else
  1780. itemType = NewTupleType(tupleItems);
  1781. const auto returnType = NewListType(itemType);
  1782. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1783. callableBuilder.Add(dict1);
  1784. callableBuilder.Add(dict2);
  1785. callableBuilder.Add(NewDataLiteral(isMulti1));
  1786. callableBuilder.Add(NewDataLiteral(isMulti2));
  1787. callableBuilder.Add(NewDataLiteral(ui32(joinKind)));
  1788. return TRuntimeNode(callableBuilder.Build(), false);
  1789. }
  1790. TRuntimeNode TProgramBuilder::GraceJoinCommon(const TStringBuf& funcName, TRuntimeNode flowLeft, TRuntimeNode flowRight, EJoinKind joinKind,
  1791. const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns,
  1792. const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, TType* returnType, EAnyJoinSettings anyJoinSettings ) {
  1793. MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified");
  1794. if (flowRight) {
  1795. MKQL_ENSURE(!rightKeyColumns.empty(), "At least one key column must be specified");
  1796. }
  1797. TRuntimeNode::TList leftKeyColumnsNodes, rightKeyColumnsNodes, leftRenamesNodes, rightRenamesNodes;
  1798. leftKeyColumnsNodes.reserve(leftKeyColumns.size());
  1799. std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(), std::back_inserter(leftKeyColumnsNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  1800. rightKeyColumnsNodes.reserve(rightKeyColumns.size());
  1801. std::transform(rightKeyColumns.cbegin(), rightKeyColumns.cend(), std::back_inserter(rightKeyColumnsNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  1802. leftRenamesNodes.reserve(leftRenames.size());
  1803. std::transform(leftRenames.cbegin(), leftRenames.cend(), std::back_inserter(leftRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  1804. rightRenamesNodes.reserve(rightRenames.size());
  1805. std::transform(rightRenames.cbegin(), rightRenames.cend(), std::back_inserter(rightRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  1806. TCallableBuilder callableBuilder(Env, funcName, returnType);
  1807. callableBuilder.Add(flowLeft);
  1808. if (flowRight) {
  1809. callableBuilder.Add(flowRight);
  1810. }
  1811. callableBuilder.Add(NewDataLiteral((ui32)joinKind));
  1812. callableBuilder.Add(NewTuple(leftKeyColumnsNodes));
  1813. callableBuilder.Add(NewTuple(rightKeyColumnsNodes));
  1814. callableBuilder.Add(NewTuple(leftRenamesNodes));
  1815. callableBuilder.Add(NewTuple(rightRenamesNodes));
  1816. callableBuilder.Add(NewDataLiteral((ui32)anyJoinSettings));
  1817. return TRuntimeNode(callableBuilder.Build(), false);
  1818. }
  1819. TRuntimeNode TProgramBuilder::GraceJoin(TRuntimeNode flowLeft, TRuntimeNode flowRight, EJoinKind joinKind,
  1820. const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns,
  1821. const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, TType* returnType, EAnyJoinSettings anyJoinSettings ) {
  1822. return GraceJoinCommon(__func__, flowLeft, flowRight, joinKind, leftKeyColumns, rightKeyColumns, leftRenames, rightRenames, returnType, anyJoinSettings);
  1823. }
  1824. TRuntimeNode TProgramBuilder::GraceSelfJoin(TRuntimeNode flowLeft, EJoinKind joinKind, const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns,
  1825. const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, TType* returnType, EAnyJoinSettings anyJoinSettings ) {
  1826. if constexpr (RuntimeVersion < 40U) {
  1827. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1828. }
  1829. return GraceJoinCommon(__func__, flowLeft, {}, joinKind, leftKeyColumns, rightKeyColumns, leftRenames, rightRenames, returnType, anyJoinSettings);
  1830. }
  1831. TRuntimeNode TProgramBuilder::ToSortedDict(TRuntimeNode list, bool all, const TUnaryLambda& keySelector,
  1832. const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1833. return ToDict(list, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1834. }
  1835. TRuntimeNode TProgramBuilder::ToHashedDict(TRuntimeNode list, bool all, const TUnaryLambda& keySelector,
  1836. const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1837. return ToDict(list, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1838. }
  1839. TRuntimeNode TProgramBuilder::SqueezeToSortedDict(TRuntimeNode stream, bool all, const TUnaryLambda& keySelector,
  1840. const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1841. return SqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1842. }
  1843. TRuntimeNode TProgramBuilder::SqueezeToHashedDict(TRuntimeNode stream, bool all, const TUnaryLambda& keySelector,
  1844. const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1845. return SqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1846. }
  1847. TRuntimeNode TProgramBuilder::NarrowSqueezeToSortedDict(TRuntimeNode stream, bool all, const TNarrowLambda& keySelector,
  1848. const TNarrowLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1849. return NarrowSqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1850. }
  1851. TRuntimeNode TProgramBuilder::NarrowSqueezeToHashedDict(TRuntimeNode stream, bool all, const TNarrowLambda& keySelector,
  1852. const TNarrowLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1853. return NarrowSqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1854. }
  1855. TRuntimeNode TProgramBuilder::SqueezeToList(TRuntimeNode flow, TRuntimeNode limit) {
  1856. if constexpr (RuntimeVersion < 25U) {
  1857. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1858. }
  1859. const auto itemType = AS_TYPE(TFlowType, flow.GetStaticType())->GetItemType();
  1860. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewListType(itemType)));
  1861. callableBuilder.Add(flow);
  1862. callableBuilder.Add(limit);
  1863. return TRuntimeNode(callableBuilder.Build(), false);
  1864. }
  1865. TRuntimeNode TProgramBuilder::Append(TRuntimeNode list, TRuntimeNode item) {
  1866. auto listType = list.GetStaticType();
  1867. AS_TYPE(TListType, listType);
  1868. const auto& listDetailedType = static_cast<const TListType&>(*listType);
  1869. auto itemType = item.GetStaticType();
  1870. MKQL_ENSURE(itemType->IsSameType(*listDetailedType.GetItemType()), "Types of list and item are different");
  1871. TCallableBuilder callableBuilder(Env, __func__, listType);
  1872. callableBuilder.Add(list);
  1873. callableBuilder.Add(item);
  1874. return TRuntimeNode(callableBuilder.Build(), false);
  1875. }
  1876. TRuntimeNode TProgramBuilder::Prepend(TRuntimeNode item, TRuntimeNode list) {
  1877. auto listType = list.GetStaticType();
  1878. AS_TYPE(TListType, listType);
  1879. const auto& listDetailedType = static_cast<const TListType&>(*listType);
  1880. auto itemType = item.GetStaticType();
  1881. MKQL_ENSURE(itemType->IsSameType(*listDetailedType.GetItemType()), "Types of list and item are different");
  1882. TCallableBuilder callableBuilder(Env, __func__, listType);
  1883. callableBuilder.Add(item);
  1884. callableBuilder.Add(list);
  1885. return TRuntimeNode(callableBuilder.Build(), false);
  1886. }
  1887. TRuntimeNode TProgramBuilder::BuildExtend(const std::string_view& callableName, const TArrayRef<const TRuntimeNode>& lists) {
  1888. MKQL_ENSURE(lists.size() > 0, "Expected at least 1 list or flow");
  1889. if (lists.size() == 1) {
  1890. return lists.front();
  1891. }
  1892. auto listType = lists.front().GetStaticType();
  1893. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected either flow, list or stream");
  1894. for (ui32 i = 1; i < lists.size(); ++i) {
  1895. auto listType2 = lists[i].GetStaticType();
  1896. MKQL_ENSURE(listType->IsSameType(*listType2), "Types of flows are different, left: " <<
  1897. PrintNode(listType, true) << ", right: " <<
  1898. PrintNode(listType2, true));
  1899. }
  1900. TCallableBuilder callableBuilder(Env, callableName, listType);
  1901. for (auto list : lists) {
  1902. callableBuilder.Add(list);
  1903. }
  1904. return TRuntimeNode(callableBuilder.Build(), false);
  1905. }
  1906. TRuntimeNode TProgramBuilder::Extend(const TArrayRef<const TRuntimeNode>& lists) {
  1907. return BuildExtend(__func__, lists);
  1908. }
  1909. TRuntimeNode TProgramBuilder::OrderedExtend(const TArrayRef<const TRuntimeNode>& lists) {
  1910. return BuildExtend(__func__, lists);
  1911. }
  1912. template<>
  1913. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::String>(const NUdf::TStringRef& data) const {
  1914. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<const char*>::Id, Env), true);
  1915. }
  1916. template<>
  1917. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Utf8>(const NUdf::TStringRef& data) const {
  1918. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TUtf8>::Id, Env), true);
  1919. }
  1920. template<>
  1921. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Yson>(const NUdf::TStringRef& data) const {
  1922. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TYson>::Id, Env), true);
  1923. }
  1924. template<>
  1925. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Json>(const NUdf::TStringRef& data) const {
  1926. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TJson>::Id, Env), true);
  1927. }
  1928. template<>
  1929. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::JsonDocument>(const NUdf::TStringRef& data) const {
  1930. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TJsonDocument>::Id, Env), true);
  1931. }
  1932. template<>
  1933. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Uuid>(const NUdf::TStringRef& data) const {
  1934. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TUuid>::Id, Env), true);
  1935. }
  1936. template<>
  1937. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Date>(const NUdf::TStringRef& data) const {
  1938. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDate>::Id, Env), true);
  1939. }
  1940. template<>
  1941. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Datetime>(const NUdf::TStringRef& data) const {
  1942. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDatetime>::Id, Env), true);
  1943. }
  1944. template<>
  1945. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Timestamp>(const NUdf::TStringRef& data) const {
  1946. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TTimestamp>::Id, Env), true);
  1947. }
  1948. template<>
  1949. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Interval>(const NUdf::TStringRef& data) const {
  1950. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TInterval>::Id, Env), true);
  1951. }
  1952. template<>
  1953. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::DyNumber>(const NUdf::TStringRef& data) const {
  1954. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDyNumber>::Id, Env), true);
  1955. }
  1956. TRuntimeNode TProgramBuilder::NewDecimalLiteral(NYql::NDecimal::TInt128 data, ui8 precision, ui8 scale) const {
  1957. return TRuntimeNode(TDataLiteral::Create(NUdf::TUnboxedValuePod(data), TDataDecimalType::Create(precision, scale, Env), Env), true);
  1958. }
  1959. template<>
  1960. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Date32>(const NUdf::TStringRef& data) const {
  1961. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDate32>::Id, Env), true);
  1962. }
  1963. template<>
  1964. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Datetime64>(const NUdf::TStringRef& data) const {
  1965. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDatetime64>::Id, Env), true);
  1966. }
  1967. template<>
  1968. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Timestamp64>(const NUdf::TStringRef& data) const {
  1969. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TTimestamp64>::Id, Env), true);
  1970. }
  1971. template<>
  1972. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Interval64>(const NUdf::TStringRef& data) const {
  1973. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TInterval64>::Id, Env), true);
  1974. }
  1975. TRuntimeNode TProgramBuilder::NewOptional(TRuntimeNode data) {
  1976. auto type = TOptionalType::Create(data.GetStaticType(), Env);
  1977. return TRuntimeNode(TOptionalLiteral::Create(data, type, Env), true);
  1978. }
  1979. TRuntimeNode TProgramBuilder::NewOptional(TType* optionalType, TRuntimeNode data) {
  1980. auto type = AS_TYPE(TOptionalType, optionalType);
  1981. return TRuntimeNode(TOptionalLiteral::Create(data, type, Env), true);
  1982. }
  1983. TRuntimeNode TProgramBuilder::NewVoid() {
  1984. return TRuntimeNode(Env.GetVoidLazy(), true);
  1985. }
  1986. TRuntimeNode TProgramBuilder::NewEmptyListOfVoid() {
  1987. return TRuntimeNode(Env.GetListOfVoidLazy(), true);
  1988. }
  1989. TRuntimeNode TProgramBuilder::NewEmptyOptional(TType* optionalOrPgType) {
  1990. MKQL_ENSURE(optionalOrPgType->IsOptional() || optionalOrPgType->IsPg(), "Expected optional or pg type");
  1991. if (optionalOrPgType->IsOptional()) {
  1992. return TRuntimeNode(TOptionalLiteral::Create(static_cast<TOptionalType*>(optionalOrPgType), Env), true);
  1993. }
  1994. return PgCast(NewNull(), optionalOrPgType);
  1995. }
  1996. TRuntimeNode TProgramBuilder::NewEmptyOptionalDataLiteral(NUdf::TDataTypeId schemeType) {
  1997. return TRuntimeNode(BuildEmptyOptionalDataLiteral(schemeType, Env), true);
  1998. }
  1999. TRuntimeNode TProgramBuilder::NewEmptyStruct() {
  2000. return TRuntimeNode(Env.GetEmptyStructLazy(), true);
  2001. }
  2002. TRuntimeNode TProgramBuilder::NewStruct(const TArrayRef<const std::pair<std::string_view, TRuntimeNode>>& members) {
  2003. if (members.empty()) {
  2004. return NewEmptyStruct();
  2005. }
  2006. TStructLiteralBuilder builder(Env);
  2007. for (auto x : members) {
  2008. builder.Add(x.first, x.second);
  2009. }
  2010. return TRuntimeNode(builder.Build(), true);
  2011. }
  2012. TRuntimeNode TProgramBuilder::NewStruct(TType* structType, const TArrayRef<const std::pair<std::string_view, TRuntimeNode>>& members) {
  2013. const auto detailedStructType = AS_TYPE(TStructType, structType);
  2014. MKQL_ENSURE(members.size() == detailedStructType->GetMembersCount(), "Mismatch count of members");
  2015. if (members.empty()) {
  2016. return NewEmptyStruct();
  2017. }
  2018. std::vector<TRuntimeNode> values(detailedStructType->GetMembersCount());
  2019. for (ui32 i = 0; i < detailedStructType->GetMembersCount(); ++i) {
  2020. const auto& name = members[i].first;
  2021. ui32 index = detailedStructType->GetMemberIndex(name);
  2022. MKQL_ENSURE(!values[index], "Duplicate of member: " << name);
  2023. values[index] = members[i].second;
  2024. }
  2025. return TRuntimeNode(TStructLiteral::Create(values.size(), values.data(), detailedStructType, Env), true);
  2026. }
  2027. TRuntimeNode TProgramBuilder::NewEmptyList() {
  2028. return TRuntimeNode(Env.GetEmptyListLazy(), true);
  2029. }
  2030. TRuntimeNode TProgramBuilder::NewEmptyList(TType* itemType) {
  2031. TListLiteralBuilder builder(Env, itemType);
  2032. return TRuntimeNode(builder.Build(), true);
  2033. }
  2034. TRuntimeNode TProgramBuilder::NewList(TType* itemType, const TArrayRef<const TRuntimeNode>& items) {
  2035. TListLiteralBuilder builder(Env, itemType);
  2036. for (auto item : items) {
  2037. builder.Add(item);
  2038. }
  2039. return TRuntimeNode(builder.Build(), true);
  2040. }
  2041. TRuntimeNode TProgramBuilder::NewEmptyDict() {
  2042. return TRuntimeNode(Env.GetEmptyDictLazy(), true);
  2043. }
  2044. TRuntimeNode TProgramBuilder::NewDict(TType* dictType, const TArrayRef<const std::pair<TRuntimeNode, TRuntimeNode>>& items) {
  2045. MKQL_ENSURE(dictType->IsDict(), "Expected dict type");
  2046. return TRuntimeNode(TDictLiteral::Create(items.size(), items.data(), static_cast<TDictType*>(dictType), Env), true);
  2047. }
  2048. TRuntimeNode TProgramBuilder::NewEmptyTuple() {
  2049. return TRuntimeNode(Env.GetEmptyTupleLazy(), true);
  2050. }
  2051. TRuntimeNode TProgramBuilder::NewTuple(TType* tupleType, const TArrayRef<const TRuntimeNode>& elements) {
  2052. MKQL_ENSURE(tupleType->IsTuple(), "Expected tuple type");
  2053. return TRuntimeNode(TTupleLiteral::Create(elements.size(), elements.data(), static_cast<TTupleType*>(tupleType), Env), true);
  2054. }
  2055. TRuntimeNode TProgramBuilder::NewTuple(const TArrayRef<const TRuntimeNode>& elements) {
  2056. std::vector<TType*> types;
  2057. types.reserve(elements.size());
  2058. for (auto elem : elements) {
  2059. types.push_back(elem.GetStaticType());
  2060. }
  2061. return NewTuple(NewTupleType(types), elements);
  2062. }
  2063. TRuntimeNode TProgramBuilder::NewVariant(TRuntimeNode item, ui32 index, TType* variantType) {
  2064. const auto type = AS_TYPE(TVariantType, variantType);
  2065. MKQL_ENSURE(type->GetUnderlyingType()->IsTuple(), "Expected tuple as underlying type");
  2066. return TRuntimeNode(TVariantLiteral::Create(item, index, type, Env), true);
  2067. }
  2068. TRuntimeNode TProgramBuilder::NewVariant(TRuntimeNode item, const std::string_view& member, TType* variantType) {
  2069. const auto type = AS_TYPE(TVariantType, variantType);
  2070. MKQL_ENSURE(type->GetUnderlyingType()->IsStruct(), "Expected struct as underlying type");
  2071. ui32 index = AS_TYPE(TStructType, type->GetUnderlyingType())->GetMemberIndex(member);
  2072. return TRuntimeNode(TVariantLiteral::Create(item, index, type, Env), true);
  2073. }
  2074. TRuntimeNode TProgramBuilder::Coalesce(TRuntimeNode data, TRuntimeNode defaultData) {
  2075. bool isOptional = false;
  2076. const auto dataType = UnpackOptional(data, isOptional);
  2077. if (!isOptional && !data.GetStaticType()->IsPg()) {
  2078. MKQL_ENSURE(data.GetStaticType()->IsSameType(*defaultData.GetStaticType()), "Mismatch operand types");
  2079. return data;
  2080. }
  2081. if (!dataType->IsSameType(*defaultData.GetStaticType())) {
  2082. bool isOptionalDefault;
  2083. const auto defaultDataType = UnpackOptional(defaultData, isOptionalDefault);
  2084. MKQL_ENSURE(dataType->IsSameType(*defaultDataType), "Mismatch operand types");
  2085. }
  2086. TCallableBuilder callableBuilder(Env, __func__, defaultData.GetStaticType());
  2087. callableBuilder.Add(data);
  2088. callableBuilder.Add(defaultData);
  2089. return TRuntimeNode(callableBuilder.Build(), false);
  2090. }
  2091. TRuntimeNode TProgramBuilder::Unwrap(TRuntimeNode optional, TRuntimeNode message, const std::string_view& file, ui32 row, ui32 column) {
  2092. bool isOptional;
  2093. auto underlyingType = UnpackOptional(optional, isOptional);
  2094. MKQL_ENSURE(isOptional, "Expected optional");
  2095. const auto& messageType = message.GetStaticType();
  2096. MKQL_ENSURE(messageType->IsData(), "Expected data");
  2097. const auto& messageTypeData = static_cast<const TDataType&>(*messageType);
  2098. MKQL_ENSURE(messageTypeData.GetSchemeType() == NUdf::TDataType<char*>::Id || messageTypeData.GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected string or utf8.");
  2099. TCallableBuilder callableBuilder(Env, __func__, underlyingType);
  2100. callableBuilder.Add(optional);
  2101. callableBuilder.Add(message);
  2102. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  2103. callableBuilder.Add(NewDataLiteral(row));
  2104. callableBuilder.Add(NewDataLiteral(column));
  2105. return TRuntimeNode(callableBuilder.Build(), false);
  2106. }
  2107. TRuntimeNode TProgramBuilder::Increment(TRuntimeNode data) {
  2108. const std::array<TRuntimeNode, 1> args = {{ data }};
  2109. bool isOptional;
  2110. const auto type = UnpackOptionalData(data, isOptional);
  2111. if (type->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2112. return Invoke(__func__, data.GetStaticType(), args);
  2113. return Invoke(TString("Inc_") += ::ToString(static_cast<TDataDecimalType*>(type)->GetParams().first), data.GetStaticType(), args);
  2114. }
  2115. TRuntimeNode TProgramBuilder::Decrement(TRuntimeNode data) {
  2116. const std::array<TRuntimeNode, 1> args = {{ data }};
  2117. bool isOptional;
  2118. const auto type = UnpackOptionalData(data, isOptional);
  2119. if (type->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2120. return Invoke(__func__, data.GetStaticType(), args);
  2121. return Invoke(TString("Dec_") += ::ToString(static_cast<TDataDecimalType*>(type)->GetParams().first), data.GetStaticType(), args);
  2122. }
  2123. TRuntimeNode TProgramBuilder::Abs(TRuntimeNode data) {
  2124. const std::array<TRuntimeNode, 1> args = {{ data }};
  2125. return Invoke(__func__, data.GetStaticType(), args);
  2126. }
  2127. TRuntimeNode TProgramBuilder::Plus(TRuntimeNode data) {
  2128. const std::array<TRuntimeNode, 1> args = {{ data }};
  2129. return Invoke(__func__, data.GetStaticType(), args);
  2130. }
  2131. TRuntimeNode TProgramBuilder::Minus(TRuntimeNode data) {
  2132. const std::array<TRuntimeNode, 1> args = {{ data }};
  2133. return Invoke(__func__, data.GetStaticType(), args);
  2134. }
  2135. TRuntimeNode TProgramBuilder::Add(TRuntimeNode data1, TRuntimeNode data2) {
  2136. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2137. bool isOptionalLeft;
  2138. const auto leftType = UnpackOptionalData(data1, isOptionalLeft);
  2139. if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2140. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2141. const auto decimalType = static_cast<TDataDecimalType*>(leftType);
  2142. bool isOptionalRight;
  2143. const auto rightType = static_cast<TDataDecimalType*>(UnpackOptionalData(data2, isOptionalRight));
  2144. MKQL_ENSURE(rightType->IsSameType(*decimalType), "Operands type mismatch");
  2145. const auto resultType = isOptionalLeft || isOptionalRight ? NewOptionalType(decimalType) : decimalType;
  2146. return Invoke(TString("Add_") += ::ToString(decimalType->GetParams().first), resultType, args);
  2147. }
  2148. TRuntimeNode TProgramBuilder::Sub(TRuntimeNode data1, TRuntimeNode data2) {
  2149. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2150. bool isOptionalLeft;
  2151. const auto leftType = UnpackOptionalData(data1, isOptionalLeft);
  2152. if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2153. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2154. const auto decimalType = static_cast<TDataDecimalType*>(leftType);
  2155. bool isOptionalRight;
  2156. const auto rightType = static_cast<TDataDecimalType*>(UnpackOptionalData(data2, isOptionalRight));
  2157. MKQL_ENSURE(rightType->IsSameType(*decimalType), "Operands type mismatch");
  2158. const auto resultType = isOptionalLeft || isOptionalRight ? NewOptionalType(decimalType) : decimalType;
  2159. return Invoke(TString("Sub_") += ::ToString(decimalType->GetParams().first), resultType, args);
  2160. }
  2161. TRuntimeNode TProgramBuilder::Mul(TRuntimeNode data1, TRuntimeNode data2) {
  2162. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2163. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2164. }
  2165. TRuntimeNode TProgramBuilder::Div(TRuntimeNode data1, TRuntimeNode data2) {
  2166. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2167. auto resultType = BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType());
  2168. if (resultType->IsData() && !(NUdf::GetDataTypeInfo(*static_cast<TDataType*>(resultType)->GetDataSlot()).Features & (NUdf::EDataTypeFeatures::FloatType | NUdf::EDataTypeFeatures::DecimalType))) {
  2169. resultType = NewOptionalType(resultType);
  2170. }
  2171. return Invoke(__func__, resultType, args);
  2172. }
  2173. TRuntimeNode TProgramBuilder::DecimalDiv(TRuntimeNode data1, TRuntimeNode data2) {
  2174. bool isOptionalLeft, isOptionalRight;
  2175. const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(data1, isOptionalLeft));
  2176. const auto rightType = UnpackOptionalData(data2, isOptionalRight);
  2177. if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id)
  2178. MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch");
  2179. else
  2180. MKQL_ENSURE(NUdf::GetDataTypeInfo(*rightType->GetDataSlot()).Features & NUdf::IntegralType, "Operands type mismatch");
  2181. const auto returnType = isOptionalLeft || isOptionalRight ? NewOptionalType(leftType) : leftType;
  2182. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2183. callableBuilder.Add(data1);
  2184. callableBuilder.Add(data2);
  2185. return TRuntimeNode(callableBuilder.Build(), false);
  2186. }
  2187. TRuntimeNode TProgramBuilder::DecimalMod(TRuntimeNode data1, TRuntimeNode data2) {
  2188. bool isOptionalLeft, isOptionalRight;
  2189. const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(data1, isOptionalLeft));
  2190. const auto rightType = UnpackOptionalData(data2, isOptionalRight);
  2191. if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id)
  2192. MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch");
  2193. else
  2194. MKQL_ENSURE(NUdf::GetDataTypeInfo(*rightType->GetDataSlot()).Features & NUdf::IntegralType, "Operands type mismatch");
  2195. const auto returnType = isOptionalLeft || isOptionalRight ? NewOptionalType(leftType) : leftType;
  2196. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2197. callableBuilder.Add(data1);
  2198. callableBuilder.Add(data2);
  2199. return TRuntimeNode(callableBuilder.Build(), false);
  2200. }
  2201. TRuntimeNode TProgramBuilder::DecimalMul(TRuntimeNode data1, TRuntimeNode data2) {
  2202. bool isOptionalLeft, isOptionalRight;
  2203. const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(data1, isOptionalLeft));
  2204. const auto rightType = UnpackOptionalData(data2, isOptionalRight);
  2205. if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id)
  2206. MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch");
  2207. else
  2208. MKQL_ENSURE(NUdf::GetDataTypeInfo(*rightType->GetDataSlot()).Features & NUdf::IntegralType, "Operands type mismatch");
  2209. const auto returnType = isOptionalLeft || isOptionalRight ? NewOptionalType(leftType) : leftType;
  2210. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2211. callableBuilder.Add(data1);
  2212. callableBuilder.Add(data2);
  2213. return TRuntimeNode(callableBuilder.Build(), false);
  2214. }
  2215. TRuntimeNode TProgramBuilder::AllOf(TRuntimeNode list, const TUnaryLambda& predicate) {
  2216. return Not(NotAllOf(list, predicate));
  2217. }
  2218. TRuntimeNode TProgramBuilder::NotAllOf(TRuntimeNode list, const TUnaryLambda& predicate) {
  2219. return Exists(ToOptional(SkipWhile(list, predicate)));
  2220. }
  2221. TRuntimeNode TProgramBuilder::BitNot(TRuntimeNode data) {
  2222. const std::array<TRuntimeNode, 1> args = {{ data }};
  2223. return Invoke(__func__, data.GetStaticType(), args);
  2224. }
  2225. TRuntimeNode TProgramBuilder::CountBits(TRuntimeNode data) {
  2226. const std::array<TRuntimeNode, 1> args = {{ data }};
  2227. return Invoke(__func__, data.GetStaticType(), args);
  2228. }
  2229. TRuntimeNode TProgramBuilder::BitAnd(TRuntimeNode data1, TRuntimeNode data2) {
  2230. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2231. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2232. }
  2233. TRuntimeNode TProgramBuilder::BitOr(TRuntimeNode data1, TRuntimeNode data2) {
  2234. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2235. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2236. }
  2237. TRuntimeNode TProgramBuilder::BitXor(TRuntimeNode data1, TRuntimeNode data2) {
  2238. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2239. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2240. }
  2241. TRuntimeNode TProgramBuilder::ShiftLeft(TRuntimeNode arg, TRuntimeNode bits) {
  2242. const std::array<TRuntimeNode, 2> args = {{ arg, bits }};
  2243. return Invoke(__func__, arg.GetStaticType(), args);
  2244. }
  2245. TRuntimeNode TProgramBuilder::RotLeft(TRuntimeNode arg, TRuntimeNode bits) {
  2246. const std::array<TRuntimeNode, 2> args = {{ arg, bits }};
  2247. return Invoke(__func__, arg.GetStaticType(), args);
  2248. }
  2249. TRuntimeNode TProgramBuilder::ShiftRight(TRuntimeNode arg, TRuntimeNode bits) {
  2250. const std::array<TRuntimeNode, 2> args = {{ arg, bits }};
  2251. return Invoke(__func__, arg.GetStaticType(), args);
  2252. }
  2253. TRuntimeNode TProgramBuilder::RotRight(TRuntimeNode arg, TRuntimeNode bits) {
  2254. const std::array<TRuntimeNode, 2> args = {{ arg, bits }};
  2255. return Invoke(__func__, arg.GetStaticType(), args);
  2256. }
  2257. TRuntimeNode TProgramBuilder::Mod(TRuntimeNode data1, TRuntimeNode data2) {
  2258. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2259. auto resultType = BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType());
  2260. if (resultType->IsData() && !(NUdf::GetDataTypeInfo(*static_cast<TDataType*>(resultType)->GetDataSlot()).Features & (NUdf::EDataTypeFeatures::FloatType | NUdf::EDataTypeFeatures::DecimalType))) {
  2261. resultType = NewOptionalType(resultType);
  2262. }
  2263. return Invoke(__func__, resultType, args);
  2264. }
  2265. TRuntimeNode TProgramBuilder::BuildMinMax(const std::string_view& callableName, const TRuntimeNode* data, size_t size) {
  2266. switch (size) {
  2267. case 0U: return NewNull();
  2268. case 1U: return *data;
  2269. case 2U: return InvokeBinary(callableName, ChooseCommonType(data[0U].GetStaticType(), data[1U].GetStaticType()), data[0U], data[1U]);
  2270. default: break;
  2271. }
  2272. const auto half = size >> 1U;
  2273. const std::array<TRuntimeNode, 2U> args = {{ BuildMinMax(callableName, data, half), BuildMinMax(callableName, data + half, size - half) }};
  2274. return BuildMinMax(callableName, args.data(), args.size());
  2275. }
  2276. TRuntimeNode TProgramBuilder::BuildWideSkipTakeBlocks(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count) {
  2277. ValidateBlockFlowType(flow.GetStaticType());
  2278. MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data");
  2279. MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  2280. TCallableBuilder callableBuilder(Env, callableName, flow.GetStaticType());
  2281. callableBuilder.Add(flow);
  2282. callableBuilder.Add(count);
  2283. return TRuntimeNode(callableBuilder.Build(), false);
  2284. }
  2285. TRuntimeNode TProgramBuilder::BuildBlockLogical(const std::string_view& callableName, TRuntimeNode first, TRuntimeNode second) {
  2286. auto firstType = AS_TYPE(TBlockType, first.GetStaticType());
  2287. auto secondType = AS_TYPE(TBlockType, second.GetStaticType());
  2288. bool isOpt1, isOpt2;
  2289. MKQL_ENSURE(UnpackOptionalData(firstType->GetItemType(), isOpt1)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  2290. MKQL_ENSURE(UnpackOptionalData(secondType->GetItemType(), isOpt2)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  2291. const auto itemType = NewDataType(NUdf::TDataType<bool>::Id, isOpt1 || isOpt2);
  2292. auto outputType = NewBlockType(itemType, GetResultShape({firstType, secondType}));
  2293. TCallableBuilder callableBuilder(Env, callableName, outputType);
  2294. callableBuilder.Add(first);
  2295. callableBuilder.Add(second);
  2296. return TRuntimeNode(callableBuilder.Build(), false);
  2297. }
  2298. TRuntimeNode TProgramBuilder::BuildBlockDecimalBinary(const std::string_view& callableName, TRuntimeNode first, TRuntimeNode second) {
  2299. auto firstType = AS_TYPE(TBlockType, first.GetStaticType());
  2300. auto secondType = AS_TYPE(TBlockType, second.GetStaticType());
  2301. bool isOpt1, isOpt2;
  2302. auto* leftDataType = UnpackOptionalData(firstType->GetItemType(), isOpt1);
  2303. UnpackOptionalData(secondType->GetItemType(), isOpt2);
  2304. MKQL_ENSURE(leftDataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id, "Requires decimal args.");
  2305. const auto& lParams = static_cast<TDataDecimalType*>(leftDataType)->GetParams();
  2306. auto [precision, scale] = lParams;
  2307. TType* outputType = TDataDecimalType::Create(precision, scale, Env);
  2308. if (isOpt1 || isOpt2) {
  2309. outputType = TOptionalType::Create(outputType, Env);
  2310. }
  2311. outputType = NewBlockType(outputType, TBlockType::EShape::Many);
  2312. TCallableBuilder callableBuilder(Env, callableName, outputType);
  2313. callableBuilder.Add(first);
  2314. callableBuilder.Add(second);
  2315. return TRuntimeNode(callableBuilder.Build(), false);
  2316. }
  2317. TRuntimeNode TProgramBuilder::Min(const TArrayRef<const TRuntimeNode>& args) {
  2318. return BuildMinMax(__func__, args.data(), args.size());
  2319. }
  2320. TRuntimeNode TProgramBuilder::Max(const TArrayRef<const TRuntimeNode>& args) {
  2321. return BuildMinMax(__func__, args.data(), args.size());
  2322. }
  2323. TRuntimeNode TProgramBuilder::Min(TRuntimeNode data1, TRuntimeNode data2) {
  2324. const std::array<TRuntimeNode, 2U> args = {{ data1, data2 }};
  2325. return Min(args);
  2326. }
  2327. TRuntimeNode TProgramBuilder::Max(TRuntimeNode data1, TRuntimeNode data2) {
  2328. const std::array<TRuntimeNode, 2U> args = {{ data1, data2 }};
  2329. return Max(args);
  2330. }
  2331. TRuntimeNode TProgramBuilder::Equals(TRuntimeNode data1, TRuntimeNode data2) {
  2332. return DataCompare(__func__, data1, data2);
  2333. }
  2334. TRuntimeNode TProgramBuilder::NotEquals(TRuntimeNode data1, TRuntimeNode data2) {
  2335. return DataCompare(__func__, data1, data2);
  2336. }
  2337. TRuntimeNode TProgramBuilder::Less(TRuntimeNode data1, TRuntimeNode data2) {
  2338. return DataCompare(__func__, data1, data2);
  2339. }
  2340. TRuntimeNode TProgramBuilder::LessOrEqual(TRuntimeNode data1, TRuntimeNode data2) {
  2341. return DataCompare(__func__, data1, data2);
  2342. }
  2343. TRuntimeNode TProgramBuilder::Greater(TRuntimeNode data1, TRuntimeNode data2) {
  2344. return DataCompare(__func__, data1, data2);
  2345. }
  2346. TRuntimeNode TProgramBuilder::GreaterOrEqual(TRuntimeNode data1, TRuntimeNode data2) {
  2347. return DataCompare(__func__, data1, data2);
  2348. }
  2349. TRuntimeNode TProgramBuilder::InvokeBinary(const std::string_view& callableName, TType* type, TRuntimeNode data1, TRuntimeNode data2) {
  2350. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2351. return Invoke(callableName, type, args);
  2352. }
  2353. TRuntimeNode TProgramBuilder::AggrCompare(const std::string_view& callableName, TRuntimeNode data1, TRuntimeNode data2) {
  2354. return InvokeBinary(callableName, NewDataType(NUdf::TDataType<bool>::Id), data1, data2);
  2355. }
  2356. TRuntimeNode TProgramBuilder::DataCompare(const std::string_view& callableName, TRuntimeNode left, TRuntimeNode right) {
  2357. bool isOptionalLeft, isOptionalRight;
  2358. const auto leftType = UnpackOptionalData(left, isOptionalLeft);
  2359. const auto rightType = UnpackOptionalData(right, isOptionalRight);
  2360. const auto lId = leftType->GetSchemeType();
  2361. const auto rId = rightType->GetSchemeType();
  2362. if (lId == NUdf::TDataType<NUdf::TDecimal>::Id && rId == NUdf::TDataType<NUdf::TDecimal>::Id) {
  2363. const auto& lDec = static_cast<TDataDecimalType*>(leftType)->GetParams();
  2364. const auto& rDec = static_cast<TDataDecimalType*>(rightType)->GetParams();
  2365. if (lDec.second < rDec.second) {
  2366. left = ToDecimal(left, std::min<ui8>(lDec.first + rDec.second - lDec.second, NYql::NDecimal::MaxPrecision), rDec.second);
  2367. } else if (lDec.second > rDec.second) {
  2368. right = ToDecimal(right, std::min<ui8>(rDec.first + lDec.second - rDec.second, NYql::NDecimal::MaxPrecision), lDec.second);
  2369. }
  2370. } else if (lId == NUdf::TDataType<NUdf::TDecimal>::Id && NUdf::GetDataTypeInfo(NUdf::GetDataSlot(rId)).Features & NUdf::EDataTypeFeatures::IntegralType) {
  2371. const auto scale = static_cast<TDataDecimalType*>(leftType)->GetParams().second;
  2372. right = ToDecimal(right, std::min<ui8>(NYql::NDecimal::MaxPrecision, NUdf::GetDataTypeInfo(NUdf::GetDataSlot(rId)).DecimalDigits + scale), scale);
  2373. } else if (rId == NUdf::TDataType<NUdf::TDecimal>::Id && NUdf::GetDataTypeInfo(NUdf::GetDataSlot(lId)).Features & NUdf::EDataTypeFeatures::IntegralType) {
  2374. const auto scale = static_cast<TDataDecimalType*>(rightType)->GetParams().second;
  2375. left = ToDecimal(left, std::min<ui8>(NYql::NDecimal::MaxPrecision, NUdf::GetDataTypeInfo(NUdf::GetDataSlot(lId)).DecimalDigits + scale), scale);
  2376. }
  2377. const std::array<TRuntimeNode, 2> args = {{ left, right }};
  2378. const auto resultType = isOptionalLeft || isOptionalRight ? NewOptionalType(NewDataType(NUdf::TDataType<bool>::Id)) : NewDataType(NUdf::TDataType<bool>::Id);
  2379. return Invoke(callableName, resultType, args);
  2380. }
  2381. TRuntimeNode TProgramBuilder::BuildRangeLogical(const std::string_view& callableName, const TArrayRef<const TRuntimeNode>& lists) {
  2382. MKQL_ENSURE(!lists.empty(), "Expecting at least one argument");
  2383. for (auto& list : lists) {
  2384. MKQL_ENSURE(list.GetStaticType()->IsList(), "Expecting lists");
  2385. MKQL_ENSURE(list.GetStaticType()->IsSameType(*lists.front().GetStaticType()), "Expecting arguments of same type");
  2386. }
  2387. TCallableBuilder callableBuilder(Env, callableName, lists.front().GetStaticType());
  2388. for (auto& list : lists) {
  2389. callableBuilder.Add(list);
  2390. }
  2391. return TRuntimeNode(callableBuilder.Build(), false);
  2392. }
  2393. TRuntimeNode TProgramBuilder::AggrEquals(TRuntimeNode data1, TRuntimeNode data2) {
  2394. return AggrCompare(__func__, data1, data2);
  2395. }
  2396. TRuntimeNode TProgramBuilder::AggrNotEquals(TRuntimeNode data1, TRuntimeNode data2) {
  2397. return AggrCompare(__func__, data1, data2);
  2398. }
  2399. TRuntimeNode TProgramBuilder::AggrLess(TRuntimeNode data1, TRuntimeNode data2) {
  2400. return AggrCompare(__func__, data1, data2);
  2401. }
  2402. TRuntimeNode TProgramBuilder::AggrLessOrEqual(TRuntimeNode data1, TRuntimeNode data2) {
  2403. return AggrCompare(__func__, data1, data2);
  2404. }
  2405. TRuntimeNode TProgramBuilder::AggrGreater(TRuntimeNode data1, TRuntimeNode data2) {
  2406. return AggrCompare(__func__, data1, data2);
  2407. }
  2408. TRuntimeNode TProgramBuilder::AggrGreaterOrEqual(TRuntimeNode data1, TRuntimeNode data2) {
  2409. return AggrCompare(__func__, data1, data2);
  2410. }
  2411. TRuntimeNode TProgramBuilder::If(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch) {
  2412. bool condOpt, thenOpt, elseOpt;
  2413. const auto conditionType = UnpackOptionalData(condition, condOpt);
  2414. MKQL_ENSURE(conditionType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool");
  2415. const auto thenUnpacked = UnpackOptional(thenBranch, thenOpt);
  2416. const auto elseUnpacked = UnpackOptional(elseBranch, elseOpt);
  2417. MKQL_ENSURE(thenUnpacked->IsSameType(*elseUnpacked), "Different return types in branches.");
  2418. const bool isOptional = condOpt || thenOpt || elseOpt;
  2419. TCallableBuilder callableBuilder(Env, __func__, isOptional ? NewOptionalType(thenUnpacked) : thenUnpacked);
  2420. callableBuilder.Add(condition);
  2421. callableBuilder.Add(thenBranch);
  2422. callableBuilder.Add(elseBranch);
  2423. return TRuntimeNode(callableBuilder.Build(), false);
  2424. }
  2425. TRuntimeNode TProgramBuilder::If(const TArrayRef<const TRuntimeNode>& args) {
  2426. MKQL_ENSURE(args.size() % 2U, "Expected odd arguments.");
  2427. MKQL_ENSURE(args.size() >= 3U, "Expected at least three arguments.");
  2428. return If(args.front(), args[1U], 3U == args.size() ? args.back() : If(args.last(args.size() - 2U)));
  2429. }
  2430. TRuntimeNode TProgramBuilder::If(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch, TType* resultType) {
  2431. bool condOpt;
  2432. const auto conditionType = UnpackOptionalData(condition, condOpt);
  2433. MKQL_ENSURE(conditionType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool");
  2434. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2435. callableBuilder.Add(condition);
  2436. callableBuilder.Add(thenBranch);
  2437. callableBuilder.Add(elseBranch);
  2438. return TRuntimeNode(callableBuilder.Build(), false);
  2439. }
  2440. TRuntimeNode TProgramBuilder::Ensure(TRuntimeNode value, TRuntimeNode predicate, TRuntimeNode message, const std::string_view& file, ui32 row, ui32 column) {
  2441. bool isOptional;
  2442. const auto unpackedType = UnpackOptionalData(predicate, isOptional);
  2443. MKQL_ENSURE(unpackedType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool");
  2444. const auto& messageType = message.GetStaticType();
  2445. MKQL_ENSURE(messageType->IsData(), "Expected data");
  2446. const auto& messageTypeData = static_cast<const TDataType&>(*messageType);
  2447. MKQL_ENSURE(messageTypeData.GetSchemeType() == NUdf::TDataType<char*>::Id || messageTypeData.GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected string or utf8.");
  2448. TCallableBuilder callableBuilder(Env, __func__, value.GetStaticType());
  2449. callableBuilder.Add(value);
  2450. callableBuilder.Add(predicate);
  2451. callableBuilder.Add(message);
  2452. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  2453. callableBuilder.Add(NewDataLiteral(row));
  2454. callableBuilder.Add(NewDataLiteral(column));
  2455. return TRuntimeNode(callableBuilder.Build(), false);
  2456. }
  2457. TRuntimeNode TProgramBuilder::SourceOf(TType* returnType) {
  2458. MKQL_ENSURE(returnType->IsFlow() || returnType->IsStream(), "Expected flow or stream.");
  2459. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2460. return TRuntimeNode(callableBuilder.Build(), false);
  2461. }
  2462. TRuntimeNode TProgramBuilder::Source() {
  2463. if constexpr (RuntimeVersion < 18U) {
  2464. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  2465. }
  2466. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType({})));
  2467. return TRuntimeNode(callableBuilder.Build(), false);
  2468. }
  2469. TRuntimeNode TProgramBuilder::IfPresent(TRuntimeNode optional, const TUnaryLambda& thenBranch, TRuntimeNode elseBranch) {
  2470. bool isOptional;
  2471. const auto unpackedType = UnpackOptional(optional, isOptional);
  2472. if (!isOptional) {
  2473. return thenBranch(optional);
  2474. }
  2475. const auto itemArg = Arg(unpackedType);
  2476. const auto then = thenBranch(itemArg);
  2477. bool thenOpt, elseOpt;
  2478. const auto thenUnpacked = UnpackOptional(then, thenOpt);
  2479. const auto elseUnpacked = UnpackOptional(elseBranch, elseOpt);
  2480. MKQL_ENSURE(thenUnpacked->IsSameType(*elseUnpacked), "Different return types in branches.");
  2481. TCallableBuilder callableBuilder(Env, __func__, (thenOpt || elseOpt) ? NewOptionalType(thenUnpacked) : thenUnpacked);
  2482. callableBuilder.Add(optional);
  2483. callableBuilder.Add(itemArg);
  2484. callableBuilder.Add(then);
  2485. callableBuilder.Add(elseBranch);
  2486. return TRuntimeNode(callableBuilder.Build(), false);
  2487. }
  2488. TRuntimeNode TProgramBuilder::IfPresent(TRuntimeNode::TList optionals, const TNarrowLambda& thenBranch, TRuntimeNode elseBranch) {
  2489. switch (optionals.size()) {
  2490. case 0U:
  2491. return thenBranch({});
  2492. case 1U:
  2493. return IfPresent(optionals.front(), [&](TRuntimeNode unwrap){ return thenBranch({unwrap}); }, elseBranch);
  2494. default:
  2495. break;
  2496. }
  2497. const auto first = optionals.front();
  2498. optionals.erase(optionals.cbegin());
  2499. return IfPresent(first,
  2500. [&](TRuntimeNode head) {
  2501. return IfPresent(optionals,
  2502. [&](TRuntimeNode::TList tail) {
  2503. tail.insert(tail.cbegin(), head);
  2504. return thenBranch(tail);
  2505. },
  2506. elseBranch
  2507. );
  2508. },
  2509. elseBranch
  2510. );
  2511. }
  2512. TRuntimeNode TProgramBuilder::Not(TRuntimeNode data) {
  2513. return UnaryDataFunction(data, __func__, TDataFunctionFlags::CommonOptionalResult | TDataFunctionFlags::RequiresBooleanArgs | TDataFunctionFlags::AllowOptionalArgs);
  2514. }
  2515. TRuntimeNode TProgramBuilder::BuildBinaryLogical(const std::string_view& callableName, TRuntimeNode data1, TRuntimeNode data2) {
  2516. bool isOpt1, isOpt2;
  2517. MKQL_ENSURE(UnpackOptionalData(data1, isOpt1)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  2518. MKQL_ENSURE(UnpackOptionalData(data2, isOpt2)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  2519. const auto resultType = NewDataType(NUdf::TDataType<bool>::Id, isOpt1 || isOpt2);
  2520. TCallableBuilder callableBuilder(Env, callableName, resultType);
  2521. callableBuilder.Add(data1);
  2522. callableBuilder.Add(data2);
  2523. return TRuntimeNode(callableBuilder.Build(), false);
  2524. }
  2525. TRuntimeNode TProgramBuilder::BuildLogical(const std::string_view& callableName, const TArrayRef<const TRuntimeNode>& args) {
  2526. MKQL_ENSURE(!args.empty(), "Empty logical args.");
  2527. switch (args.size()) {
  2528. case 1U: return args.front();
  2529. case 2U: return BuildBinaryLogical(callableName, args.front(), args.back());
  2530. }
  2531. const auto half = (args.size() + 1U) >> 1U;
  2532. const TArrayRef<const TRuntimeNode> one(args.data(), half), two(args.data() + half, args.size() - half);
  2533. return BuildBinaryLogical(callableName, BuildLogical(callableName, one), BuildLogical(callableName, two));
  2534. }
  2535. TRuntimeNode TProgramBuilder::And(const TArrayRef<const TRuntimeNode>& args) {
  2536. return BuildLogical(__func__, args);
  2537. }
  2538. TRuntimeNode TProgramBuilder::Or(const TArrayRef<const TRuntimeNode>& args) {
  2539. return BuildLogical(__func__, args);
  2540. }
  2541. TRuntimeNode TProgramBuilder::Xor(const TArrayRef<const TRuntimeNode>& args) {
  2542. return BuildLogical(__func__, args);
  2543. }
  2544. TRuntimeNode TProgramBuilder::Exists(TRuntimeNode data) {
  2545. const auto& nodeType = data.GetStaticType();
  2546. if (nodeType->IsVoid()) {
  2547. return NewDataLiteral(false);
  2548. }
  2549. if (!nodeType->IsOptional() && !nodeType->IsPg()) {
  2550. return NewDataLiteral(true);
  2551. }
  2552. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<bool>::Id));
  2553. callableBuilder.Add(data);
  2554. return TRuntimeNode(callableBuilder.Build(), false);
  2555. }
  2556. TRuntimeNode TProgramBuilder::NewMTRand(TRuntimeNode seed) {
  2557. auto seedData = AS_TYPE(TDataType, seed);
  2558. MKQL_ENSURE(seedData->GetSchemeType() == NUdf::TDataType<ui64>::Id, "seed must be ui64");
  2559. TCallableBuilder callableBuilder(Env, __func__, NewResourceType(RandomMTResource), true);
  2560. callableBuilder.Add(seed);
  2561. return TRuntimeNode(callableBuilder.Build(), false);
  2562. }
  2563. TRuntimeNode TProgramBuilder::NextMTRand(TRuntimeNode rand) {
  2564. auto resType = AS_TYPE(TResourceType, rand);
  2565. MKQL_ENSURE(resType->GetTag() == RandomMTResource, "Expected MTRand resource");
  2566. const std::array<TType*, 2U> tupleTypes = {{ NewDataType(NUdf::TDataType<ui64>::Id), rand.GetStaticType() }};
  2567. auto returnType = NewTupleType(tupleTypes);
  2568. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2569. callableBuilder.Add(rand);
  2570. return TRuntimeNode(callableBuilder.Build(), false);
  2571. }
  2572. TRuntimeNode TProgramBuilder::AggrCountInit(TRuntimeNode value) {
  2573. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<ui64>::Id));
  2574. callableBuilder.Add(value);
  2575. return TRuntimeNode(callableBuilder.Build(), false);
  2576. }
  2577. TRuntimeNode TProgramBuilder::AggrCountUpdate(TRuntimeNode value, TRuntimeNode state) {
  2578. MKQL_ENSURE(AS_TYPE(TDataType, state)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64 type");
  2579. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<ui64>::Id));
  2580. callableBuilder.Add(value);
  2581. callableBuilder.Add(state);
  2582. return TRuntimeNode(callableBuilder.Build(), false);
  2583. }
  2584. TRuntimeNode TProgramBuilder::AggrMin(TRuntimeNode data1, TRuntimeNode data2) {
  2585. const auto type = data1.GetStaticType();
  2586. MKQL_ENSURE(type->IsSameType(*data2.GetStaticType()), "Must be same type.");
  2587. return InvokeBinary(__func__, type, data1, data2);
  2588. }
  2589. TRuntimeNode TProgramBuilder::AggrMax(TRuntimeNode data1, TRuntimeNode data2) {
  2590. const auto type = data1.GetStaticType();
  2591. MKQL_ENSURE(type->IsSameType(*data2.GetStaticType()), "Must be same type.");
  2592. return InvokeBinary(__func__, type, data1, data2);
  2593. }
  2594. TRuntimeNode TProgramBuilder::AggrAdd(TRuntimeNode data1, TRuntimeNode data2) {
  2595. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2596. bool isOptionalLeft;
  2597. const auto leftType = UnpackOptionalData(data1, isOptionalLeft);
  2598. if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2599. return Invoke(__func__, data1.GetStaticType(), args);
  2600. const auto decimalType = static_cast<TDataDecimalType*>(leftType);
  2601. bool isOptionalRight;
  2602. const auto rightType = static_cast<TDataDecimalType*>(UnpackOptionalData(data2, isOptionalRight));
  2603. MKQL_ENSURE(rightType->IsSameType(*decimalType), "Operands type mismatch");
  2604. return Invoke(TString("AggrAdd_") += ::ToString(decimalType->GetParams().first), data1.GetStaticType(), args);
  2605. }
  2606. TRuntimeNode TProgramBuilder::QueueCreate(TRuntimeNode initCapacity, TRuntimeNode initSize, const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) {
  2607. auto resType = AS_TYPE(TResourceType, returnType);
  2608. const auto tag = resType->GetTag();
  2609. if (initCapacity.GetStaticType()->IsVoid()) {
  2610. MKQL_ENSURE(RuntimeVersion >= 13, "Unbounded queue is not supported in runtime version " << RuntimeVersion);
  2611. } else {
  2612. auto initCapacityType = AS_TYPE(TDataType, initCapacity);
  2613. MKQL_ENSURE(initCapacityType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "init capcity must be ui64");
  2614. }
  2615. auto initSizeType = AS_TYPE(TDataType, initSize);
  2616. MKQL_ENSURE(initSizeType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "init size must be ui64");
  2617. TCallableBuilder callableBuilder(Env, __func__, returnType, true);
  2618. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(tag));
  2619. callableBuilder.Add(initCapacity);
  2620. callableBuilder.Add(initSize);
  2621. for (auto node : dependentNodes) {
  2622. callableBuilder.Add(node);
  2623. }
  2624. return TRuntimeNode(callableBuilder.Build(), false);
  2625. }
  2626. TRuntimeNode TProgramBuilder::QueuePush(TRuntimeNode resource, TRuntimeNode value) {
  2627. auto resType = AS_TYPE(TResourceType, resource);
  2628. const auto tag = resType->GetTag();
  2629. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "Expected Queue resource");
  2630. TCallableBuilder callableBuilder(Env, __func__, resource.GetStaticType());
  2631. callableBuilder.Add(resource);
  2632. callableBuilder.Add(value);
  2633. return TRuntimeNode(callableBuilder.Build(), false);
  2634. }
  2635. TRuntimeNode TProgramBuilder::QueuePop(TRuntimeNode resource) {
  2636. auto resType = AS_TYPE(TResourceType, resource);
  2637. const auto tag = resType->GetTag();
  2638. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "Expected Queue resource");
  2639. TCallableBuilder callableBuilder(Env, __func__, resource.GetStaticType());
  2640. callableBuilder.Add(resource);
  2641. return TRuntimeNode(callableBuilder.Build(), false);
  2642. }
  2643. TRuntimeNode TProgramBuilder::QueuePeek(TRuntimeNode resource, TRuntimeNode index, const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) {
  2644. MKQL_ENSURE(returnType->IsOptional(), "Expected optional type as result of QueuePeek");
  2645. auto resType = AS_TYPE(TResourceType, resource);
  2646. auto indexType = AS_TYPE(TDataType, index);
  2647. MKQL_ENSURE(indexType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "index size must be ui64");
  2648. const auto tag = resType->GetTag();
  2649. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "Expected Queue resource");
  2650. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2651. callableBuilder.Add(resource);
  2652. callableBuilder.Add(index);
  2653. for (auto node : dependentNodes) {
  2654. callableBuilder.Add(node);
  2655. }
  2656. return TRuntimeNode(callableBuilder.Build(), false);
  2657. }
  2658. TRuntimeNode TProgramBuilder::QueueRange(TRuntimeNode resource, TRuntimeNode begin, TRuntimeNode end, const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) {
  2659. MKQL_ENSURE(RuntimeVersion >= 14, "QueueRange is not supported in runtime version " << RuntimeVersion);
  2660. MKQL_ENSURE(returnType->IsList(), "Expected list type as result of QueueRange");
  2661. auto resType = AS_TYPE(TResourceType, resource);
  2662. auto beginType = AS_TYPE(TDataType, begin);
  2663. MKQL_ENSURE(beginType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "begin index must be ui64");
  2664. auto endType = AS_TYPE(TDataType, end);
  2665. MKQL_ENSURE(endType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "end index must be ui64");
  2666. const auto tag = resType->GetTag();
  2667. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "Expected Queue resource");
  2668. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2669. callableBuilder.Add(resource);
  2670. callableBuilder.Add(begin);
  2671. callableBuilder.Add(end);
  2672. for (auto node : dependentNodes) {
  2673. callableBuilder.Add(node);
  2674. }
  2675. return TRuntimeNode(callableBuilder.Build(), false);
  2676. }
  2677. TRuntimeNode TProgramBuilder::PreserveStream(TRuntimeNode stream, TRuntimeNode queue, TRuntimeNode outpace) {
  2678. auto streamType = AS_TYPE(TStreamType, stream);
  2679. auto resType = AS_TYPE(TResourceType, queue);
  2680. auto outpaceType = AS_TYPE(TDataType, outpace);
  2681. MKQL_ENSURE(outpaceType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "PreserveStream: outpace size must be ui64");
  2682. const auto tag = resType->GetTag();
  2683. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "PreserveStream: Expected Queue resource");
  2684. TCallableBuilder callableBuilder(Env, __func__, streamType);
  2685. callableBuilder.Add(stream);
  2686. callableBuilder.Add(queue);
  2687. callableBuilder.Add(outpace);
  2688. return TRuntimeNode(callableBuilder.Build(), false);
  2689. }
  2690. TRuntimeNode TProgramBuilder::Seq(const TArrayRef<const TRuntimeNode>& args, TType* returnType) {
  2691. MKQL_ENSURE(RuntimeVersion >= 15, "Seq is not supported in runtime version " << RuntimeVersion);
  2692. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2693. for (auto node : args) {
  2694. callableBuilder.Add(node);
  2695. }
  2696. return TRuntimeNode(callableBuilder.Build(), false);
  2697. }
  2698. TRuntimeNode TProgramBuilder::FromYsonSimpleType(TRuntimeNode input, NUdf::TDataTypeId schemeType) {
  2699. auto type = input.GetStaticType();
  2700. if (type->IsOptional()) {
  2701. type = static_cast<const TOptionalType&>(*type).GetItemType();
  2702. }
  2703. MKQL_ENSURE(type->IsData(), "Expected data type");
  2704. auto resDataType = NewDataType(schemeType);
  2705. auto resultType = NewOptionalType(resDataType);
  2706. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2707. callableBuilder.Add(input);
  2708. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(schemeType)));
  2709. return TRuntimeNode(callableBuilder.Build(), false);
  2710. }
  2711. TRuntimeNode TProgramBuilder::TryWeakMemberFromDict(TRuntimeNode other, TRuntimeNode rest, NUdf::TDataTypeId schemeType, const std::string_view& memberName) {
  2712. auto resDataType = NewDataType(schemeType);
  2713. auto resultType = NewOptionalType(resDataType);
  2714. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2715. callableBuilder.Add(other);
  2716. callableBuilder.Add(rest);
  2717. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(schemeType)));
  2718. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(memberName));
  2719. return TRuntimeNode(callableBuilder.Build(), false);
  2720. }
  2721. TRuntimeNode TProgramBuilder::TimezoneId(TRuntimeNode name) {
  2722. bool isOptional;
  2723. auto dataType = UnpackOptionalData(name, isOptional);
  2724. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expected string");
  2725. auto resultType = NewOptionalType(NewDataType(NUdf::EDataSlot::Uint16));
  2726. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2727. callableBuilder.Add(name);
  2728. return TRuntimeNode(callableBuilder.Build(), false);
  2729. }
  2730. TRuntimeNode TProgramBuilder::TimezoneName(TRuntimeNode id) {
  2731. bool isOptional;
  2732. auto dataType = UnpackOptionalData(id, isOptional);
  2733. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<ui16>::Id, "Expected ui32");
  2734. auto resultType = NewOptionalType(NewDataType(NUdf::EDataSlot::String));
  2735. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2736. callableBuilder.Add(id);
  2737. return TRuntimeNode(callableBuilder.Build(), false);
  2738. }
  2739. TRuntimeNode TProgramBuilder::AddTimezone(TRuntimeNode utc, TRuntimeNode id) {
  2740. bool isOptional1;
  2741. auto dataType1 = UnpackOptionalData(utc, isOptional1);
  2742. MKQL_ENSURE(NUdf::GetDataTypeInfo(*dataType1->GetDataSlot()).Features & NUdf::DateType, "Expected date type");
  2743. bool isOptional2;
  2744. auto dataType2 = UnpackOptionalData(id, isOptional2);
  2745. MKQL_ENSURE(dataType2->GetSchemeType() == NUdf::TDataType<ui16>::Id, "Expected ui16");
  2746. NUdf::EDataSlot tzType;
  2747. switch (*dataType1->GetDataSlot()) {
  2748. case NUdf::EDataSlot::Date: tzType = NUdf::EDataSlot::TzDate; break;
  2749. case NUdf::EDataSlot::Datetime: tzType = NUdf::EDataSlot::TzDatetime; break;
  2750. case NUdf::EDataSlot::Timestamp: tzType = NUdf::EDataSlot::TzTimestamp; break;
  2751. case NUdf::EDataSlot::Date32: tzType = NUdf::EDataSlot::TzDate32; break;
  2752. case NUdf::EDataSlot::Datetime64: tzType = NUdf::EDataSlot::TzDatetime64; break;
  2753. case NUdf::EDataSlot::Timestamp64: tzType = NUdf::EDataSlot::TzTimestamp64; break;
  2754. default:
  2755. ythrow yexception() << "Unknown date type: " << *dataType1->GetDataSlot();
  2756. }
  2757. auto resultType = NewOptionalType(NewDataType(tzType));
  2758. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2759. callableBuilder.Add(utc);
  2760. callableBuilder.Add(id);
  2761. return TRuntimeNode(callableBuilder.Build(), false);
  2762. }
  2763. TRuntimeNode TProgramBuilder::RemoveTimezone(TRuntimeNode local) {
  2764. bool isOptional1;
  2765. const auto dataType1 = UnpackOptionalData(local, isOptional1);
  2766. MKQL_ENSURE((NUdf::GetDataTypeInfo(*dataType1->GetDataSlot()).Features & NUdf::TzDateType), "Expected date with timezone type");
  2767. NUdf::EDataSlot type;
  2768. switch (*dataType1->GetDataSlot()) {
  2769. case NUdf::EDataSlot::TzDate: type = NUdf::EDataSlot::Date; break;
  2770. case NUdf::EDataSlot::TzDatetime: type = NUdf::EDataSlot::Datetime; break;
  2771. case NUdf::EDataSlot::TzTimestamp: type = NUdf::EDataSlot::Timestamp; break;
  2772. case NUdf::EDataSlot::TzDate32: type = NUdf::EDataSlot::Date32; break;
  2773. case NUdf::EDataSlot::TzDatetime64: type = NUdf::EDataSlot::Datetime64; break;
  2774. case NUdf::EDataSlot::TzTimestamp64: type = NUdf::EDataSlot::Timestamp64; break;
  2775. default:
  2776. ythrow yexception() << "Unknown date with timezone type: " << *dataType1->GetDataSlot();
  2777. }
  2778. return Convert(local, NewDataType(type, isOptional1));
  2779. }
  2780. TRuntimeNode TProgramBuilder::Nth(TRuntimeNode tuple, ui32 index) {
  2781. bool isOptional;
  2782. const auto type = AS_TYPE(TTupleType, UnpackOptional(tuple.GetStaticType(), isOptional));
  2783. MKQL_ENSURE(index < type->GetElementsCount(), "Index out of range: " << index <<
  2784. " is not less than " << type->GetElementsCount());
  2785. auto itemType = type->GetElementType(index);
  2786. if (isOptional && !itemType->IsOptional() && !itemType->IsNull() && !itemType->IsPg()) {
  2787. itemType = TOptionalType::Create(itemType, Env);
  2788. }
  2789. TCallableBuilder callableBuilder(Env, __func__, itemType);
  2790. callableBuilder.Add(tuple);
  2791. callableBuilder.Add(NewDataLiteral<ui32>(index));
  2792. return TRuntimeNode(callableBuilder.Build(), false);
  2793. }
  2794. TRuntimeNode TProgramBuilder::Element(TRuntimeNode tuple, ui32 index) {
  2795. return Nth(tuple, index);
  2796. }
  2797. TRuntimeNode TProgramBuilder::Guess(TRuntimeNode variant, ui32 tupleIndex) {
  2798. bool isOptional;
  2799. auto unpacked = UnpackOptional(variant, isOptional);
  2800. auto type = AS_TYPE(TVariantType, unpacked);
  2801. auto underlyingType = AS_TYPE(TTupleType, type->GetUnderlyingType());
  2802. MKQL_ENSURE(tupleIndex < underlyingType->GetElementsCount(), "Wrong tuple index");
  2803. auto resType = TOptionalType::Create(underlyingType->GetElementType(tupleIndex), Env);
  2804. TCallableBuilder callableBuilder(Env, __func__, resType);
  2805. callableBuilder.Add(variant);
  2806. callableBuilder.Add(NewDataLiteral<ui32>(tupleIndex));
  2807. return TRuntimeNode(callableBuilder.Build(), false);
  2808. }
  2809. TRuntimeNode TProgramBuilder::Guess(TRuntimeNode variant, const std::string_view& memberName) {
  2810. bool isOptional;
  2811. auto unpacked = UnpackOptional(variant, isOptional);
  2812. auto type = AS_TYPE(TVariantType, unpacked);
  2813. auto underlyingType = AS_TYPE(TStructType, type->GetUnderlyingType());
  2814. auto structIndex = underlyingType->GetMemberIndex(memberName);
  2815. auto resType = TOptionalType::Create(underlyingType->GetMemberType(structIndex), Env);
  2816. TCallableBuilder callableBuilder(Env, __func__, resType);
  2817. callableBuilder.Add(variant);
  2818. callableBuilder.Add(NewDataLiteral<ui32>(structIndex));
  2819. return TRuntimeNode(callableBuilder.Build(), false);
  2820. }
  2821. TRuntimeNode TProgramBuilder::Way(TRuntimeNode variant) {
  2822. bool isOptional;
  2823. auto unpacked = UnpackOptional(variant, isOptional);
  2824. auto type = AS_TYPE(TVariantType, unpacked);
  2825. auto underlyingType = type->GetUnderlyingType();
  2826. auto dataType = NewDataType(underlyingType->IsTuple() ? NUdf::EDataSlot::Uint32 : NUdf::EDataSlot::Utf8);
  2827. auto resType = isOptional ? TOptionalType::Create(dataType, Env) : dataType;
  2828. TCallableBuilder callableBuilder(Env, __func__, resType);
  2829. callableBuilder.Add(variant);
  2830. return TRuntimeNode(callableBuilder.Build(), false);
  2831. }
  2832. TRuntimeNode TProgramBuilder::VariantItem(TRuntimeNode variant) {
  2833. bool isOptional;
  2834. auto unpacked = UnpackOptional(variant, isOptional);
  2835. auto type = AS_TYPE(TVariantType, unpacked);
  2836. auto underlyingType = type->GetAlternativeType(0);
  2837. auto resType = isOptional ? TOptionalType::Create(underlyingType, Env) : underlyingType;
  2838. TCallableBuilder callableBuilder(Env, __func__, resType);
  2839. callableBuilder.Add(variant);
  2840. return TRuntimeNode(callableBuilder.Build(), false);
  2841. }
  2842. TRuntimeNode TProgramBuilder::DynamicVariant(TRuntimeNode item, TRuntimeNode index, TType* variantType) {
  2843. if constexpr (RuntimeVersion < 56U) {
  2844. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  2845. }
  2846. auto type = AS_TYPE(TVariantType, variantType);
  2847. auto expectedIndexSlot = type->GetUnderlyingType()->IsTuple() ? NUdf::EDataSlot::Uint32 : NUdf::EDataSlot::Utf8;
  2848. bool isOptional;
  2849. auto indexType = UnpackOptionalData(index.GetStaticType(), isOptional);
  2850. MKQL_ENSURE(indexType->GetDataSlot() == expectedIndexSlot, "Mismatch type of index");
  2851. auto resType = TOptionalType::Create(type, Env);
  2852. TCallableBuilder callableBuilder(Env, __func__, resType);
  2853. callableBuilder.Add(item);
  2854. callableBuilder.Add(index);
  2855. callableBuilder.Add(TRuntimeNode(variantType, true));
  2856. return TRuntimeNode(callableBuilder.Build(), false);
  2857. }
  2858. TRuntimeNode TProgramBuilder::VisitAll(TRuntimeNode variant, std::function<TRuntimeNode(ui32, TRuntimeNode)> handler) {
  2859. const auto type = AS_TYPE(TVariantType, variant);
  2860. std::vector<TRuntimeNode> items;
  2861. std::vector<TRuntimeNode> newItems;
  2862. for (ui32 i = 0; i < type->GetAlternativesCount(); ++i) {
  2863. const auto itemType = type->GetAlternativeType(i);
  2864. const auto itemArg = Arg(itemType);
  2865. const auto res = handler(i, itemArg);
  2866. items.emplace_back(itemArg);
  2867. newItems.emplace_back(res);
  2868. }
  2869. bool hasOptional;
  2870. const auto firstUnpacked = UnpackOptional(newItems.front(), hasOptional);
  2871. bool allOptional = hasOptional;
  2872. for (size_t i = 1U; i < newItems.size(); ++i) {
  2873. bool isOptional;
  2874. const auto unpacked = UnpackOptional(newItems[i].GetStaticType(), isOptional);
  2875. MKQL_ENSURE(unpacked->IsSameType(*firstUnpacked), "Different return types in branches.");
  2876. hasOptional = hasOptional || isOptional;
  2877. allOptional = allOptional && isOptional;
  2878. }
  2879. if (hasOptional && !allOptional) {
  2880. for (auto& item : newItems) {
  2881. if (!item.GetStaticType()->IsOptional()) {
  2882. item = NewOptional(item);
  2883. }
  2884. }
  2885. }
  2886. TCallableBuilder callableBuilder(Env, __func__, newItems.front().GetStaticType());
  2887. callableBuilder.Add(variant);
  2888. for (ui32 i = 0; i < type->GetAlternativesCount(); ++i) {
  2889. callableBuilder.Add(items[i]);
  2890. callableBuilder.Add(newItems[i]);
  2891. }
  2892. return TRuntimeNode(callableBuilder.Build(), false);
  2893. }
  2894. TRuntimeNode TProgramBuilder::UnaryDataFunction(TRuntimeNode data, const std::string_view& callableName, ui32 flags) {
  2895. bool isOptional;
  2896. auto type = UnpackOptionalData(data, isOptional);
  2897. if (!(flags & TDataFunctionFlags::AllowOptionalArgs)) {
  2898. MKQL_ENSURE(!isOptional, "Optional data is not allowed");
  2899. }
  2900. auto schemeType = type->GetSchemeType();
  2901. if (flags & TDataFunctionFlags::RequiresBooleanArgs) {
  2902. MKQL_ENSURE(schemeType == NUdf::TDataType<bool>::Id, "Boolean data is required");
  2903. } else if (flags & TDataFunctionFlags::RequiresStringArgs) {
  2904. MKQL_ENSURE(schemeType == NUdf::TDataType<char*>::Id, "String data is required");
  2905. }
  2906. if (!schemeType) {
  2907. MKQL_ENSURE((flags & TDataFunctionFlags::AllowNull) != 0, "Null is not allowed");
  2908. }
  2909. TType* resultType;
  2910. if (flags & TDataFunctionFlags::HasBooleanResult) {
  2911. resultType = TDataType::Create(NUdf::TDataType<bool>::Id, Env);
  2912. } else if (flags & TDataFunctionFlags::HasUi32Result) {
  2913. resultType = TDataType::Create(NUdf::TDataType<ui32>::Id, Env);
  2914. } else if (flags & TDataFunctionFlags::HasStringResult) {
  2915. resultType = TDataType::Create(NUdf::TDataType<char*>::Id, Env);
  2916. } else if (flags & TDataFunctionFlags::HasOptionalResult) {
  2917. resultType = TOptionalType::Create(type, Env);
  2918. } else {
  2919. resultType = type;
  2920. }
  2921. if ((flags & TDataFunctionFlags::CommonOptionalResult) && isOptional) {
  2922. resultType = TOptionalType::Create(resultType, Env);
  2923. }
  2924. TCallableBuilder callableBuilder(Env, callableName, resultType);
  2925. callableBuilder.Add(data);
  2926. return TRuntimeNode(callableBuilder.Build(), false);
  2927. }
  2928. TRuntimeNode TProgramBuilder::ToDict(TRuntimeNode list, bool multi, const TUnaryLambda& keySelector,
  2929. const TUnaryLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint)
  2930. {
  2931. bool isOptional;
  2932. const auto type = UnpackOptional(list, isOptional);
  2933. MKQL_ENSURE(type->IsList(), "Expected list.");
  2934. if (isOptional) {
  2935. return Map(list, [&](TRuntimeNode unpacked) { return ToDict(unpacked, multi, keySelector, payloadSelector, callableName, isCompact, itemsCountHint); } );
  2936. }
  2937. const auto itemType = AS_TYPE(TListType, type)->GetItemType();
  2938. ThrowIfListOfVoid(itemType);
  2939. const auto itemArg = Arg(itemType);
  2940. const auto key = keySelector(itemArg);
  2941. const auto keyType = key.GetStaticType();
  2942. auto payload = payloadSelector(itemArg);
  2943. auto payloadType = payload.GetStaticType();
  2944. if (multi) {
  2945. payloadType = TListType::Create(payloadType, Env);
  2946. }
  2947. auto dictType = TDictType::Create(keyType, payloadType, Env);
  2948. TCallableBuilder callableBuilder(Env, callableName, dictType);
  2949. callableBuilder.Add(list);
  2950. callableBuilder.Add(itemArg);
  2951. callableBuilder.Add(key);
  2952. callableBuilder.Add(payload);
  2953. callableBuilder.Add(NewDataLiteral(multi));
  2954. callableBuilder.Add(NewDataLiteral(isCompact));
  2955. callableBuilder.Add(NewDataLiteral(itemsCountHint));
  2956. return TRuntimeNode(callableBuilder.Build(), false);
  2957. }
  2958. TRuntimeNode TProgramBuilder::SqueezeToDict(TRuntimeNode stream, bool multi, const TUnaryLambda& keySelector,
  2959. const TUnaryLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint)
  2960. {
  2961. if constexpr (RuntimeVersion < 21U) {
  2962. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  2963. }
  2964. const auto type = stream.GetStaticType();
  2965. MKQL_ENSURE(type->IsStream() || type->IsFlow(), "Expected stream or flow.");
  2966. const auto itemType = type->IsFlow() ? AS_TYPE(TFlowType, type)->GetItemType() : AS_TYPE(TStreamType, type)->GetItemType();
  2967. ThrowIfListOfVoid(itemType);
  2968. const auto itemArg = Arg(itemType);
  2969. const auto key = keySelector(itemArg);
  2970. const auto keyType = key.GetStaticType();
  2971. auto payload = payloadSelector(itemArg);
  2972. auto payloadType = payload.GetStaticType();
  2973. if (multi) {
  2974. payloadType = TListType::Create(payloadType, Env);
  2975. }
  2976. auto dictType = TDictType::Create(keyType, payloadType, Env);
  2977. auto returnType = type->IsFlow()
  2978. ? (TType*) TFlowType::Create(dictType, Env)
  2979. : (TType*) TStreamType::Create(dictType, Env);
  2980. TCallableBuilder callableBuilder(Env, callableName, returnType);
  2981. callableBuilder.Add(stream);
  2982. callableBuilder.Add(itemArg);
  2983. callableBuilder.Add(key);
  2984. callableBuilder.Add(payload);
  2985. callableBuilder.Add(NewDataLiteral(multi));
  2986. callableBuilder.Add(NewDataLiteral(isCompact));
  2987. callableBuilder.Add(NewDataLiteral(itemsCountHint));
  2988. return TRuntimeNode(callableBuilder.Build(), false);
  2989. }
  2990. TRuntimeNode TProgramBuilder::NarrowSqueezeToDict(TRuntimeNode flow, bool multi, const TNarrowLambda& keySelector,
  2991. const TNarrowLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint)
  2992. {
  2993. if constexpr (RuntimeVersion < 23U) {
  2994. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  2995. }
  2996. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  2997. TRuntimeNode::TList itemArgs;
  2998. itemArgs.reserve(wideComponents.size());
  2999. auto i = 0U;
  3000. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3001. const auto key = keySelector(itemArgs);
  3002. const auto keyType = key.GetStaticType();
  3003. auto payload = payloadSelector(itemArgs);
  3004. auto payloadType = payload.GetStaticType();
  3005. if (multi) {
  3006. payloadType = TListType::Create(payloadType, Env);
  3007. }
  3008. const auto dictType = TDictType::Create(keyType, payloadType, Env);
  3009. const auto returnType = TFlowType::Create(dictType, Env);
  3010. TCallableBuilder callableBuilder(Env, callableName, returnType);
  3011. callableBuilder.Add(flow);
  3012. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3013. callableBuilder.Add(key);
  3014. callableBuilder.Add(payload);
  3015. callableBuilder.Add(NewDataLiteral(multi));
  3016. callableBuilder.Add(NewDataLiteral(isCompact));
  3017. callableBuilder.Add(NewDataLiteral(itemsCountHint));
  3018. return TRuntimeNode(callableBuilder.Build(), false);
  3019. }
  3020. void TProgramBuilder::ThrowIfListOfVoid(TType* type) {
  3021. MKQL_ENSURE(!VoidWithEffects || !type->IsVoid(), "List of void is forbidden for current function");
  3022. }
  3023. TRuntimeNode TProgramBuilder::BuildFlatMap(const std::string_view& callableName, TRuntimeNode list, const TUnaryLambda& handler)
  3024. {
  3025. const auto listType = list.GetStaticType();
  3026. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsOptional() || listType->IsStream(), "Expected flow, list, stream or optional");
  3027. if (listType->IsOptional()) {
  3028. const auto itemArg = Arg(AS_TYPE(TOptionalType, listType)->GetItemType());
  3029. const auto newList = handler(itemArg);
  3030. const auto type = newList.GetStaticType();
  3031. MKQL_ENSURE(type->IsList() || type->IsOptional() || type->IsStream() || type->IsFlow(), "Expected flow, list, stream or optional");
  3032. return IfPresent(list, [&](TRuntimeNode item) {
  3033. return handler(item);
  3034. }, type->IsOptional() ? NewEmptyOptional(type) : type->IsList() ? NewEmptyList(AS_TYPE(TListType, type)->GetItemType()) : EmptyIterator(type));
  3035. }
  3036. const auto itemType = listType->IsFlow() ?
  3037. AS_TYPE(TFlowType, listType)->GetItemType():
  3038. listType->IsList() ?
  3039. AS_TYPE(TListType, listType)->GetItemType():
  3040. AS_TYPE(TStreamType, listType)->GetItemType();
  3041. ThrowIfListOfVoid(itemType);
  3042. const auto itemArg = Arg(itemType);
  3043. const auto newList = handler(itemArg);
  3044. const auto type = newList.GetStaticType();
  3045. TType* retItemType = nullptr;
  3046. if (type->IsOptional()) {
  3047. retItemType = AS_TYPE(TOptionalType, type)->GetItemType();
  3048. } else if (type->IsFlow()) {
  3049. retItemType = AS_TYPE(TFlowType, type)->GetItemType();
  3050. } else if (type->IsList()) {
  3051. retItemType = AS_TYPE(TListType, type)->GetItemType();
  3052. } else if (type->IsStream()) {
  3053. retItemType = AS_TYPE(TStreamType, type)->GetItemType();
  3054. } else {
  3055. THROW yexception() << "Expected flow, list or stream.";
  3056. }
  3057. const auto resultListType = listType->IsFlow() || type->IsFlow() ?
  3058. TFlowType::Create(retItemType, Env):
  3059. listType->IsList() ?
  3060. (TType*)TListType::Create(retItemType, Env):
  3061. (TType*)TStreamType::Create(retItemType, Env);
  3062. TCallableBuilder callableBuilder(Env, callableName, resultListType);
  3063. callableBuilder.Add(list);
  3064. callableBuilder.Add(itemArg);
  3065. callableBuilder.Add(newList);
  3066. return TRuntimeNode(callableBuilder.Build(), false);
  3067. }
  3068. TRuntimeNode TProgramBuilder::MultiMap(TRuntimeNode list, const TExpandLambda& handler)
  3069. {
  3070. if constexpr (RuntimeVersion < 16U) {
  3071. const auto single = [=](TRuntimeNode item) -> TRuntimeNode {
  3072. const auto newList = handler(item);
  3073. const auto retItemType = newList.front().GetStaticType();
  3074. MKQL_ENSURE(retItemType->IsSameType(*newList.back().GetStaticType()), "Must be same type.");
  3075. return NewList(retItemType, newList);
  3076. };
  3077. return OrderedFlatMap(list, single);
  3078. }
  3079. const auto listType = list.GetStaticType();
  3080. MKQL_ENSURE(listType->IsFlow() || listType->IsList(), "Expected flow, list, stream or optional");
  3081. const auto itemType = listType->IsFlow() ? AS_TYPE(TFlowType, listType)->GetItemType() : AS_TYPE(TListType, listType)->GetItemType();
  3082. const auto itemArg = Arg(itemType);
  3083. const auto newList = handler(itemArg);
  3084. MKQL_ENSURE(newList.size() > 1U, "Expected many items.");
  3085. const auto retItemType = newList.front().GetStaticType();
  3086. MKQL_ENSURE(retItemType->IsSameType(*newList.back().GetStaticType()), "Must be same type.");
  3087. const auto resultListType = listType->IsFlow() ?
  3088. (TType*)TFlowType::Create(retItemType, Env) : (TType*)TListType::Create(retItemType, Env);
  3089. TCallableBuilder callableBuilder(Env, __func__, resultListType);
  3090. callableBuilder.Add(list);
  3091. callableBuilder.Add(itemArg);
  3092. std::for_each(newList.cbegin(), newList.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3093. return TRuntimeNode(callableBuilder.Build(), false);
  3094. }
  3095. TRuntimeNode TProgramBuilder::NarrowMultiMap(TRuntimeNode flow, const TWideLambda& handler) {
  3096. if constexpr (RuntimeVersion < 18U) {
  3097. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3098. }
  3099. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3100. TRuntimeNode::TList itemArgs;
  3101. itemArgs.reserve(wideComponents.size());
  3102. auto i = 0U;
  3103. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3104. const auto newList = handler(itemArgs);
  3105. MKQL_ENSURE(newList.size() > 1U, "Expected many items.");
  3106. const auto retItemType = newList.front().GetStaticType();
  3107. MKQL_ENSURE(retItemType->IsSameType(*newList.back().GetStaticType()), "Must be same type.");
  3108. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(newList.front().GetStaticType()));
  3109. callableBuilder.Add(flow);
  3110. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3111. std::for_each(newList.cbegin(), newList.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3112. return TRuntimeNode(callableBuilder.Build(), false);
  3113. }
  3114. TRuntimeNode TProgramBuilder::ExpandMap(TRuntimeNode flow, const TExpandLambda& handler) {
  3115. if constexpr (RuntimeVersion < 18U) {
  3116. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3117. }
  3118. const auto itemType = AS_TYPE(TFlowType, flow.GetStaticType())->GetItemType();
  3119. const auto itemArg = Arg(itemType);
  3120. const auto newItems = handler(itemArg);
  3121. std::vector<TType*> tupleItems;
  3122. tupleItems.reserve(newItems.size());
  3123. std::transform(newItems.cbegin(), newItems.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  3124. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  3125. callableBuilder.Add(flow);
  3126. callableBuilder.Add(itemArg);
  3127. std::for_each(newItems.cbegin(), newItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3128. return TRuntimeNode(callableBuilder.Build(), false);
  3129. }
  3130. TRuntimeNode TProgramBuilder::WideMap(TRuntimeNode flow, const TWideLambda& handler) {
  3131. if constexpr (RuntimeVersion < 18U) {
  3132. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3133. }
  3134. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3135. TRuntimeNode::TList itemArgs;
  3136. itemArgs.reserve(wideComponents.size());
  3137. auto i = 0U;
  3138. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3139. const auto newItems = handler(itemArgs);
  3140. std::vector<TType*> tupleItems;
  3141. tupleItems.reserve(newItems.size());
  3142. std::transform(newItems.cbegin(), newItems.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  3143. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  3144. callableBuilder.Add(flow);
  3145. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3146. std::for_each(newItems.cbegin(), newItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3147. return TRuntimeNode(callableBuilder.Build(), false);
  3148. }
  3149. TRuntimeNode TProgramBuilder::WideChain1Map(TRuntimeNode flow, const TWideLambda& init, const TBinaryWideLambda& update) {
  3150. if constexpr (RuntimeVersion < 23U) {
  3151. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3152. }
  3153. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3154. TRuntimeNode::TList inputArgs;
  3155. inputArgs.reserve(wideComponents.size());
  3156. auto i = 0U;
  3157. std::generate_n(std::back_inserter(inputArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3158. const auto initItems = init(inputArgs);
  3159. std::vector<TType*> tupleItems;
  3160. tupleItems.reserve(initItems.size());
  3161. std::transform(initItems.cbegin(), initItems.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  3162. TRuntimeNode::TList outputArgs;
  3163. outputArgs.reserve(tupleItems.size());
  3164. std::transform(tupleItems.cbegin(), tupleItems.cend(), std::back_inserter(outputArgs), std::bind(&TProgramBuilder::Arg, this, std::placeholders::_1));
  3165. const auto updateItems = update(inputArgs, outputArgs);
  3166. MKQL_ENSURE(initItems.size() == updateItems.size(), "Expected same width.");
  3167. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  3168. callableBuilder.Add(flow);
  3169. std::for_each(inputArgs.cbegin(), inputArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3170. std::for_each(initItems.cbegin(), initItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3171. std::for_each(outputArgs.cbegin(), outputArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3172. std::for_each(updateItems.cbegin(), updateItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3173. return TRuntimeNode(callableBuilder.Build(), false);
  3174. }
  3175. TRuntimeNode TProgramBuilder::NarrowMap(TRuntimeNode flow, const TNarrowLambda& handler) {
  3176. if constexpr (RuntimeVersion < 18U) {
  3177. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3178. }
  3179. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3180. TRuntimeNode::TList itemArgs;
  3181. itemArgs.reserve(wideComponents.size());
  3182. auto i = 0U;
  3183. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3184. const auto newItem = handler(itemArgs);
  3185. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(newItem.GetStaticType()));
  3186. callableBuilder.Add(flow);
  3187. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3188. callableBuilder.Add(newItem);
  3189. return TRuntimeNode(callableBuilder.Build(), false);
  3190. }
  3191. TRuntimeNode TProgramBuilder::NarrowFlatMap(TRuntimeNode flow, const TNarrowLambda& handler) {
  3192. if constexpr (RuntimeVersion < 18U) {
  3193. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3194. }
  3195. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3196. TRuntimeNode::TList itemArgs;
  3197. itemArgs.reserve(wideComponents.size());
  3198. auto i = 0U;
  3199. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3200. const auto newList = handler(itemArgs);
  3201. const auto type = newList.GetStaticType();
  3202. TType* retItemType = nullptr;
  3203. if (type->IsOptional()) {
  3204. retItemType = AS_TYPE(TOptionalType, type)->GetItemType();
  3205. } else if (type->IsFlow()) {
  3206. retItemType = AS_TYPE(TFlowType, type)->GetItemType();
  3207. } else if (type->IsList()) {
  3208. retItemType = AS_TYPE(TListType, type)->GetItemType();
  3209. } else if (type->IsStream()) {
  3210. retItemType = AS_TYPE(TStreamType, type)->GetItemType();
  3211. } else {
  3212. THROW yexception() << "Expected flow, list or stream.";
  3213. }
  3214. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(retItemType));
  3215. callableBuilder.Add(flow);
  3216. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3217. callableBuilder.Add(newList);
  3218. return TRuntimeNode(callableBuilder.Build(), false);
  3219. }
  3220. TRuntimeNode TProgramBuilder::BuildWideFilter(const std::string_view& callableName, TRuntimeNode flow, const TNarrowLambda& handler) {
  3221. if constexpr (RuntimeVersion < 18U) {
  3222. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3223. }
  3224. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3225. TRuntimeNode::TList itemArgs;
  3226. itemArgs.reserve(wideComponents.size());
  3227. auto i = 0U;
  3228. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3229. const auto predicate = handler(itemArgs);
  3230. TCallableBuilder callableBuilder(Env, callableName, flow.GetStaticType());
  3231. callableBuilder.Add(flow);
  3232. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3233. callableBuilder.Add(predicate);
  3234. return TRuntimeNode(callableBuilder.Build(), false);
  3235. }
  3236. TRuntimeNode TProgramBuilder::WideFilter(TRuntimeNode flow, const TNarrowLambda& handler) {
  3237. return BuildWideFilter(__func__, flow, handler);
  3238. }
  3239. TRuntimeNode TProgramBuilder::WideTakeWhile(TRuntimeNode flow, const TNarrowLambda& handler) {
  3240. return BuildWideFilter(__func__, flow, handler);
  3241. }
  3242. TRuntimeNode TProgramBuilder::WideSkipWhile(TRuntimeNode flow, const TNarrowLambda& handler) {
  3243. return BuildWideFilter(__func__, flow, handler);
  3244. }
  3245. TRuntimeNode TProgramBuilder::WideTakeWhileInclusive(TRuntimeNode flow, const TNarrowLambda& handler) {
  3246. return BuildWideFilter(__func__, flow, handler);
  3247. }
  3248. TRuntimeNode TProgramBuilder::WideSkipWhileInclusive(TRuntimeNode flow, const TNarrowLambda& handler) {
  3249. return BuildWideFilter(__func__, flow, handler);
  3250. }
  3251. TRuntimeNode TProgramBuilder::WideFilter(TRuntimeNode flow, TRuntimeNode limit, const TNarrowLambda& handler) {
  3252. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3253. TRuntimeNode::TList itemArgs;
  3254. itemArgs.reserve(wideComponents.size());
  3255. auto i = 0U;
  3256. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3257. const auto predicate = handler(itemArgs);
  3258. TCallableBuilder callableBuilder(Env, __func__, flow.GetStaticType());
  3259. callableBuilder.Add(flow);
  3260. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3261. callableBuilder.Add(predicate);
  3262. callableBuilder.Add(limit);
  3263. return TRuntimeNode(callableBuilder.Build(), false);
  3264. }
  3265. TRuntimeNode TProgramBuilder::BuildFilter(const std::string_view& callableName, TRuntimeNode list, const TUnaryLambda& handler, TType* resultType)
  3266. {
  3267. const auto listType = list.GetStaticType();
  3268. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream.");
  3269. const auto outputType = resultType ? resultType : listType;
  3270. const auto itemType = listType->IsFlow() ?
  3271. AS_TYPE(TFlowType, listType)->GetItemType():
  3272. listType->IsList() ?
  3273. AS_TYPE(TListType, listType)->GetItemType():
  3274. AS_TYPE(TStreamType, listType)->GetItemType();
  3275. ThrowIfListOfVoid(itemType);
  3276. const auto itemArg = Arg(itemType);
  3277. const auto predicate = handler(itemArg);
  3278. MKQL_ENSURE(predicate.GetStaticType()->IsData(), "Expected boolean data");
  3279. const auto& detailedPredicateType = static_cast<const TDataType&>(*predicate.GetStaticType());
  3280. MKQL_ENSURE(detailedPredicateType.GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected boolean data");
  3281. TCallableBuilder callableBuilder(Env, callableName, outputType);
  3282. callableBuilder.Add(list);
  3283. callableBuilder.Add(itemArg);
  3284. callableBuilder.Add(predicate);
  3285. return TRuntimeNode(callableBuilder.Build(), false);
  3286. }
  3287. TRuntimeNode TProgramBuilder::BuildFilter(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode limit, const TUnaryLambda& handler, TType* resultType)
  3288. {
  3289. if constexpr (RuntimeVersion < 4U) {
  3290. return Take(BuildFilter(callableName, list, handler, resultType), limit);
  3291. }
  3292. const auto listType = list.GetStaticType();
  3293. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream.");
  3294. MKQL_ENSURE(limit.GetStaticType()->IsData(), "Expected data");
  3295. const auto outputType = resultType ? resultType : listType;
  3296. const auto itemType = listType->IsFlow() ?
  3297. AS_TYPE(TFlowType, listType)->GetItemType():
  3298. listType->IsList() ?
  3299. AS_TYPE(TListType, listType)->GetItemType():
  3300. AS_TYPE(TStreamType, listType)->GetItemType();
  3301. ThrowIfListOfVoid(itemType);
  3302. const auto itemArg = Arg(itemType);
  3303. const auto predicate = handler(itemArg);
  3304. MKQL_ENSURE(predicate.GetStaticType()->IsData(), "Expected boolean data");
  3305. const auto& detailedPredicateType = static_cast<const TDataType&>(*predicate.GetStaticType());
  3306. MKQL_ENSURE(detailedPredicateType.GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected boolean data");
  3307. TCallableBuilder callableBuilder(Env, callableName, outputType);
  3308. callableBuilder.Add(list);
  3309. callableBuilder.Add(limit);
  3310. callableBuilder.Add(itemArg);
  3311. callableBuilder.Add(predicate);
  3312. return TRuntimeNode(callableBuilder.Build(), false);
  3313. }
  3314. TRuntimeNode TProgramBuilder::Filter(TRuntimeNode list, const TUnaryLambda& handler, TType* resultType)
  3315. {
  3316. const auto type = list.GetStaticType();
  3317. if (type->IsOptional()) {
  3318. return
  3319. IfPresent(list,
  3320. [&](TRuntimeNode item) {
  3321. return If(handler(item), item, NewEmptyOptional(resultType), resultType);
  3322. },
  3323. NewEmptyOptional(resultType)
  3324. );
  3325. }
  3326. return BuildFilter(__func__, list, handler, resultType);
  3327. }
  3328. TRuntimeNode TProgramBuilder::BuildHeap(const std::string_view& callableName, TRuntimeNode list, const TBinaryLambda& comparator) {
  3329. const auto listType = list.GetStaticType();
  3330. MKQL_ENSURE(listType->IsList(), "Expected list.");
  3331. const auto itemType = AS_TYPE(TListType, listType)->GetItemType();
  3332. const auto leftArg = Arg(itemType);
  3333. const auto rightArg = Arg(itemType);
  3334. const auto predicate = comparator(leftArg, rightArg);
  3335. TCallableBuilder callableBuilder(Env, callableName, listType);
  3336. callableBuilder.Add(list);
  3337. callableBuilder.Add(leftArg);
  3338. callableBuilder.Add(rightArg);
  3339. callableBuilder.Add(predicate);
  3340. return TRuntimeNode(callableBuilder.Build(), false);
  3341. }
  3342. TRuntimeNode TProgramBuilder::BuildNth(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode n, const TBinaryLambda& comparator) {
  3343. const auto listType = list.GetStaticType();
  3344. MKQL_ENSURE(listType->IsList(), "Expected list.");
  3345. const auto itemType = AS_TYPE(TListType, listType)->GetItemType();
  3346. MKQL_ENSURE(n.GetStaticType()->IsData(), "Expected data");
  3347. MKQL_ENSURE(static_cast<const TDataType&>(*n.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  3348. const auto leftArg = Arg(itemType);
  3349. const auto rightArg = Arg(itemType);
  3350. const auto predicate = comparator(leftArg, rightArg);
  3351. TCallableBuilder callableBuilder(Env, callableName, listType);
  3352. callableBuilder.Add(list);
  3353. callableBuilder.Add(n);
  3354. callableBuilder.Add(leftArg);
  3355. callableBuilder.Add(rightArg);
  3356. callableBuilder.Add(predicate);
  3357. return TRuntimeNode(callableBuilder.Build(), false);
  3358. }
  3359. TRuntimeNode TProgramBuilder::MakeHeap(TRuntimeNode list, const TBinaryLambda& comparator) {
  3360. return BuildHeap(__func__, list, std::move(comparator));
  3361. }
  3362. TRuntimeNode TProgramBuilder::PushHeap(TRuntimeNode list, const TBinaryLambda& comparator) {
  3363. return BuildHeap(__func__, list, std::move(comparator));
  3364. }
  3365. TRuntimeNode TProgramBuilder::PopHeap(TRuntimeNode list, const TBinaryLambda& comparator) {
  3366. return BuildHeap(__func__, list, std::move(comparator));
  3367. }
  3368. TRuntimeNode TProgramBuilder::SortHeap(TRuntimeNode list, const TBinaryLambda& comparator) {
  3369. return BuildHeap(__func__, list, std::move(comparator));
  3370. }
  3371. TRuntimeNode TProgramBuilder::StableSort(TRuntimeNode list, const TBinaryLambda& comparator) {
  3372. return BuildHeap(__func__, list, std::move(comparator));
  3373. }
  3374. TRuntimeNode TProgramBuilder::NthElement(TRuntimeNode list, TRuntimeNode n, const TBinaryLambda& comparator) {
  3375. return BuildNth(__func__, list, n, std::move(comparator));
  3376. }
  3377. TRuntimeNode TProgramBuilder::PartialSort(TRuntimeNode list, TRuntimeNode n, const TBinaryLambda& comparator) {
  3378. return BuildNth(__func__, list, n, std::move(comparator));
  3379. }
  3380. TRuntimeNode TProgramBuilder::BuildMap(const std::string_view& callableName, TRuntimeNode list, const TUnaryLambda& handler)
  3381. {
  3382. const auto listType = list.GetStaticType();
  3383. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream() || listType->IsOptional(), "Expected flow, list, stream or optional");
  3384. if (listType->IsOptional()) {
  3385. const auto itemArg = Arg(AS_TYPE(TOptionalType, listType)->GetItemType());
  3386. const auto newItem = handler(itemArg);
  3387. return IfPresent(list,
  3388. [&](TRuntimeNode item) { return NewOptional(handler(item)); },
  3389. NewEmptyOptional(NewOptionalType(newItem.GetStaticType()))
  3390. );
  3391. }
  3392. const auto itemType = listType->IsFlow() ?
  3393. AS_TYPE(TFlowType, listType)->GetItemType():
  3394. listType->IsList() ?
  3395. AS_TYPE(TListType, listType)->GetItemType():
  3396. AS_TYPE(TStreamType, listType)->GetItemType();
  3397. ThrowIfListOfVoid(itemType);
  3398. const auto itemArg = Arg(itemType);
  3399. const auto newItem = handler(itemArg);
  3400. const auto resultListType = listType->IsFlow() ?
  3401. (TType*)TFlowType::Create(newItem.GetStaticType(), Env):
  3402. listType->IsList() ?
  3403. (TType*)TListType::Create(newItem.GetStaticType(), Env):
  3404. (TType*)TStreamType::Create(newItem.GetStaticType(), Env);
  3405. TCallableBuilder callableBuilder(Env, callableName, resultListType);
  3406. callableBuilder.Add(list);
  3407. callableBuilder.Add(itemArg);
  3408. callableBuilder.Add(newItem);
  3409. return TRuntimeNode(callableBuilder.Build(), false);
  3410. }
  3411. TRuntimeNode TProgramBuilder::Invoke(const std::string_view& funcName, TType* resultType, const TArrayRef<const TRuntimeNode>& args) {
  3412. MKQL_ENSURE(args.size() >= 1U && args.size() <= 3U, "Expected from one to three arguments.");
  3413. std::array<TArgType, 4U> argTypes;
  3414. argTypes.front().first = UnpackOptionalData(resultType, argTypes.front().second)->GetSchemeType();
  3415. auto i = 0U;
  3416. for (const auto& arg : args) {
  3417. ++i;
  3418. argTypes[i].first = UnpackOptionalData(arg, argTypes[i].second)->GetSchemeType();
  3419. }
  3420. FunctionRegistry.GetBuiltins()->GetBuiltin(funcName, argTypes.data(), 1U + args.size());
  3421. TCallableBuilder callableBuilder(Env, __func__, resultType);
  3422. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(funcName));
  3423. for (const auto& arg : args) {
  3424. callableBuilder.Add(arg);
  3425. }
  3426. return TRuntimeNode(callableBuilder.Build(), false);
  3427. }
  3428. TRuntimeNode TProgramBuilder::Udf(
  3429. const std::string_view& funcName,
  3430. TRuntimeNode runConfig,
  3431. TType* userType,
  3432. const std::string_view& typeConfig
  3433. )
  3434. {
  3435. TRuntimeNode userTypeNode = userType ? TRuntimeNode(userType, true) : TRuntimeNode(Env.GetVoidLazy()->GetType(), true);
  3436. const ui32 flags = NUdf::IUdfModule::TFlags::TypesOnly;
  3437. if (!TypeInfoHelper) {
  3438. TypeInfoHelper = new TTypeInfoHelper();
  3439. }
  3440. TFunctionTypeInfo funcInfo;
  3441. TStatus status = FunctionRegistry.FindFunctionTypeInfo(
  3442. Env, TypeInfoHelper, nullptr, funcName, userType, typeConfig, flags, {}, nullptr, &funcInfo);
  3443. MKQL_ENSURE(status.IsOk(), status.GetError());
  3444. auto runConfigType = funcInfo.RunConfigType;
  3445. if (runConfig) {
  3446. bool typesMatch = runConfigType->IsSameType(*runConfig.GetStaticType());
  3447. MKQL_ENSURE(typesMatch, "RunConfig type mismatch");
  3448. } else {
  3449. MKQL_ENSURE(runConfigType->IsVoid() || runConfigType->IsOptional(), "RunConfig must be void or optional");
  3450. if (runConfigType->IsVoid()) {
  3451. runConfig = NewVoid();
  3452. } else {
  3453. runConfig = NewEmptyOptional(const_cast<TType*>(runConfigType));
  3454. }
  3455. }
  3456. auto funNameNode = NewDataLiteral<NUdf::EDataSlot::String>(funcName);
  3457. auto typeConfigNode = NewDataLiteral<NUdf::EDataSlot::String>(typeConfig);
  3458. TCallableBuilder callableBuilder(Env, __func__, funcInfo.FunctionType);
  3459. callableBuilder.Add(funNameNode);
  3460. callableBuilder.Add(userTypeNode);
  3461. callableBuilder.Add(typeConfigNode);
  3462. callableBuilder.Add(runConfig);
  3463. return TRuntimeNode(callableBuilder.Build(), false);
  3464. }
  3465. TRuntimeNode TProgramBuilder::TypedUdf(
  3466. const std::string_view& funcName,
  3467. TType* funcType,
  3468. TRuntimeNode runConfig,
  3469. TType* userType,
  3470. const std::string_view& typeConfig,
  3471. const std::string_view& file,
  3472. ui32 row,
  3473. ui32 column)
  3474. {
  3475. auto funNameNode = NewDataLiteral<NUdf::EDataSlot::String>(funcName);
  3476. auto typeConfigNode = NewDataLiteral<NUdf::EDataSlot::String>(typeConfig);
  3477. TRuntimeNode userTypeNode = userType ? TRuntimeNode(userType, true) : TRuntimeNode(Env.GetVoidLazy(), true);
  3478. TCallableBuilder callableBuilder(Env, "Udf", funcType);
  3479. callableBuilder.Add(funNameNode);
  3480. callableBuilder.Add(userTypeNode);
  3481. callableBuilder.Add(typeConfigNode);
  3482. callableBuilder.Add(runConfig);
  3483. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  3484. callableBuilder.Add(NewDataLiteral(row));
  3485. callableBuilder.Add(NewDataLiteral(column));
  3486. return TRuntimeNode(callableBuilder.Build(), false);
  3487. }
  3488. TRuntimeNode TProgramBuilder::ScriptUdf(
  3489. const std::string_view& moduleName,
  3490. const std::string_view& funcName,
  3491. TType* funcType,
  3492. TRuntimeNode script,
  3493. const std::string_view& file,
  3494. ui32 row,
  3495. ui32 column)
  3496. {
  3497. MKQL_ENSURE(funcType, "UDF callable type must not be empty");
  3498. MKQL_ENSURE(funcType->IsCallable(), "type must be callable");
  3499. auto scriptType = NKikimr::NMiniKQL::ScriptTypeFromStr(moduleName);
  3500. MKQL_ENSURE(scriptType != EScriptType::Unknown, "unknown script type '" << moduleName << "'");
  3501. EnsureScriptSpecificTypes(scriptType, static_cast<TCallableType*>(funcType), Env);
  3502. auto scriptTypeStr = IsCustomPython(scriptType) ? moduleName : ScriptTypeAsStr(CanonizeScriptType(scriptType));
  3503. TStringBuilder name;
  3504. name.reserve(scriptTypeStr.size() + funcName.size() + 1);
  3505. name << scriptTypeStr << '.' << funcName;
  3506. auto funcNameNode = NewDataLiteral<NUdf::EDataSlot::String>(name);
  3507. TRuntimeNode userTypeNode(funcType, true);
  3508. auto typeConfigNode = NewDataLiteral<NUdf::EDataSlot::String>("");
  3509. TCallableBuilder callableBuilder(Env, __func__, funcType);
  3510. callableBuilder.Add(funcNameNode);
  3511. callableBuilder.Add(userTypeNode);
  3512. callableBuilder.Add(typeConfigNode);
  3513. callableBuilder.Add(script);
  3514. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  3515. callableBuilder.Add(NewDataLiteral(row));
  3516. callableBuilder.Add(NewDataLiteral(column));
  3517. return TRuntimeNode(callableBuilder.Build(), false);
  3518. }
  3519. TRuntimeNode TProgramBuilder::Apply(TRuntimeNode callableNode, const TArrayRef<const TRuntimeNode>& args,
  3520. const std::string_view& file, ui32 row, ui32 column, ui32 dependentCount) {
  3521. MKQL_ENSURE(dependentCount <= args.size(), "Too many dependent nodes");
  3522. ui32 usedArgs = args.size() - dependentCount;
  3523. MKQL_ENSURE(!callableNode.IsImmediate() && callableNode.GetNode()->GetType()->IsCallable(),
  3524. "Expected callable");
  3525. auto callable = static_cast<TCallable*>(callableNode.GetNode());
  3526. TType* returnType = callable->GetType()->GetReturnType();
  3527. MKQL_ENSURE(returnType->IsCallable(), "Expected callable as return type");
  3528. auto callableType = static_cast<TCallableType*>(returnType);
  3529. MKQL_ENSURE(usedArgs <= callableType->GetArgumentsCount(), "Too many arguments");
  3530. MKQL_ENSURE(usedArgs >= callableType->GetArgumentsCount() - callableType->GetOptionalArgumentsCount(), "Too few arguments");
  3531. for (ui32 i = 0; i < usedArgs; i++) {
  3532. TType* argType = callableType->GetArgumentType(i);
  3533. TRuntimeNode arg = args[i];
  3534. MKQL_ENSURE(arg.GetStaticType()->IsConvertableTo(*argType),
  3535. "Argument type mismatch for argument " << i << ": runtime " << argType->GetKindAsStr()
  3536. << " with static " << arg.GetStaticType()->GetKindAsStr());
  3537. }
  3538. TCallableBuilder callableBuilder(Env, RuntimeVersion >= 8 ? "Apply2" : "Apply", callableType->GetReturnType());
  3539. callableBuilder.Add(callableNode);
  3540. callableBuilder.Add(NewDataLiteral<ui32>(dependentCount));
  3541. if constexpr (RuntimeVersion >= 8) {
  3542. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  3543. callableBuilder.Add(NewDataLiteral(row));
  3544. callableBuilder.Add(NewDataLiteral(column));
  3545. }
  3546. for (const auto& arg: args) {
  3547. callableBuilder.Add(arg);
  3548. }
  3549. return TRuntimeNode(callableBuilder.Build(), false);
  3550. }
  3551. TRuntimeNode TProgramBuilder::Apply(
  3552. TRuntimeNode callableNode,
  3553. const TArrayRef<const TRuntimeNode>& args,
  3554. ui32 dependentCount) {
  3555. return Apply(callableNode, args, {}, 0, 0, dependentCount);
  3556. }
  3557. TRuntimeNode TProgramBuilder::Callable(TType* callableType, const TArrayLambda& handler) {
  3558. auto castedCallableType = AS_TYPE(TCallableType, callableType);
  3559. std::vector<TRuntimeNode> args;
  3560. args.reserve(castedCallableType->GetArgumentsCount());
  3561. for (ui32 i = 0; i < castedCallableType->GetArgumentsCount(); ++i) {
  3562. args.push_back(Arg(castedCallableType->GetArgumentType(i)));
  3563. }
  3564. auto res = handler(args);
  3565. TCallableBuilder callableBuilder(Env, __func__, callableType);
  3566. for (ui32 i = 0; i < castedCallableType->GetArgumentsCount(); ++i) {
  3567. callableBuilder.Add(args[i]);
  3568. }
  3569. callableBuilder.Add(res);
  3570. return TRuntimeNode(callableBuilder.Build(), false);
  3571. }
  3572. TRuntimeNode TProgramBuilder::NewNull() {
  3573. if (!UseNullType || RuntimeVersion < 11) {
  3574. TCallableBuilder callableBuilder(Env, "Null", NewOptionalType(Env.GetVoidLazy()->GetType()));
  3575. return TRuntimeNode(callableBuilder.Build(), false);
  3576. } else {
  3577. return TRuntimeNode(Env.GetNullLazy(), true);
  3578. }
  3579. }
  3580. TRuntimeNode TProgramBuilder::Concat(TRuntimeNode data1, TRuntimeNode data2) {
  3581. bool isOpt1, isOpt2;
  3582. const auto type1 = UnpackOptionalData(data1, isOpt1)->GetSchemeType();
  3583. const auto type2 = UnpackOptionalData(data2, isOpt2)->GetSchemeType();
  3584. const auto resultType = NewDataType(type1 == type2 ? type1 : NUdf::TDataType<char*>::Id);
  3585. return InvokeBinary(__func__, isOpt1 || isOpt2 ? NewOptionalType(resultType) : resultType, data1, data2);
  3586. }
  3587. TRuntimeNode TProgramBuilder::AggrConcat(TRuntimeNode data1, TRuntimeNode data2) {
  3588. MKQL_ENSURE(data1.GetStaticType()->IsSameType(*data2.GetStaticType()), "Operands type mismatch.");
  3589. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  3590. return Invoke(__func__, data1.GetStaticType(), args);
  3591. }
  3592. TRuntimeNode TProgramBuilder::Substring(TRuntimeNode data, TRuntimeNode start, TRuntimeNode count) {
  3593. const std::array<TRuntimeNode, 3U> args = {{ data, start, count }};
  3594. return Invoke(__func__, data.GetStaticType(), args);
  3595. }
  3596. TRuntimeNode TProgramBuilder::Find(TRuntimeNode haystack, TRuntimeNode needle, TRuntimeNode pos) {
  3597. const std::array<TRuntimeNode, 3U> args = {{ haystack, needle, pos }};
  3598. return Invoke(__func__, NewOptionalType(NewDataType(NUdf::TDataType<ui32>::Id)), args);
  3599. }
  3600. TRuntimeNode TProgramBuilder::RFind(TRuntimeNode haystack, TRuntimeNode needle, TRuntimeNode pos) {
  3601. const std::array<TRuntimeNode, 3U> args = {{ haystack, needle, pos }};
  3602. return Invoke(__func__, NewOptionalType(NewDataType(NUdf::TDataType<ui32>::Id)), args);
  3603. }
  3604. TRuntimeNode TProgramBuilder::StartsWith(TRuntimeNode string, TRuntimeNode prefix) {
  3605. if constexpr (RuntimeVersion < 19U) {
  3606. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3607. }
  3608. return DataCompare(__func__, string, prefix);
  3609. }
  3610. TRuntimeNode TProgramBuilder::EndsWith(TRuntimeNode string, TRuntimeNode suffix) {
  3611. if constexpr (RuntimeVersion < 19U) {
  3612. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3613. }
  3614. return DataCompare(__func__, string, suffix);
  3615. }
  3616. TRuntimeNode TProgramBuilder::StringContains(TRuntimeNode string, TRuntimeNode pattern) {
  3617. bool isOpt1, isOpt2;
  3618. TDataType* type1 = UnpackOptionalData(string, isOpt1);
  3619. TDataType* type2 = UnpackOptionalData(pattern, isOpt2);
  3620. MKQL_ENSURE(type1->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id ||
  3621. type1->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expecting string as first argument");
  3622. MKQL_ENSURE(type2->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id ||
  3623. type2->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expecting string as second argument");
  3624. if constexpr (RuntimeVersion < 32U) {
  3625. auto stringCasted = (type1->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id) ? ToString(string) : string;
  3626. auto patternCasted = (type2->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id) ? ToString(pattern) : pattern;
  3627. auto found = Exists(Find(stringCasted, patternCasted, NewDataLiteral(ui32(0))));
  3628. if (!isOpt1 && !isOpt2) {
  3629. return found;
  3630. }
  3631. TVector<TRuntimeNode> predicates;
  3632. if (isOpt1) {
  3633. predicates.push_back(Exists(string));
  3634. }
  3635. if (isOpt2) {
  3636. predicates.push_back(Exists(pattern));
  3637. }
  3638. TRuntimeNode argsNotNull = (predicates.size() == 1) ? predicates.front() : And(predicates);
  3639. return If(argsNotNull, NewOptional(found), NewEmptyOptionalDataLiteral(NUdf::TDataType<bool>::Id));
  3640. }
  3641. return DataCompare(__func__, string, pattern);
  3642. }
  3643. TRuntimeNode TProgramBuilder::ByteAt(TRuntimeNode data, TRuntimeNode index) {
  3644. const std::array<TRuntimeNode, 2U> args = {{ data, index }};
  3645. return Invoke(__func__, NewOptionalType(NewDataType(NUdf::TDataType<ui8>::Id)), args);
  3646. }
  3647. TRuntimeNode TProgramBuilder::Size(TRuntimeNode data) {
  3648. return UnaryDataFunction(data, __func__, TDataFunctionFlags::HasUi32Result | TDataFunctionFlags::AllowNull | TDataFunctionFlags::AllowOptionalArgs | TDataFunctionFlags::CommonOptionalResult);
  3649. }
  3650. template <bool Utf8>
  3651. TRuntimeNode TProgramBuilder::ToString(TRuntimeNode data) {
  3652. bool isOptional;
  3653. UnpackOptionalData(data, isOptional);
  3654. const auto resultType = NewDataType(Utf8 ? NUdf::EDataSlot::Utf8 : NUdf::EDataSlot::String, isOptional);
  3655. TCallableBuilder callableBuilder(Env, __func__, resultType);
  3656. callableBuilder.Add(data);
  3657. return TRuntimeNode(callableBuilder.Build(), false);
  3658. }
  3659. TRuntimeNode TProgramBuilder::FromString(TRuntimeNode data, TType* type) {
  3660. bool isOptional;
  3661. const auto sourceType = UnpackOptionalData(data, isOptional);
  3662. const auto targetType = UnpackOptionalData(type, isOptional);
  3663. MKQL_ENSURE(sourceType->GetSchemeType() == NUdf::TDataType<char*>::Id || sourceType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected String");
  3664. MKQL_ENSURE(targetType->GetSchemeType() != 0, "Null is not allowed");
  3665. TCallableBuilder callableBuilder(Env, __func__, type);
  3666. callableBuilder.Add(data);
  3667. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(targetType->GetSchemeType())));
  3668. if (targetType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3669. const auto& params = static_cast<const TDataDecimalType*>(targetType)->GetParams();
  3670. callableBuilder.Add(NewDataLiteral(params.first));
  3671. callableBuilder.Add(NewDataLiteral(params.second));
  3672. }
  3673. return TRuntimeNode(callableBuilder.Build(), false);
  3674. }
  3675. TRuntimeNode TProgramBuilder::StrictFromString(TRuntimeNode data, TType* type) {
  3676. bool isOptional;
  3677. const auto sourceType = UnpackOptionalData(data, isOptional);
  3678. const auto targetType = UnpackOptionalData(type, isOptional);
  3679. MKQL_ENSURE(sourceType->GetSchemeType() == NUdf::TDataType<char*>::Id || sourceType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected String");
  3680. MKQL_ENSURE(targetType->GetSchemeType() != 0, "Null is not allowed");
  3681. TCallableBuilder callableBuilder(Env, __func__, type);
  3682. callableBuilder.Add(data);
  3683. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(targetType->GetSchemeType())));
  3684. if (targetType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3685. const auto& params = static_cast<const TDataDecimalType*>(targetType)->GetParams();
  3686. callableBuilder.Add(NewDataLiteral(params.first));
  3687. callableBuilder.Add(NewDataLiteral(params.second));
  3688. }
  3689. return TRuntimeNode(callableBuilder.Build(), false);
  3690. }
  3691. TRuntimeNode TProgramBuilder::ToBytes(TRuntimeNode data) {
  3692. return UnaryDataFunction(data, __func__, TDataFunctionFlags::HasStringResult | TDataFunctionFlags::AllowOptionalArgs | TDataFunctionFlags::CommonOptionalResult);
  3693. }
  3694. TRuntimeNode TProgramBuilder::FromBytes(TRuntimeNode data, TType* targetType) {
  3695. auto type = data.GetStaticType();
  3696. bool isOptional;
  3697. auto dataType = UnpackOptionalData(type, isOptional);
  3698. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expected String");
  3699. auto resultType = NewOptionalType(targetType);
  3700. TCallableBuilder callableBuilder(Env, __func__, resultType);
  3701. callableBuilder.Add(data);
  3702. auto targetDataType = AS_TYPE(TDataType, targetType);
  3703. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(targetDataType->GetSchemeType())));
  3704. if (targetDataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3705. const auto& params = static_cast<const TDataDecimalType*>(targetType)->GetParams();
  3706. callableBuilder.Add(NewDataLiteral(params.first));
  3707. callableBuilder.Add(NewDataLiteral(params.second));
  3708. }
  3709. return TRuntimeNode(callableBuilder.Build(), false);
  3710. }
  3711. TRuntimeNode TProgramBuilder::InversePresortString(TRuntimeNode data) {
  3712. const std::array<TRuntimeNode, 1U> args = {{ data }};
  3713. return Invoke(__func__, NewDataType(NUdf::TDataType<char*>::Id), args);
  3714. }
  3715. TRuntimeNode TProgramBuilder::InverseString(TRuntimeNode data) {
  3716. const std::array<TRuntimeNode, 1U> args = {{ data }};
  3717. return Invoke(__func__, NewDataType(NUdf::TDataType<char*>::Id), args);
  3718. }
  3719. TRuntimeNode TProgramBuilder::Random(const TArrayRef<const TRuntimeNode>& dependentNodes) {
  3720. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<double>::Id));
  3721. for (auto& x : dependentNodes) {
  3722. callableBuilder.Add(x);
  3723. }
  3724. return TRuntimeNode(callableBuilder.Build(), false);
  3725. }
  3726. TRuntimeNode TProgramBuilder::RandomNumber(const TArrayRef<const TRuntimeNode>& dependentNodes) {
  3727. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<ui64>::Id));
  3728. for (auto& x : dependentNodes) {
  3729. callableBuilder.Add(x);
  3730. }
  3731. return TRuntimeNode(callableBuilder.Build(), false);
  3732. }
  3733. TRuntimeNode TProgramBuilder::RandomUuid(const TArrayRef<const TRuntimeNode>& dependentNodes) {
  3734. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<NUdf::TUuid>::Id));
  3735. for (auto& x : dependentNodes) {
  3736. callableBuilder.Add(x);
  3737. }
  3738. return TRuntimeNode(callableBuilder.Build(), false);
  3739. }
  3740. TRuntimeNode TProgramBuilder::Now(const TArrayRef<const TRuntimeNode>& args) {
  3741. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<ui64>::Id));
  3742. for (const auto& x : args) {
  3743. callableBuilder.Add(x);
  3744. }
  3745. return TRuntimeNode(callableBuilder.Build(), false);
  3746. }
  3747. TRuntimeNode TProgramBuilder::CurrentUtcDate(const TArrayRef<const TRuntimeNode>& args) {
  3748. return Cast(CurrentUtcTimestamp(args), NewDataType(NUdf::TDataType<NUdf::TDate>::Id));
  3749. }
  3750. TRuntimeNode TProgramBuilder::CurrentUtcDatetime(const TArrayRef<const TRuntimeNode>& args) {
  3751. return Cast(CurrentUtcTimestamp(args), NewDataType(NUdf::TDataType<NUdf::TDatetime>::Id));
  3752. }
  3753. TRuntimeNode TProgramBuilder::CurrentUtcTimestamp(const TArrayRef<const TRuntimeNode>& args) {
  3754. return Coalesce(ToIntegral(Now(args), NewDataType(NUdf::TDataType<NUdf::TTimestamp>::Id, true)),
  3755. TRuntimeNode(BuildDataLiteral(NUdf::TUnboxedValuePod(ui64(NUdf::MAX_TIMESTAMP - 1ULL)), NUdf::TDataType<NUdf::TTimestamp>::Id, Env), true));
  3756. }
  3757. TRuntimeNode TProgramBuilder::Pickle(TRuntimeNode data) {
  3758. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::EDataSlot::String));
  3759. callableBuilder.Add(data);
  3760. return TRuntimeNode(callableBuilder.Build(), false);
  3761. }
  3762. TRuntimeNode TProgramBuilder::StablePickle(TRuntimeNode data) {
  3763. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::EDataSlot::String));
  3764. callableBuilder.Add(data);
  3765. return TRuntimeNode(callableBuilder.Build(), false);
  3766. }
  3767. TRuntimeNode TProgramBuilder::Unpickle(TType* type, TRuntimeNode serialized) {
  3768. MKQL_ENSURE(AS_TYPE(TDataType, serialized)->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expected String");
  3769. TCallableBuilder callableBuilder(Env, __func__, type);
  3770. callableBuilder.Add(TRuntimeNode(type, true));
  3771. callableBuilder.Add(serialized);
  3772. return TRuntimeNode(callableBuilder.Build(), false);
  3773. }
  3774. TRuntimeNode TProgramBuilder::Ascending(TRuntimeNode data) {
  3775. auto dataType = NewDataType(NUdf::EDataSlot::String);
  3776. TCallableBuilder callableBuilder(Env, __func__, dataType);
  3777. callableBuilder.Add(data);
  3778. return TRuntimeNode(callableBuilder.Build(), false);
  3779. }
  3780. TRuntimeNode TProgramBuilder::Descending(TRuntimeNode data) {
  3781. auto dataType = NewDataType(NUdf::EDataSlot::String);
  3782. TCallableBuilder callableBuilder(Env, __func__, dataType);
  3783. callableBuilder.Add(data);
  3784. return TRuntimeNode(callableBuilder.Build(), false);
  3785. }
  3786. TRuntimeNode TProgramBuilder::Convert(TRuntimeNode data, TType* type) {
  3787. if (data.GetStaticType()->IsSameType(*type)) {
  3788. return data;
  3789. }
  3790. bool isOptional;
  3791. const auto dataType = UnpackOptionalData(data, isOptional);
  3792. const std::array<TRuntimeNode, 1> args = {{ data }};
  3793. if (dataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3794. const auto targetSchemeType = UnpackOptionalData(type, isOptional)->GetSchemeType();
  3795. TStringStream str;
  3796. str << "To" << NUdf::GetDataTypeInfo(NUdf::GetDataSlot(targetSchemeType)).Name
  3797. << '_' << ::ToString(static_cast<const TDataDecimalType*>(dataType)->GetParams().second);
  3798. return Invoke(str.Str().c_str(), type, args);
  3799. }
  3800. return Invoke(__func__, type, args);
  3801. }
  3802. TRuntimeNode TProgramBuilder::ToDecimal(TRuntimeNode data, ui8 precision, ui8 scale) {
  3803. bool isOptional;
  3804. auto dataType = UnpackOptionalData(data, isOptional);
  3805. TType* decimal = TDataDecimalType::Create(precision, scale, Env);
  3806. if (isOptional)
  3807. decimal = TOptionalType::Create(decimal, Env);
  3808. const std::array<TRuntimeNode, 1> args = {{ data }};
  3809. if (dataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3810. const auto& params = static_cast<const TDataDecimalType*>(dataType)->GetParams();
  3811. if (precision - scale < params.first - params.second && scale != params.second) {
  3812. return ToDecimal(ToDecimal(data, precision - scale + params.second, params.second), precision, scale);
  3813. } else if (params.second < scale) {
  3814. return Invoke("ScaleUp_" + ::ToString(scale - params.second), decimal, args);
  3815. } else if (params.second > scale) {
  3816. TRuntimeNode scaled = Invoke("ScaleDown_" + ::ToString(params.second - scale), decimal, args);
  3817. return Invoke("CheckBounds_" + ::ToString(precision), decimal, {{ scaled }});
  3818. } else if (precision < params.first) {
  3819. return Invoke("CheckBounds_" + ::ToString(precision), decimal, args);
  3820. } else if (precision > params.first) {
  3821. return Invoke("Plus", decimal, args);
  3822. } else {
  3823. return data;
  3824. }
  3825. } else {
  3826. const auto digits = NUdf::GetDataTypeInfo(*dataType->GetDataSlot()).DecimalDigits;
  3827. MKQL_ENSURE(digits, "Can't cast into Decimal.");
  3828. if (digits <= precision && !scale)
  3829. return Invoke(__func__, decimal, args);
  3830. else
  3831. return ToDecimal(ToDecimal(data, digits, 0), precision, scale);
  3832. }
  3833. }
  3834. TRuntimeNode TProgramBuilder::ToIntegral(TRuntimeNode data, TType* type) {
  3835. bool isOptional;
  3836. auto dataType = UnpackOptionalData(data, isOptional);
  3837. if (dataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3838. const auto& params = static_cast<const TDataDecimalType*>(dataType)->GetParams();
  3839. if (params.second)
  3840. return ToIntegral(ToDecimal(data, params.first - params.second, 0), type);
  3841. }
  3842. const std::array<TRuntimeNode, 1> args = {{ data }};
  3843. return Invoke(__func__, type, args);
  3844. }
  3845. TRuntimeNode TProgramBuilder::ListIf(TRuntimeNode predicate, TRuntimeNode item) {
  3846. return If(predicate, NewList(item.GetStaticType(), {item}), NewEmptyList(item.GetStaticType()));
  3847. }
  3848. TRuntimeNode TProgramBuilder::AsList(TRuntimeNode item) {
  3849. TListLiteralBuilder builder(Env, item.GetStaticType());
  3850. builder.Add(item);
  3851. return TRuntimeNode(builder.Build(), true);
  3852. }
  3853. TRuntimeNode TProgramBuilder::AsList(const TArrayRef<const TRuntimeNode>& items) {
  3854. MKQL_ENSURE(!items.empty(), "required not empty list of items");
  3855. TListLiteralBuilder builder(Env, items[0].GetStaticType());
  3856. for (auto item : items) {
  3857. builder.Add(item);
  3858. }
  3859. return TRuntimeNode(builder.Build(), true);
  3860. }
  3861. TRuntimeNode TProgramBuilder::MapJoinCore(TRuntimeNode flow, TRuntimeNode dict, EJoinKind joinKind,
  3862. const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& leftRenames,
  3863. const TArrayRef<const ui32>& rightRenames, TType* returnType) {
  3864. MKQL_ENSURE(joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left || joinKind == EJoinKind::LeftSemi || joinKind == EJoinKind::LeftOnly, "Unsupported join kind");
  3865. MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified");
  3866. MKQL_ENSURE(leftRenames.size() % 2U == 0U, "Expected even count");
  3867. MKQL_ENSURE(rightRenames.size() % 2U == 0U, "Expected even count");
  3868. TRuntimeNode::TList leftKeyColumnsNodes, leftRenamesNodes, rightRenamesNodes;
  3869. leftKeyColumnsNodes.reserve(leftKeyColumns.size());
  3870. std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(), std::back_inserter(leftKeyColumnsNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  3871. leftRenamesNodes.reserve(leftRenames.size());
  3872. std::transform(leftRenames.cbegin(), leftRenames.cend(), std::back_inserter(leftRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  3873. rightRenamesNodes.reserve(rightRenames.size());
  3874. std::transform(rightRenames.cbegin(), rightRenames.cend(), std::back_inserter(rightRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  3875. TCallableBuilder callableBuilder(Env, __func__, returnType);
  3876. callableBuilder.Add(flow);
  3877. callableBuilder.Add(dict);
  3878. callableBuilder.Add(NewDataLiteral((ui32)joinKind));
  3879. callableBuilder.Add(NewTuple(leftKeyColumnsNodes));
  3880. callableBuilder.Add(NewTuple(leftRenamesNodes));
  3881. callableBuilder.Add(NewTuple(rightRenamesNodes));
  3882. return TRuntimeNode(callableBuilder.Build(), false);
  3883. }
  3884. TRuntimeNode TProgramBuilder::CommonJoinCore(TRuntimeNode flow, EJoinKind joinKind,
  3885. const TArrayRef<const ui32>& leftColumns, const TArrayRef<const ui32>& rightColumns,
  3886. const TArrayRef<const ui32>& requiredColumns, const TArrayRef<const ui32>& keyColumns,
  3887. ui64 memLimit, std::optional<ui32> sortedTableOrder,
  3888. EAnyJoinSettings anyJoinSettings, const ui32 tableIndexField, TType* returnType) {
  3889. if constexpr (RuntimeVersion < 17U) {
  3890. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3891. }
  3892. MKQL_ENSURE(leftColumns.size() % 2U == 0U, "Expected even count");
  3893. MKQL_ENSURE(rightColumns.size() % 2U == 0U, "Expected even count");
  3894. TRuntimeNode::TList leftInputColumnsNodes, rightInputColumnsNodes, requiredColumnsNodes,
  3895. leftOutputColumnsNodes, rightOutputColumnsNodes, keyColumnsNodes;
  3896. bool s = false;
  3897. for (const auto idx : leftColumns) {
  3898. ((s = !s) ? leftInputColumnsNodes : leftOutputColumnsNodes).emplace_back(NewDataLiteral(idx));
  3899. }
  3900. for (const auto idx : rightColumns) {
  3901. ((s = !s) ? rightInputColumnsNodes : rightOutputColumnsNodes).emplace_back(NewDataLiteral(idx));
  3902. }
  3903. const std::unordered_set<ui32> requiredIndices(requiredColumns.cbegin(), requiredColumns.cend());
  3904. MKQL_ENSURE(requiredIndices.size() == requiredColumns.size(), "Duplication of requred columns.");
  3905. requiredColumnsNodes.reserve(requiredColumns.size());
  3906. std::transform(requiredColumns.cbegin(), requiredColumns.cend(), std::back_inserter(requiredColumnsNodes),
  3907. std::bind(&TProgramBuilder::NewDataLiteral<ui32>, this, std::placeholders::_1));
  3908. const std::unordered_set<ui32> keyIndices(keyColumns.cbegin(), keyColumns.cend());
  3909. MKQL_ENSURE(keyIndices.size() == keyColumns.size(), "Duplication of key columns.");
  3910. keyColumnsNodes.reserve(keyColumns.size());
  3911. std::transform(keyColumns.cbegin(), keyColumns.cend(), std::back_inserter(keyColumnsNodes),
  3912. std::bind(&TProgramBuilder::NewDataLiteral<ui32>, this, std::placeholders::_1));
  3913. TCallableBuilder callableBuilder(Env, __func__, returnType);
  3914. callableBuilder.Add(flow);
  3915. callableBuilder.Add(NewDataLiteral((ui32)joinKind));
  3916. callableBuilder.Add(NewTuple(leftInputColumnsNodes));
  3917. callableBuilder.Add(NewTuple(rightInputColumnsNodes));
  3918. callableBuilder.Add(NewTuple(requiredColumnsNodes));
  3919. callableBuilder.Add(NewTuple(leftOutputColumnsNodes));
  3920. callableBuilder.Add(NewTuple(rightOutputColumnsNodes));
  3921. callableBuilder.Add(NewTuple(keyColumnsNodes));
  3922. callableBuilder.Add(NewDataLiteral(memLimit));
  3923. callableBuilder.Add(sortedTableOrder ? NewDataLiteral(*sortedTableOrder) : NewVoid());
  3924. callableBuilder.Add(NewDataLiteral((ui32)anyJoinSettings));
  3925. callableBuilder.Add(NewDataLiteral(tableIndexField));
  3926. return TRuntimeNode(callableBuilder.Build(), false);
  3927. }
  3928. TRuntimeNode TProgramBuilder::WideCombiner(TRuntimeNode flow, i64 memLimit, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) {
  3929. if constexpr (RuntimeVersion < 18U) {
  3930. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3931. }
  3932. if (memLimit < 0) {
  3933. if constexpr (RuntimeVersion < 46U) {
  3934. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__ << " with limit " << memLimit;
  3935. }
  3936. }
  3937. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3938. TRuntimeNode::TList itemArgs;
  3939. itemArgs.reserve(wideComponents.size());
  3940. auto i = 0U;
  3941. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3942. const auto keys = extractor(itemArgs);
  3943. TRuntimeNode::TList keyArgs;
  3944. keyArgs.reserve(keys.size());
  3945. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  3946. const auto first = init(keyArgs, itemArgs);
  3947. TRuntimeNode::TList stateArgs;
  3948. stateArgs.reserve(first.size());
  3949. std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  3950. const auto next = update(keyArgs, itemArgs, stateArgs);
  3951. MKQL_ENSURE(next.size() == first.size(), "Mismatch init and update state size.");
  3952. TRuntimeNode::TList finishKeyArgs;
  3953. finishKeyArgs.reserve(keys.size());
  3954. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(finishKeyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  3955. TRuntimeNode::TList finishStateArgs;
  3956. finishStateArgs.reserve(next.size());
  3957. std::transform(next.cbegin(), next.cend(), std::back_inserter(finishStateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  3958. const auto output = finish(finishKeyArgs, finishStateArgs);
  3959. std::vector<TType*> tupleItems;
  3960. tupleItems.reserve(output.size());
  3961. std::transform(output.cbegin(), output.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  3962. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  3963. callableBuilder.Add(flow);
  3964. if constexpr (RuntimeVersion < 46U)
  3965. callableBuilder.Add(NewDataLiteral(ui64(memLimit)));
  3966. else
  3967. callableBuilder.Add(NewDataLiteral(memLimit));
  3968. callableBuilder.Add(NewDataLiteral(ui32(keyArgs.size())));
  3969. callableBuilder.Add(NewDataLiteral(ui32(stateArgs.size())));
  3970. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3971. std::for_each(keys.cbegin(), keys.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3972. std::for_each(keyArgs.cbegin(), keyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3973. std::for_each(first.cbegin(), first.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3974. std::for_each(stateArgs.cbegin(), stateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3975. std::for_each(next.cbegin(), next.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3976. std::for_each(finishKeyArgs.cbegin(), finishKeyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3977. std::for_each(finishStateArgs.cbegin(), finishStateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3978. std::for_each(output.cbegin(), output.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3979. return TRuntimeNode(callableBuilder.Build(), false);
  3980. }
  3981. TRuntimeNode TProgramBuilder::WideLastCombinerCommon(const TStringBuf& funcName, TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) {
  3982. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3983. TRuntimeNode::TList itemArgs;
  3984. itemArgs.reserve(wideComponents.size());
  3985. auto i = 0U;
  3986. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3987. const auto keys = extractor(itemArgs);
  3988. TRuntimeNode::TList keyArgs;
  3989. keyArgs.reserve(keys.size());
  3990. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  3991. const auto first = init(keyArgs, itemArgs);
  3992. TRuntimeNode::TList stateArgs;
  3993. stateArgs.reserve(first.size());
  3994. std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  3995. const auto next = update(keyArgs, itemArgs, stateArgs);
  3996. MKQL_ENSURE(next.size() == first.size(), "Mismatch init and update state size.");
  3997. TRuntimeNode::TList finishKeyArgs;
  3998. finishKeyArgs.reserve(keys.size());
  3999. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(finishKeyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  4000. TRuntimeNode::TList finishStateArgs;
  4001. finishStateArgs.reserve(next.size());
  4002. std::transform(next.cbegin(), next.cend(), std::back_inserter(finishStateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  4003. const auto output = finish(finishKeyArgs, finishStateArgs);
  4004. std::vector<TType*> tupleItems;
  4005. tupleItems.reserve(output.size());
  4006. std::transform(output.cbegin(), output.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  4007. TCallableBuilder callableBuilder(Env, funcName, NewFlowType(NewMultiType(tupleItems)));
  4008. callableBuilder.Add(flow);
  4009. callableBuilder.Add(NewDataLiteral(ui32(keyArgs.size())));
  4010. callableBuilder.Add(NewDataLiteral(ui32(stateArgs.size())));
  4011. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4012. std::for_each(keys.cbegin(), keys.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4013. std::for_each(keyArgs.cbegin(), keyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4014. std::for_each(first.cbegin(), first.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4015. std::for_each(stateArgs.cbegin(), stateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4016. std::for_each(next.cbegin(), next.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4017. std::for_each(finishKeyArgs.cbegin(), finishKeyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4018. std::for_each(finishStateArgs.cbegin(), finishStateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4019. std::for_each(output.cbegin(), output.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4020. return TRuntimeNode(callableBuilder.Build(), false);
  4021. }
  4022. TRuntimeNode TProgramBuilder::WideLastCombiner(TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) {
  4023. if constexpr (RuntimeVersion < 29U) {
  4024. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4025. }
  4026. return WideLastCombinerCommon(__func__, flow, extractor, init, update, finish);
  4027. }
  4028. TRuntimeNode TProgramBuilder::WideLastCombinerWithSpilling(TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) {
  4029. if constexpr (RuntimeVersion < 49U) {
  4030. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4031. }
  4032. return WideLastCombinerCommon(__func__, flow, extractor, init, update, finish);
  4033. }
  4034. TRuntimeNode TProgramBuilder::WideCondense1(TRuntimeNode flow, const TWideLambda& init, const TWideSwitchLambda& switcher, const TBinaryWideLambda& update, bool useCtx) {
  4035. if constexpr (RuntimeVersion < 18U) {
  4036. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4037. }
  4038. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  4039. TRuntimeNode::TList itemArgs;
  4040. itemArgs.reserve(wideComponents.size());
  4041. auto i = 0U;
  4042. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  4043. const auto first = init(itemArgs);
  4044. TRuntimeNode::TList stateArgs;
  4045. stateArgs.reserve(first.size());
  4046. std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  4047. const auto chop = switcher(itemArgs, stateArgs);
  4048. const auto next = update(itemArgs, stateArgs);
  4049. MKQL_ENSURE(next.size() == first.size(), "Mismatch init and update state size.");
  4050. std::vector<TType*> tupleItems;
  4051. tupleItems.reserve(next.size());
  4052. std::transform(next.cbegin(), next.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  4053. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  4054. callableBuilder.Add(flow);
  4055. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4056. std::for_each(first.cbegin(), first.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4057. std::for_each(stateArgs.cbegin(), stateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4058. callableBuilder.Add(chop);
  4059. std::for_each(next.cbegin(), next.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4060. if (useCtx) {
  4061. MKQL_ENSURE(RuntimeVersion >= 30U, "Too old runtime version");
  4062. callableBuilder.Add(NewDataLiteral<bool>(useCtx));
  4063. }
  4064. return TRuntimeNode(callableBuilder.Build(), false);
  4065. }
  4066. TRuntimeNode TProgramBuilder::CombineCore(TRuntimeNode stream,
  4067. const TUnaryLambda& keyExtractor,
  4068. const TBinaryLambda& init,
  4069. const TTernaryLambda& update,
  4070. const TBinaryLambda& finish,
  4071. ui64 memLimit)
  4072. {
  4073. if constexpr (RuntimeVersion < 3U) {
  4074. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4075. }
  4076. const bool isStream = stream.GetStaticType()->IsStream();
  4077. const auto itemType = isStream ? AS_TYPE(TStreamType, stream)->GetItemType() : AS_TYPE(TFlowType, stream)->GetItemType();
  4078. const auto itemArg = Arg(itemType);
  4079. const auto key = keyExtractor(itemArg);
  4080. const auto keyType = key.GetStaticType();
  4081. const auto keyArg = Arg(keyType);
  4082. const auto stateInit = init(keyArg, itemArg);
  4083. const auto stateType = stateInit.GetStaticType();
  4084. const auto stateArg = Arg(stateType);
  4085. const auto stateUpdate = update(keyArg, itemArg, stateArg);
  4086. const auto finishItem = finish(keyArg, stateArg);
  4087. const auto finishType = finishItem.GetStaticType();
  4088. MKQL_ENSURE(finishType->IsList() || finishType->IsStream() || finishType->IsOptional(), "Expected list, stream or optional");
  4089. TType* retItemType = nullptr;
  4090. if (finishType->IsOptional()) {
  4091. retItemType = AS_TYPE(TOptionalType, finishType)->GetItemType();
  4092. } else if (finishType->IsList()) {
  4093. retItemType = AS_TYPE(TListType, finishType)->GetItemType();
  4094. } else if (finishType->IsStream()) {
  4095. retItemType = AS_TYPE(TStreamType, finishType)->GetItemType();
  4096. }
  4097. const auto resultStreamType = isStream ? NewStreamType(retItemType) : NewFlowType(retItemType);
  4098. TCallableBuilder callableBuilder(Env, __func__, resultStreamType);
  4099. callableBuilder.Add(stream);
  4100. callableBuilder.Add(itemArg);
  4101. callableBuilder.Add(key);
  4102. callableBuilder.Add(keyArg);
  4103. callableBuilder.Add(stateInit);
  4104. callableBuilder.Add(stateArg);
  4105. callableBuilder.Add(stateUpdate);
  4106. callableBuilder.Add(finishItem);
  4107. callableBuilder.Add(NewDataLiteral(memLimit));
  4108. return TRuntimeNode(callableBuilder.Build(), false);
  4109. }
  4110. TRuntimeNode TProgramBuilder::GroupingCore(TRuntimeNode stream,
  4111. const TBinaryLambda& groupSwitch,
  4112. const TUnaryLambda& keyExtractor,
  4113. const TUnaryLambda& handler)
  4114. {
  4115. if (handler && RuntimeVersion < 20U) {
  4116. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__ << " with handler";
  4117. }
  4118. auto itemType = AS_TYPE(TStreamType, stream)->GetItemType();
  4119. TRuntimeNode keyExtractorItemArg = Arg(itemType);
  4120. TRuntimeNode keyExtractorResult = keyExtractor(keyExtractorItemArg);
  4121. TRuntimeNode groupSwitchKeyArg = Arg(keyExtractorResult.GetStaticType());
  4122. TRuntimeNode groupSwitchItemArg = Arg(itemType);
  4123. TRuntimeNode groupSwitchResult = groupSwitch(groupSwitchKeyArg, groupSwitchItemArg);
  4124. MKQL_ENSURE(AS_TYPE(TDataType, groupSwitchResult)->GetSchemeType() == NUdf::TDataType<bool>::Id,
  4125. "Expected bool type");
  4126. TRuntimeNode handlerItemArg;
  4127. TRuntimeNode handlerResult;
  4128. if (handler) {
  4129. handlerItemArg = Arg(itemType);
  4130. handlerResult = handler(handlerItemArg);
  4131. itemType = handlerResult.GetStaticType();
  4132. }
  4133. const std::array<TType*, 2U> tupleItems = {{ keyExtractorResult.GetStaticType(), NewStreamType(itemType) }};
  4134. const auto finishType = NewStreamType(NewTupleType(tupleItems));
  4135. TCallableBuilder callableBuilder(Env, __func__, finishType);
  4136. callableBuilder.Add(stream);
  4137. callableBuilder.Add(keyExtractorResult);
  4138. callableBuilder.Add(groupSwitchResult);
  4139. callableBuilder.Add(keyExtractorItemArg);
  4140. callableBuilder.Add(groupSwitchKeyArg);
  4141. callableBuilder.Add(groupSwitchItemArg);
  4142. if (handler) {
  4143. callableBuilder.Add(handlerResult);
  4144. callableBuilder.Add(handlerItemArg);
  4145. }
  4146. return TRuntimeNode(callableBuilder.Build(), false);
  4147. }
  4148. TRuntimeNode TProgramBuilder::Chopper(TRuntimeNode flow, const TUnaryLambda& keyExtractor, const TBinaryLambda& groupSwitch, const TBinaryLambda& groupHandler) {
  4149. const auto flowType = flow.GetStaticType();
  4150. MKQL_ENSURE(flowType->IsFlow() || flowType->IsStream(), "Expected flow or stream.");
  4151. if constexpr (RuntimeVersion < 9U) {
  4152. return FlatMap(GroupingCore(flow, groupSwitch, keyExtractor),
  4153. [&](TRuntimeNode item) -> TRuntimeNode { return groupHandler(Nth(item, 0U), Nth(item, 1U)); }
  4154. );
  4155. }
  4156. const bool isStream = flowType->IsStream();
  4157. const auto itemType = isStream ? AS_TYPE(TStreamType, flow)->GetItemType() : AS_TYPE(TFlowType, flow)->GetItemType();
  4158. const auto itemArg = Arg(itemType);
  4159. const auto keyExtractorResult = keyExtractor(itemArg);
  4160. const auto keyArg = Arg(keyExtractorResult.GetStaticType());
  4161. const auto groupSwitchResult = groupSwitch(keyArg, itemArg);
  4162. const auto input = Arg(flowType);
  4163. const auto output = groupHandler(keyArg, input);
  4164. TCallableBuilder callableBuilder(Env, __func__, output.GetStaticType());
  4165. callableBuilder.Add(flow);
  4166. callableBuilder.Add(itemArg);
  4167. callableBuilder.Add(keyExtractorResult);
  4168. callableBuilder.Add(keyArg);
  4169. callableBuilder.Add(groupSwitchResult);
  4170. callableBuilder.Add(input);
  4171. callableBuilder.Add(output);
  4172. return TRuntimeNode(callableBuilder.Build(), false);
  4173. }
  4174. TRuntimeNode TProgramBuilder::WideChopper(TRuntimeNode flow, const TWideLambda& extractor, const TWideSwitchLambda& groupSwitch,
  4175. const std::function<TRuntimeNode (TRuntimeNode::TList, TRuntimeNode)>& groupHandler
  4176. ) {
  4177. if constexpr (RuntimeVersion < 18U) {
  4178. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4179. }
  4180. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  4181. TRuntimeNode::TList itemArgs, keyArgs;
  4182. itemArgs.reserve(wideComponents.size());
  4183. auto i = 0U;
  4184. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  4185. const auto keys = extractor(itemArgs);
  4186. keyArgs.reserve(keys.size());
  4187. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  4188. const auto groupSwitchResult = groupSwitch(keyArgs, itemArgs);
  4189. const auto input = WideFlowArg(flow.GetStaticType());
  4190. const auto output = groupHandler(keyArgs, input);
  4191. TCallableBuilder callableBuilder(Env, __func__, output.GetStaticType());
  4192. callableBuilder.Add(flow);
  4193. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4194. std::for_each(keys.cbegin(), keys.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4195. std::for_each(keyArgs.cbegin(), keyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4196. callableBuilder.Add(groupSwitchResult);
  4197. callableBuilder.Add(input);
  4198. callableBuilder.Add(output);
  4199. return TRuntimeNode(callableBuilder.Build(), false);
  4200. }
  4201. TRuntimeNode TProgramBuilder::HoppingCore(TRuntimeNode list,
  4202. const TUnaryLambda& timeExtractor,
  4203. const TUnaryLambda& init,
  4204. const TBinaryLambda& update,
  4205. const TUnaryLambda& save,
  4206. const TUnaryLambda& load,
  4207. const TBinaryLambda& merge,
  4208. const TBinaryLambda& finish,
  4209. TRuntimeNode hop, TRuntimeNode interval, TRuntimeNode delay)
  4210. {
  4211. auto streamType = AS_TYPE(TStreamType, list);
  4212. auto itemType = AS_TYPE(TStructType, streamType->GetItemType());
  4213. auto timestampType = TOptionalType::Create(TDataType::Create(NUdf::TDataType<NUdf::TTimestamp>::Id, Env), Env);
  4214. TRuntimeNode itemArg = Arg(itemType);
  4215. auto outTime = timeExtractor(itemArg);
  4216. auto outStateInit = init(itemArg);
  4217. auto stateType = outStateInit.GetStaticType();
  4218. TRuntimeNode stateArg = Arg(stateType);
  4219. auto outStateUpdate = update(itemArg, stateArg);
  4220. auto hasSaveLoad = (bool)save;
  4221. TRuntimeNode saveArg, outSave, loadArg, outLoad;
  4222. if (hasSaveLoad) {
  4223. saveArg = Arg(stateType);
  4224. outSave = save(saveArg);
  4225. loadArg = Arg(outSave.GetStaticType());
  4226. outLoad = load(loadArg);
  4227. MKQL_ENSURE(outLoad.GetStaticType()->IsSameType(*stateType), "Loaded type is changed by the load handler");
  4228. } else {
  4229. saveArg = outSave = loadArg = outLoad = NewVoid();
  4230. }
  4231. TRuntimeNode state2Arg = Arg(stateType);
  4232. TRuntimeNode timeArg = Arg(timestampType);
  4233. auto outStateMerge = merge(stateArg, state2Arg);
  4234. auto outItemFinish = finish(stateArg, timeArg);
  4235. auto finishType = outItemFinish.GetStaticType();
  4236. MKQL_ENSURE(finishType->IsStruct(), "Expected struct type as finish lambda output");
  4237. auto resultType = TStreamType::Create(outItemFinish.GetStaticType(), Env);
  4238. TCallableBuilder callableBuilder(Env, __func__, resultType);
  4239. callableBuilder.Add(list);
  4240. callableBuilder.Add(itemArg);
  4241. callableBuilder.Add(stateArg);
  4242. callableBuilder.Add(state2Arg);
  4243. callableBuilder.Add(timeArg);
  4244. callableBuilder.Add(saveArg);
  4245. callableBuilder.Add(loadArg);
  4246. callableBuilder.Add(outTime);
  4247. callableBuilder.Add(outStateInit);
  4248. callableBuilder.Add(outStateUpdate);
  4249. callableBuilder.Add(outSave);
  4250. callableBuilder.Add(outLoad);
  4251. callableBuilder.Add(outStateMerge);
  4252. callableBuilder.Add(outItemFinish);
  4253. callableBuilder.Add(hop);
  4254. callableBuilder.Add(interval);
  4255. callableBuilder.Add(delay);
  4256. return TRuntimeNode(callableBuilder.Build(), false);
  4257. }
  4258. TRuntimeNode TProgramBuilder::MultiHoppingCore(TRuntimeNode list,
  4259. const TUnaryLambda& keyExtractor,
  4260. const TUnaryLambda& timeExtractor,
  4261. const TUnaryLambda& init,
  4262. const TBinaryLambda& update,
  4263. const TUnaryLambda& save,
  4264. const TUnaryLambda& load,
  4265. const TBinaryLambda& merge,
  4266. const TTernaryLambda& finish,
  4267. TRuntimeNode hop, TRuntimeNode interval, TRuntimeNode delay,
  4268. TRuntimeNode dataWatermarks, TRuntimeNode watermarksMode)
  4269. {
  4270. if constexpr (RuntimeVersion < 22U) {
  4271. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4272. }
  4273. auto streamType = AS_TYPE(TStreamType, list);
  4274. auto itemType = AS_TYPE(TStructType, streamType->GetItemType());
  4275. auto timestampType = TOptionalType::Create(TDataType::Create(NUdf::TDataType<NUdf::TTimestamp>::Id, Env), Env);
  4276. TRuntimeNode itemArg = Arg(itemType);
  4277. auto keyExtract = keyExtractor(itemArg);
  4278. auto keyType = keyExtract.GetStaticType();
  4279. TRuntimeNode keyArg = Arg(keyType);
  4280. auto outTime = timeExtractor(itemArg);
  4281. auto outStateInit = init(itemArg);
  4282. auto stateType = outStateInit.GetStaticType();
  4283. TRuntimeNode stateArg = Arg(stateType);
  4284. auto outStateUpdate = update(itemArg, stateArg);
  4285. auto hasSaveLoad = (bool)save;
  4286. TRuntimeNode saveArg, outSave, loadArg, outLoad;
  4287. if (hasSaveLoad) {
  4288. saveArg = Arg(stateType);
  4289. outSave = save(saveArg);
  4290. loadArg = Arg(outSave.GetStaticType());
  4291. outLoad = load(loadArg);
  4292. MKQL_ENSURE(outLoad.GetStaticType()->IsSameType(*stateType), "Loaded type is changed by the load handler");
  4293. } else {
  4294. saveArg = outSave = loadArg = outLoad = NewVoid();
  4295. }
  4296. TRuntimeNode state2Arg = Arg(stateType);
  4297. TRuntimeNode timeArg = Arg(timestampType);
  4298. auto outStateMerge = merge(stateArg, state2Arg);
  4299. auto outItemFinish = finish(keyArg, stateArg, timeArg);
  4300. auto finishType = outItemFinish.GetStaticType();
  4301. MKQL_ENSURE(finishType->IsStruct(), "Expected struct type as finish lambda output");
  4302. auto resultType = TStreamType::Create(outItemFinish.GetStaticType(), Env);
  4303. TCallableBuilder callableBuilder(Env, __func__, resultType);
  4304. callableBuilder.Add(list);
  4305. callableBuilder.Add(itemArg);
  4306. callableBuilder.Add(keyArg);
  4307. callableBuilder.Add(stateArg);
  4308. callableBuilder.Add(state2Arg);
  4309. callableBuilder.Add(timeArg);
  4310. callableBuilder.Add(saveArg);
  4311. callableBuilder.Add(loadArg);
  4312. callableBuilder.Add(keyExtract);
  4313. callableBuilder.Add(outTime);
  4314. callableBuilder.Add(outStateInit);
  4315. callableBuilder.Add(outStateUpdate);
  4316. callableBuilder.Add(outSave);
  4317. callableBuilder.Add(outLoad);
  4318. callableBuilder.Add(outStateMerge);
  4319. callableBuilder.Add(outItemFinish);
  4320. callableBuilder.Add(hop);
  4321. callableBuilder.Add(interval);
  4322. callableBuilder.Add(delay);
  4323. callableBuilder.Add(dataWatermarks);
  4324. callableBuilder.Add(watermarksMode);
  4325. return TRuntimeNode(callableBuilder.Build(), false);
  4326. }
  4327. TRuntimeNode TProgramBuilder::Default(TType* type) {
  4328. bool isOptional;
  4329. const auto targetType = UnpackOptionalData(type, isOptional);
  4330. if (isOptional) {
  4331. return NewOptional(Default(targetType));
  4332. }
  4333. const auto scheme = targetType->GetSchemeType();
  4334. const auto value = scheme == NUdf::TDataType<NUdf::TUuid>::Id ?
  4335. Env.NewStringValue("\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"sv) :
  4336. scheme == NUdf::TDataType<NUdf::TDyNumber>::Id ? NUdf::TUnboxedValuePod::Embedded("\1") : NUdf::TUnboxedValuePod::Zero();
  4337. return TRuntimeNode(TDataLiteral::Create(value, targetType, Env), true);
  4338. }
  4339. TRuntimeNode TProgramBuilder::Cast(TRuntimeNode arg, TType* type) {
  4340. if (arg.GetStaticType()->IsSameType(*type)) {
  4341. return arg;
  4342. }
  4343. bool isOptional;
  4344. const auto targetType = UnpackOptionalData(type, isOptional);
  4345. const auto sourceType = UnpackOptionalData(arg, isOptional);
  4346. const auto sId = sourceType->GetSchemeType();
  4347. const auto tId = targetType->GetSchemeType();
  4348. if (sId == NUdf::TDataType<char*>::Id) {
  4349. if (tId != NUdf::TDataType<char*>::Id) {
  4350. return FromString(arg, type);
  4351. } else {
  4352. return arg;
  4353. }
  4354. }
  4355. if (sId == NUdf::TDataType<NUdf::TUtf8>::Id) {
  4356. if (tId != NUdf::TDataType<char*>::Id) {
  4357. return FromString(arg, type);
  4358. } else {
  4359. return ToString(arg);
  4360. }
  4361. }
  4362. if (tId == NUdf::TDataType<char*>::Id) {
  4363. return ToString(arg);
  4364. }
  4365. if (tId == NUdf::TDataType<NUdf::TUtf8>::Id) {
  4366. return ToString<true>(arg);
  4367. }
  4368. if (tId == NUdf::TDataType<NUdf::TDecimal>::Id) {
  4369. const auto& params = static_cast<const TDataDecimalType*>(targetType)->GetParams();
  4370. return ToDecimal(arg, params.first, params.second);
  4371. }
  4372. const auto options = NKikimr::NUdf::GetCastResult(*sourceType->GetDataSlot(), *targetType->GetDataSlot());
  4373. MKQL_ENSURE((*options & NKikimr::NUdf::ECastOptions::Undefined) ||
  4374. !(*options & NKikimr::NUdf::ECastOptions::Impossible),
  4375. "Impossible to cast " << *static_cast<TType*>(sourceType) << " into " << *static_cast<TType*>(targetType));
  4376. const bool useToIntegral = (*options & NKikimr::NUdf::ECastOptions::Undefined) ||
  4377. (*options & NKikimr::NUdf::ECastOptions::MayFail);
  4378. return useToIntegral ? ToIntegral(arg, type) : Convert(arg, type);
  4379. }
  4380. TRuntimeNode TProgramBuilder::RangeCreate(TRuntimeNode list) {
  4381. MKQL_ENSURE(list.GetStaticType()->IsList(), "Expecting list");
  4382. auto itemType = static_cast<TListType*>(list.GetStaticType())->GetItemType();
  4383. MKQL_ENSURE(itemType->IsTuple(), "Expecting list of tuples");
  4384. auto tupleType = static_cast<TTupleType*>(itemType);
  4385. MKQL_ENSURE(tupleType->GetElementsCount() == 2,
  4386. "Expecting list ot 2-element tuples, got: " << tupleType->GetElementsCount() << " elements");
  4387. MKQL_ENSURE(tupleType->GetElementType(0)->IsSameType(*tupleType->GetElementType(1)),
  4388. "Expecting list ot 2-element tuples of same type");
  4389. MKQL_ENSURE(tupleType->GetElementType(0)->IsTuple(),
  4390. "Expecting range boundary to be tuple");
  4391. auto boundaryType = static_cast<TTupleType*>(tupleType->GetElementType(0));
  4392. MKQL_ENSURE(boundaryType->GetElementsCount() >= 2,
  4393. "Range boundary should have at least 2 components, got: " << boundaryType->GetElementsCount());
  4394. auto lastComp = boundaryType->GetElementType(boundaryType->GetElementsCount() - 1);
  4395. std::vector<TType*> outputComponents;
  4396. for (ui32 i = 0; i < boundaryType->GetElementsCount() - 1; ++i) {
  4397. outputComponents.push_back(lastComp);
  4398. outputComponents.push_back(boundaryType->GetElementType(i));
  4399. }
  4400. outputComponents.push_back(lastComp);
  4401. auto outputBoundary = TTupleType::Create(outputComponents.size(), &outputComponents.front(), Env);
  4402. std::vector<TType*> outputRangeComps(2, outputBoundary);
  4403. auto outputRange = TTupleType::Create(outputRangeComps.size(), &outputRangeComps.front(), Env);
  4404. TCallableBuilder callableBuilder(Env, __func__, TListType::Create(outputRange, Env));
  4405. callableBuilder.Add(list);
  4406. return TRuntimeNode(callableBuilder.Build(), false);
  4407. }
  4408. TRuntimeNode TProgramBuilder::RangeUnion(const TArrayRef<const TRuntimeNode>& lists) {
  4409. return BuildRangeLogical(__func__, lists);
  4410. }
  4411. TRuntimeNode TProgramBuilder::RangeIntersect(const TArrayRef<const TRuntimeNode>& lists) {
  4412. return BuildRangeLogical(__func__, lists);
  4413. }
  4414. TRuntimeNode TProgramBuilder::RangeMultiply(const TArrayRef<const TRuntimeNode>& args) {
  4415. MKQL_ENSURE(args.size() >= 2, "Expecting at least two arguments");
  4416. bool unlimited = false;
  4417. if (args.front().GetStaticType()->IsVoid()) {
  4418. unlimited = true;
  4419. } else {
  4420. MKQL_ENSURE(args.front().GetStaticType()->IsData() &&
  4421. static_cast<TDataType*>(args.front().GetStaticType())->GetSchemeType() == NUdf::TDataType<ui64>::Id,
  4422. "Expected ui64 as first argument");
  4423. }
  4424. std::vector<TType*> outputComponents;
  4425. for (size_t i = 1; i < args.size(); ++i) {
  4426. const auto& list = args[i];
  4427. MKQL_ENSURE(list.GetStaticType()->IsList(), "Expecting list");
  4428. auto listItemType = static_cast<TListType*>(list.GetStaticType())->GetItemType();
  4429. MKQL_ENSURE(listItemType->IsTuple(), "Expecting list of tuples");
  4430. auto rangeType = static_cast<TTupleType*>(listItemType);
  4431. MKQL_ENSURE(rangeType->GetElementsCount() == 2, "Expecting list of 2-element tuples");
  4432. MKQL_ENSURE(rangeType->GetElementType(0)->IsTuple(), "Range boundary should be tuple");
  4433. auto boundaryType = static_cast<TTupleType*>(rangeType->GetElementType(0));
  4434. ui32 elementsCount = boundaryType->GetElementsCount();
  4435. MKQL_ENSURE(elementsCount >= 3 && elementsCount % 2 == 1, "Range boundary should have odd number components (at least 3)");
  4436. for (size_t j = 0; j < elementsCount - 1; ++j) {
  4437. outputComponents.push_back(boundaryType->GetElementType(j));
  4438. }
  4439. }
  4440. outputComponents.push_back(TDataType::Create(NUdf::TDataType<i32>::Id, Env));
  4441. auto outputBoundary = TTupleType::Create(outputComponents.size(), &outputComponents.front(), Env);
  4442. std::vector<TType*> outputRangeComps(2, outputBoundary);
  4443. auto outputRange = TTupleType::Create(outputRangeComps.size(), &outputRangeComps.front(), Env);
  4444. TCallableBuilder callableBuilder(Env, __func__, TListType::Create(outputRange, Env));
  4445. if (unlimited) {
  4446. callableBuilder.Add(NewDataLiteral<ui64>(std::numeric_limits<ui64>::max()));
  4447. } else {
  4448. callableBuilder.Add(args[0]);
  4449. }
  4450. for (size_t i = 1; i < args.size(); ++i) {
  4451. callableBuilder.Add(args[i]);
  4452. }
  4453. return TRuntimeNode(callableBuilder.Build(), false);
  4454. }
  4455. TRuntimeNode TProgramBuilder::RangeFinalize(TRuntimeNode list) {
  4456. MKQL_ENSURE(list.GetStaticType()->IsList(), "Expecting list");
  4457. auto listItemType = static_cast<TListType*>(list.GetStaticType())->GetItemType();
  4458. MKQL_ENSURE(listItemType->IsTuple(), "Expecting list of tuples");
  4459. auto rangeType = static_cast<TTupleType*>(listItemType);
  4460. MKQL_ENSURE(rangeType->GetElementsCount() == 2, "Expecting list of 2-element tuples");
  4461. MKQL_ENSURE(rangeType->GetElementType(0)->IsTuple(), "Range boundary should be tuple");
  4462. auto boundaryType = static_cast<TTupleType*>(rangeType->GetElementType(0));
  4463. ui32 elementsCount = boundaryType->GetElementsCount();
  4464. MKQL_ENSURE(elementsCount >= 3 && elementsCount % 2 == 1, "Range boundary should have odd number components (at least 3)");
  4465. std::vector<TType*> outputComponents;
  4466. for (ui32 i = 0; i < elementsCount; ++i) {
  4467. if (i % 2 == 1 || i + 1 == elementsCount) {
  4468. outputComponents.push_back(boundaryType->GetElementType(i));
  4469. }
  4470. }
  4471. auto outputBoundary = TTupleType::Create(outputComponents.size(), &outputComponents.front(), Env);
  4472. std::vector<TType*> outputRangeComps(2, outputBoundary);
  4473. auto outputRange = TTupleType::Create(outputRangeComps.size(), &outputRangeComps.front(), Env);
  4474. TCallableBuilder callableBuilder(Env, __func__, TListType::Create(outputRange, Env));
  4475. callableBuilder.Add(list);
  4476. return TRuntimeNode(callableBuilder.Build(), false);
  4477. }
  4478. TRuntimeNode TProgramBuilder::Round(const std::string_view& callableName, TRuntimeNode source, TType* targetType) {
  4479. const auto sourceType = source.GetStaticType();
  4480. MKQL_ENSURE(sourceType->IsData(), "Expecting first arg to be of Data type");
  4481. MKQL_ENSURE(targetType->IsData(), "Expecting second arg to be Data type");
  4482. const auto ss = *static_cast<TDataType*>(sourceType)->GetDataSlot();
  4483. const auto ts = *static_cast<TDataType*>(targetType)->GetDataSlot();
  4484. const auto options = NKikimr::NUdf::GetCastResult(ss, ts);
  4485. MKQL_ENSURE(!(*options & NKikimr::NUdf::ECastOptions::Impossible),
  4486. "Impossible to cast " << *sourceType << " into " << *targetType);
  4487. MKQL_ENSURE(*options & (NKikimr::NUdf::ECastOptions::MayFail |
  4488. NKikimr::NUdf::ECastOptions::MayLoseData |
  4489. NKikimr::NUdf::ECastOptions::AnywayLoseData),
  4490. "Rounding from " << *sourceType << " to " << *targetType << " is trivial");
  4491. TCallableBuilder callableBuilder(Env, callableName, TOptionalType::Create(targetType, Env));
  4492. callableBuilder.Add(source);
  4493. return TRuntimeNode(callableBuilder.Build(), false);
  4494. }
  4495. TRuntimeNode TProgramBuilder::NextValue(TRuntimeNode value) {
  4496. const auto valueType = value.GetStaticType();
  4497. MKQL_ENSURE(valueType->IsData(), "Expecting argument of Data type");
  4498. const auto slot = *static_cast<TDataType*>(valueType)->GetDataSlot();
  4499. MKQL_ENSURE(slot == NUdf::EDataSlot::String || slot == NUdf::EDataSlot::Utf8,
  4500. "Unsupported type: " << *valueType);
  4501. TCallableBuilder callableBuilder(Env, __func__, TOptionalType::Create(valueType, Env));
  4502. callableBuilder.Add(value);
  4503. return TRuntimeNode(callableBuilder.Build(), false);
  4504. }
  4505. TRuntimeNode TProgramBuilder::Nop(TRuntimeNode value, TType* returnType) {
  4506. if constexpr (RuntimeVersion < 35U) {
  4507. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4508. }
  4509. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4510. callableBuilder.Add(value);
  4511. return TRuntimeNode(callableBuilder.Build(), false);
  4512. }
  4513. bool TProgramBuilder::IsNull(TRuntimeNode arg) {
  4514. return arg.GetStaticType()->IsSameType(*NewNull().GetStaticType()); // TODO ->IsNull();
  4515. }
  4516. TRuntimeNode TProgramBuilder::Replicate(TRuntimeNode item, TRuntimeNode count, const std::string_view& file, ui32 row, ui32 column) {
  4517. MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data");
  4518. MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  4519. const auto listType = TListType::Create(item.GetStaticType(), Env);
  4520. TCallableBuilder callableBuilder(Env, __func__, listType);
  4521. callableBuilder.Add(item);
  4522. callableBuilder.Add(count);
  4523. if constexpr (RuntimeVersion >= 2) {
  4524. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  4525. callableBuilder.Add(NewDataLiteral(row));
  4526. callableBuilder.Add(NewDataLiteral(column));
  4527. }
  4528. return TRuntimeNode(callableBuilder.Build(), false);
  4529. }
  4530. TRuntimeNode TProgramBuilder::PgConst(TPgType* pgType, const std::string_view& value, TRuntimeNode typeMod) {
  4531. if constexpr (RuntimeVersion < 30U) {
  4532. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4533. }
  4534. TCallableBuilder callableBuilder(Env, __func__, pgType);
  4535. callableBuilder.Add(NewDataLiteral(pgType->GetTypeId()));
  4536. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(value));
  4537. if (typeMod) {
  4538. callableBuilder.Add(typeMod);
  4539. }
  4540. return TRuntimeNode(callableBuilder.Build(), false);
  4541. }
  4542. TRuntimeNode TProgramBuilder::PgResolvedCall(bool useContext, const std::string_view& name,
  4543. ui32 id, const TArrayRef<const TRuntimeNode>& args,
  4544. TType* returnType, bool rangeFunction) {
  4545. if constexpr (RuntimeVersion < 45U) {
  4546. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4547. }
  4548. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4549. callableBuilder.Add(NewDataLiteral(useContext));
  4550. callableBuilder.Add(NewDataLiteral(rangeFunction));
  4551. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(name));
  4552. callableBuilder.Add(NewDataLiteral(id));
  4553. for (const auto& arg : args) {
  4554. callableBuilder.Add(arg);
  4555. }
  4556. return TRuntimeNode(callableBuilder.Build(), false);
  4557. }
  4558. TRuntimeNode TProgramBuilder::BlockPgResolvedCall(const std::string_view& name, ui32 id,
  4559. const TArrayRef<const TRuntimeNode>& args, TType* returnType) {
  4560. if constexpr (RuntimeVersion < 30U) {
  4561. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4562. }
  4563. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4564. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(name));
  4565. callableBuilder.Add(NewDataLiteral(id));
  4566. for (const auto& arg : args) {
  4567. callableBuilder.Add(arg);
  4568. }
  4569. return TRuntimeNode(callableBuilder.Build(), false);
  4570. }
  4571. TRuntimeNode TProgramBuilder::PgArray(const TArrayRef<const TRuntimeNode>& args, TType* returnType) {
  4572. if constexpr (RuntimeVersion < 30U) {
  4573. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4574. }
  4575. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4576. for (const auto& arg : args) {
  4577. callableBuilder.Add(arg);
  4578. }
  4579. return TRuntimeNode(callableBuilder.Build(), false);
  4580. }
  4581. TRuntimeNode TProgramBuilder::PgTableContent(
  4582. const std::string_view& cluster,
  4583. const std::string_view& table,
  4584. TType* returnType) {
  4585. if constexpr (RuntimeVersion < 47U) {
  4586. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4587. }
  4588. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4589. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(cluster));
  4590. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(table));
  4591. return TRuntimeNode(callableBuilder.Build(), false);
  4592. }
  4593. TRuntimeNode TProgramBuilder::PgToRecord(TRuntimeNode input, const TArrayRef<std::pair<std::string_view, std::string_view>>& members) {
  4594. if constexpr (RuntimeVersion < 48U) {
  4595. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4596. }
  4597. MKQL_ENSURE(input.GetStaticType()->IsStruct(), "Expected struct");
  4598. auto structType = AS_TYPE(TStructType, input.GetStaticType());
  4599. for (ui32 i = 0; i < structType->GetMembersCount(); ++i) {
  4600. auto itemType = structType->GetMemberType(i);
  4601. MKQL_ENSURE(itemType->IsNull() || itemType->IsPg(), "Expected null or pg");
  4602. }
  4603. auto returnType = NewPgType(NYql::NPg::LookupType("record").TypeId);
  4604. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4605. callableBuilder.Add(input);
  4606. TVector<TRuntimeNode> names;
  4607. for (const auto& x : members) {
  4608. names.push_back(NewDataLiteral<NUdf::EDataSlot::String>(x.first));
  4609. names.push_back(NewDataLiteral<NUdf::EDataSlot::String>(x.second));
  4610. }
  4611. callableBuilder.Add(NewTuple(names));
  4612. return TRuntimeNode(callableBuilder.Build(), false);
  4613. }
  4614. TRuntimeNode TProgramBuilder::PgCast(TRuntimeNode input, TType* returnType, TRuntimeNode typeMod) {
  4615. if constexpr (RuntimeVersion < 30U) {
  4616. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4617. }
  4618. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4619. callableBuilder.Add(input);
  4620. if (typeMod) {
  4621. callableBuilder.Add(typeMod);
  4622. }
  4623. return TRuntimeNode(callableBuilder.Build(), false);
  4624. }
  4625. TRuntimeNode TProgramBuilder::FromPg(TRuntimeNode input, TType* returnType) {
  4626. if constexpr (RuntimeVersion < 30U) {
  4627. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4628. }
  4629. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4630. callableBuilder.Add(input);
  4631. return TRuntimeNode(callableBuilder.Build(), false);
  4632. }
  4633. TRuntimeNode TProgramBuilder::ToPg(TRuntimeNode input, TType* returnType) {
  4634. if constexpr (RuntimeVersion < 30U) {
  4635. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4636. }
  4637. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4638. callableBuilder.Add(input);
  4639. return TRuntimeNode(callableBuilder.Build(), false);
  4640. }
  4641. TRuntimeNode TProgramBuilder::PgClone(TRuntimeNode input, const TArrayRef<const TRuntimeNode>& dependentNodes) {
  4642. if constexpr (RuntimeVersion < 38U) {
  4643. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4644. }
  4645. TCallableBuilder callableBuilder(Env, __func__, input.GetStaticType());
  4646. callableBuilder.Add(input);
  4647. for (const auto& node : dependentNodes) {
  4648. callableBuilder.Add(node);
  4649. }
  4650. return TRuntimeNode(callableBuilder.Build(), false);
  4651. }
  4652. TRuntimeNode TProgramBuilder::WithContext(TRuntimeNode input, const std::string_view& contextType) {
  4653. if constexpr (RuntimeVersion < 30U) {
  4654. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4655. }
  4656. TCallableBuilder callableBuilder(Env, __func__, input.GetStaticType());
  4657. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(contextType));
  4658. callableBuilder.Add(input);
  4659. return TRuntimeNode(callableBuilder.Build(), false);
  4660. }
  4661. TRuntimeNode TProgramBuilder::PgInternal0(TType* returnType) {
  4662. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4663. return TRuntimeNode(callableBuilder.Build(), false);
  4664. }
  4665. TRuntimeNode TProgramBuilder::BlockIf(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch) {
  4666. const auto conditionType = AS_TYPE(TBlockType, condition.GetStaticType());
  4667. MKQL_ENSURE(AS_TYPE(TDataType, conditionType->GetItemType())->GetSchemeType() == NUdf::TDataType<bool>::Id,
  4668. "Expected bool as first argument");
  4669. const auto thenType = AS_TYPE(TBlockType, thenBranch.GetStaticType());
  4670. const auto elseType = AS_TYPE(TBlockType, elseBranch.GetStaticType());
  4671. MKQL_ENSURE(thenType->GetItemType()->IsSameType(*elseType->GetItemType()), "Different return types in branches.");
  4672. auto returnType = NewBlockType(thenType->GetItemType(), GetResultShape({conditionType, thenType, elseType}));
  4673. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4674. callableBuilder.Add(condition);
  4675. callableBuilder.Add(thenBranch);
  4676. callableBuilder.Add(elseBranch);
  4677. return TRuntimeNode(callableBuilder.Build(), false);
  4678. }
  4679. TRuntimeNode TProgramBuilder::BlockJust(TRuntimeNode data) {
  4680. const auto initialType = AS_TYPE(TBlockType, data.GetStaticType());
  4681. auto returnType = NewBlockType(NewOptionalType(initialType->GetItemType()), initialType->GetShape());
  4682. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4683. callableBuilder.Add(data);
  4684. return TRuntimeNode(callableBuilder.Build(), false);
  4685. }
  4686. TRuntimeNode TProgramBuilder::BlockFunc(const std::string_view& funcName, TType* returnType, const TArrayRef<const TRuntimeNode>& args) {
  4687. for (const auto& arg : args) {
  4688. MKQL_ENSURE(arg.GetStaticType()->IsBlock(), "Expected Block type");
  4689. }
  4690. TCallableBuilder builder(Env, __func__, returnType);
  4691. builder.Add(NewDataLiteral<NUdf::EDataSlot::String>(funcName));
  4692. for (const auto& arg : args) {
  4693. builder.Add(arg);
  4694. }
  4695. return TRuntimeNode(builder.Build(), false);
  4696. }
  4697. TRuntimeNode TProgramBuilder::BuildBlockCombineAll(const std::string_view& callableName, TRuntimeNode input, std::optional<ui32> filterColumn,
  4698. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4699. const auto inputType = input.GetStaticType();
  4700. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type");
  4701. MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type");
  4702. TCallableBuilder builder(Env, callableName, returnType);
  4703. builder.Add(input);
  4704. if (!filterColumn) {
  4705. builder.Add(NewEmptyOptionalDataLiteral(NUdf::TDataType<ui32>::Id));
  4706. } else {
  4707. builder.Add(NewOptional(NewDataLiteral<ui32>(*filterColumn)));
  4708. }
  4709. TVector<TRuntimeNode> aggsNodes;
  4710. for (const auto& agg : aggs) {
  4711. TVector<TRuntimeNode> params;
  4712. params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name));
  4713. for (const auto& col : agg.ArgsColumns) {
  4714. params.push_back(NewDataLiteral<ui32>(col));
  4715. }
  4716. aggsNodes.push_back(NewTuple(params));
  4717. }
  4718. builder.Add(NewTuple(aggsNodes));
  4719. return TRuntimeNode(builder.Build(), false);
  4720. }
  4721. TRuntimeNode TProgramBuilder::BlockCombineAll(TRuntimeNode stream, std::optional<ui32> filterColumn,
  4722. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4723. if constexpr (RuntimeVersion < 31U) {
  4724. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4725. }
  4726. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type");
  4727. MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type");
  4728. if constexpr (RuntimeVersion < 52U) {
  4729. const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType());
  4730. return FromFlow(BuildBlockCombineAll(__func__, ToFlow(stream), filterColumn, aggs, flowReturnType));
  4731. } else {
  4732. return BuildBlockCombineAll(__func__, stream, filterColumn, aggs, returnType);
  4733. }
  4734. }
  4735. TRuntimeNode TProgramBuilder::BuildBlockCombineHashed(const std::string_view& callableName, TRuntimeNode input, std::optional<ui32> filterColumn,
  4736. const TArrayRef<ui32>& keys, const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4737. const auto inputType = input.GetStaticType();
  4738. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type");
  4739. MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type");
  4740. TCallableBuilder builder(Env, callableName, returnType);
  4741. builder.Add(input);
  4742. if (!filterColumn) {
  4743. builder.Add(NewEmptyOptionalDataLiteral(NUdf::TDataType<ui32>::Id));
  4744. } else {
  4745. builder.Add(NewOptional(NewDataLiteral<ui32>(*filterColumn)));
  4746. }
  4747. TVector<TRuntimeNode> keyNodes;
  4748. for (const auto& key : keys) {
  4749. keyNodes.push_back(NewDataLiteral<ui32>(key));
  4750. }
  4751. builder.Add(NewTuple(keyNodes));
  4752. TVector<TRuntimeNode> aggsNodes;
  4753. for (const auto& agg : aggs) {
  4754. TVector<TRuntimeNode> params;
  4755. params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name));
  4756. for (const auto& col : agg.ArgsColumns) {
  4757. params.push_back(NewDataLiteral<ui32>(col));
  4758. }
  4759. aggsNodes.push_back(NewTuple(params));
  4760. }
  4761. builder.Add(NewTuple(aggsNodes));
  4762. return TRuntimeNode(builder.Build(), false);
  4763. }
  4764. TRuntimeNode TProgramBuilder::BlockCombineHashed(TRuntimeNode stream, std::optional<ui32> filterColumn, const TArrayRef<ui32>& keys,
  4765. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4766. if constexpr (RuntimeVersion < 31U) {
  4767. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4768. }
  4769. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type");
  4770. MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type");
  4771. if constexpr (RuntimeVersion < 52U) {
  4772. const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType());
  4773. return FromFlow(BuildBlockCombineHashed(__func__, ToFlow(stream), filterColumn, keys, aggs, flowReturnType));
  4774. } else {
  4775. return BuildBlockCombineHashed(__func__, stream, filterColumn, keys, aggs, returnType);
  4776. }
  4777. }
  4778. TRuntimeNode TProgramBuilder::BuildBlockMergeFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys,
  4779. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4780. const auto inputType = input.GetStaticType();
  4781. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type");
  4782. MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type");
  4783. TCallableBuilder builder(Env, callableName, returnType);
  4784. builder.Add(input);
  4785. TVector<TRuntimeNode> keyNodes;
  4786. for (const auto& key : keys) {
  4787. keyNodes.push_back(NewDataLiteral<ui32>(key));
  4788. }
  4789. builder.Add(NewTuple(keyNodes));
  4790. TVector<TRuntimeNode> aggsNodes;
  4791. for (const auto& agg : aggs) {
  4792. TVector<TRuntimeNode> params;
  4793. params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name));
  4794. for (const auto& col : agg.ArgsColumns) {
  4795. params.push_back(NewDataLiteral<ui32>(col));
  4796. }
  4797. aggsNodes.push_back(NewTuple(params));
  4798. }
  4799. builder.Add(NewTuple(aggsNodes));
  4800. return TRuntimeNode(builder.Build(), false);
  4801. }
  4802. TRuntimeNode TProgramBuilder::BlockMergeFinalizeHashed(TRuntimeNode stream, const TArrayRef<ui32>& keys,
  4803. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4804. if constexpr (RuntimeVersion < 31U) {
  4805. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4806. }
  4807. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type");
  4808. MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type");
  4809. if constexpr (RuntimeVersion < 52U) {
  4810. const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType());
  4811. return FromFlow(BuildBlockMergeFinalizeHashed(__func__, ToFlow(stream), keys, aggs, flowReturnType));
  4812. } else {
  4813. return BuildBlockMergeFinalizeHashed(__func__, stream, keys, aggs, returnType);
  4814. }
  4815. }
  4816. TRuntimeNode TProgramBuilder::BuildBlockMergeManyFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys,
  4817. const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType) {
  4818. const auto inputType = input.GetStaticType();
  4819. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type");
  4820. MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type");
  4821. TCallableBuilder builder(Env, callableName, returnType);
  4822. builder.Add(input);
  4823. TVector<TRuntimeNode> keyNodes;
  4824. for (const auto& key : keys) {
  4825. keyNodes.push_back(NewDataLiteral<ui32>(key));
  4826. }
  4827. builder.Add(NewTuple(keyNodes));
  4828. TVector<TRuntimeNode> aggsNodes;
  4829. for (const auto& agg : aggs) {
  4830. TVector<TRuntimeNode> params;
  4831. params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name));
  4832. for (const auto& col : agg.ArgsColumns) {
  4833. params.push_back(NewDataLiteral<ui32>(col));
  4834. }
  4835. aggsNodes.push_back(NewTuple(params));
  4836. }
  4837. builder.Add(NewTuple(aggsNodes));
  4838. builder.Add(NewDataLiteral<ui32>(streamIndex));
  4839. TVector<TRuntimeNode> streamsNodes;
  4840. for (const auto& s : streams) {
  4841. TVector<TRuntimeNode> streamNodes;
  4842. for (const auto& i : s) {
  4843. streamNodes.push_back(NewDataLiteral<ui32>(i));
  4844. }
  4845. streamsNodes.push_back(NewTuple(streamNodes));
  4846. }
  4847. builder.Add(NewTuple(streamsNodes));
  4848. return TRuntimeNode(builder.Build(), false);
  4849. }
  4850. TRuntimeNode TProgramBuilder::BlockMergeManyFinalizeHashed(TRuntimeNode stream, const TArrayRef<ui32>& keys,
  4851. const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType) {
  4852. if constexpr (RuntimeVersion < 31U) {
  4853. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4854. }
  4855. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type");
  4856. MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type");
  4857. if constexpr (RuntimeVersion < 52U) {
  4858. const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType());
  4859. return FromFlow(BuildBlockMergeManyFinalizeHashed(__func__, ToFlow(stream), keys, aggs, streamIndex, streams, flowReturnType));
  4860. } else {
  4861. return BuildBlockMergeManyFinalizeHashed(__func__, stream, keys, aggs, streamIndex, streams, returnType);
  4862. }
  4863. }
  4864. TRuntimeNode TProgramBuilder::ScalarApply(const TArrayRef<const TRuntimeNode>& args, const TArrayLambda& handler) {
  4865. if constexpr (RuntimeVersion < 39U) {
  4866. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4867. }
  4868. MKQL_ENSURE(!args.empty(), "Required at least one argument");
  4869. TVector<TRuntimeNode> lambdaArgs;
  4870. bool scalarOnly = true;
  4871. std::shared_ptr<arrow::DataType> arrowType;
  4872. for (const auto& arg : args) {
  4873. auto blockType = AS_TYPE(TBlockType, arg.GetStaticType());
  4874. scalarOnly = scalarOnly && blockType->GetShape() == TBlockType::EShape::Scalar;
  4875. MKQL_ENSURE(ConvertArrowType(blockType->GetItemType(), arrowType), "Unsupported arrow type");
  4876. lambdaArgs.emplace_back(Arg(blockType->GetItemType()));
  4877. }
  4878. auto ret = handler(lambdaArgs);
  4879. MKQL_ENSURE(ConvertArrowType(ret.GetStaticType(), arrowType), "Unsupported arrow type");
  4880. auto returnType = NewBlockType(ret.GetStaticType(), scalarOnly ? TBlockType::EShape::Scalar : TBlockType::EShape::Many);
  4881. TCallableBuilder builder(Env, __func__, returnType);
  4882. for (const auto& arg : args) {
  4883. builder.Add(arg);
  4884. }
  4885. for (const auto& arg : lambdaArgs) {
  4886. builder.Add(arg);
  4887. }
  4888. builder.Add(ret);
  4889. return TRuntimeNode(builder.Build(), false);
  4890. }
  4891. TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode leftStream, TRuntimeNode rightStream, EJoinKind joinKind,
  4892. const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& leftKeyDrops,
  4893. const TArrayRef<const ui32>& rightKeyColumns, const TArrayRef<const ui32>& rightKeyDrops, bool rightAny, TType* returnType
  4894. ) {
  4895. if constexpr (RuntimeVersion < 53U) {
  4896. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4897. }
  4898. if (RuntimeVersion < 57U && joinKind == EJoinKind::Cross) {
  4899. THROW yexception() << __func__ << " does not support cross join in runtime version (" << RuntimeVersion << ")";
  4900. }
  4901. MKQL_ENSURE(joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left ||
  4902. joinKind == EJoinKind::LeftSemi || joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::Cross,
  4903. "Unsupported join kind");
  4904. MKQL_ENSURE(leftKeyColumns.size() == rightKeyColumns.size(), "Key column count mismatch");
  4905. if (joinKind == EJoinKind::Cross) {
  4906. MKQL_ENSURE(leftKeyColumns.empty(), "Specifying key columns is not allowed for cross join");
  4907. } else {
  4908. MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified");
  4909. }
  4910. ValidateBlockStreamType(leftStream.GetStaticType());
  4911. ValidateBlockStreamType(rightStream.GetStaticType());
  4912. ValidateBlockStreamType(returnType);
  4913. TRuntimeNode::TList leftKeyColumnsNodes;
  4914. leftKeyColumnsNodes.reserve(leftKeyColumns.size());
  4915. std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(),
  4916. std::back_inserter(leftKeyColumnsNodes), [this](const ui32 idx) {
  4917. return NewDataLiteral(idx);
  4918. });
  4919. TRuntimeNode::TList leftKeyDropsNodes;
  4920. leftKeyDropsNodes.reserve(leftKeyDrops.size());
  4921. std::transform(leftKeyDrops.cbegin(), leftKeyDrops.cend(),
  4922. std::back_inserter(leftKeyDropsNodes), [this](const ui32 idx) {
  4923. return NewDataLiteral(idx);
  4924. });
  4925. TRuntimeNode::TList rightKeyColumnsNodes;
  4926. rightKeyColumnsNodes.reserve(rightKeyColumns.size());
  4927. std::transform(rightKeyColumns.cbegin(), rightKeyColumns.cend(),
  4928. std::back_inserter(rightKeyColumnsNodes), [this](const ui32 idx) {
  4929. return NewDataLiteral(idx);
  4930. });
  4931. TRuntimeNode::TList rightKeyDropsNodes;
  4932. rightKeyDropsNodes.reserve(leftKeyDrops.size());
  4933. std::transform(rightKeyDrops.cbegin(), rightKeyDrops.cend(),
  4934. std::back_inserter(rightKeyDropsNodes), [this](const ui32 idx) {
  4935. return NewDataLiteral(idx);
  4936. });
  4937. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4938. callableBuilder.Add(leftStream);
  4939. callableBuilder.Add(rightStream);
  4940. callableBuilder.Add(NewDataLiteral((ui32)joinKind));
  4941. callableBuilder.Add(NewTuple(leftKeyColumnsNodes));
  4942. callableBuilder.Add(NewTuple(leftKeyDropsNodes));
  4943. callableBuilder.Add(NewTuple(rightKeyColumnsNodes));
  4944. callableBuilder.Add(NewTuple(rightKeyDropsNodes));
  4945. callableBuilder.Add(NewDataLiteral((bool)rightAny));
  4946. return TRuntimeNode(callableBuilder.Build(), false);
  4947. }
  4948. namespace {
  4949. using namespace NYql::NMatchRecognize;
  4950. TRuntimeNode PatternToRuntimeNode(const TRowPattern& pattern, const TProgramBuilder& programBuilder) {
  4951. const auto& env = programBuilder.GetTypeEnvironment();
  4952. TTupleLiteralBuilder patternBuilder(env);
  4953. for (const auto& term: pattern) {
  4954. TTupleLiteralBuilder termBuilder(env);
  4955. for (const auto& factor: term) {
  4956. TTupleLiteralBuilder factorBuilder(env);
  4957. factorBuilder.Add(std::visit(TOverloaded {
  4958. [&](const TString& s) {
  4959. return programBuilder.NewDataLiteral<NUdf::EDataSlot::String>(s);
  4960. },
  4961. [&](const TRowPattern& pattern) {
  4962. return PatternToRuntimeNode(pattern, programBuilder);
  4963. },
  4964. }, factor.Primary));
  4965. factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMin));
  4966. factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMax));
  4967. factorBuilder.Add(programBuilder.NewDataLiteral(factor.Greedy));
  4968. factorBuilder.Add(programBuilder.NewDataLiteral(factor.Output));
  4969. factorBuilder.Add(programBuilder.NewDataLiteral(factor.Unused));
  4970. termBuilder.Add({factorBuilder.Build(), true});
  4971. }
  4972. patternBuilder.Add({termBuilder.Build(), true});
  4973. }
  4974. return {patternBuilder.Build(), true};
  4975. };
  4976. } //namespace
  4977. TRuntimeNode TProgramBuilder::MatchRecognizeCore(
  4978. TRuntimeNode inputStream,
  4979. const TUnaryLambda& getPartitionKeySelectorNode,
  4980. const TArrayRef<TStringBuf>& partitionColumnNames,
  4981. const TVector<TStringBuf>& measureColumnNames,
  4982. const TVector<TBinaryLambda>& getMeasures,
  4983. const NYql::NMatchRecognize::TRowPattern& pattern,
  4984. const TVector<TStringBuf>& defineVarNames,
  4985. const TVector<TTernaryLambda>& getDefines,
  4986. bool streamingMode,
  4987. const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo,
  4988. NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch
  4989. ) {
  4990. MKQL_ENSURE(RuntimeVersion >= 42, "MatchRecognize is not supported in runtime version " << RuntimeVersion);
  4991. const auto inputRowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputStream.GetStaticType())->GetItemType());
  4992. const auto inputRowArg = Arg(inputRowType);
  4993. const auto partitionKeySelectorNode = getPartitionKeySelectorNode(inputRowArg);
  4994. const auto partitionColumnTypes = AS_TYPE(TTupleType, partitionKeySelectorNode.GetStaticType())->GetElements();
  4995. const auto rangeList = NewListType(NewStructType({
  4996. {"From", NewDataType(NUdf::EDataSlot::Uint64)},
  4997. {"To", NewDataType(NUdf::EDataSlot::Uint64)}
  4998. }));
  4999. TStructTypeBuilder matchedVarsTypeBuilder(Env);
  5000. for (const auto& var: GetPatternVars(pattern)) {
  5001. matchedVarsTypeBuilder.Add(var, rangeList);
  5002. }
  5003. const auto matchedVarsType = matchedVarsTypeBuilder.Build();
  5004. TRuntimeNode matchedVarsArg = Arg(matchedVarsType);
  5005. //---These vars may be empty in case of no measures
  5006. TRuntimeNode measureInputDataArg;
  5007. std::vector<TRuntimeNode> specialColumnIndexesInMeasureInputDataRow;
  5008. TVector<TRuntimeNode> measures;
  5009. //---
  5010. if (getMeasures.empty()) {
  5011. measureInputDataArg = Arg(Env.GetTypeOfVoidLazy());
  5012. } else {
  5013. measures.reserve(getMeasures.size());
  5014. specialColumnIndexesInMeasureInputDataRow.resize(static_cast<size_t>(NYql::NMatchRecognize::EMeasureInputDataSpecialColumns::Last));
  5015. TStructTypeBuilder measureInputDataRowTypeBuilder(Env);
  5016. for (ui32 i = 0; i < inputRowType->GetMembersCount(); ++i) {
  5017. measureInputDataRowTypeBuilder.Add(inputRowType->GetMemberName(i), inputRowType->GetMemberType(i));
  5018. }
  5019. measureInputDataRowTypeBuilder.Add(
  5020. MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::Classifier),
  5021. NewDataType(NUdf::EDataSlot::Utf8)
  5022. );
  5023. measureInputDataRowTypeBuilder.Add(
  5024. MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::MatchNumber),
  5025. NewDataType(NUdf::EDataSlot::Uint64)
  5026. );
  5027. const auto measureInputDataRowType = measureInputDataRowTypeBuilder.Build();
  5028. for (ui32 i = 0; i < measureInputDataRowType->GetMembersCount(); ++i) {
  5029. //assume a few, if grows, it's better to use a lookup table here
  5030. static_assert(static_cast<size_t>(EMeasureInputDataSpecialColumns::Last) < 5);
  5031. for (size_t j = 0; j != static_cast<size_t>(EMeasureInputDataSpecialColumns::Last); ++j) {
  5032. if (measureInputDataRowType->GetMemberName(i) ==
  5033. NYql::NMatchRecognize::MeasureInputDataSpecialColumnName(static_cast<EMeasureInputDataSpecialColumns>(j)))
  5034. specialColumnIndexesInMeasureInputDataRow[j] = NewDataLiteral(i);
  5035. }
  5036. }
  5037. measureInputDataArg = Arg(NewListType(measureInputDataRowType));
  5038. for (size_t i = 0; i != getMeasures.size(); ++i) {
  5039. measures.push_back(getMeasures[i](measureInputDataArg, matchedVarsArg));
  5040. }
  5041. }
  5042. TStructTypeBuilder outputRowTypeBuilder(Env);
  5043. THashMap<TStringBuf, size_t> partitionColumnLookup;
  5044. THashMap<TStringBuf, size_t> measureColumnLookup;
  5045. THashMap<TStringBuf, size_t> otherColumnLookup;
  5046. for (size_t i = 0; i < measureColumnNames.size(); ++i) {
  5047. const auto name = measureColumnNames[i];
  5048. measureColumnLookup.emplace(name, i);
  5049. outputRowTypeBuilder.Add(name, measures[i].GetStaticType());
  5050. }
  5051. switch (rowsPerMatch) {
  5052. case NYql::NMatchRecognize::ERowsPerMatch::OneRow:
  5053. for (size_t i = 0; i < partitionColumnNames.size(); ++i) {
  5054. const auto name = partitionColumnNames[i];
  5055. partitionColumnLookup.emplace(name, i);
  5056. outputRowTypeBuilder.Add(name, partitionColumnTypes[i]);
  5057. }
  5058. break;
  5059. case NYql::NMatchRecognize::ERowsPerMatch::AllRows:
  5060. for (size_t i = 0; i < inputRowType->GetMembersCount(); ++i) {
  5061. const auto name = inputRowType->GetMemberName(i);
  5062. otherColumnLookup.emplace(name, i);
  5063. outputRowTypeBuilder.Add(name, inputRowType->GetMemberType(i));
  5064. }
  5065. break;
  5066. }
  5067. auto outputRowType = outputRowTypeBuilder.Build();
  5068. std::vector<TRuntimeNode> partitionColumnIndexes(partitionColumnLookup.size());
  5069. std::vector<TRuntimeNode> measureColumnIndexes(measureColumnLookup.size());
  5070. TVector<TRuntimeNode> outputColumnOrder(NDetail::TReserveTag{outputRowType->GetMembersCount()});
  5071. for (ui32 i = 0; i < outputRowType->GetMembersCount(); ++i) {
  5072. const auto name = outputRowType->GetMemberName(i);
  5073. if (auto iter = partitionColumnLookup.find(name);
  5074. iter != partitionColumnLookup.end()) {
  5075. partitionColumnIndexes[iter->second] = NewDataLiteral(i);
  5076. outputColumnOrder.push_back(NewStruct({
  5077. std::pair{"Index", NewDataLiteral(iter->second)},
  5078. std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::PartitionKey))},
  5079. }));
  5080. } else if (auto iter = measureColumnLookup.find(name);
  5081. iter != measureColumnLookup.end()) {
  5082. measureColumnIndexes[iter->second] = NewDataLiteral(i);
  5083. outputColumnOrder.push_back(NewStruct({
  5084. std::pair{"Index", NewDataLiteral(iter->second)},
  5085. std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Measure))},
  5086. }));
  5087. } else if (auto iter = otherColumnLookup.find(name);
  5088. iter != otherColumnLookup.end()) {
  5089. outputColumnOrder.push_back(NewStruct({
  5090. std::pair{"Index", NewDataLiteral(iter->second)},
  5091. std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Other))},
  5092. }));
  5093. }
  5094. }
  5095. const auto outputType = NewFlowType(outputRowType);
  5096. THashMap<TStringBuf, size_t> patternVarLookup;
  5097. for (ui32 i = 0; i < matchedVarsType->GetMembersCount(); ++i) {
  5098. patternVarLookup[matchedVarsType->GetMemberName(i)] = i;
  5099. }
  5100. THashMap<TStringBuf, size_t> defineLookup;
  5101. for (size_t i = 0; i < defineVarNames.size(); ++i) {
  5102. const auto name = defineVarNames[i];
  5103. defineLookup[name] = i;
  5104. }
  5105. TVector<TRuntimeNode> defineNames(patternVarLookup.size());
  5106. TVector<TRuntimeNode> defineNodes(patternVarLookup.size());
  5107. const auto inputDataArg = Arg(NewListType(inputRowType));
  5108. const auto currentRowIndexArg = Arg(NewDataType(NUdf::EDataSlot::Uint64));
  5109. for (const auto& [v, i]: patternVarLookup) {
  5110. defineNames[i] = NewDataLiteral<NUdf::EDataSlot::String>(v);
  5111. if (auto iter = defineLookup.find(v);
  5112. iter != defineLookup.end()) {
  5113. defineNodes[i] = getDefines[iter->second](inputDataArg, matchedVarsArg, currentRowIndexArg);
  5114. } else if ("$" == v || "^" == v) {
  5115. //DO nothing, //will be handled in a specific way
  5116. } else { // a var without a predicate matches any row
  5117. defineNodes[i] = NewDataLiteral(true);
  5118. }
  5119. }
  5120. TCallableBuilder callableBuilder(GetTypeEnvironment(), "MatchRecognizeCore", outputType);
  5121. const auto indexType = NewDataType(NUdf::EDataSlot::Uint32);
  5122. const auto outputColumnEntryType = NewStructType({
  5123. {"Index", NewDataType(NUdf::EDataSlot::Uint64)},
  5124. {"SourceType", NewDataType(NUdf::EDataSlot::Int32)},
  5125. });
  5126. callableBuilder.Add(inputStream);
  5127. callableBuilder.Add(inputRowArg);
  5128. callableBuilder.Add(partitionKeySelectorNode);
  5129. callableBuilder.Add(NewList(indexType, partitionColumnIndexes));
  5130. callableBuilder.Add(measureInputDataArg);
  5131. callableBuilder.Add(NewList(indexType, specialColumnIndexesInMeasureInputDataRow));
  5132. callableBuilder.Add(NewDataLiteral(inputRowType->GetMembersCount()));
  5133. callableBuilder.Add(matchedVarsArg);
  5134. callableBuilder.Add(NewList(indexType, measureColumnIndexes));
  5135. for (const auto& m: measures) {
  5136. callableBuilder.Add(m);
  5137. }
  5138. callableBuilder.Add(PatternToRuntimeNode(pattern, *this));
  5139. callableBuilder.Add(currentRowIndexArg);
  5140. callableBuilder.Add(inputDataArg);
  5141. callableBuilder.Add(NewList(NewDataType(NUdf::EDataSlot::String), defineNames));
  5142. for (const auto& d: defineNodes) {
  5143. callableBuilder.Add(d);
  5144. }
  5145. callableBuilder.Add(NewDataLiteral(streamingMode));
  5146. if constexpr (RuntimeVersion >= 52U) {
  5147. callableBuilder.Add(NewDataLiteral(static_cast<i32>(skipTo.To)));
  5148. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(skipTo.Var));
  5149. }
  5150. if constexpr (RuntimeVersion >= 54U) {
  5151. callableBuilder.Add(NewDataLiteral(static_cast<i32>(rowsPerMatch)));
  5152. callableBuilder.Add(NewList(outputColumnEntryType, outputColumnOrder));
  5153. }
  5154. return TRuntimeNode(callableBuilder.Build(), false);
  5155. }
  5156. TRuntimeNode TProgramBuilder::TimeOrderRecover(
  5157. TRuntimeNode inputStream,
  5158. const TUnaryLambda& getTimeExtractor,
  5159. TRuntimeNode delay,
  5160. TRuntimeNode ahead,
  5161. TRuntimeNode rowLimit
  5162. )
  5163. {
  5164. MKQL_ENSURE(RuntimeVersion >= 44, "TimeOrderRecover is not supported in runtime version " << RuntimeVersion);
  5165. auto& inputRowType = *static_cast<TStructType*>(AS_TYPE(TStructType, AS_TYPE(TFlowType, inputStream.GetStaticType())->GetItemType()));
  5166. const auto inputRowArg = Arg(&inputRowType);
  5167. TStructTypeBuilder outputRowTypeBuilder(Env);
  5168. outputRowTypeBuilder.Reserve(inputRowType.GetMembersCount() + 1);
  5169. const ui32 inputRowColumnCount = inputRowType.GetMembersCount();
  5170. for (ui32 i = 0; i != inputRowColumnCount; ++i) {
  5171. outputRowTypeBuilder.Add(inputRowType.GetMemberName(i), inputRowType.GetMemberType(i));
  5172. }
  5173. using NYql::NTimeOrderRecover::OUT_OF_ORDER_MARKER;
  5174. outputRowTypeBuilder.Add(OUT_OF_ORDER_MARKER, TDataType::Create(NUdf::TDataType<bool>::Id, Env));
  5175. const auto outputRowType = outputRowTypeBuilder.Build();
  5176. const auto outOfOrderColumnIndex = outputRowType->GetMemberIndex(OUT_OF_ORDER_MARKER);
  5177. TCallableBuilder callableBuilder(GetTypeEnvironment(), "TimeOrderRecover", TFlowType::Create(outputRowType, Env));
  5178. callableBuilder.Add(inputStream);
  5179. callableBuilder.Add(inputRowArg);
  5180. callableBuilder.Add(getTimeExtractor(inputRowArg));
  5181. callableBuilder.Add(NewDataLiteral(inputRowColumnCount));
  5182. callableBuilder.Add(NewDataLiteral(outOfOrderColumnIndex));
  5183. callableBuilder.Add(delay),
  5184. callableBuilder.Add(ahead),
  5185. callableBuilder.Add(rowLimit);
  5186. return TRuntimeNode(callableBuilder.Build(), false);
  5187. }
  5188. bool CanExportType(TType* type, const TTypeEnvironment& env) {
  5189. if (type->GetKind() == TType::EKind::Type) {
  5190. return false; // Type of Type
  5191. }
  5192. TExploringNodeVisitor explorer;
  5193. explorer.Walk(type, env);
  5194. bool canExport = true;
  5195. for (auto& node : explorer.GetNodes()) {
  5196. switch (static_cast<TType*>(node)->GetKind()) {
  5197. case TType::EKind::Void:
  5198. node->SetCookie(1);
  5199. break;
  5200. case TType::EKind::Data:
  5201. node->SetCookie(1);
  5202. break;
  5203. case TType::EKind::Pg:
  5204. node->SetCookie(1);
  5205. break;
  5206. case TType::EKind::Optional: {
  5207. auto optionalType = static_cast<TOptionalType*>(node);
  5208. if (!optionalType->GetItemType()->GetCookie()) {
  5209. canExport = false;
  5210. } else {
  5211. node->SetCookie(1);
  5212. }
  5213. break;
  5214. }
  5215. case TType::EKind::List: {
  5216. auto listType = static_cast<TListType*>(node);
  5217. if (!listType->GetItemType()->GetCookie()) {
  5218. canExport = false;
  5219. } else {
  5220. node->SetCookie(1);
  5221. }
  5222. break;
  5223. }
  5224. case TType::EKind::Struct: {
  5225. auto structType = static_cast<TStructType*>(node);
  5226. for (ui32 index = 0; index < structType->GetMembersCount(); ++index) {
  5227. if (!structType->GetMemberType(index)->GetCookie()) {
  5228. canExport = false;
  5229. break;
  5230. }
  5231. }
  5232. if (canExport) {
  5233. node->SetCookie(1);
  5234. }
  5235. break;
  5236. }
  5237. case TType::EKind::Tuple: {
  5238. auto tupleType = static_cast<TTupleType*>(node);
  5239. for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) {
  5240. if (!tupleType->GetElementType(index)->GetCookie()) {
  5241. canExport = false;
  5242. break;
  5243. }
  5244. }
  5245. if (canExport) {
  5246. node->SetCookie(1);
  5247. }
  5248. break;
  5249. }
  5250. case TType::EKind::Dict: {
  5251. auto dictType = static_cast<TDictType*>(node);
  5252. if (!dictType->GetKeyType()->GetCookie() || !dictType->GetPayloadType()->GetCookie()) {
  5253. canExport = false;
  5254. } else {
  5255. node->SetCookie(1);
  5256. }
  5257. break;
  5258. }
  5259. case TType::EKind::Variant: {
  5260. auto variantType = static_cast<TVariantType*>(node);
  5261. TType* innerType = variantType->GetUnderlyingType();
  5262. if (innerType->IsStruct()) {
  5263. auto structType = static_cast<TStructType*>(innerType);
  5264. for (ui32 index = 0; index < structType->GetMembersCount(); ++index) {
  5265. if (!structType->GetMemberType(index)->GetCookie()) {
  5266. canExport = false;
  5267. break;
  5268. }
  5269. }
  5270. }
  5271. if (innerType->IsTuple()) {
  5272. auto tupleType = static_cast<TTupleType*>(innerType);
  5273. for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) {
  5274. if (!tupleType->GetElementType(index)->GetCookie()) {
  5275. canExport = false;
  5276. break;
  5277. }
  5278. }
  5279. }
  5280. if (canExport) {
  5281. node->SetCookie(1);
  5282. }
  5283. break;
  5284. }
  5285. case TType::EKind::Type:
  5286. break;
  5287. default:
  5288. canExport = false;
  5289. }
  5290. if (!canExport) {
  5291. break;
  5292. }
  5293. }
  5294. for (auto& node : explorer.GetNodes()) {
  5295. node->SetCookie(0);
  5296. }
  5297. return canExport;
  5298. }
  5299. }
  5300. }