mkql_program_builder.cpp 276 KB


  1. #include "mkql_program_builder.h"
  2. #include "mkql_node_visitor.h"
  3. #include "mkql_node_cast.h"
  4. #include "mkql_runtime_version.h"
  5. #include "yql/essentials/minikql/mkql_node_printer.h"
  6. #include "yql/essentials/minikql/mkql_function_registry.h"
  7. #include "yql/essentials/minikql/mkql_utils.h"
  8. #include "yql/essentials/minikql/mkql_type_builder.h"
  9. #include "yql/essentials/core/sql_types/match_recognize.h"
  10. #include "yql/essentials/core/sql_types/time_order_recover.h"
  11. #include <yql/essentials/parser/pg_catalog/catalog.h>
  12. #include <util/generic/overloaded.h>
  13. #include <util/string/cast.h>
  14. #include <util/string/printf.h>
  15. #include <array>
  16. using namespace std::string_view_literals;
  17. namespace NKikimr {
  18. namespace NMiniKQL {
  19. namespace {
  20. struct TDataFunctionFlags {
  21. enum {
  22. HasBooleanResult = 0x01,
  23. RequiresBooleanArgs = 0x02,
  24. HasOptionalResult = 0x04,
  25. AllowOptionalArgs = 0x08,
  26. HasUi32Result = 0x10,
  27. RequiresCompare = 0x20,
  28. HasStringResult = 0x40,
  29. RequiresStringArgs = 0x80,
  30. RequiresHash = 0x100,
  31. RequiresEquals = 0x200,
  32. AllowNull = 0x400,
  33. CommonOptionalResult = 0x800,
  34. SupportsTuple = 0x1000,
  35. SameOptionalArgs = 0x2000,
  36. Default = 0x00
  37. };
  38. };
  39. #define MKQL_BAD_TYPE_VISIT(NodeType, ScriptName) \
  40. void Visit(NodeType& node) override { \
  41. Y_UNUSED(node); \
  42. MKQL_ENSURE(false, "Can't convert " #NodeType " to " ScriptName " object"); \
  43. }
  44. class TPythonTypeChecker : public TExploringNodeVisitor {
  45. using TExploringNodeVisitor::Visit;
  46. MKQL_BAD_TYPE_VISIT(TAnyType, "Python");
  47. };
  48. class TLuaTypeChecker : public TExploringNodeVisitor {
  49. using TExploringNodeVisitor::Visit;
  50. MKQL_BAD_TYPE_VISIT(TVoidType, "Lua");
  51. MKQL_BAD_TYPE_VISIT(TAnyType, "Lua");
  52. MKQL_BAD_TYPE_VISIT(TVariantType, "Lua");
  53. };
  54. class TJavascriptTypeChecker : public TExploringNodeVisitor {
  55. using TExploringNodeVisitor::Visit;
  56. MKQL_BAD_TYPE_VISIT(TAnyType, "Javascript");
  57. };
  58. #undef MKQL_BAD_TYPE_VISIT
  59. void EnsureScriptSpecificTypes(
  60. EScriptType scriptType,
  61. TCallableType* funcType,
  62. const TTypeEnvironment& env)
  63. {
  64. switch (scriptType) {
  65. case EScriptType::Lua:
  66. return TLuaTypeChecker().Walk(funcType, env);
  67. case EScriptType::Python:
  68. case EScriptType::Python2:
  69. case EScriptType::Python3:
  70. case EScriptType::ArcPython:
  71. case EScriptType::ArcPython2:
  72. case EScriptType::ArcPython3:
  73. case EScriptType::CustomPython:
  74. case EScriptType::CustomPython2:
  75. case EScriptType::CustomPython3:
  76. case EScriptType::SystemPython2:
  77. case EScriptType::SystemPython3:
  78. case EScriptType::SystemPython3_8:
  79. case EScriptType::SystemPython3_9:
  80. case EScriptType::SystemPython3_10:
  81. case EScriptType::SystemPython3_11:
  82. case EScriptType::SystemPython3_12:
  83. case EScriptType::SystemPython3_13:
  84. return TPythonTypeChecker().Walk(funcType, env);
  85. case EScriptType::Javascript:
  86. return TJavascriptTypeChecker().Walk(funcType, env);
  87. default:
  88. MKQL_ENSURE(false, "Unknown script type " << static_cast<ui32>(scriptType));
  89. }
  90. }
  91. ui32 GetNumericSchemeTypeLevel(NUdf::TDataTypeId typeId) {
  92. switch (typeId) {
  93. case NUdf::TDataType<ui8>::Id:
  94. return 0;
  95. case NUdf::TDataType<i8>::Id:
  96. return 1;
  97. case NUdf::TDataType<ui16>::Id:
  98. return 2;
  99. case NUdf::TDataType<i16>::Id:
  100. return 3;
  101. case NUdf::TDataType<ui32>::Id:
  102. return 4;
  103. case NUdf::TDataType<i32>::Id:
  104. return 5;
  105. case NUdf::TDataType<ui64>::Id:
  106. return 6;
  107. case NUdf::TDataType<i64>::Id:
  108. return 7;
  109. case NUdf::TDataType<float>::Id:
  110. return 8;
  111. case NUdf::TDataType<double>::Id:
  112. return 9;
  113. default:
  114. ythrow yexception() << "Unknown numeric type: " << typeId;
  115. }
  116. }
  117. NUdf::TDataTypeId GetNumericSchemeTypeByLevel(ui32 level) {
  118. switch (level) {
  119. case 0:
  120. return NUdf::TDataType<ui8>::Id;
  121. case 1:
  122. return NUdf::TDataType<i8>::Id;
  123. case 2:
  124. return NUdf::TDataType<ui16>::Id;
  125. case 3:
  126. return NUdf::TDataType<i16>::Id;
  127. case 4:
  128. return NUdf::TDataType<ui32>::Id;
  129. case 5:
  130. return NUdf::TDataType<i32>::Id;
  131. case 6:
  132. return NUdf::TDataType<ui64>::Id;
  133. case 7:
  134. return NUdf::TDataType<i64>::Id;
  135. case 8:
  136. return NUdf::TDataType<float>::Id;
  137. case 9:
  138. return NUdf::TDataType<double>::Id;
  139. default:
  140. ythrow yexception() << "Unknown numeric level: " << level;
  141. }
  142. }
  143. NUdf::TDataTypeId MakeNumericDataSuperType(NUdf::TDataTypeId typeId1, NUdf::TDataTypeId typeId2) {
  144. return typeId1 == typeId2 ? typeId1 :
  145. GetNumericSchemeTypeByLevel(std::max(GetNumericSchemeTypeLevel(typeId1), GetNumericSchemeTypeLevel(typeId2)));
  146. }
  147. template<bool IsFilter>
  148. bool CollectOptionalElements(const TType* type, std::vector<std::string_view>& test, std::vector<std::pair<std::string_view, TType*>>& output) {
  149. const auto structType = AS_TYPE(TStructType, type);
  150. test.reserve(structType->GetMembersCount());
  151. output.reserve(structType->GetMembersCount());
  152. bool multiOptional = false;
  153. for (ui32 i = 0; i < structType->GetMembersCount(); ++i) {
  154. output.emplace_back(structType->GetMemberName(i), structType->GetMemberType(i));
  155. auto& memberType = output.back().second;
  156. if (memberType->IsOptional()) {
  157. test.emplace_back(output.back().first);
  158. if constexpr (IsFilter) {
  159. memberType = AS_TYPE(TOptionalType, memberType)->GetItemType();
  160. multiOptional = multiOptional || memberType->IsOptional();
  161. }
  162. }
  163. }
  164. return multiOptional;
  165. }
  166. template<bool IsFilter>
  167. bool CollectOptionalElements(const TType* type, std::vector<ui32>& test, std::vector<TType*>& output) {
  168. const auto typleType = AS_TYPE(TTupleType, type);
  169. test.reserve(typleType->GetElementsCount());
  170. output.reserve(typleType->GetElementsCount());
  171. bool multiOptional = false;
  172. for (ui32 i = 0; i < typleType->GetElementsCount(); ++i) {
  173. output.emplace_back(typleType->GetElementType(i));
  174. auto& elementType = output.back();
  175. if (elementType->IsOptional()) {
  176. test.emplace_back(i);
  177. if constexpr (IsFilter) {
  178. elementType = AS_TYPE(TOptionalType, elementType)->GetItemType();
  179. multiOptional = multiOptional || elementType->IsOptional();
  180. }
  181. }
  182. }
  183. return multiOptional;
  184. }
  185. bool ReduceOptionalElements(const TType* type, const TArrayRef<const std::string_view>& test, std::vector<std::pair<std::string_view, TType*>>& output) {
  186. const auto structType = AS_TYPE(TStructType, type);
  187. output.reserve(structType->GetMembersCount());
  188. for (ui32 i = 0U; i < structType->GetMembersCount(); ++i) {
  189. output.emplace_back(structType->GetMemberName(i), structType->GetMemberType(i));
  190. }
  191. bool multiOptional = false;
  192. for (const auto& member : test) {
  193. auto& memberType = output[structType->GetMemberIndex(member)].second;
  194. MKQL_ENSURE(memberType->IsOptional(), "Required optional column type");
  195. memberType = AS_TYPE(TOptionalType, memberType)->GetItemType();
  196. multiOptional = multiOptional || memberType->IsOptional();
  197. }
  198. return multiOptional;
  199. }
  200. bool ReduceOptionalElements(const TType* type, const TArrayRef<const ui32>& test, std::vector<TType*>& output) {
  201. const auto typleType = AS_TYPE(TTupleType, type);
  202. output.reserve(typleType->GetElementsCount());
  203. for (ui32 i = 0U; i < typleType->GetElementsCount(); ++i) {
  204. output.emplace_back(typleType->GetElementType(i));
  205. }
  206. bool multiOptional = false;
  207. for (const auto& member : test) {
  208. auto& memberType = output[member];
  209. MKQL_ENSURE(memberType->IsOptional(), "Required optional column type");
  210. memberType = AS_TYPE(TOptionalType, memberType)->GetItemType();
  211. multiOptional = multiOptional || memberType->IsOptional();
  212. }
  213. return multiOptional;
  214. }
  215. static std::vector<TType*> ValidateBlockItems(const TArrayRef<TType* const>& wideComponents, bool unwrap) {
  216. MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column");
  217. std::vector<TType*> items;
  218. items.reserve(wideComponents.size());
  219. // XXX: Declare these variables outside the loop body to use for the last
  220. // item (i.e. block length column) in the assertions below.
  221. bool isScalar;
  222. TType* itemType;
  223. for (const auto& wideComponent : wideComponents) {
  224. auto blockType = AS_TYPE(TBlockType, wideComponent);
  225. isScalar = blockType->GetShape() == TBlockType::EShape::Scalar;
  226. itemType = blockType->GetItemType();
  227. items.push_back(unwrap ? itemType : blockType);
  228. }
  229. MKQL_ENSURE(isScalar, "Last column should be scalar");
  230. MKQL_ENSURE(AS_TYPE(TDataType, itemType)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64");
  231. return items;
  232. }
  233. } // namespace
  234. std::string_view ScriptTypeAsStr(EScriptType type) {
  235. switch (type) {
  236. #define MKQL_SCRIPT_TYPE_CASE(name, value, ...) \
  237. case EScriptType::name: return std::string_view(#name);
  238. MKQL_SCRIPT_TYPES(MKQL_SCRIPT_TYPE_CASE)
  239. #undef MKQL_SCRIPT_TYPE_CASE
  240. } // switch
  241. return std::string_view("Unknown");
  242. }
  243. EScriptType ScriptTypeFromStr(std::string_view str) {
  244. TString lowerStr = TString(str);
  245. lowerStr.to_lower();
  246. #define MKQL_SCRIPT_TYPE_FROM_STR(name, value, lowerName, allowSuffix) \
  247. if ((allowSuffix && lowerStr.StartsWith(#lowerName)) || lowerStr == #lowerName) return EScriptType::name;
  248. MKQL_SCRIPT_TYPES(MKQL_SCRIPT_TYPE_FROM_STR)
  249. #undef MKQL_SCRIPT_TYPE_FROM_STR
  250. return EScriptType::Unknown;
  251. }
  252. bool IsCustomPython(EScriptType type) {
  253. return type == EScriptType::CustomPython ||
  254. type == EScriptType::CustomPython2 ||
  255. type == EScriptType::CustomPython3;
  256. }
  257. bool IsSystemPython(EScriptType type) {
  258. return type == EScriptType::SystemPython2
  259. || type == EScriptType::SystemPython3
  260. || type == EScriptType::SystemPython3_8
  261. || type == EScriptType::SystemPython3_9
  262. || type == EScriptType::SystemPython3_10
  263. || type == EScriptType::SystemPython3_11
  264. || type == EScriptType::SystemPython3_12
  265. || type == EScriptType::SystemPython3_13
  266. || type == EScriptType::Python
  267. || type == EScriptType::Python2;
  268. }
  269. EScriptType CanonizeScriptType(EScriptType type) {
  270. if (type == EScriptType::Python) {
  271. return EScriptType::Python2;
  272. }
  273. if (type == EScriptType::ArcPython) {
  274. return EScriptType::ArcPython2;
  275. }
  276. return type;
  277. }
// Verifies that the node's static type is either a data type or an optional whose item
// type is a data type; otherwise the MKQL_ENSURE check fails with
// "Expected data or optional of data".
// NOTE(review): MKQL_ENSURE is not defined in this chunk and may embed the condition text in
// its failure message, so the expression is left exactly as written — confirm before reformatting.
void EnsureDataOrOptionalOfData(TRuntimeNode node) {
    // `&&` binds tighter than `||`: the condition reads IsData() || (IsOptional() && item->IsData()).
    MKQL_ENSURE(node.GetStaticType()->IsData() ||
        node.GetStaticType()->IsOptional() && AS_TYPE(TOptionalType, node.GetStaticType())
        ->GetItemType()->IsData(), "Expected data or optional of data");
}
  283. std::vector<TType*> ValidateBlockStreamType(const TType* streamType, bool unwrap) {
  284. const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, streamType));
  285. return ValidateBlockItems(wideComponents, unwrap);
  286. }
  287. std::vector<TType*> ValidateBlockFlowType(const TType* flowType, bool unwrap) {
  288. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flowType));
  289. return ValidateBlockItems(wideComponents, unwrap);
  290. }
  291. TProgramBuilder::TProgramBuilder(const TTypeEnvironment& env, const IFunctionRegistry& functionRegistry, bool voidWithEffects)
  292. : TTypeBuilder(env)
  293. , FunctionRegistry(functionRegistry)
  294. , VoidWithEffects(voidWithEffects)
  295. {}
  296. const TTypeEnvironment& TProgramBuilder::GetTypeEnvironment() const {
  297. return Env;
  298. }
  299. const IFunctionRegistry& TProgramBuilder::GetFunctionRegistry() const {
  300. return FunctionRegistry;
  301. }
  302. TType* TProgramBuilder::ChooseCommonType(TType* type1, TType* type2) {
  303. bool isOptional1, isOptional2;
  304. const auto data1 = UnpackOptionalData(type1, isOptional1);
  305. const auto data2 = UnpackOptionalData(type2, isOptional2);
  306. if (data1->IsSameType(*data2)) {
  307. return isOptional1 ? type1 : type2;
  308. }
  309. MKQL_ENSURE(!
  310. ((NUdf::GetDataTypeInfo(*data1->GetDataSlot()).Features | NUdf::GetDataTypeInfo(*data2->GetDataSlot()).Features) & (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType)),
  311. "Not same date types: " << *type1 << " and " << *type2
  312. );
  313. const auto data = NewDataType(MakeNumericDataSuperType(data1->GetSchemeType(), data2->GetSchemeType()));
  314. return isOptional1 || isOptional2 ? NewOptionalType(data) : data;
  315. }
  316. TType* TProgramBuilder::BuildArithmeticCommonType(TType* type1, TType* type2) {
  317. bool isOptional1, isOptional2;
  318. const auto data1 = UnpackOptionalData(type1, isOptional1);
  319. const auto data2 = UnpackOptionalData(type2, isOptional2);
  320. const auto features1 = NUdf::GetDataTypeInfo(*data1->GetDataSlot()).Features;
  321. const auto features2 = NUdf::GetDataTypeInfo(*data2->GetDataSlot()).Features;
  322. const bool isOptional = isOptional1 || isOptional2;
  323. if (features1 & features2 & NUdf::EDataTypeFeatures::TimeIntervalType) {
  324. return NewOptionalType(features1 & NUdf::EDataTypeFeatures::BigDateType ? data1 : data2);
  325. } else if (features1 & NUdf::EDataTypeFeatures::TimeIntervalType) {
  326. return NewOptionalType(features2 & NUdf::EDataTypeFeatures::IntegralType ? data1 : data2);
  327. } else if (features2 & NUdf::EDataTypeFeatures::TimeIntervalType) {
  328. return NewOptionalType(features1 & NUdf::EDataTypeFeatures::IntegralType ? data2 : data1);
  329. } else if (
  330. features1 & (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType) &&
  331. features2 & (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType)
  332. ) {
  333. const auto used = ((features1 | features2) & NUdf::EDataTypeFeatures::BigDateType)
  334. ? NewDataType(NUdf::EDataSlot::Interval64)
  335. : NewDataType(NUdf::EDataSlot::Interval);
  336. return isOptional ? NewOptionalType(used) : used;
  337. } else if (data1->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  338. MKQL_ENSURE(data1->IsSameType(*data2), "Must be same type.");
  339. return isOptional ? NewOptionalType(data1) : data2;
  340. }
  341. const auto data = NewDataType(MakeNumericDataSuperType(data1->GetSchemeType(), data2->GetSchemeType()));
  342. return isOptional ? NewOptionalType(data) : data;
  343. }
  344. TRuntimeNode TProgramBuilder::Arg(TType* type) const {
  345. TCallableBuilder builder(Env, __func__, type, true);
  346. return TRuntimeNode(builder.Build(), false);
  347. }
  348. TRuntimeNode TProgramBuilder::WideFlowArg(TType* type) const {
  349. if constexpr (RuntimeVersion < 18U) {
  350. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  351. }
  352. TCallableBuilder builder(Env, __func__, type, true);
  353. return TRuntimeNode(builder.Build(), false);
  354. }
  355. TRuntimeNode TProgramBuilder::Member(TRuntimeNode structObj, const std::string_view& memberName) {
  356. bool isOptional;
  357. const auto type = AS_TYPE(TStructType, UnpackOptional(structObj.GetStaticType(), isOptional));
  358. const auto memberIndex = type->GetMemberIndex(memberName);
  359. auto memberType = type->GetMemberType(memberIndex);
  360. if (isOptional && !memberType->IsOptional() && !memberType->IsNull() && !memberType->IsPg()) {
  361. memberType = NewOptionalType(memberType);
  362. }
  363. TCallableBuilder callableBuilder(Env, __func__, memberType);
  364. callableBuilder.Add(structObj);
  365. callableBuilder.Add(NewDataLiteral<ui32>(memberIndex));
  366. return TRuntimeNode(callableBuilder.Build(), false);
  367. }
  368. TRuntimeNode TProgramBuilder::Element(TRuntimeNode structObj, const std::string_view& memberName) {
  369. return Member(structObj, memberName);
  370. }
  371. TRuntimeNode TProgramBuilder::AddMember(TRuntimeNode structObj, const std::string_view& memberName, TRuntimeNode memberValue) {
  372. auto oldType = structObj.GetStaticType();
  373. MKQL_ENSURE(oldType->IsStruct(), "Expected struct");
  374. const auto& oldTypeDetailed = static_cast<const TStructType&>(*oldType);
  375. TStructTypeBuilder newTypeBuilder(Env);
  376. newTypeBuilder.Reserve(oldTypeDetailed.GetMembersCount() + 1);
  377. for (ui32 i = 0, e = oldTypeDetailed.GetMembersCount(); i < e; ++i) {
  378. newTypeBuilder.Add(oldTypeDetailed.GetMemberName(i), oldTypeDetailed.GetMemberType(i));
  379. }
  380. newTypeBuilder.Add(memberName, memberValue.GetStaticType());
  381. auto newType = newTypeBuilder.Build();
  382. for (ui32 i = 0, e = newType->GetMembersCount(); i < e; ++i) {
  383. if (newType->GetMemberName(i) == memberName) {
  384. // insert at position i in the struct
  385. TCallableBuilder callableBuilder(Env, __func__, newType);
  386. callableBuilder.Add(structObj);
  387. callableBuilder.Add(memberValue);
  388. callableBuilder.Add(NewDataLiteral<ui32>(i));
  389. return TRuntimeNode(callableBuilder.Build(), false);
  390. }
  391. }
  392. Y_ABORT();
  393. }
  394. TRuntimeNode TProgramBuilder::RemoveMember(TRuntimeNode structObj, const std::string_view& memberName, bool forced) {
  395. auto oldType = structObj.GetStaticType();
  396. MKQL_ENSURE(oldType->IsStruct(), "Expected struct");
  397. const auto& oldTypeDetailed = static_cast<const TStructType&>(*oldType);
  398. MKQL_ENSURE(oldTypeDetailed.GetMembersCount() > 0, "Expected non-empty struct");
  399. TStructTypeBuilder newTypeBuilder(Env);
  400. newTypeBuilder.Reserve(oldTypeDetailed.GetMembersCount() - 1);
  401. std::optional<ui32> memberIndex;
  402. for (ui32 i = 0, e = oldTypeDetailed.GetMembersCount(); i < e; ++i) {
  403. if (oldTypeDetailed.GetMemberName(i) != memberName) {
  404. newTypeBuilder.Add(oldTypeDetailed.GetMemberName(i), oldTypeDetailed.GetMemberType(i));
  405. }
  406. else {
  407. memberIndex = i;
  408. }
  409. }
  410. if (!memberIndex && forced) {
  411. return structObj;
  412. }
  413. MKQL_ENSURE(memberIndex, "Unknown member name: " << memberName);
  414. // remove at position i in the struct
  415. auto newType = newTypeBuilder.Build();
  416. TCallableBuilder callableBuilder(Env, __func__, newType);
  417. callableBuilder.Add(structObj);
  418. callableBuilder.Add(NewDataLiteral<ui32>(*memberIndex));
  419. return TRuntimeNode(callableBuilder.Build(), false);
  420. }
  421. TRuntimeNode TProgramBuilder::Zip(const TArrayRef<const TRuntimeNode>& lists) {
  422. if (lists.empty()) {
  423. return NewEmptyList(Env.GetEmptyTupleLazy()->GetGenericType());
  424. }
  425. std::vector<TType*> tupleTypes;
  426. tupleTypes.reserve(lists.size());
  427. for (auto& list : lists) {
  428. if (list.GetStaticType()->IsEmptyList()) {
  429. tupleTypes.push_back(Env.GetTypeOfVoidLazy());
  430. continue;
  431. }
  432. AS_TYPE(TListType, list.GetStaticType());
  433. auto itemType = static_cast<const TListType&>(*list.GetStaticType()).GetItemType();
  434. tupleTypes.push_back(itemType);
  435. }
  436. auto returnType = TListType::Create(TTupleType::Create(tupleTypes.size(), tupleTypes.data(), Env), Env);
  437. TCallableBuilder callableBuilder(Env, __func__, returnType);
  438. for (auto& list : lists) {
  439. callableBuilder.Add(list);
  440. }
  441. return TRuntimeNode(callableBuilder.Build(), false);
  442. }
  443. TRuntimeNode TProgramBuilder::ZipAll(const TArrayRef<const TRuntimeNode>& lists) {
  444. if (lists.empty()) {
  445. return NewEmptyList(Env.GetEmptyTupleLazy()->GetGenericType());
  446. }
  447. std::vector<TType*> tupleTypes;
  448. tupleTypes.reserve(lists.size());
  449. for (auto& list : lists) {
  450. if (list.GetStaticType()->IsEmptyList()) {
  451. tupleTypes.push_back(TOptionalType::Create(Env.GetTypeOfVoidLazy(), Env));
  452. continue;
  453. }
  454. AS_TYPE(TListType, list.GetStaticType());
  455. auto itemType = static_cast<const TListType&>(*list.GetStaticType()).GetItemType();
  456. tupleTypes.push_back(TOptionalType::Create(itemType, Env));
  457. }
  458. auto returnType = TListType::Create(TTupleType::Create(tupleTypes.size(), tupleTypes.data(), Env), Env);
  459. TCallableBuilder callableBuilder(Env, __func__, returnType);
  460. for (auto& list : lists) {
  461. callableBuilder.Add(list);
  462. }
  463. return TRuntimeNode(callableBuilder.Build(), false);
  464. }
  465. TRuntimeNode TProgramBuilder::Enumerate(TRuntimeNode list, TRuntimeNode start, TRuntimeNode step) {
  466. const auto itemType = AS_TYPE(TListType, list.GetStaticType())->GetItemType();
  467. ThrowIfListOfVoid(itemType);
  468. MKQL_ENSURE(AS_TYPE(TDataType, start)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64 as start");
  469. MKQL_ENSURE(AS_TYPE(TDataType, step)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64 as step");
  470. const std::array<TType*, 2U> tupleTypes = {{ NewDataType(NUdf::EDataSlot::Uint64), itemType }};
  471. const auto returnType = NewListType(NewTupleType(tupleTypes));
  472. TCallableBuilder callableBuilder(Env, __func__, returnType);
  473. callableBuilder.Add(list);
  474. callableBuilder.Add(start);
  475. callableBuilder.Add(step);
  476. return TRuntimeNode(callableBuilder.Build(), false);
  477. }
  478. TRuntimeNode TProgramBuilder::Enumerate(TRuntimeNode list) {
  479. return TProgramBuilder::Enumerate(list, NewDataLiteral<ui64>(0), NewDataLiteral<ui64>(1));
  480. }
  481. TRuntimeNode TProgramBuilder::Fold(TRuntimeNode list, TRuntimeNode state, const TBinaryLambda& handler) {
  482. const auto itemType = AS_TYPE(TListType, list.GetStaticType())->GetItemType();
  483. ThrowIfListOfVoid(itemType);
  484. const auto stateNodeArg = Arg(state.GetStaticType());
  485. const auto itemArg = Arg(itemType);
  486. const auto newState = handler(itemArg, stateNodeArg);
  487. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*state.GetStaticType()), "State type is changed by the handler");
  488. TCallableBuilder callableBuilder(Env, __func__, state.GetStaticType());
  489. callableBuilder.Add(list);
  490. callableBuilder.Add(state);
  491. callableBuilder.Add(itemArg);
  492. callableBuilder.Add(stateNodeArg);
  493. callableBuilder.Add(newState);
  494. return TRuntimeNode(callableBuilder.Build(), false);
  495. }
  496. TRuntimeNode TProgramBuilder::Fold1(TRuntimeNode list, const TUnaryLambda& init, const TBinaryLambda& handler) {
  497. const auto itemType = AS_TYPE(TListType, list.GetStaticType())->GetItemType();
  498. ThrowIfListOfVoid(itemType);
  499. const auto itemArg = Arg(itemType);
  500. const auto initState = init(itemArg);
  501. const auto stateNodeArg = Arg(initState.GetStaticType());
  502. const auto newState = handler(itemArg, stateNodeArg);
  503. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*initState.GetStaticType()), "State type is changed by the handler");
  504. TCallableBuilder callableBuilder(Env, __func__, NewOptionalType(newState.GetStaticType()));
  505. callableBuilder.Add(list);
  506. callableBuilder.Add(itemArg);
  507. callableBuilder.Add(initState);
  508. callableBuilder.Add(stateNodeArg);
  509. callableBuilder.Add(newState);
  510. return TRuntimeNode(callableBuilder.Build(), false);
  511. }
  512. TRuntimeNode TProgramBuilder::Reduce(TRuntimeNode list, TRuntimeNode state1,
  513. const TBinaryLambda& handler1,
  514. const TUnaryLambda& handler2,
  515. TRuntimeNode state3,
  516. const TBinaryLambda& handler3) {
  517. const auto listType = list.GetStaticType();
  518. MKQL_ENSURE(listType->IsList() || listType->IsStream(), "Expected list or stream");
  519. const auto itemType = listType->IsList()?
  520. static_cast<const TListType&>(*listType).GetItemType():
  521. static_cast<const TStreamType&>(*listType).GetItemType();
  522. ThrowIfListOfVoid(itemType);
  523. const auto state1NodeArg = Arg(state1.GetStaticType());
  524. const auto state3NodeArg = Arg(state3.GetStaticType());
  525. const auto itemArg = Arg(itemType);
  526. const auto newState1 = handler1(itemArg, state1NodeArg);
  527. MKQL_ENSURE(newState1.GetStaticType()->IsSameType(*state1.GetStaticType()), "State 1 type is changed by the handler");
  528. const auto newState2 = handler2(state1NodeArg);
  529. TRuntimeNode itemState2Arg = Arg(newState2.GetStaticType());
  530. const auto newState3 = handler3(itemState2Arg, state3NodeArg);
  531. MKQL_ENSURE(newState3.GetStaticType()->IsSameType(*state3.GetStaticType()), "State 3 type is changed by the handler");
  532. TCallableBuilder callableBuilder(Env, __func__, newState3.GetStaticType());
  533. callableBuilder.Add(list);
  534. callableBuilder.Add(state1);
  535. callableBuilder.Add(state3);
  536. callableBuilder.Add(itemArg);
  537. callableBuilder.Add(state1NodeArg);
  538. callableBuilder.Add(newState1);
  539. callableBuilder.Add(newState2);
  540. callableBuilder.Add(itemState2Arg);
  541. callableBuilder.Add(state3NodeArg);
  542. callableBuilder.Add(newState3);
  543. return TRuntimeNode(callableBuilder.Build(), false);
  544. }
  545. TRuntimeNode TProgramBuilder::Condense(TRuntimeNode flow, TRuntimeNode state,
  546. const TBinaryLambda& switcher,
  547. const TBinaryLambda& handler, bool useCtx) {
  548. const auto flowType = flow.GetStaticType();
  549. if (flowType->IsList()) {
  550. // TODO: Native implementation for list.
  551. return Collect(Condense(ToFlow(flow), state, switcher, handler));
  552. }
  553. MKQL_ENSURE(flowType->IsFlow() || flowType->IsStream(), "Expected flow or stream.");
  554. const auto itemType = flowType->IsFlow() ?
  555. static_cast<const TFlowType&>(*flowType).GetItemType():
  556. static_cast<const TStreamType&>(*flowType).GetItemType();
  557. const auto itemArg = Arg(itemType);
  558. const auto stateArg = Arg(state.GetStaticType());
  559. const auto outSwitch = switcher(itemArg, stateArg);
  560. const auto newState = handler(itemArg, stateArg);
  561. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*state.GetStaticType()), "State type is changed by the handler");
  562. TCallableBuilder callableBuilder(Env, __func__, flowType->IsFlow() ? NewFlowType(state.GetStaticType()) : NewStreamType(state.GetStaticType()));
  563. callableBuilder.Add(flow);
  564. callableBuilder.Add(state);
  565. callableBuilder.Add(itemArg);
  566. callableBuilder.Add(stateArg);
  567. callableBuilder.Add(outSwitch);
  568. callableBuilder.Add(newState);
  569. if (useCtx) {
  570. MKQL_ENSURE(RuntimeVersion >= 30U, "Too old runtime version");
  571. callableBuilder.Add(NewDataLiteral<bool>(useCtx));
  572. }
  573. return TRuntimeNode(callableBuilder.Build(), false);
  574. }
  575. TRuntimeNode TProgramBuilder::Condense1(TRuntimeNode flow, const TUnaryLambda& init,
  576. const TBinaryLambda& switcher,
  577. const TBinaryLambda& handler, bool useCtx) {
  578. const auto flowType = flow.GetStaticType();
  579. if (flowType->IsList()) {
  580. // TODO: Native implementation for list.
  581. return Collect(Condense1(ToFlow(flow), init, switcher, handler));
  582. }
  583. MKQL_ENSURE(flowType->IsFlow() || flowType->IsStream(), "Expected flow or stream.");
  584. const auto itemType = flowType->IsFlow() ?
  585. static_cast<const TFlowType&>(*flowType).GetItemType():
  586. static_cast<const TStreamType&>(*flowType).GetItemType();
  587. const auto itemArg = Arg(itemType);
  588. const auto initState = init(itemArg);
  589. const auto stateArg = Arg(initState.GetStaticType());
  590. const auto outSwitch = switcher(itemArg, stateArg);
  591. const auto newState = handler(itemArg, stateArg);
  592. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*initState.GetStaticType()), "State type is changed by the handler");
  593. TCallableBuilder callableBuilder(Env, __func__, flowType->IsFlow() ? NewFlowType(newState.GetStaticType()) : NewStreamType(newState.GetStaticType()));
  594. callableBuilder.Add(flow);
  595. callableBuilder.Add(itemArg);
  596. callableBuilder.Add(initState);
  597. callableBuilder.Add(stateArg);
  598. callableBuilder.Add(outSwitch);
  599. callableBuilder.Add(newState);
  600. if (useCtx) {
  601. MKQL_ENSURE(RuntimeVersion >= 30U, "Too old runtime version");
  602. callableBuilder.Add(NewDataLiteral<bool>(useCtx));
  603. }
  604. return TRuntimeNode(callableBuilder.Build(), false);
  605. }
  606. TRuntimeNode TProgramBuilder::Squeeze(TRuntimeNode stream, TRuntimeNode state,
  607. const TBinaryLambda& handler,
  608. const TUnaryLambda& save,
  609. const TUnaryLambda& load) {
  610. const auto streamType = stream.GetStaticType();
  611. MKQL_ENSURE(streamType->IsStream(), "Expected stream");
  612. const auto& streamDetailedType = static_cast<const TStreamType&>(*streamType);
  613. const auto itemType = streamDetailedType.GetItemType();
  614. ThrowIfListOfVoid(itemType);
  615. const auto stateNodeArg = Arg(state.GetStaticType());
  616. const auto itemArg = Arg(itemType);
  617. const auto newState = handler(itemArg, stateNodeArg);
  618. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*state.GetStaticType()), "State type is changed by the handler");
  619. TRuntimeNode saveArg, outSave, loadArg, outLoad;
  620. if (save && load) {
  621. outSave = save(saveArg = Arg(state.GetStaticType()));
  622. outLoad = load(loadArg = Arg(outSave.GetStaticType()));
  623. MKQL_ENSURE(outLoad.GetStaticType()->IsSameType(*state.GetStaticType()), "Loaded type is changed by the load handler");
  624. } else {
  625. saveArg = outSave = loadArg = outLoad = NewVoid();
  626. }
  627. TCallableBuilder callableBuilder(Env, __func__, TStreamType::Create(state.GetStaticType(), Env));
  628. callableBuilder.Add(stream);
  629. callableBuilder.Add(state);
  630. callableBuilder.Add(itemArg);
  631. callableBuilder.Add(stateNodeArg);
  632. callableBuilder.Add(newState);
  633. callableBuilder.Add(saveArg);
  634. callableBuilder.Add(outSave);
  635. callableBuilder.Add(loadArg);
  636. callableBuilder.Add(outLoad);
  637. return TRuntimeNode(callableBuilder.Build(), false);
  638. }
  639. TRuntimeNode TProgramBuilder::Squeeze1(TRuntimeNode stream, const TUnaryLambda& init,
  640. const TBinaryLambda& handler,
  641. const TUnaryLambda& save,
  642. const TUnaryLambda& load) {
  643. const auto streamType = stream.GetStaticType();
  644. MKQL_ENSURE(streamType->IsStream(), "Expected stream");
  645. const auto& streamDetailedType = static_cast<const TStreamType&>(*streamType);
  646. const auto itemType = streamDetailedType.GetItemType();
  647. ThrowIfListOfVoid(itemType);
  648. const auto itemArg = Arg(itemType);
  649. const auto initState = init(itemArg);
  650. const auto stateNodeArg = Arg(initState.GetStaticType());
  651. const auto newState = handler(itemArg, stateNodeArg);
  652. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*initState.GetStaticType()), "State type is changed by the handler");
  653. TRuntimeNode saveArg, outSave, loadArg, outLoad;
  654. if (save && load) {
  655. outSave = save(saveArg = Arg(initState.GetStaticType()));
  656. outLoad = load(loadArg = Arg(outSave.GetStaticType()));
  657. MKQL_ENSURE(outLoad.GetStaticType()->IsSameType(*initState.GetStaticType()), "Loaded type is changed by the load handler");
  658. } else {
  659. saveArg = outSave = loadArg = outLoad = NewVoid();
  660. }
  661. TCallableBuilder callableBuilder(Env, __func__, NewStreamType(newState.GetStaticType()));
  662. callableBuilder.Add(stream);
  663. callableBuilder.Add(itemArg);
  664. callableBuilder.Add(initState);
  665. callableBuilder.Add(stateNodeArg);
  666. callableBuilder.Add(newState);
  667. callableBuilder.Add(saveArg);
  668. callableBuilder.Add(outSave);
  669. callableBuilder.Add(loadArg);
  670. callableBuilder.Add(outLoad);
  671. return TRuntimeNode(callableBuilder.Build(), false);
  672. }
  673. TRuntimeNode TProgramBuilder::Discard(TRuntimeNode stream) {
  674. const auto streamType = stream.GetStaticType();
  675. MKQL_ENSURE(streamType->IsStream() || streamType->IsFlow(), "Expected stream or flow.");
  676. TCallableBuilder callableBuilder(Env, __func__, streamType);
  677. callableBuilder.Add(stream);
  678. return TRuntimeNode(callableBuilder.Build(), false);
  679. }
  680. TRuntimeNode TProgramBuilder::Map(TRuntimeNode list, const TUnaryLambda& handler) {
  681. return BuildMap(__func__, list, handler);
  682. }
  683. TRuntimeNode TProgramBuilder::OrderedMap(TRuntimeNode list, const TUnaryLambda& handler) {
  684. return BuildMap(__func__, list, handler);
  685. }
  686. TRuntimeNode TProgramBuilder::MapNext(TRuntimeNode list, const TBinaryLambda& handler) {
  687. const auto listType = list.GetStaticType();
  688. MKQL_ENSURE(listType->IsStream() || listType->IsFlow(), "Expected stream or flow");
  689. const auto itemType = listType->IsFlow() ?
  690. AS_TYPE(TFlowType, listType)->GetItemType():
  691. AS_TYPE(TStreamType, listType)->GetItemType();
  692. ThrowIfListOfVoid(itemType);
  693. TType* nextItemType = TOptionalType::Create(itemType, Env);
  694. const auto itemArg = Arg(itemType);
  695. const auto nextItemArg = Arg(nextItemType);
  696. const auto newItem = handler(itemArg, nextItemArg);
  697. const auto resultListType = listType->IsFlow() ?
  698. (TType*)TFlowType::Create(newItem.GetStaticType(), Env):
  699. (TType*)TStreamType::Create(newItem.GetStaticType(), Env);
  700. TCallableBuilder callableBuilder(Env, __func__, resultListType);
  701. callableBuilder.Add(list);
  702. callableBuilder.Add(itemArg);
  703. callableBuilder.Add(nextItemArg);
  704. callableBuilder.Add(newItem);
  705. return TRuntimeNode(callableBuilder.Build(), false);
  706. }
  707. template <bool Ordered>
  708. TRuntimeNode TProgramBuilder::BuildExtract(TRuntimeNode list, const std::string_view& name) {
  709. const auto listType = list.GetStaticType();
  710. MKQL_ENSURE(listType->IsList() || listType->IsOptional(), "Expected list or optional.");
  711. const auto itemType = listType->IsList() ?
  712. AS_TYPE(TListType, listType)->GetItemType():
  713. AS_TYPE(TOptionalType, listType)->GetItemType();
  714. const auto lambda = [&](TRuntimeNode item) {
  715. return itemType->IsStruct() ? Member(item, name) : Nth(item, ::FromString<ui32>(name));
  716. };
  717. return Ordered ? OrderedMap(list, lambda) : Map(list, lambda);
  718. }
  719. TRuntimeNode TProgramBuilder::Extract(TRuntimeNode list, const std::string_view& name) {
  720. return BuildExtract<false>(list, name);
  721. }
  722. TRuntimeNode TProgramBuilder::OrderedExtract(TRuntimeNode list, const std::string_view& name) {
  723. return BuildExtract<true>(list, name);
  724. }
  725. TRuntimeNode TProgramBuilder::ChainMap(TRuntimeNode list, TRuntimeNode state, const TBinaryLambda& handler) {
  726. return ChainMap(list, state, [&](TRuntimeNode item, TRuntimeNode state) -> TRuntimeNodePair {
  727. const auto result = handler(item, state);
  728. return {result, result};
  729. });
  730. }
  731. TRuntimeNode TProgramBuilder::ChainMap(TRuntimeNode list, TRuntimeNode state, const TBinarySplitLambda& handler) {
  732. const auto listType = list.GetStaticType();
  733. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream");
  734. const auto itemType = listType->IsFlow() ?
  735. AS_TYPE(TFlowType, listType)->GetItemType():
  736. listType->IsList() ?
  737. AS_TYPE(TListType, listType)->GetItemType():
  738. AS_TYPE(TStreamType, listType)->GetItemType();
  739. ThrowIfListOfVoid(itemType);
  740. const auto stateNodeArg = Arg(state.GetStaticType());
  741. const auto itemArg = Arg(itemType);
  742. const auto newItemAndState = handler(itemArg, stateNodeArg);
  743. MKQL_ENSURE(std::get<1U>(newItemAndState).GetStaticType()->IsSameType(*state.GetStaticType()), "State type is changed by the handler");
  744. const auto resultItemType = std::get<0U>(newItemAndState).GetStaticType();
  745. TType* resultListType = nullptr;
  746. if (listType->IsFlow()) {
  747. resultListType = TFlowType::Create(resultItemType, Env);
  748. } else if (listType->IsList()) {
  749. resultListType = TListType::Create(resultItemType, Env);
  750. } else if (listType->IsStream()) {
  751. resultListType = TStreamType::Create(resultItemType, Env);
  752. }
  753. TCallableBuilder callableBuilder(Env, __func__, resultListType);
  754. callableBuilder.Add(list);
  755. callableBuilder.Add(state);
  756. callableBuilder.Add(itemArg);
  757. callableBuilder.Add(stateNodeArg);
  758. callableBuilder.Add(std::get<0U>(newItemAndState));
  759. callableBuilder.Add(std::get<1U>(newItemAndState));
  760. return TRuntimeNode(callableBuilder.Build(), false);
  761. }
  762. TRuntimeNode TProgramBuilder::Chain1Map(TRuntimeNode list, const TUnaryLambda& init, const TBinaryLambda& handler) {
  763. return Chain1Map(list,
  764. [&](TRuntimeNode item) -> TRuntimeNodePair {
  765. const auto result = init(item);
  766. return {result, result};
  767. },
  768. [&](TRuntimeNode item, TRuntimeNode state) -> TRuntimeNodePair {
  769. const auto result = handler(item, state);
  770. return {result, result};
  771. }
  772. );
  773. }
  774. TRuntimeNode TProgramBuilder::Chain1Map(TRuntimeNode list, const TUnarySplitLambda& init, const TBinarySplitLambda& handler) {
  775. const auto listType = list.GetStaticType();
  776. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream");
  777. const auto itemType = listType->IsFlow() ?
  778. AS_TYPE(TFlowType, listType)->GetItemType():
  779. listType->IsList() ?
  780. AS_TYPE(TListType, listType)->GetItemType():
  781. AS_TYPE(TStreamType, listType)->GetItemType();
  782. ThrowIfListOfVoid(itemType);
  783. const auto itemArg = Arg(itemType);
  784. const auto initItemAndState = init(itemArg);
  785. const auto resultItemType = std::get<0U>(initItemAndState).GetStaticType();
  786. const auto stateType = std::get<1U>(initItemAndState).GetStaticType();;
  787. TType* resultListType = nullptr;
  788. if (listType->IsFlow()) {
  789. resultListType = TFlowType::Create(resultItemType, Env);
  790. } else if (listType->IsList()) {
  791. resultListType = TListType::Create(resultItemType, Env);
  792. } else if (listType->IsStream()) {
  793. resultListType = TStreamType::Create(resultItemType, Env);
  794. }
  795. const auto stateArg = Arg(stateType);
  796. const auto updateItemAndState = handler(itemArg, stateArg);
  797. MKQL_ENSURE(std::get<0U>(updateItemAndState).GetStaticType()->IsSameType(*resultItemType), "Item type is changed by the handler");
  798. MKQL_ENSURE(std::get<1U>(updateItemAndState).GetStaticType()->IsSameType(*stateType), "State type is changed by the handler");
  799. TCallableBuilder callableBuilder(Env, __func__, resultListType);
  800. callableBuilder.Add(list);
  801. callableBuilder.Add(itemArg);
  802. callableBuilder.Add(std::get<0U>(initItemAndState));
  803. callableBuilder.Add(std::get<1U>(initItemAndState));
  804. callableBuilder.Add(stateArg);
  805. callableBuilder.Add(std::get<0U>(updateItemAndState));
  806. callableBuilder.Add(std::get<1U>(updateItemAndState));
  807. return TRuntimeNode(callableBuilder.Build(), false);
  808. }
  809. TRuntimeNode TProgramBuilder::ToList(TRuntimeNode optional) {
  810. const auto optionalType = optional.GetStaticType();
  811. MKQL_ENSURE(optionalType->IsOptional(), "Expected optional");
  812. const auto& optionalDetailedType = static_cast<const TOptionalType&>(*optionalType);
  813. const auto itemType = optionalDetailedType.GetItemType();
  814. return IfPresent(optional, [&](TRuntimeNode item) { return AsList(item); }, NewEmptyList(itemType));
  815. }
  816. TRuntimeNode TProgramBuilder::Iterable(TZeroLambda lambda) {
  817. if constexpr (RuntimeVersion < 19U) {
  818. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  819. }
  820. const auto itemArg = Arg(NewNull().GetStaticType());
  821. auto lambdaRes = lambda();
  822. const auto resultType = NewListType(AS_TYPE(TStreamType, lambdaRes.GetStaticType())->GetItemType());
  823. TCallableBuilder callableBuilder(Env, __func__, resultType);
  824. callableBuilder.Add(lambdaRes);
  825. callableBuilder.Add(itemArg);
  826. return TRuntimeNode(callableBuilder.Build(), false);
  827. }
  828. TRuntimeNode TProgramBuilder::ToOptional(TRuntimeNode list) {
  829. return Head(list);
  830. }
  831. TRuntimeNode TProgramBuilder::Head(TRuntimeNode list) {
  832. const auto resultType = NewOptionalType(AS_TYPE(TListType, list.GetStaticType())->GetItemType());
  833. TCallableBuilder callableBuilder(Env, __func__, resultType);
  834. callableBuilder.Add(list);
  835. return TRuntimeNode(callableBuilder.Build(), false);
  836. }
  837. TRuntimeNode TProgramBuilder::Last(TRuntimeNode list) {
  838. const auto resultType = NewOptionalType(AS_TYPE(TListType, list.GetStaticType())->GetItemType());
  839. TCallableBuilder callableBuilder(Env, __func__, resultType);
  840. callableBuilder.Add(list);
  841. return TRuntimeNode(callableBuilder.Build(), false);
  842. }
  843. TRuntimeNode TProgramBuilder::Nanvl(TRuntimeNode data, TRuntimeNode dataIfNaN) {
  844. const std::array<TRuntimeNode, 2> args = {{ data, dataIfNaN }};
  845. return Invoke(__func__, BuildArithmeticCommonType(data.GetStaticType(), dataIfNaN.GetStaticType()), args);
  846. }
  847. TRuntimeNode TProgramBuilder::FlatMap(TRuntimeNode list, const TUnaryLambda& handler)
  848. {
  849. return BuildFlatMap(__func__, list, handler);
  850. }
  851. TRuntimeNode TProgramBuilder::OrderedFlatMap(TRuntimeNode list, const TUnaryLambda& handler)
  852. {
  853. return BuildFlatMap(__func__, list, handler);
  854. }
  855. TRuntimeNode TProgramBuilder::Filter(TRuntimeNode list, const TUnaryLambda& handler)
  856. {
  857. return BuildFilter(__func__, list, handler);
  858. }
  859. TRuntimeNode TProgramBuilder::Filter(TRuntimeNode list, TRuntimeNode limit, const TUnaryLambda& handler)
  860. {
  861. return BuildFilter(__func__, list, limit, handler);
  862. }
  863. TRuntimeNode TProgramBuilder::OrderedFilter(TRuntimeNode list, const TUnaryLambda& handler)
  864. {
  865. return BuildFilter(__func__, list, handler);
  866. }
  867. TRuntimeNode TProgramBuilder::OrderedFilter(TRuntimeNode list, TRuntimeNode limit, const TUnaryLambda& handler)
  868. {
  869. return BuildFilter(__func__, list, limit, handler);
  870. }
  871. TRuntimeNode TProgramBuilder::TakeWhile(TRuntimeNode list, const TUnaryLambda& handler)
  872. {
  873. return BuildFilter(__func__, list, handler);
  874. }
  875. TRuntimeNode TProgramBuilder::SkipWhile(TRuntimeNode list, const TUnaryLambda& handler)
  876. {
  877. return BuildFilter(__func__, list, handler);
  878. }
  879. TRuntimeNode TProgramBuilder::TakeWhileInclusive(TRuntimeNode list, const TUnaryLambda& handler)
  880. {
  881. return BuildFilter(__func__, list, handler);
  882. }
  883. TRuntimeNode TProgramBuilder::SkipWhileInclusive(TRuntimeNode list, const TUnaryLambda& handler)
  884. {
  885. return BuildFilter(__func__, list, handler);
  886. }
  887. TRuntimeNode TProgramBuilder::BuildListSort(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode ascending,
  888. const TUnaryLambda& keyExtractor)
  889. {
  890. const auto listType = list.GetStaticType();
  891. MKQL_ENSURE(listType->IsList(), "Expected list.");
  892. const auto itemType = static_cast<const TListType&>(*listType).GetItemType();
  893. ThrowIfListOfVoid(itemType);
  894. const auto ascendingType = ascending.GetStaticType();
  895. const auto itemArg = Arg(itemType);
  896. auto key = keyExtractor(itemArg);
  897. if (ascendingType->IsTuple()) {
  898. const auto ascendingTuple = AS_TYPE(TTupleType, ascendingType);
  899. if (ascendingTuple->GetElementsCount() == 0) {
  900. return list;
  901. }
  902. if (ascendingTuple->GetElementsCount() == 1) {
  903. ascending = Nth(ascending, 0);
  904. key = Nth(key, 0);
  905. }
  906. }
  907. TCallableBuilder callableBuilder(Env, callableName, listType);
  908. callableBuilder.Add(list);
  909. callableBuilder.Add(itemArg);
  910. callableBuilder.Add(key);
  911. callableBuilder.Add(ascending);
  912. return TRuntimeNode(callableBuilder.Build(), false);
  913. }
  914. TRuntimeNode TProgramBuilder::BuildListNth(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode n, TRuntimeNode ascending,
  915. const TUnaryLambda& keyExtractor)
  916. {
  917. const auto listType = list.GetStaticType();
  918. MKQL_ENSURE(listType->IsList(), "Expected list.");
  919. const auto itemType = static_cast<const TListType&>(*listType).GetItemType();
  920. ThrowIfListOfVoid(itemType);
  921. MKQL_ENSURE(n.GetStaticType()->IsData(), "Expected data");
  922. MKQL_ENSURE(static_cast<const TDataType&>(*n.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  923. const auto ascendingType = ascending.GetStaticType();
  924. const auto itemArg = Arg(itemType);
  925. auto key = keyExtractor(itemArg);
  926. if (ascendingType->IsTuple()) {
  927. const auto ascendingTuple = AS_TYPE(TTupleType, ascendingType);
  928. if (ascendingTuple->GetElementsCount() == 0) {
  929. return Take(list, n);
  930. }
  931. if (ascendingTuple->GetElementsCount() == 1) {
  932. ascending = Nth(ascending, 0);
  933. key = Nth(key, 0);
  934. }
  935. }
  936. TCallableBuilder callableBuilder(Env, callableName, listType);
  937. callableBuilder.Add(list);
  938. callableBuilder.Add(n);
  939. callableBuilder.Add(itemArg);
  940. callableBuilder.Add(key);
  941. callableBuilder.Add(ascending);
  942. return TRuntimeNode(callableBuilder.Build(), false);
  943. }
  944. TRuntimeNode TProgramBuilder::BuildSort(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode ascending,
  945. const TUnaryLambda& keyExtractor)
  946. {
  947. if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
  948. const bool newVersion = RuntimeVersion >= 25U && flowType->IsFlow();
  949. const auto condense = newVersion ?
  950. SqueezeToList(Map(flow, [&](TRuntimeNode item) { return Pickle(item); }), NewEmptyOptionalDataLiteral(NUdf::TDataType<ui64>::Id)) :
  951. Condense1(flow,
  952. [this](TRuntimeNode item) { return AsList(item); },
  953. [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
  954. [this](TRuntimeNode item, TRuntimeNode state) { return Append(state, item); }
  955. );
  956. const auto finalKeyExtractor = newVersion ? [&](TRuntimeNode item) {
  957. auto itemType = AS_TYPE(TFlowType, flowType)->GetItemType();
  958. return keyExtractor(Unpickle(itemType, item));
  959. } : keyExtractor;
  960. return FlatMap(condense, [&](TRuntimeNode list) {
  961. auto stealed = RuntimeVersion >= 27U ? Steal(list) : list;
  962. auto sorted = BuildSort(RuntimeVersion >= 26U ? "UnstableSort" : callableName, stealed, ascending, finalKeyExtractor);
  963. return newVersion ? Map(LazyList(sorted), [&](TRuntimeNode item) {
  964. auto itemType = AS_TYPE(TFlowType, flowType)->GetItemType();
  965. return Unpickle(itemType, item);
  966. }) : sorted;
  967. });
  968. }
  969. return BuildListSort(callableName, flow, ascending, keyExtractor);
  970. }
  971. TRuntimeNode TProgramBuilder::BuildNth(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode n, TRuntimeNode ascending,
  972. const TUnaryLambda& keyExtractor)
  973. {
  974. if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
  975. return FlatMap(Condense1(flow,
  976. [this](TRuntimeNode item) { return AsList(item); },
  977. [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
  978. [this](TRuntimeNode item, TRuntimeNode state) { return Append(state, item); }
  979. ),
  980. [&](TRuntimeNode list) { return BuildNth(callableName, list, n, ascending, keyExtractor); }
  981. );
  982. }
  983. return BuildListNth(callableName, flow, n, ascending, keyExtractor);
  984. }
  985. TRuntimeNode TProgramBuilder::BuildTake(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count) {
  986. const auto listType = flow.GetStaticType();
  987. TType* itemType = nullptr;
  988. if (listType->IsFlow()) {
  989. itemType = AS_TYPE(TFlowType, listType)->GetItemType();
  990. } else if (listType->IsList()) {
  991. itemType = AS_TYPE(TListType, listType)->GetItemType();
  992. } else if (listType->IsStream()) {
  993. itemType = AS_TYPE(TStreamType, listType)->GetItemType();
  994. }
  995. MKQL_ENSURE(itemType, "Expected flow, list or stream.");
  996. ThrowIfListOfVoid(itemType);
  997. MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data");
  998. MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  999. TCallableBuilder callableBuilder(Env, callableName, listType);
  1000. callableBuilder.Add(flow);
  1001. callableBuilder.Add(count);
  1002. return TRuntimeNode(callableBuilder.Build(), false);
  1003. }
  1004. template<bool IsFilter, bool OnStruct>
  1005. TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list) {
  1006. const auto listType = list.GetStaticType();
  1007. TType* itemType;
  1008. if (listType->IsFlow()) {
  1009. itemType = AS_TYPE(TFlowType, listType)->GetItemType();
  1010. } else if (listType->IsList()) {
  1011. itemType = AS_TYPE(TListType, listType)->GetItemType();
  1012. } else if (listType->IsStream()) {
  1013. itemType = AS_TYPE(TStreamType, listType)->GetItemType();
  1014. } else if (listType->IsOptional()) {
  1015. itemType = AS_TYPE(TOptionalType, listType)->GetItemType();
  1016. } else {
  1017. THROW yexception() << "Expected flow or list or stream or optional of " << (OnStruct ? "struct." : "tuple.");
  1018. }
  1019. std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>> filteredItems;
  1020. std::vector<std::conditional_t<OnStruct, std::string_view, ui32>> members;
  1021. const bool multiOptional = CollectOptionalElements<IsFilter>(itemType, members, filteredItems);
  1022. const auto predicate = [=](TRuntimeNode item) {
  1023. std::vector<TRuntimeNode> checkMembers;
  1024. checkMembers.reserve(members.size());
  1025. std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers),
  1026. [=](const auto& i){ return Exists(Element(item, i)); });
  1027. return And(checkMembers);
  1028. };
  1029. auto resultType = listType;
  1030. if constexpr (IsFilter) {
  1031. if (const auto filteredItemType = NewArrayType(filteredItems); multiOptional) {
  1032. return BuildFilterNulls<OnStruct>(list, members, filteredItems);
  1033. } else {
  1034. resultType = listType->IsFlow() ?
  1035. NewFlowType(filteredItemType):
  1036. listType->IsList() ?
  1037. NewListType(filteredItemType):
  1038. listType->IsStream() ? NewStreamType(filteredItemType) : NewOptionalType(filteredItemType);
  1039. }
  1040. }
  1041. return Filter(list, predicate, resultType);
  1042. }
  1043. template<bool IsFilter, bool OnStruct>
  1044. TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, const TArrayRef<std::conditional_t<OnStruct, const std::string_view, const ui32>>& members) {
  1045. if (members.empty()) {
  1046. return list;
  1047. }
  1048. const auto listType = list.GetStaticType();
  1049. TType* itemType;
  1050. if (listType->IsFlow()) {
  1051. itemType = AS_TYPE(TFlowType, listType)->GetItemType();
  1052. } else if (listType->IsList()) {
  1053. itemType = AS_TYPE(TListType, listType)->GetItemType();
  1054. } else if (listType->IsStream()) {
  1055. itemType = AS_TYPE(TStreamType, listType)->GetItemType();
  1056. } else if (listType->IsOptional()) {
  1057. itemType = AS_TYPE(TOptionalType, listType)->GetItemType();
  1058. } else {
  1059. THROW yexception() << "Expected flow or list or stream or optional of struct.";
  1060. }
  1061. const auto predicate = [=](TRuntimeNode item) {
  1062. TRuntimeNode::TList checkMembers;
  1063. checkMembers.reserve(members.size());
  1064. std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers),
  1065. [=](const auto& i){ return Exists(Element(item, i)); });
  1066. return And(checkMembers);
  1067. };
  1068. auto resultType = listType;
  1069. if constexpr (IsFilter) {
  1070. if (std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>> filteredItems;
  1071. ReduceOptionalElements(itemType, members, filteredItems)) {
  1072. return BuildFilterNulls<OnStruct>(list, members, filteredItems);
  1073. } else {
  1074. const auto filteredItemType = NewArrayType(filteredItems);
  1075. resultType = listType->IsFlow() ?
  1076. NewFlowType(filteredItemType):
  1077. listType->IsList() ?
  1078. NewListType(filteredItemType):
  1079. listType->IsStream() ? NewStreamType(filteredItemType) : NewOptionalType(filteredItemType);
  1080. }
  1081. }
  1082. return Filter(list, predicate, resultType);
  1083. }
  1084. template<bool OnStruct>
  1085. TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, const TArrayRef<std::conditional_t<OnStruct, const std::string_view, const ui32>>& members,
  1086. const std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>>& filteredItems) {
  1087. return FlatMap(list, [&](TRuntimeNode item) {
  1088. TRuntimeNode::TList checkMembers;
  1089. checkMembers.reserve(members.size());
  1090. std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers),
  1091. [=](const auto& i){ return Element(item, i); });
  1092. return IfPresent(checkMembers, [&](TRuntimeNode::TList items) {
  1093. std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TRuntimeNode>>, TRuntimeNode::TList> row;
  1094. row.reserve(filteredItems.size());
  1095. auto j = 0U;
  1096. if constexpr (OnStruct) {
  1097. std::transform(filteredItems.cbegin(), filteredItems.cend(), std::back_inserter(row),
  1098. [&](const std::pair<std::string_view, TType*>& i) {
  1099. const auto& member = i.first;
  1100. const bool passtrought = members.cend() == std::find(members.cbegin(), members.cend(), member);
  1101. return std::make_pair(member, passtrought ? Element(item, member) : items[j++]);
  1102. }
  1103. );
  1104. return NewOptional(NewStruct(row));
  1105. } else {
  1106. auto i = 0U;
  1107. std::generate_n(std::back_inserter(row), filteredItems.size(),
  1108. [&]() {
  1109. const auto index = i++;
  1110. const bool passtrought = members.cend() == std::find(members.cbegin(), members.cend(), index);
  1111. return passtrought ? Element(item, index) : items[j++];
  1112. }
  1113. );
  1114. return NewOptional(NewTuple(row));
  1115. }
  1116. }, NewEmptyOptional(NewOptionalType(NewArrayType(filteredItems))));
  1117. });
  1118. }
  1119. TRuntimeNode TProgramBuilder::SkipNullMembers(TRuntimeNode list) {
  1120. return BuildFilterNulls<false, true>(list);
  1121. }
  1122. TRuntimeNode TProgramBuilder::FilterNullMembers(TRuntimeNode list) {
  1123. return BuildFilterNulls<true, true>(list);
  1124. }
  1125. TRuntimeNode TProgramBuilder::SkipNullMembers(TRuntimeNode list, const TArrayRef<const std::string_view>& members) {
  1126. return BuildFilterNulls<false, true>(list, members);
  1127. }
  1128. TRuntimeNode TProgramBuilder::FilterNullMembers(TRuntimeNode list, const TArrayRef<const std::string_view>& members) {
  1129. return BuildFilterNulls<true, true>(list, members);
  1130. }
  1131. TRuntimeNode TProgramBuilder::FilterNullElements(TRuntimeNode list) {
  1132. return BuildFilterNulls<true, false>(list);
  1133. }
  1134. TRuntimeNode TProgramBuilder::SkipNullElements(TRuntimeNode list) {
  1135. return BuildFilterNulls<false, false>(list);
  1136. }
  1137. TRuntimeNode TProgramBuilder::FilterNullElements(TRuntimeNode list, const TArrayRef<const ui32>& elements) {
  1138. return BuildFilterNulls<true, false>(list, elements);
  1139. }
  1140. TRuntimeNode TProgramBuilder::SkipNullElements(TRuntimeNode list, const TArrayRef<const ui32>& elements) {
  1141. return BuildFilterNulls<false, false>(list, elements);
  1142. }
  1143. template <typename ResultType>
  1144. TRuntimeNode TProgramBuilder::BuildContainerProperty(const std::string_view& callableName, TRuntimeNode listOrDict) {
  1145. const auto type = listOrDict.GetStaticType();
  1146. MKQL_ENSURE(type->IsList() || type->IsDict() || type->IsEmptyList() || type->IsEmptyDict(), "Expected list or dict.");
  1147. if (type->IsList()) {
  1148. const auto itemType = AS_TYPE(TListType, type)->GetItemType();
  1149. ThrowIfListOfVoid(itemType);
  1150. }
  1151. TCallableBuilder callableBuilder(Env, callableName, NewDataType(NUdf::TDataType<ResultType>::Id));
  1152. callableBuilder.Add(listOrDict);
  1153. return TRuntimeNode(callableBuilder.Build(), false);
  1154. }
  1155. TRuntimeNode TProgramBuilder::Length(TRuntimeNode listOrDict) {
  1156. return BuildContainerProperty<ui64>(__func__, listOrDict);
  1157. }
  1158. TRuntimeNode TProgramBuilder::Iterator(TRuntimeNode list, const TArrayRef<const TRuntimeNode>& dependentNodes) {
  1159. const auto streamType = NewStreamType(AS_TYPE(TListType, list.GetStaticType())->GetItemType());
  1160. TCallableBuilder callableBuilder(Env, __func__, streamType);
  1161. callableBuilder.Add(list);
  1162. for (auto node : dependentNodes) {
  1163. callableBuilder.Add(node);
  1164. }
  1165. return TRuntimeNode(callableBuilder.Build(), false);
  1166. }
  1167. TRuntimeNode TProgramBuilder::EmptyIterator(TType* streamType) {
  1168. MKQL_ENSURE(streamType->IsStream() || streamType->IsFlow(), "Expected stream or flow.");
  1169. if (RuntimeVersion < 7U && streamType->IsFlow()) {
  1170. return ToFlow(EmptyIterator(NewStreamType(AS_TYPE(TFlowType, streamType)->GetItemType())));
  1171. }
  1172. TCallableBuilder callableBuilder(Env, __func__, streamType);
  1173. return TRuntimeNode(callableBuilder.Build(), false);
  1174. }
  1175. TRuntimeNode TProgramBuilder::Collect(TRuntimeNode flow) {
  1176. const auto seqType = flow.GetStaticType();
  1177. TType* itemType = nullptr;
  1178. if (seqType->IsFlow()) {
  1179. itemType = AS_TYPE(TFlowType, seqType)->GetItemType();
  1180. } else if (seqType->IsList()) {
  1181. itemType = AS_TYPE(TListType, seqType)->GetItemType();
  1182. } else if (seqType->IsStream()) {
  1183. itemType = AS_TYPE(TStreamType, seqType)->GetItemType();
  1184. } else {
  1185. THROW yexception() << "Expected flow, list or stream.";
  1186. }
  1187. TCallableBuilder callableBuilder(Env, __func__, NewListType(itemType));
  1188. callableBuilder.Add(flow);
  1189. return TRuntimeNode(callableBuilder.Build(), false);
  1190. }
  1191. TRuntimeNode TProgramBuilder::LazyList(TRuntimeNode list) {
  1192. const auto type = list.GetStaticType();
  1193. bool isOptional;
  1194. const auto listType = UnpackOptional(type, isOptional);
  1195. MKQL_ENSURE(listType->IsList(), "Expected list");
  1196. TCallableBuilder callableBuilder(Env, __func__, type);
  1197. callableBuilder.Add(list);
  1198. return TRuntimeNode(callableBuilder.Build(), false);
  1199. }
  1200. TRuntimeNode TProgramBuilder::ForwardList(TRuntimeNode stream) {
  1201. const auto type = stream.GetStaticType();
  1202. MKQL_ENSURE(type->IsStream() || type->IsFlow(), "Expected flow or stream.");
  1203. if constexpr (RuntimeVersion < 10U) {
  1204. if (type->IsFlow()) {
  1205. return ForwardList(FromFlow(stream));
  1206. }
  1207. }
  1208. TCallableBuilder callableBuilder(Env, __func__, NewListType(type->IsFlow() ? AS_TYPE(TFlowType, stream)->GetItemType() : AS_TYPE(TStreamType, stream)->GetItemType()));
  1209. callableBuilder.Add(stream);
  1210. return TRuntimeNode(callableBuilder.Build(), false);
  1211. }
  1212. TRuntimeNode TProgramBuilder::ToFlow(TRuntimeNode stream) {
  1213. const auto type = stream.GetStaticType();
  1214. MKQL_ENSURE(type->IsStream() || type->IsList() || type->IsOptional(), "Expected stream, list or optional.");
  1215. const auto itemType = type->IsStream() ? AS_TYPE(TStreamType, stream)->GetItemType() :
  1216. type->IsList() ? AS_TYPE(TListType, stream)->GetItemType() : AS_TYPE(TOptionalType, stream)->GetItemType();
  1217. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(itemType));
  1218. callableBuilder.Add(stream);
  1219. return TRuntimeNode(callableBuilder.Build(), false);
  1220. }
  1221. TRuntimeNode TProgramBuilder::FromFlow(TRuntimeNode flow) {
  1222. MKQL_ENSURE(flow.GetStaticType()->IsFlow(), "Expected flow.");
  1223. TCallableBuilder callableBuilder(Env, __func__, NewStreamType(AS_TYPE(TFlowType, flow)->GetItemType()));
  1224. callableBuilder.Add(flow);
  1225. return TRuntimeNode(callableBuilder.Build(), false);
  1226. }
  1227. TRuntimeNode TProgramBuilder::Steal(TRuntimeNode input) {
  1228. if constexpr (RuntimeVersion < 27U) {
  1229. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1230. }
  1231. TCallableBuilder callableBuilder(Env, __func__, input.GetStaticType(), true);
  1232. callableBuilder.Add(input);
  1233. return TRuntimeNode(callableBuilder.Build(), false);
  1234. }
  1235. TRuntimeNode TProgramBuilder::ToBlocks(TRuntimeNode flow) {
  1236. auto* flowType = AS_TYPE(TFlowType, flow.GetStaticType());
  1237. auto* blockType = NewBlockType(flowType->GetItemType(), TBlockType::EShape::Many);
  1238. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(blockType));
  1239. callableBuilder.Add(flow);
  1240. return TRuntimeNode(callableBuilder.Build(), false);
  1241. }
  1242. TType* TProgramBuilder::BuildWideBlockType(const TArrayRef<TType* const>& wideComponents) {
  1243. std::vector<TType*> blockItems;
  1244. blockItems.reserve(wideComponents.size());
  1245. for (size_t i = 0; i < wideComponents.size(); i++) {
  1246. blockItems.push_back(NewBlockType(wideComponents[i], TBlockType::EShape::Many));
  1247. }
  1248. blockItems.push_back(NewBlockType(NewDataType(NUdf::TDataType<ui64>::Id), TBlockType::EShape::Scalar));
  1249. return NewMultiType(blockItems);
  1250. }
  1251. TRuntimeNode TProgramBuilder::WideToBlocks(TRuntimeNode stream) {
  1252. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected WideStream as input type");
  1253. if constexpr (RuntimeVersion < 58U) {
  1254. // Preserve the old behaviour for ABI compatibility.
  1255. // Emit (FromFlow (WideToBlocks (ToFlow (<stream>)))) to
  1256. // process the flow in favor to the given stream following
  1257. // the older MKQL ABI.
  1258. // FIXME: Drop the branch below, when the time comes.
  1259. const auto inputFlow = ToFlow(stream);
  1260. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, inputFlow.GetStaticType()));
  1261. TType* outputMultiType = BuildWideBlockType(wideComponents);
  1262. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(outputMultiType));
  1263. callableBuilder.Add(inputFlow);
  1264. const auto outputFlow = TRuntimeNode(callableBuilder.Build(), false);
  1265. return FromFlow(outputFlow);
  1266. }
  1267. const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, stream.GetStaticType()));
  1268. TType* outputMultiType = BuildWideBlockType(wideComponents);
  1269. TCallableBuilder callableBuilder(Env, __func__, NewStreamType(outputMultiType));
  1270. callableBuilder.Add(stream);
  1271. return TRuntimeNode(callableBuilder.Build(), false);
  1272. }
  1273. TRuntimeNode TProgramBuilder::FromBlocks(TRuntimeNode flow) {
  1274. auto* flowType = AS_TYPE(TFlowType, flow.GetStaticType());
  1275. auto* blockType = AS_TYPE(TBlockType, flowType->GetItemType());
  1276. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(blockType->GetItemType()));
  1277. callableBuilder.Add(flow);
  1278. return TRuntimeNode(callableBuilder.Build(), false);
  1279. }
  1280. TRuntimeNode TProgramBuilder::WideFromBlocks(TRuntimeNode stream) {
  1281. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected WideStream as input type");
  1282. if constexpr (RuntimeVersion < 55U) {
  1283. // Preserve the old behaviour for ABI compatibility.
  1284. // Emit (FromFlow (WideFromBlocks (ToFlow (<stream>)))) to
  1285. // process the flow in favor to the given stream following
  1286. // the older MKQL ABI.
  1287. // FIXME: Drop the branch below, when the time comes.
  1288. const auto inputFlow = ToFlow(stream);
  1289. auto outputItems = ValidateBlockFlowType(inputFlow.GetStaticType());
  1290. outputItems.pop_back();
  1291. TType* outputMultiType = NewMultiType(outputItems);
  1292. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(outputMultiType));
  1293. callableBuilder.Add(inputFlow);
  1294. const auto outputFlow = TRuntimeNode(callableBuilder.Build(), false);
  1295. return FromFlow(outputFlow);
  1296. }
  1297. auto outputItems = ValidateBlockStreamType(stream.GetStaticType());
  1298. outputItems.pop_back();
  1299. TType* outputMultiType = NewMultiType(outputItems);
  1300. TCallableBuilder callableBuilder(Env, __func__, NewStreamType(outputMultiType));
  1301. callableBuilder.Add(stream);
  1302. return TRuntimeNode(callableBuilder.Build(), false);
  1303. }
  1304. TRuntimeNode TProgramBuilder::WideSkipBlocks(TRuntimeNode flow, TRuntimeNode count) {
  1305. return BuildWideSkipTakeBlocks(__func__, flow, count);
  1306. }
  1307. TRuntimeNode TProgramBuilder::WideTakeBlocks(TRuntimeNode flow, TRuntimeNode count) {
  1308. return BuildWideSkipTakeBlocks(__func__, flow, count);
  1309. }
  1310. TRuntimeNode TProgramBuilder::WideTopBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
  1311. return BuildWideTopOrSort(__func__, flow, count, keys);
  1312. }
  1313. TRuntimeNode TProgramBuilder::WideTopSortBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
  1314. return BuildWideTopOrSort(__func__, flow, count, keys);
  1315. }
  1316. TRuntimeNode TProgramBuilder::WideSortBlocks(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
  1317. return BuildWideTopOrSort(__func__, flow, Nothing(), keys);
  1318. }
  1319. TRuntimeNode TProgramBuilder::AsScalar(TRuntimeNode value) {
  1320. TCallableBuilder callableBuilder(Env, __func__, NewBlockType(value.GetStaticType(), TBlockType::EShape::Scalar));
  1321. callableBuilder.Add(value);
  1322. return TRuntimeNode(callableBuilder.Build(), false);
  1323. }
  1324. TRuntimeNode TProgramBuilder::ReplicateScalar(TRuntimeNode value, TRuntimeNode count) {
  1325. if constexpr (RuntimeVersion < 43U) {
  1326. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1327. }
  1328. auto valueType = AS_TYPE(TBlockType, value.GetStaticType());
  1329. auto countType = AS_TYPE(TBlockType, count.GetStaticType());
  1330. MKQL_ENSURE(valueType->GetShape() == TBlockType::EShape::Scalar, "Expecting scalar as first arguemnt");
  1331. MKQL_ENSURE(countType->GetShape() == TBlockType::EShape::Scalar, "Expecting scalar as second arguemnt");
  1332. MKQL_ENSURE(countType->GetItemType()->IsData(), "Expected scalar data as second argument");
  1333. MKQL_ENSURE(AS_TYPE(TDataType, countType->GetItemType())->GetSchemeType() ==
  1334. NUdf::TDataType<ui64>::Id, "Expected scalar ui64 as second argument");
  1335. auto outputType = NewBlockType(valueType->GetItemType(), TBlockType::EShape::Many);
  1336. TCallableBuilder callableBuilder(Env, __func__, outputType);
  1337. callableBuilder.Add(value);
  1338. callableBuilder.Add(count);
  1339. return TRuntimeNode(callableBuilder.Build(), false);
  1340. }
  1341. TRuntimeNode TProgramBuilder::BlockCompress(TRuntimeNode flow, ui32 bitmapIndex) {
  1342. auto blockItemTypes = ValidateBlockFlowType(flow.GetStaticType());
  1343. MKQL_ENSURE(blockItemTypes.size() >= 2, "Expected at least two input columns");
  1344. MKQL_ENSURE(bitmapIndex < blockItemTypes.size() - 1, "Invalid bitmap index");
  1345. MKQL_ENSURE(AS_TYPE(TDataType, blockItemTypes[bitmapIndex])->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected Bool as bitmap column type");
  1346. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  1347. MKQL_ENSURE(wideComponents.size() == blockItemTypes.size(), "Unexpected tuple size");
  1348. std::vector<TType*> flowItems;
  1349. for (size_t i = 0; i < wideComponents.size(); ++i) {
  1350. if (i == bitmapIndex) {
  1351. continue;
  1352. }
  1353. flowItems.push_back(wideComponents[i]);
  1354. }
  1355. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(flowItems)));
  1356. callableBuilder.Add(flow);
  1357. callableBuilder.Add(NewDataLiteral<ui32>(bitmapIndex));
  1358. return TRuntimeNode(callableBuilder.Build(), false);
  1359. }
  1360. TRuntimeNode TProgramBuilder::BlockExpandChunked(TRuntimeNode comp) {
  1361. if (comp.GetStaticType()->IsStream()) {
  1362. ValidateBlockStreamType(comp.GetStaticType());
  1363. } else {
  1364. ValidateBlockFlowType(comp.GetStaticType());
  1365. }
  1366. TCallableBuilder callableBuilder(Env, __func__, comp.GetStaticType());
  1367. callableBuilder.Add(comp);
  1368. return TRuntimeNode(callableBuilder.Build(), false);
  1369. }
  1370. TRuntimeNode TProgramBuilder::BlockCoalesce(TRuntimeNode first, TRuntimeNode second) {
  1371. auto firstType = AS_TYPE(TBlockType, first.GetStaticType());
  1372. auto secondType = AS_TYPE(TBlockType, second.GetStaticType());
  1373. auto firstItemType = firstType->GetItemType();
  1374. auto secondItemType = secondType->GetItemType();
  1375. MKQL_ENSURE(firstItemType->IsOptional() || firstItemType->IsPg(), "Expecting Optional or Pg type as first argument");
  1376. if (!firstItemType->IsSameType(*secondItemType)) {
  1377. bool firstOptional;
  1378. firstItemType = UnpackOptional(firstItemType, firstOptional);
  1379. MKQL_ENSURE(firstItemType->IsSameType(*secondItemType), "Uncompatible arguemnt types");
  1380. }
  1381. auto outputType = NewBlockType(secondType->GetItemType(), GetResultShape({firstType, secondType}));
  1382. TCallableBuilder callableBuilder(Env, __func__, outputType);
  1383. callableBuilder.Add(first);
  1384. callableBuilder.Add(second);
  1385. return TRuntimeNode(callableBuilder.Build(), false);
  1386. }
  1387. TRuntimeNode TProgramBuilder::BlockExists(TRuntimeNode data) {
  1388. auto dataType = AS_TYPE(TBlockType, data.GetStaticType());
  1389. auto outputType = NewBlockType(NewDataType(NUdf::TDataType<bool>::Id), dataType->GetShape());
  1390. TCallableBuilder callableBuilder(Env, __func__, outputType);
  1391. callableBuilder.Add(data);
  1392. return TRuntimeNode(callableBuilder.Build(), false);
  1393. }
  1394. TRuntimeNode TProgramBuilder::BlockMember(TRuntimeNode structObj, const std::string_view& memberName) {
  1395. auto blockType = AS_TYPE(TBlockType, structObj.GetStaticType());
  1396. bool isOptional;
  1397. const auto type = AS_TYPE(TStructType, UnpackOptional(blockType->GetItemType(), isOptional));
  1398. const auto memberIndex = type->GetMemberIndex(memberName);
  1399. auto memberType = type->GetMemberType(memberIndex);
  1400. if (isOptional && !memberType->IsOptional() && !memberType->IsNull() && !memberType->IsPg()) {
  1401. memberType = NewOptionalType(memberType);
  1402. }
  1403. auto returnType = NewBlockType(memberType, blockType->GetShape());
  1404. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1405. callableBuilder.Add(structObj);
  1406. callableBuilder.Add(NewDataLiteral<ui32>(memberIndex));
  1407. return TRuntimeNode(callableBuilder.Build(), false);
  1408. }
  1409. TRuntimeNode TProgramBuilder::BlockNth(TRuntimeNode tuple, ui32 index) {
  1410. auto blockType = AS_TYPE(TBlockType, tuple.GetStaticType());
  1411. bool isOptional;
  1412. const auto type = AS_TYPE(TTupleType, UnpackOptional(blockType->GetItemType(), isOptional));
  1413. MKQL_ENSURE(index < type->GetElementsCount(), "Index out of range: " << index <<
  1414. " is not less than " << type->GetElementsCount());
  1415. auto itemType = type->GetElementType(index);
  1416. if (isOptional && !itemType->IsOptional() && !itemType->IsNull() && !itemType->IsPg()) {
  1417. itemType = TOptionalType::Create(itemType, Env);
  1418. }
  1419. auto returnType = NewBlockType(itemType, blockType->GetShape());
  1420. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1421. callableBuilder.Add(tuple);
  1422. callableBuilder.Add(NewDataLiteral<ui32>(index));
  1423. return TRuntimeNode(callableBuilder.Build(), false);
  1424. }
  1425. TRuntimeNode TProgramBuilder::BlockAsStruct(const TArrayRef<std::pair<std::string_view, TRuntimeNode>>& args) {
  1426. MKQL_ENSURE(!args.empty(), "Expected at least one argument");
  1427. TBlockType::EShape resultShape = TBlockType::EShape::Scalar;
  1428. TVector<std::pair<std::string_view, TType*>> members;
  1429. for (const auto& x : args) {
  1430. auto blockType = AS_TYPE(TBlockType, x.second.GetStaticType());
  1431. members.emplace_back(x.first, blockType->GetItemType());
  1432. if (blockType->GetShape() == TBlockType::EShape::Many) {
  1433. resultShape = TBlockType::EShape::Many;
  1434. }
  1435. }
  1436. auto returnType = NewBlockType(NewStructType(members), resultShape);
  1437. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1438. for (const auto& x : args) {
  1439. callableBuilder.Add(x.second);
  1440. }
  1441. return TRuntimeNode(callableBuilder.Build(), false);
  1442. }
  1443. TRuntimeNode TProgramBuilder::BlockAsTuple(const TArrayRef<const TRuntimeNode>& args) {
  1444. MKQL_ENSURE(!args.empty(), "Expected at least one argument");
  1445. TBlockType::EShape resultShape = TBlockType::EShape::Scalar;
  1446. TVector<TType*> types;
  1447. for (const auto& x : args) {
  1448. auto blockType = AS_TYPE(TBlockType, x.GetStaticType());
  1449. types.push_back(blockType->GetItemType());
  1450. if (blockType->GetShape() == TBlockType::EShape::Many) {
  1451. resultShape = TBlockType::EShape::Many;
  1452. }
  1453. }
  1454. auto tupleType = NewTupleType(types);
  1455. auto returnType = NewBlockType(tupleType, resultShape);
  1456. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1457. for (const auto& x : args) {
  1458. callableBuilder.Add(x);
  1459. }
  1460. return TRuntimeNode(callableBuilder.Build(), false);
  1461. }
  1462. TRuntimeNode TProgramBuilder::BlockToPg(TRuntimeNode input, TType* returnType) {
  1463. if constexpr (RuntimeVersion < 37U) {
  1464. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1465. }
  1466. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1467. callableBuilder.Add(input);
  1468. return TRuntimeNode(callableBuilder.Build(), false);
  1469. }
  1470. TRuntimeNode TProgramBuilder::BlockFromPg(TRuntimeNode input, TType* returnType) {
  1471. if constexpr (RuntimeVersion < 37U) {
  1472. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1473. }
  1474. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1475. callableBuilder.Add(input);
  1476. return TRuntimeNode(callableBuilder.Build(), false);
  1477. }
  1478. TRuntimeNode TProgramBuilder::BlockNot(TRuntimeNode data) {
  1479. auto dataType = AS_TYPE(TBlockType, data.GetStaticType());
  1480. bool isOpt;
  1481. MKQL_ENSURE(UnpackOptionalData(dataType->GetItemType(), isOpt)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  1482. TCallableBuilder callableBuilder(Env, __func__, data.GetStaticType());
  1483. callableBuilder.Add(data);
  1484. return TRuntimeNode(callableBuilder.Build(), false);
  1485. }
  1486. TRuntimeNode TProgramBuilder::BlockAnd(TRuntimeNode first, TRuntimeNode second) {
  1487. return BuildBlockLogical(__func__, first, second);
  1488. }
  1489. TRuntimeNode TProgramBuilder::BlockOr(TRuntimeNode first, TRuntimeNode second) {
  1490. return BuildBlockLogical(__func__, first, second);
  1491. }
  1492. TRuntimeNode TProgramBuilder::BlockXor(TRuntimeNode first, TRuntimeNode second) {
  1493. return BuildBlockLogical(__func__, first, second);
  1494. }
  1495. TRuntimeNode TProgramBuilder::BlockDecimalDiv(TRuntimeNode first, TRuntimeNode second) {
  1496. return BuildBlockDecimalBinary(__func__, first, second);
  1497. }
  1498. TRuntimeNode TProgramBuilder::BlockDecimalMod(TRuntimeNode first, TRuntimeNode second) {
  1499. return BuildBlockDecimalBinary(__func__, first, second);
  1500. }
  1501. TRuntimeNode TProgramBuilder::BlockDecimalMul(TRuntimeNode first, TRuntimeNode second) {
  1502. return BuildBlockDecimalBinary(__func__, first, second);
  1503. }
  1504. TRuntimeNode TProgramBuilder::ListFromRange(TRuntimeNode start, TRuntimeNode end, TRuntimeNode step) {
  1505. MKQL_ENSURE(start.GetStaticType()->IsData(), "Expected data");
  1506. MKQL_ENSURE(end.GetStaticType()->IsSameType(*start.GetStaticType()), "Mismatch type");
  1507. if constexpr (RuntimeVersion < 24U) {
  1508. MKQL_ENSURE(IsNumericType(AS_TYPE(TDataType, start)->GetSchemeType()), "Expected numeric");
  1509. } else {
  1510. MKQL_ENSURE(IsNumericType(AS_TYPE(TDataType, start)->GetSchemeType()) ||
  1511. IsDateType(AS_TYPE(TDataType, start)->GetSchemeType()) ||
  1512. IsTzDateType(AS_TYPE(TDataType, start)->GetSchemeType()) ||
  1513. IsIntervalType(AS_TYPE(TDataType, start)->GetSchemeType()),
  1514. "Expected numeric, date or tzdate");
  1515. if (IsNumericType(AS_TYPE(TDataType, start)->GetSchemeType())) {
  1516. MKQL_ENSURE(IsNumericType(AS_TYPE(TDataType, step)->GetSchemeType()), "Expected numeric");
  1517. } else {
  1518. MKQL_ENSURE(IsIntervalType(AS_TYPE(TDataType, step)->GetSchemeType()), "Expected interval");
  1519. }
  1520. }
  1521. TCallableBuilder callableBuilder(Env, __func__, TListType::Create(start.GetStaticType(), Env));
  1522. callableBuilder.Add(start);
  1523. callableBuilder.Add(end);
  1524. callableBuilder.Add(step);
  1525. return TRuntimeNode(callableBuilder.Build(), false);
  1526. }
  1527. TRuntimeNode TProgramBuilder::Switch(TRuntimeNode stream,
  1528. const TArrayRef<const TSwitchInput>& handlerInputs,
  1529. std::function<TRuntimeNode(ui32 index, TRuntimeNode item)> handler,
  1530. ui64 memoryLimitBytes, TType* returnType) {
  1531. MKQL_ENSURE(stream.GetStaticType()->IsStream() || stream.GetStaticType()->IsFlow(), "Expected stream or flow.");
  1532. std::vector<TRuntimeNode> argNodes(handlerInputs.size());
  1533. std::vector<TRuntimeNode> outputNodes(handlerInputs.size());
  1534. for (ui32 i = 0; i < handlerInputs.size(); ++i) {
  1535. TRuntimeNode arg = Arg(handlerInputs[i].InputType);
  1536. argNodes[i] = arg;
  1537. outputNodes[i] = handler(i, arg);
  1538. }
  1539. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1540. callableBuilder.Add(stream);
  1541. callableBuilder.Add(NewDataLiteral<ui64>(memoryLimitBytes));
  1542. for (ui32 i = 0; i < handlerInputs.size(); ++i) {
  1543. std::vector<TRuntimeNode> tupleElems;
  1544. for (auto index : handlerInputs[i].Indicies) {
  1545. tupleElems.push_back(NewDataLiteral<ui32>(index));
  1546. }
  1547. auto indiciesTuple = NewTuple(tupleElems);
  1548. callableBuilder.Add(indiciesTuple);
  1549. callableBuilder.Add(argNodes[i]);
  1550. callableBuilder.Add(outputNodes[i]);
  1551. if (!handlerInputs[i].ResultVariantOffset) {
  1552. callableBuilder.Add(NewVoid());
  1553. } else {
  1554. callableBuilder.Add(NewDataLiteral<ui32>(*handlerInputs[i].ResultVariantOffset));
  1555. }
  1556. }
  1557. return TRuntimeNode(callableBuilder.Build(), false);
  1558. }
  1559. TRuntimeNode TProgramBuilder::HasItems(TRuntimeNode listOrDict) {
  1560. return BuildContainerProperty<bool>(__func__, listOrDict);
  1561. }
  1562. TRuntimeNode TProgramBuilder::Reverse(TRuntimeNode list) {
  1563. bool isOptional = false;
  1564. const auto listType = UnpackOptional(list, isOptional);
  1565. if (isOptional) {
  1566. return Map(list, [&](TRuntimeNode unpacked) { return Reverse(unpacked); } );
  1567. }
  1568. const auto listDetailedType = AS_TYPE(TListType, listType);
  1569. const auto itemType = listDetailedType->GetItemType();
  1570. ThrowIfListOfVoid(itemType);
  1571. TCallableBuilder callableBuilder(Env, __func__, listType);
  1572. callableBuilder.Add(list);
  1573. return TRuntimeNode(callableBuilder.Build(), false);
  1574. }
  1575. TRuntimeNode TProgramBuilder::Skip(TRuntimeNode list, TRuntimeNode count) {
  1576. return BuildTake(__func__, list, count);
  1577. }
  1578. TRuntimeNode TProgramBuilder::Take(TRuntimeNode list, TRuntimeNode count) {
  1579. return BuildTake(__func__, list, count);
  1580. }
  1581. TRuntimeNode TProgramBuilder::Sort(TRuntimeNode list, TRuntimeNode ascending, const TUnaryLambda& keyExtractor)
  1582. {
  1583. return BuildSort(__func__, list, ascending, keyExtractor);
  1584. }
  1585. TRuntimeNode TProgramBuilder::WideTop(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys)
  1586. {
  1587. return BuildWideTopOrSort(__func__, flow, count, keys);
  1588. }
  1589. TRuntimeNode TProgramBuilder::WideTopSort(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys)
  1590. {
  1591. return BuildWideTopOrSort(__func__, flow, count, keys);
  1592. }
  1593. TRuntimeNode TProgramBuilder::WideSort(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys)
  1594. {
  1595. return BuildWideTopOrSort(__func__, flow, Nothing(), keys);
  1596. }
  1597. TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callableName, TRuntimeNode flow, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
  1598. if (count) {
  1599. if constexpr (RuntimeVersion < 33U) {
  1600. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << callableName;
  1601. }
  1602. } else {
  1603. if constexpr (RuntimeVersion < 34U) {
  1604. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << callableName;
  1605. }
  1606. }
  1607. const auto width = GetWideComponentsCount(AS_TYPE(TFlowType, flow.GetStaticType()));
  1608. MKQL_ENSURE(!keys.empty() && keys.size() <= width, "Unexpected keys count: " << keys.size());
  1609. TCallableBuilder callableBuilder(Env, callableName, flow.GetStaticType());
  1610. callableBuilder.Add(flow);
  1611. if (count) {
  1612. callableBuilder.Add(*count);
  1613. }
  1614. std::for_each(keys.cbegin(), keys.cend(), [&](const std::pair<ui32, TRuntimeNode>& key) {
  1615. MKQL_ENSURE(key.first < width, "Key index too large: " << key.first);
  1616. callableBuilder.Add(NewDataLiteral(key.first));
  1617. callableBuilder.Add(key.second);
  1618. });
  1619. return TRuntimeNode(callableBuilder.Build(), false);
  1620. }
  1621. TRuntimeNode TProgramBuilder::Top(TRuntimeNode flow, TRuntimeNode count, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) {
  1622. if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
  1623. const TUnaryLambda getKey = [&](TRuntimeNode item) { return Nth(item, 0U); };
  1624. const TUnaryLambda getItem = [&](TRuntimeNode item) { return Nth(item, 1U); };
  1625. const TUnaryLambda cacheKeyExtractor = [&](TRuntimeNode item) {
  1626. return NewTuple({keyExtractor(item), item});
  1627. };
  1628. return FlatMap(Condense1(Map(flow, cacheKeyExtractor),
  1629. [&](TRuntimeNode item) { return AsList(item); },
  1630. [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
  1631. [&](TRuntimeNode item, TRuntimeNode state) {
  1632. return KeepTop(count, state, item, ascending, getKey);
  1633. }
  1634. ),
  1635. [&](TRuntimeNode list) { return Map(Top(list, count, ascending, getKey), getItem); }
  1636. );
  1637. }
  1638. return BuildListNth(__func__, flow, count, ascending, keyExtractor);
  1639. }
  1640. TRuntimeNode TProgramBuilder::TopSort(TRuntimeNode flow, TRuntimeNode count, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) {
  1641. if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
  1642. const TUnaryLambda getKey = [&](TRuntimeNode item) { return Nth(item, 0U); };
  1643. const TUnaryLambda getItem = [&](TRuntimeNode item) { return Nth(item, 1U); };
  1644. const TUnaryLambda cacheKeyExtractor = [&](TRuntimeNode item) {
  1645. return NewTuple({keyExtractor(item), item});
  1646. };
  1647. return FlatMap(Condense1(Map(flow, cacheKeyExtractor),
  1648. [&](TRuntimeNode item) { return AsList(item); },
  1649. [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
  1650. [&](TRuntimeNode item, TRuntimeNode state) {
  1651. return KeepTop(count, state, item, ascending, getKey);
  1652. }
  1653. ),
  1654. [&](TRuntimeNode list) { return Map(TopSort(list, count, ascending, getKey), getItem); }
  1655. );
  1656. }
  1657. if constexpr (RuntimeVersion >= 25U)
  1658. return BuildListNth(__func__, flow, count, ascending, keyExtractor);
  1659. else
  1660. return BuildListSort("Sort", BuildListNth("Top", flow, count, ascending, keyExtractor), ascending, keyExtractor);
  1661. }
  1662. TRuntimeNode TProgramBuilder::KeepTop(TRuntimeNode count, TRuntimeNode list, TRuntimeNode item, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) {
  1663. const auto listType = list.GetStaticType();
  1664. MKQL_ENSURE(listType->IsList(), "Expected list.");
  1665. const auto itemType = static_cast<const TListType&>(*listType).GetItemType();
  1666. ThrowIfListOfVoid(itemType);
  1667. MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data");
  1668. MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  1669. MKQL_ENSURE(itemType->IsSameType(*item.GetStaticType()), "Types of list and item are different.");
  1670. const auto ascendingType = ascending.GetStaticType();
  1671. const auto itemArg = Arg(itemType);
  1672. auto key = keyExtractor(itemArg);
  1673. const auto hotkey = Arg(key.GetStaticType());
  1674. if (ascendingType->IsTuple()) {
  1675. const auto ascendingTuple = AS_TYPE(TTupleType, ascendingType);
  1676. if (ascendingTuple->GetElementsCount() == 0) {
  1677. return If(AggrLess(Length(list), count), Append(list, item), list);
  1678. }
  1679. if (ascendingTuple->GetElementsCount() == 1) {
  1680. ascending = Nth(ascending, 0);
  1681. key = Nth(key, 0);
  1682. }
  1683. }
  1684. TCallableBuilder callableBuilder(Env, __func__, listType);
  1685. callableBuilder.Add(count);
  1686. callableBuilder.Add(list);
  1687. callableBuilder.Add(item);
  1688. callableBuilder.Add(itemArg);
  1689. callableBuilder.Add(key);
  1690. callableBuilder.Add(ascending);
  1691. callableBuilder.Add(hotkey);
  1692. return TRuntimeNode(callableBuilder.Build(), false);
  1693. }
  1694. TRuntimeNode TProgramBuilder::Contains(TRuntimeNode dict, TRuntimeNode key) {
  1695. if constexpr (RuntimeVersion >= 25U)
  1696. if (!dict.GetStaticType()->IsDict())
  1697. return DataCompare(__func__, dict, key);
  1698. const auto keyType = AS_TYPE(TDictType, dict.GetStaticType())->GetKeyType();
  1699. MKQL_ENSURE(keyType->IsSameType(*key.GetStaticType()), "Key type mismatch. Requred: " << *keyType << ", but got: " << *key.GetStaticType());
  1700. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<bool>::Id));
  1701. callableBuilder.Add(dict);
  1702. callableBuilder.Add(key);
  1703. return TRuntimeNode(callableBuilder.Build(), false);
  1704. }
  1705. TRuntimeNode TProgramBuilder::Lookup(TRuntimeNode dict, TRuntimeNode key) {
  1706. const auto dictType = AS_TYPE(TDictType, dict.GetStaticType());
  1707. const auto keyType = dictType->GetKeyType();
  1708. MKQL_ENSURE(keyType->IsSameType(*key.GetStaticType()), "Key type mismatch. Requred: " << *keyType << ", but got: " << *key.GetStaticType());
  1709. TCallableBuilder callableBuilder(Env, __func__, NewOptionalType(dictType->GetPayloadType()));
  1710. callableBuilder.Add(dict);
  1711. callableBuilder.Add(key);
  1712. return TRuntimeNode(callableBuilder.Build(), false);
  1713. }
  1714. TRuntimeNode TProgramBuilder::DictItems(TRuntimeNode dict, EDictItems mode) {
  1715. const auto dictTypeChecked = AS_TYPE(TDictType, dict.GetStaticType());
  1716. TType* itemType;
  1717. switch (mode) {
  1718. case EDictItems::Both: {
  1719. const std::array<TType*, 2U> tupleTypes = {{ dictTypeChecked->GetKeyType(), dictTypeChecked->GetPayloadType() }};
  1720. itemType = NewTupleType(tupleTypes);
  1721. break;
  1722. }
  1723. case EDictItems::Keys: itemType = dictTypeChecked->GetKeyType(); break;
  1724. case EDictItems::Payloads: itemType = dictTypeChecked->GetPayloadType(); break;
  1725. }
  1726. TCallableBuilder callableBuilder(Env, __func__, NewListType(itemType));
  1727. callableBuilder.Add(dict);
  1728. callableBuilder.Add(NewDataLiteral((ui32)mode));
  1729. return TRuntimeNode(callableBuilder.Build(), false);
  1730. }
  1731. TRuntimeNode TProgramBuilder::DictItems(TRuntimeNode dict) {
  1732. if constexpr (RuntimeVersion < 6U) {
  1733. return DictItems(dict, EDictItems::Both);
  1734. }
  1735. const auto dictTypeChecked = AS_TYPE(TDictType, dict.GetStaticType());
  1736. const auto itemType = NewTupleType({ dictTypeChecked->GetKeyType(), dictTypeChecked->GetPayloadType() });
  1737. TCallableBuilder callableBuilder(Env, __func__, NewListType(itemType));
  1738. callableBuilder.Add(dict);
  1739. return TRuntimeNode(callableBuilder.Build(), false);
  1740. }
  1741. TRuntimeNode TProgramBuilder::DictKeys(TRuntimeNode dict) {
  1742. if constexpr (RuntimeVersion < 6U) {
  1743. return DictItems(dict, EDictItems::Keys);
  1744. }
  1745. const auto dictTypeChecked = AS_TYPE(TDictType, dict.GetStaticType());
  1746. TCallableBuilder callableBuilder(Env, __func__, NewListType(dictTypeChecked->GetKeyType()));
  1747. callableBuilder.Add(dict);
  1748. return TRuntimeNode(callableBuilder.Build(), false);
  1749. }
  1750. TRuntimeNode TProgramBuilder::DictPayloads(TRuntimeNode dict) {
  1751. if constexpr (RuntimeVersion < 6U) {
  1752. return DictItems(dict, EDictItems::Payloads);
  1753. }
  1754. const auto dictTypeChecked = AS_TYPE(TDictType, dict.GetStaticType());
  1755. TCallableBuilder callableBuilder(Env, __func__, NewListType(dictTypeChecked->GetPayloadType()));
  1756. callableBuilder.Add(dict);
  1757. return TRuntimeNode(callableBuilder.Build(), false);
  1758. }
  1759. TRuntimeNode TProgramBuilder::ToIndexDict(TRuntimeNode list) {
  1760. const auto itemType = AS_TYPE(TListType, list.GetStaticType())->GetItemType();
  1761. ThrowIfListOfVoid(itemType);
  1762. const auto keyType = NewDataType(NUdf::TDataType<ui64>::Id);
  1763. const auto dictType = NewDictType(keyType, itemType, false);
  1764. TCallableBuilder callableBuilder(Env, __func__, dictType);
  1765. callableBuilder.Add(list);
  1766. return TRuntimeNode(callableBuilder.Build(), false);
  1767. }
  1768. TRuntimeNode TProgramBuilder::JoinDict(TRuntimeNode dict1, bool isMulti1, TRuntimeNode dict2, bool isMulti2, EJoinKind joinKind) {
  1769. const auto dict1type = AS_TYPE(TDictType, dict1);
  1770. const auto dict2type = AS_TYPE(TDictType, dict2);
  1771. MKQL_ENSURE(dict1type->GetKeyType()->IsSameType(*dict2type->GetKeyType()), "Dict key types must be the same");
  1772. if (joinKind == EJoinKind::RightOnly || joinKind == EJoinKind::RightSemi)
  1773. MKQL_ENSURE(dict1type->GetPayloadType()->IsVoid(), "Void required for first dict payload.");
  1774. else if (isMulti1)
  1775. MKQL_ENSURE(dict1type->GetPayloadType()->IsList(), "List required for first dict payload.");
  1776. if (joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::LeftSemi)
  1777. MKQL_ENSURE(dict2type->GetPayloadType()->IsVoid(), "Void required for second dict payload.");
  1778. else if (isMulti2)
  1779. MKQL_ENSURE(dict2type->GetPayloadType()->IsList(), "List required for second dict payload.");
  1780. std::array<TType*, 2> tupleItems = {{ dict1type->GetPayloadType(), dict2type->GetPayloadType() }};
  1781. if (isMulti1 && tupleItems.front()->IsList())
  1782. tupleItems.front() = AS_TYPE(TListType, tupleItems.front())->GetItemType();
  1783. if (isMulti2 && tupleItems.back()->IsList())
  1784. tupleItems.back() = AS_TYPE(TListType, tupleItems.back())->GetItemType();
  1785. if (IsLeftOptional(joinKind))
  1786. tupleItems.front() = NewOptionalType(tupleItems.front());
  1787. if (IsRightOptional(joinKind))
  1788. tupleItems.back() = NewOptionalType(tupleItems.back());
  1789. TType* itemType;
  1790. if (joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::LeftSemi)
  1791. itemType = tupleItems.front();
  1792. else if (joinKind == EJoinKind::RightOnly || joinKind == EJoinKind::RightSemi)
  1793. itemType = tupleItems.back();
  1794. else
  1795. itemType = NewTupleType(tupleItems);
  1796. const auto returnType = NewListType(itemType);
  1797. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1798. callableBuilder.Add(dict1);
  1799. callableBuilder.Add(dict2);
  1800. callableBuilder.Add(NewDataLiteral(isMulti1));
  1801. callableBuilder.Add(NewDataLiteral(isMulti2));
  1802. callableBuilder.Add(NewDataLiteral(ui32(joinKind)));
  1803. return TRuntimeNode(callableBuilder.Build(), false);
  1804. }
  1805. TRuntimeNode TProgramBuilder::GraceJoinCommon(const TStringBuf& funcName, TRuntimeNode flowLeft, TRuntimeNode flowRight, EJoinKind joinKind,
  1806. const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns,
  1807. const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, TType* returnType, EAnyJoinSettings anyJoinSettings ) {
  1808. MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified");
  1809. if (flowRight) {
  1810. MKQL_ENSURE(!rightKeyColumns.empty(), "At least one key column must be specified");
  1811. }
  1812. TRuntimeNode::TList leftKeyColumnsNodes, rightKeyColumnsNodes, leftRenamesNodes, rightRenamesNodes;
  1813. leftKeyColumnsNodes.reserve(leftKeyColumns.size());
  1814. std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(), std::back_inserter(leftKeyColumnsNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  1815. rightKeyColumnsNodes.reserve(rightKeyColumns.size());
  1816. std::transform(rightKeyColumns.cbegin(), rightKeyColumns.cend(), std::back_inserter(rightKeyColumnsNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  1817. leftRenamesNodes.reserve(leftRenames.size());
  1818. std::transform(leftRenames.cbegin(), leftRenames.cend(), std::back_inserter(leftRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  1819. rightRenamesNodes.reserve(rightRenames.size());
  1820. std::transform(rightRenames.cbegin(), rightRenames.cend(), std::back_inserter(rightRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  1821. TCallableBuilder callableBuilder(Env, funcName, returnType);
  1822. callableBuilder.Add(flowLeft);
  1823. if (flowRight) {
  1824. callableBuilder.Add(flowRight);
  1825. }
  1826. callableBuilder.Add(NewDataLiteral((ui32)joinKind));
  1827. callableBuilder.Add(NewTuple(leftKeyColumnsNodes));
  1828. callableBuilder.Add(NewTuple(rightKeyColumnsNodes));
  1829. callableBuilder.Add(NewTuple(leftRenamesNodes));
  1830. callableBuilder.Add(NewTuple(rightRenamesNodes));
  1831. callableBuilder.Add(NewDataLiteral((ui32)anyJoinSettings));
  1832. return TRuntimeNode(callableBuilder.Build(), false);
  1833. }
  1834. TRuntimeNode TProgramBuilder::GraceJoin(TRuntimeNode flowLeft, TRuntimeNode flowRight, EJoinKind joinKind,
  1835. const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns,
  1836. const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, TType* returnType, EAnyJoinSettings anyJoinSettings ) {
  1837. return GraceJoinCommon(__func__, flowLeft, flowRight, joinKind, leftKeyColumns, rightKeyColumns, leftRenames, rightRenames, returnType, anyJoinSettings);
  1838. }
  1839. TRuntimeNode TProgramBuilder::GraceSelfJoin(TRuntimeNode flowLeft, EJoinKind joinKind, const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns,
  1840. const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, TType* returnType, EAnyJoinSettings anyJoinSettings ) {
  1841. if constexpr (RuntimeVersion < 40U) {
  1842. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1843. }
  1844. return GraceJoinCommon(__func__, flowLeft, {}, joinKind, leftKeyColumns, rightKeyColumns, leftRenames, rightRenames, returnType, anyJoinSettings);
  1845. }
  1846. TRuntimeNode TProgramBuilder::ToSortedDict(TRuntimeNode list, bool all, const TUnaryLambda& keySelector,
  1847. const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1848. return ToDict(list, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1849. }
  1850. TRuntimeNode TProgramBuilder::ToHashedDict(TRuntimeNode list, bool all, const TUnaryLambda& keySelector,
  1851. const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1852. return ToDict(list, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1853. }
  1854. TRuntimeNode TProgramBuilder::SqueezeToSortedDict(TRuntimeNode stream, bool all, const TUnaryLambda& keySelector,
  1855. const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1856. return SqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1857. }
  1858. TRuntimeNode TProgramBuilder::SqueezeToHashedDict(TRuntimeNode stream, bool all, const TUnaryLambda& keySelector,
  1859. const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1860. return SqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1861. }
  1862. TRuntimeNode TProgramBuilder::NarrowSqueezeToSortedDict(TRuntimeNode stream, bool all, const TNarrowLambda& keySelector,
  1863. const TNarrowLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1864. return NarrowSqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1865. }
  1866. TRuntimeNode TProgramBuilder::NarrowSqueezeToHashedDict(TRuntimeNode stream, bool all, const TNarrowLambda& keySelector,
  1867. const TNarrowLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1868. return NarrowSqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1869. }
  1870. TRuntimeNode TProgramBuilder::SqueezeToList(TRuntimeNode flow, TRuntimeNode limit) {
  1871. if constexpr (RuntimeVersion < 25U) {
  1872. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1873. }
  1874. const auto itemType = AS_TYPE(TFlowType, flow.GetStaticType())->GetItemType();
  1875. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewListType(itemType)));
  1876. callableBuilder.Add(flow);
  1877. callableBuilder.Add(limit);
  1878. return TRuntimeNode(callableBuilder.Build(), false);
  1879. }
  1880. TRuntimeNode TProgramBuilder::Append(TRuntimeNode list, TRuntimeNode item) {
  1881. auto listType = list.GetStaticType();
  1882. AS_TYPE(TListType, listType);
  1883. const auto& listDetailedType = static_cast<const TListType&>(*listType);
  1884. auto itemType = item.GetStaticType();
  1885. MKQL_ENSURE(itemType->IsSameType(*listDetailedType.GetItemType()), "Types of list and item are different");
  1886. TCallableBuilder callableBuilder(Env, __func__, listType);
  1887. callableBuilder.Add(list);
  1888. callableBuilder.Add(item);
  1889. return TRuntimeNode(callableBuilder.Build(), false);
  1890. }
  1891. TRuntimeNode TProgramBuilder::Prepend(TRuntimeNode item, TRuntimeNode list) {
  1892. auto listType = list.GetStaticType();
  1893. AS_TYPE(TListType, listType);
  1894. const auto& listDetailedType = static_cast<const TListType&>(*listType);
  1895. auto itemType = item.GetStaticType();
  1896. MKQL_ENSURE(itemType->IsSameType(*listDetailedType.GetItemType()), "Types of list and item are different");
  1897. TCallableBuilder callableBuilder(Env, __func__, listType);
  1898. callableBuilder.Add(item);
  1899. callableBuilder.Add(list);
  1900. return TRuntimeNode(callableBuilder.Build(), false);
  1901. }
  1902. TRuntimeNode TProgramBuilder::BuildExtend(const std::string_view& callableName, const TArrayRef<const TRuntimeNode>& lists) {
  1903. MKQL_ENSURE(lists.size() > 0, "Expected at least 1 list or flow");
  1904. if (lists.size() == 1) {
  1905. return lists.front();
  1906. }
  1907. auto listType = lists.front().GetStaticType();
  1908. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected either flow, list or stream");
  1909. for (ui32 i = 1; i < lists.size(); ++i) {
  1910. auto listType2 = lists[i].GetStaticType();
  1911. MKQL_ENSURE(listType->IsSameType(*listType2), "Types of flows are different, left: " <<
  1912. PrintNode(listType, true) << ", right: " <<
  1913. PrintNode(listType2, true));
  1914. }
  1915. TCallableBuilder callableBuilder(Env, callableName, listType);
  1916. for (auto list : lists) {
  1917. callableBuilder.Add(list);
  1918. }
  1919. return TRuntimeNode(callableBuilder.Build(), false);
  1920. }
  1921. TRuntimeNode TProgramBuilder::Extend(const TArrayRef<const TRuntimeNode>& lists) {
  1922. return BuildExtend(__func__, lists);
  1923. }
  1924. TRuntimeNode TProgramBuilder::OrderedExtend(const TArrayRef<const TRuntimeNode>& lists) {
  1925. return BuildExtend(__func__, lists);
  1926. }
  1927. template<>
  1928. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::String>(const NUdf::TStringRef& data) const {
  1929. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<const char*>::Id, Env), true);
  1930. }
  1931. template<>
  1932. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Utf8>(const NUdf::TStringRef& data) const {
  1933. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TUtf8>::Id, Env), true);
  1934. }
  1935. template<>
  1936. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Yson>(const NUdf::TStringRef& data) const {
  1937. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TYson>::Id, Env), true);
  1938. }
  1939. template<>
  1940. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Json>(const NUdf::TStringRef& data) const {
  1941. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TJson>::Id, Env), true);
  1942. }
  1943. template<>
  1944. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::JsonDocument>(const NUdf::TStringRef& data) const {
  1945. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TJsonDocument>::Id, Env), true);
  1946. }
  1947. template<>
  1948. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Uuid>(const NUdf::TStringRef& data) const {
  1949. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TUuid>::Id, Env), true);
  1950. }
  1951. template<>
  1952. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Date>(const NUdf::TStringRef& data) const {
  1953. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDate>::Id, Env), true);
  1954. }
  1955. template<>
  1956. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Datetime>(const NUdf::TStringRef& data) const {
  1957. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDatetime>::Id, Env), true);
  1958. }
  1959. template<>
  1960. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Timestamp>(const NUdf::TStringRef& data) const {
  1961. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TTimestamp>::Id, Env), true);
  1962. }
  1963. template<>
  1964. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Interval>(const NUdf::TStringRef& data) const {
  1965. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TInterval>::Id, Env), true);
  1966. }
  1967. template<>
  1968. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::DyNumber>(const NUdf::TStringRef& data) const {
  1969. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDyNumber>::Id, Env), true);
  1970. }
  1971. TRuntimeNode TProgramBuilder::NewDecimalLiteral(NYql::NDecimal::TInt128 data, ui8 precision, ui8 scale) const {
  1972. return TRuntimeNode(TDataLiteral::Create(NUdf::TUnboxedValuePod(data), TDataDecimalType::Create(precision, scale, Env), Env), true);
  1973. }
  1974. template<>
  1975. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Date32>(const NUdf::TStringRef& data) const {
  1976. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDate32>::Id, Env), true);
  1977. }
  1978. template<>
  1979. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Datetime64>(const NUdf::TStringRef& data) const {
  1980. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDatetime64>::Id, Env), true);
  1981. }
  1982. template<>
  1983. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Timestamp64>(const NUdf::TStringRef& data) const {
  1984. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TTimestamp64>::Id, Env), true);
  1985. }
  1986. template<>
  1987. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Interval64>(const NUdf::TStringRef& data) const {
  1988. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TInterval64>::Id, Env), true);
  1989. }
  1990. TRuntimeNode TProgramBuilder::NewOptional(TRuntimeNode data) {
  1991. auto type = TOptionalType::Create(data.GetStaticType(), Env);
  1992. return TRuntimeNode(TOptionalLiteral::Create(data, type, Env), true);
  1993. }
  1994. TRuntimeNode TProgramBuilder::NewOptional(TType* optionalType, TRuntimeNode data) {
  1995. auto type = AS_TYPE(TOptionalType, optionalType);
  1996. return TRuntimeNode(TOptionalLiteral::Create(data, type, Env), true);
  1997. }
  1998. TRuntimeNode TProgramBuilder::NewVoid() {
  1999. return TRuntimeNode(Env.GetVoidLazy(), true);
  2000. }
  2001. TRuntimeNode TProgramBuilder::NewEmptyListOfVoid() {
  2002. return TRuntimeNode(Env.GetListOfVoidLazy(), true);
  2003. }
  2004. TRuntimeNode TProgramBuilder::NewEmptyOptional(TType* optionalOrPgType) {
  2005. MKQL_ENSURE(optionalOrPgType->IsOptional() || optionalOrPgType->IsPg(), "Expected optional or pg type");
  2006. if (optionalOrPgType->IsOptional()) {
  2007. return TRuntimeNode(TOptionalLiteral::Create(static_cast<TOptionalType*>(optionalOrPgType), Env), true);
  2008. }
  2009. return PgCast(NewNull(), optionalOrPgType);
  2010. }
  2011. TRuntimeNode TProgramBuilder::NewEmptyOptionalDataLiteral(NUdf::TDataTypeId schemeType) {
  2012. return TRuntimeNode(BuildEmptyOptionalDataLiteral(schemeType, Env), true);
  2013. }
  2014. TRuntimeNode TProgramBuilder::NewEmptyStruct() {
  2015. return TRuntimeNode(Env.GetEmptyStructLazy(), true);
  2016. }
  2017. TRuntimeNode TProgramBuilder::NewStruct(const TArrayRef<const std::pair<std::string_view, TRuntimeNode>>& members) {
  2018. if (members.empty()) {
  2019. return NewEmptyStruct();
  2020. }
  2021. TStructLiteralBuilder builder(Env);
  2022. for (auto x : members) {
  2023. builder.Add(x.first, x.second);
  2024. }
  2025. return TRuntimeNode(builder.Build(), true);
  2026. }
  2027. TRuntimeNode TProgramBuilder::NewStruct(TType* structType, const TArrayRef<const std::pair<std::string_view, TRuntimeNode>>& members) {
  2028. const auto detailedStructType = AS_TYPE(TStructType, structType);
  2029. MKQL_ENSURE(members.size() == detailedStructType->GetMembersCount(), "Mismatch count of members");
  2030. if (members.empty()) {
  2031. return NewEmptyStruct();
  2032. }
  2033. std::vector<TRuntimeNode> values(detailedStructType->GetMembersCount());
  2034. for (ui32 i = 0; i < detailedStructType->GetMembersCount(); ++i) {
  2035. const auto& name = members[i].first;
  2036. ui32 index = detailedStructType->GetMemberIndex(name);
  2037. MKQL_ENSURE(!values[index], "Duplicate of member: " << name);
  2038. values[index] = members[i].second;
  2039. }
  2040. return TRuntimeNode(TStructLiteral::Create(values.size(), values.data(), detailedStructType, Env), true);
  2041. }
  2042. TRuntimeNode TProgramBuilder::NewEmptyList() {
  2043. return TRuntimeNode(Env.GetEmptyListLazy(), true);
  2044. }
  2045. TRuntimeNode TProgramBuilder::NewEmptyList(TType* itemType) {
  2046. TListLiteralBuilder builder(Env, itemType);
  2047. return TRuntimeNode(builder.Build(), true);
  2048. }
  2049. TRuntimeNode TProgramBuilder::NewList(TType* itemType, const TArrayRef<const TRuntimeNode>& items) {
  2050. TListLiteralBuilder builder(Env, itemType);
  2051. for (auto item : items) {
  2052. builder.Add(item);
  2053. }
  2054. return TRuntimeNode(builder.Build(), true);
  2055. }
  2056. TRuntimeNode TProgramBuilder::NewEmptyDict() {
  2057. return TRuntimeNode(Env.GetEmptyDictLazy(), true);
  2058. }
  2059. TRuntimeNode TProgramBuilder::NewDict(TType* dictType, const TArrayRef<const std::pair<TRuntimeNode, TRuntimeNode>>& items) {
  2060. MKQL_ENSURE(dictType->IsDict(), "Expected dict type");
  2061. return TRuntimeNode(TDictLiteral::Create(items.size(), items.data(), static_cast<TDictType*>(dictType), Env), true);
  2062. }
  2063. TRuntimeNode TProgramBuilder::NewEmptyTuple() {
  2064. return TRuntimeNode(Env.GetEmptyTupleLazy(), true);
  2065. }
  2066. TRuntimeNode TProgramBuilder::NewTuple(TType* tupleType, const TArrayRef<const TRuntimeNode>& elements) {
  2067. MKQL_ENSURE(tupleType->IsTuple(), "Expected tuple type");
  2068. return TRuntimeNode(TTupleLiteral::Create(elements.size(), elements.data(), static_cast<TTupleType*>(tupleType), Env), true);
  2069. }
  2070. TRuntimeNode TProgramBuilder::NewTuple(const TArrayRef<const TRuntimeNode>& elements) {
  2071. std::vector<TType*> types;
  2072. types.reserve(elements.size());
  2073. for (auto elem : elements) {
  2074. types.push_back(elem.GetStaticType());
  2075. }
  2076. return NewTuple(NewTupleType(types), elements);
  2077. }
  2078. TRuntimeNode TProgramBuilder::NewVariant(TRuntimeNode item, ui32 index, TType* variantType) {
  2079. const auto type = AS_TYPE(TVariantType, variantType);
  2080. MKQL_ENSURE(type->GetUnderlyingType()->IsTuple(), "Expected tuple as underlying type");
  2081. return TRuntimeNode(TVariantLiteral::Create(item, index, type, Env), true);
  2082. }
  2083. TRuntimeNode TProgramBuilder::NewVariant(TRuntimeNode item, const std::string_view& member, TType* variantType) {
  2084. const auto type = AS_TYPE(TVariantType, variantType);
  2085. MKQL_ENSURE(type->GetUnderlyingType()->IsStruct(), "Expected struct as underlying type");
  2086. ui32 index = AS_TYPE(TStructType, type->GetUnderlyingType())->GetMemberIndex(member);
  2087. return TRuntimeNode(TVariantLiteral::Create(item, index, type, Env), true);
  2088. }
  2089. TRuntimeNode TProgramBuilder::Coalesce(TRuntimeNode data, TRuntimeNode defaultData) {
  2090. bool isOptional = false;
  2091. const auto dataType = UnpackOptional(data, isOptional);
  2092. if (!isOptional && !data.GetStaticType()->IsPg()) {
  2093. MKQL_ENSURE(data.GetStaticType()->IsSameType(*defaultData.GetStaticType()), "Mismatch operand types");
  2094. return data;
  2095. }
  2096. if (!dataType->IsSameType(*defaultData.GetStaticType())) {
  2097. bool isOptionalDefault;
  2098. const auto defaultDataType = UnpackOptional(defaultData, isOptionalDefault);
  2099. MKQL_ENSURE(dataType->IsSameType(*defaultDataType), "Mismatch operand types");
  2100. }
  2101. TCallableBuilder callableBuilder(Env, __func__, defaultData.GetStaticType());
  2102. callableBuilder.Add(data);
  2103. callableBuilder.Add(defaultData);
  2104. return TRuntimeNode(callableBuilder.Build(), false);
  2105. }
  2106. TRuntimeNode TProgramBuilder::Unwrap(TRuntimeNode optional, TRuntimeNode message, const std::string_view& file, ui32 row, ui32 column) {
  2107. bool isOptional;
  2108. auto underlyingType = UnpackOptional(optional, isOptional);
  2109. MKQL_ENSURE(isOptional, "Expected optional");
  2110. const auto& messageType = message.GetStaticType();
  2111. MKQL_ENSURE(messageType->IsData(), "Expected data");
  2112. const auto& messageTypeData = static_cast<const TDataType&>(*messageType);
  2113. MKQL_ENSURE(messageTypeData.GetSchemeType() == NUdf::TDataType<char*>::Id || messageTypeData.GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected string or utf8.");
  2114. TCallableBuilder callableBuilder(Env, __func__, underlyingType);
  2115. callableBuilder.Add(optional);
  2116. callableBuilder.Add(message);
  2117. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  2118. callableBuilder.Add(NewDataLiteral(row));
  2119. callableBuilder.Add(NewDataLiteral(column));
  2120. return TRuntimeNode(callableBuilder.Build(), false);
  2121. }
  2122. TRuntimeNode TProgramBuilder::Increment(TRuntimeNode data) {
  2123. const std::array<TRuntimeNode, 1> args = {{ data }};
  2124. bool isOptional;
  2125. const auto type = UnpackOptionalData(data, isOptional);
  2126. if (type->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2127. return Invoke(__func__, data.GetStaticType(), args);
  2128. return Invoke(TString("Inc_") += ::ToString(static_cast<TDataDecimalType*>(type)->GetParams().first), data.GetStaticType(), args);
  2129. }
  2130. TRuntimeNode TProgramBuilder::Decrement(TRuntimeNode data) {
  2131. const std::array<TRuntimeNode, 1> args = {{ data }};
  2132. bool isOptional;
  2133. const auto type = UnpackOptionalData(data, isOptional);
  2134. if (type->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2135. return Invoke(__func__, data.GetStaticType(), args);
  2136. return Invoke(TString("Dec_") += ::ToString(static_cast<TDataDecimalType*>(type)->GetParams().first), data.GetStaticType(), args);
  2137. }
  2138. TRuntimeNode TProgramBuilder::Abs(TRuntimeNode data) {
  2139. const std::array<TRuntimeNode, 1> args = {{ data }};
  2140. return Invoke(__func__, data.GetStaticType(), args);
  2141. }
  2142. TRuntimeNode TProgramBuilder::Plus(TRuntimeNode data) {
  2143. const std::array<TRuntimeNode, 1> args = {{ data }};
  2144. return Invoke(__func__, data.GetStaticType(), args);
  2145. }
  2146. TRuntimeNode TProgramBuilder::Minus(TRuntimeNode data) {
  2147. const std::array<TRuntimeNode, 1> args = {{ data }};
  2148. return Invoke(__func__, data.GetStaticType(), args);
  2149. }
  2150. TRuntimeNode TProgramBuilder::Add(TRuntimeNode data1, TRuntimeNode data2) {
  2151. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2152. bool isOptionalLeft;
  2153. const auto leftType = UnpackOptionalData(data1, isOptionalLeft);
  2154. if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2155. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2156. const auto decimalType = static_cast<TDataDecimalType*>(leftType);
  2157. bool isOptionalRight;
  2158. const auto rightType = static_cast<TDataDecimalType*>(UnpackOptionalData(data2, isOptionalRight));
  2159. MKQL_ENSURE(rightType->IsSameType(*decimalType), "Operands type mismatch");
  2160. const auto resultType = isOptionalLeft || isOptionalRight ? NewOptionalType(decimalType) : decimalType;
  2161. return Invoke(TString("Add_") += ::ToString(decimalType->GetParams().first), resultType, args);
  2162. }
  2163. TRuntimeNode TProgramBuilder::Sub(TRuntimeNode data1, TRuntimeNode data2) {
  2164. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2165. bool isOptionalLeft;
  2166. const auto leftType = UnpackOptionalData(data1, isOptionalLeft);
  2167. if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2168. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2169. const auto decimalType = static_cast<TDataDecimalType*>(leftType);
  2170. bool isOptionalRight;
  2171. const auto rightType = static_cast<TDataDecimalType*>(UnpackOptionalData(data2, isOptionalRight));
  2172. MKQL_ENSURE(rightType->IsSameType(*decimalType), "Operands type mismatch");
  2173. const auto resultType = isOptionalLeft || isOptionalRight ? NewOptionalType(decimalType) : decimalType;
  2174. return Invoke(TString("Sub_") += ::ToString(decimalType->GetParams().first), resultType, args);
  2175. }
  2176. TRuntimeNode TProgramBuilder::Mul(TRuntimeNode data1, TRuntimeNode data2) {
  2177. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2178. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2179. }
  2180. TRuntimeNode TProgramBuilder::Div(TRuntimeNode data1, TRuntimeNode data2) {
  2181. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2182. auto resultType = BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType());
  2183. if (resultType->IsData() && !(NUdf::GetDataTypeInfo(*static_cast<TDataType*>(resultType)->GetDataSlot()).Features & (NUdf::EDataTypeFeatures::FloatType | NUdf::EDataTypeFeatures::DecimalType))) {
  2184. resultType = NewOptionalType(resultType);
  2185. }
  2186. return Invoke(__func__, resultType, args);
  2187. }
  2188. TRuntimeNode TProgramBuilder::DecimalDiv(TRuntimeNode data1, TRuntimeNode data2) {
  2189. bool isOptionalLeft, isOptionalRight;
  2190. const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(data1, isOptionalLeft));
  2191. const auto rightType = UnpackOptionalData(data2, isOptionalRight);
  2192. if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id)
  2193. MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch");
  2194. else
  2195. MKQL_ENSURE(NUdf::GetDataTypeInfo(*rightType->GetDataSlot()).Features & NUdf::IntegralType, "Operands type mismatch");
  2196. const auto returnType = isOptionalLeft || isOptionalRight ? NewOptionalType(leftType) : leftType;
  2197. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2198. callableBuilder.Add(data1);
  2199. callableBuilder.Add(data2);
  2200. return TRuntimeNode(callableBuilder.Build(), false);
  2201. }
  2202. TRuntimeNode TProgramBuilder::DecimalMod(TRuntimeNode data1, TRuntimeNode data2) {
  2203. bool isOptionalLeft, isOptionalRight;
  2204. const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(data1, isOptionalLeft));
  2205. const auto rightType = UnpackOptionalData(data2, isOptionalRight);
  2206. if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id)
  2207. MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch");
  2208. else
  2209. MKQL_ENSURE(NUdf::GetDataTypeInfo(*rightType->GetDataSlot()).Features & NUdf::IntegralType, "Operands type mismatch");
  2210. const auto returnType = isOptionalLeft || isOptionalRight ? NewOptionalType(leftType) : leftType;
  2211. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2212. callableBuilder.Add(data1);
  2213. callableBuilder.Add(data2);
  2214. return TRuntimeNode(callableBuilder.Build(), false);
  2215. }
  2216. TRuntimeNode TProgramBuilder::DecimalMul(TRuntimeNode data1, TRuntimeNode data2) {
  2217. bool isOptionalLeft, isOptionalRight;
  2218. const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(data1, isOptionalLeft));
  2219. const auto rightType = UnpackOptionalData(data2, isOptionalRight);
  2220. if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id)
  2221. MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch");
  2222. else
  2223. MKQL_ENSURE(NUdf::GetDataTypeInfo(*rightType->GetDataSlot()).Features & NUdf::IntegralType, "Operands type mismatch");
  2224. const auto returnType = isOptionalLeft || isOptionalRight ? NewOptionalType(leftType) : leftType;
  2225. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2226. callableBuilder.Add(data1);
  2227. callableBuilder.Add(data2);
  2228. return TRuntimeNode(callableBuilder.Build(), false);
  2229. }
  2230. TRuntimeNode TProgramBuilder::AllOf(TRuntimeNode list, const TUnaryLambda& predicate) {
  2231. return Not(NotAllOf(list, predicate));
  2232. }
  2233. TRuntimeNode TProgramBuilder::NotAllOf(TRuntimeNode list, const TUnaryLambda& predicate) {
  2234. return Exists(ToOptional(SkipWhile(list, predicate)));
  2235. }
  2236. TRuntimeNode TProgramBuilder::BitNot(TRuntimeNode data) {
  2237. const std::array<TRuntimeNode, 1> args = {{ data }};
  2238. return Invoke(__func__, data.GetStaticType(), args);
  2239. }
  2240. TRuntimeNode TProgramBuilder::CountBits(TRuntimeNode data) {
  2241. const std::array<TRuntimeNode, 1> args = {{ data }};
  2242. return Invoke(__func__, data.GetStaticType(), args);
  2243. }
  2244. TRuntimeNode TProgramBuilder::BitAnd(TRuntimeNode data1, TRuntimeNode data2) {
  2245. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2246. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2247. }
  2248. TRuntimeNode TProgramBuilder::BitOr(TRuntimeNode data1, TRuntimeNode data2) {
  2249. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2250. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2251. }
  2252. TRuntimeNode TProgramBuilder::BitXor(TRuntimeNode data1, TRuntimeNode data2) {
  2253. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2254. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2255. }
  2256. TRuntimeNode TProgramBuilder::ShiftLeft(TRuntimeNode arg, TRuntimeNode bits) {
  2257. const std::array<TRuntimeNode, 2> args = {{ arg, bits }};
  2258. return Invoke(__func__, arg.GetStaticType(), args);
  2259. }
  2260. TRuntimeNode TProgramBuilder::RotLeft(TRuntimeNode arg, TRuntimeNode bits) {
  2261. const std::array<TRuntimeNode, 2> args = {{ arg, bits }};
  2262. return Invoke(__func__, arg.GetStaticType(), args);
  2263. }
  2264. TRuntimeNode TProgramBuilder::ShiftRight(TRuntimeNode arg, TRuntimeNode bits) {
  2265. const std::array<TRuntimeNode, 2> args = {{ arg, bits }};
  2266. return Invoke(__func__, arg.GetStaticType(), args);
  2267. }
  2268. TRuntimeNode TProgramBuilder::RotRight(TRuntimeNode arg, TRuntimeNode bits) {
  2269. const std::array<TRuntimeNode, 2> args = {{ arg, bits }};
  2270. return Invoke(__func__, arg.GetStaticType(), args);
  2271. }
  2272. TRuntimeNode TProgramBuilder::Mod(TRuntimeNode data1, TRuntimeNode data2) {
  2273. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2274. auto resultType = BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType());
  2275. if (resultType->IsData() && !(NUdf::GetDataTypeInfo(*static_cast<TDataType*>(resultType)->GetDataSlot()).Features & (NUdf::EDataTypeFeatures::FloatType | NUdf::EDataTypeFeatures::DecimalType))) {
  2276. resultType = NewOptionalType(resultType);
  2277. }
  2278. return Invoke(__func__, resultType, args);
  2279. }
  2280. TRuntimeNode TProgramBuilder::BuildMinMax(const std::string_view& callableName, const TRuntimeNode* data, size_t size) {
  2281. switch (size) {
  2282. case 0U: return NewNull();
  2283. case 1U: return *data;
  2284. case 2U: return InvokeBinary(callableName, ChooseCommonType(data[0U].GetStaticType(), data[1U].GetStaticType()), data[0U], data[1U]);
  2285. default: break;
  2286. }
  2287. const auto half = size >> 1U;
  2288. const std::array<TRuntimeNode, 2U> args = {{ BuildMinMax(callableName, data, half), BuildMinMax(callableName, data + half, size - half) }};
  2289. return BuildMinMax(callableName, args.data(), args.size());
  2290. }
  2291. TRuntimeNode TProgramBuilder::BuildWideSkipTakeBlocks(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count) {
  2292. ValidateBlockFlowType(flow.GetStaticType());
  2293. MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data");
  2294. MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  2295. TCallableBuilder callableBuilder(Env, callableName, flow.GetStaticType());
  2296. callableBuilder.Add(flow);
  2297. callableBuilder.Add(count);
  2298. return TRuntimeNode(callableBuilder.Build(), false);
  2299. }
  2300. TRuntimeNode TProgramBuilder::BuildBlockLogical(const std::string_view& callableName, TRuntimeNode first, TRuntimeNode second) {
  2301. auto firstType = AS_TYPE(TBlockType, first.GetStaticType());
  2302. auto secondType = AS_TYPE(TBlockType, second.GetStaticType());
  2303. bool isOpt1, isOpt2;
  2304. MKQL_ENSURE(UnpackOptionalData(firstType->GetItemType(), isOpt1)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  2305. MKQL_ENSURE(UnpackOptionalData(secondType->GetItemType(), isOpt2)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  2306. const auto itemType = NewDataType(NUdf::TDataType<bool>::Id, isOpt1 || isOpt2);
  2307. auto outputType = NewBlockType(itemType, GetResultShape({firstType, secondType}));
  2308. TCallableBuilder callableBuilder(Env, callableName, outputType);
  2309. callableBuilder.Add(first);
  2310. callableBuilder.Add(second);
  2311. return TRuntimeNode(callableBuilder.Build(), false);
  2312. }
  2313. TRuntimeNode TProgramBuilder::BuildBlockDecimalBinary(const std::string_view& callableName, TRuntimeNode first, TRuntimeNode second) {
  2314. auto firstType = AS_TYPE(TBlockType, first.GetStaticType());
  2315. auto secondType = AS_TYPE(TBlockType, second.GetStaticType());
  2316. bool isOpt1, isOpt2;
  2317. auto* leftDataType = UnpackOptionalData(firstType->GetItemType(), isOpt1);
  2318. UnpackOptionalData(secondType->GetItemType(), isOpt2);
  2319. MKQL_ENSURE(leftDataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id, "Requires decimal args.");
  2320. const auto& lParams = static_cast<TDataDecimalType*>(leftDataType)->GetParams();
  2321. auto [precision, scale] = lParams;
  2322. TType* outputType = TDataDecimalType::Create(precision, scale, Env);
  2323. if (isOpt1 || isOpt2) {
  2324. outputType = TOptionalType::Create(outputType, Env);
  2325. }
  2326. outputType = NewBlockType(outputType, TBlockType::EShape::Many);
  2327. TCallableBuilder callableBuilder(Env, callableName, outputType);
  2328. callableBuilder.Add(first);
  2329. callableBuilder.Add(second);
  2330. return TRuntimeNode(callableBuilder.Build(), false);
  2331. }
// Array form of Min: forwards the argument pointer and count to BuildMinMax,
// using this function's own name (__func__) as the callable name.
TRuntimeNode TProgramBuilder::Min(const TArrayRef<const TRuntimeNode>& args) {
    return BuildMinMax(__func__, args.data(), args.size());
}
// Array form of Max: forwards the argument pointer and count to BuildMinMax,
// using this function's own name (__func__) as the callable name.
TRuntimeNode TProgramBuilder::Max(const TArrayRef<const TRuntimeNode>& args) {
    return BuildMinMax(__func__, args.data(), args.size());
}
  2338. TRuntimeNode TProgramBuilder::Min(TRuntimeNode data1, TRuntimeNode data2) {
  2339. const std::array<TRuntimeNode, 2U> args = {{ data1, data2 }};
  2340. return Min(args);
  2341. }
  2342. TRuntimeNode TProgramBuilder::Max(TRuntimeNode data1, TRuntimeNode data2) {
  2343. const std::array<TRuntimeNode, 2U> args = {{ data1, data2 }};
  2344. return Max(args);
  2345. }
// Delegates to DataCompare, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::Equals(TRuntimeNode data1, TRuntimeNode data2) {
    return DataCompare(__func__, data1, data2);
}
// Delegates to DataCompare, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::NotEquals(TRuntimeNode data1, TRuntimeNode data2) {
    return DataCompare(__func__, data1, data2);
}
// Delegates to DataCompare, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::Less(TRuntimeNode data1, TRuntimeNode data2) {
    return DataCompare(__func__, data1, data2);
}
// Delegates to DataCompare, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::LessOrEqual(TRuntimeNode data1, TRuntimeNode data2) {
    return DataCompare(__func__, data1, data2);
}
// Delegates to DataCompare, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::Greater(TRuntimeNode data1, TRuntimeNode data2) {
    return DataCompare(__func__, data1, data2);
}
// Delegates to DataCompare, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::GreaterOrEqual(TRuntimeNode data1, TRuntimeNode data2) {
    return DataCompare(__func__, data1, data2);
}
  2364. TRuntimeNode TProgramBuilder::InvokeBinary(const std::string_view& callableName, TType* type, TRuntimeNode data1, TRuntimeNode data2) {
  2365. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2366. return Invoke(callableName, type, args);
  2367. }
  2368. TRuntimeNode TProgramBuilder::AggrCompare(const std::string_view& callableName, TRuntimeNode data1, TRuntimeNode data2) {
  2369. return InvokeBinary(callableName, NewDataType(NUdf::TDataType<bool>::Id), data1, data2);
  2370. }
  2371. TRuntimeNode TProgramBuilder::DataCompare(const std::string_view& callableName, TRuntimeNode left, TRuntimeNode right) {
  2372. bool isOptionalLeft, isOptionalRight;
  2373. const auto leftType = UnpackOptionalData(left, isOptionalLeft);
  2374. const auto rightType = UnpackOptionalData(right, isOptionalRight);
  2375. const auto lId = leftType->GetSchemeType();
  2376. const auto rId = rightType->GetSchemeType();
  2377. if (lId == NUdf::TDataType<NUdf::TDecimal>::Id && rId == NUdf::TDataType<NUdf::TDecimal>::Id) {
  2378. const auto& lDec = static_cast<TDataDecimalType*>(leftType)->GetParams();
  2379. const auto& rDec = static_cast<TDataDecimalType*>(rightType)->GetParams();
  2380. if (lDec.second < rDec.second) {
  2381. left = ToDecimal(left, std::min<ui8>(lDec.first + rDec.second - lDec.second, NYql::NDecimal::MaxPrecision), rDec.second);
  2382. } else if (lDec.second > rDec.second) {
  2383. right = ToDecimal(right, std::min<ui8>(rDec.first + lDec.second - rDec.second, NYql::NDecimal::MaxPrecision), lDec.second);
  2384. }
  2385. } else if (lId == NUdf::TDataType<NUdf::TDecimal>::Id && NUdf::GetDataTypeInfo(NUdf::GetDataSlot(rId)).Features & NUdf::EDataTypeFeatures::IntegralType) {
  2386. const auto scale = static_cast<TDataDecimalType*>(leftType)->GetParams().second;
  2387. right = ToDecimal(right, std::min<ui8>(NYql::NDecimal::MaxPrecision, NUdf::GetDataTypeInfo(NUdf::GetDataSlot(rId)).DecimalDigits + scale), scale);
  2388. } else if (rId == NUdf::TDataType<NUdf::TDecimal>::Id && NUdf::GetDataTypeInfo(NUdf::GetDataSlot(lId)).Features & NUdf::EDataTypeFeatures::IntegralType) {
  2389. const auto scale = static_cast<TDataDecimalType*>(rightType)->GetParams().second;
  2390. left = ToDecimal(left, std::min<ui8>(NYql::NDecimal::MaxPrecision, NUdf::GetDataTypeInfo(NUdf::GetDataSlot(lId)).DecimalDigits + scale), scale);
  2391. }
  2392. const std::array<TRuntimeNode, 2> args = {{ left, right }};
  2393. const auto resultType = isOptionalLeft || isOptionalRight ? NewOptionalType(NewDataType(NUdf::TDataType<bool>::Id)) : NewDataType(NUdf::TDataType<bool>::Id);
  2394. return Invoke(callableName, resultType, args);
  2395. }
  2396. TRuntimeNode TProgramBuilder::BuildRangeLogical(const std::string_view& callableName, const TArrayRef<const TRuntimeNode>& lists) {
  2397. MKQL_ENSURE(!lists.empty(), "Expecting at least one argument");
  2398. for (auto& list : lists) {
  2399. MKQL_ENSURE(list.GetStaticType()->IsList(), "Expecting lists");
  2400. MKQL_ENSURE(list.GetStaticType()->IsSameType(*lists.front().GetStaticType()), "Expecting arguments of same type");
  2401. }
  2402. TCallableBuilder callableBuilder(Env, callableName, lists.front().GetStaticType());
  2403. for (auto& list : lists) {
  2404. callableBuilder.Add(list);
  2405. }
  2406. return TRuntimeNode(callableBuilder.Build(), false);
  2407. }
// Delegates to AggrCompare, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::AggrEquals(TRuntimeNode data1, TRuntimeNode data2) {
    return AggrCompare(__func__, data1, data2);
}
// Delegates to AggrCompare, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::AggrNotEquals(TRuntimeNode data1, TRuntimeNode data2) {
    return AggrCompare(__func__, data1, data2);
}
// Delegates to AggrCompare, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::AggrLess(TRuntimeNode data1, TRuntimeNode data2) {
    return AggrCompare(__func__, data1, data2);
}
// Delegates to AggrCompare, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::AggrLessOrEqual(TRuntimeNode data1, TRuntimeNode data2) {
    return AggrCompare(__func__, data1, data2);
}
// Delegates to AggrCompare, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::AggrGreater(TRuntimeNode data1, TRuntimeNode data2) {
    return AggrCompare(__func__, data1, data2);
}
// Delegates to AggrCompare, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::AggrGreaterOrEqual(TRuntimeNode data1, TRuntimeNode data2) {
    return AggrCompare(__func__, data1, data2);
}
  2426. TRuntimeNode TProgramBuilder::If(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch) {
  2427. bool condOpt, thenOpt, elseOpt;
  2428. const auto conditionType = UnpackOptionalData(condition, condOpt);
  2429. MKQL_ENSURE(conditionType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool");
  2430. const auto thenUnpacked = UnpackOptional(thenBranch, thenOpt);
  2431. const auto elseUnpacked = UnpackOptional(elseBranch, elseOpt);
  2432. MKQL_ENSURE(thenUnpacked->IsSameType(*elseUnpacked), "Different return types in branches.");
  2433. const bool isOptional = condOpt || thenOpt || elseOpt;
  2434. TCallableBuilder callableBuilder(Env, __func__, isOptional ? NewOptionalType(thenUnpacked) : thenUnpacked);
  2435. callableBuilder.Add(condition);
  2436. callableBuilder.Add(thenBranch);
  2437. callableBuilder.Add(elseBranch);
  2438. return TRuntimeNode(callableBuilder.Build(), false);
  2439. }
  2440. TRuntimeNode TProgramBuilder::If(const TArrayRef<const TRuntimeNode>& args) {
  2441. MKQL_ENSURE(args.size() % 2U, "Expected odd arguments.");
  2442. MKQL_ENSURE(args.size() >= 3U, "Expected at least three arguments.");
  2443. return If(args.front(), args[1U], 3U == args.size() ? args.back() : If(args.last(args.size() - 2U)));
  2444. }
  2445. TRuntimeNode TProgramBuilder::If(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch, TType* resultType) {
  2446. bool condOpt;
  2447. const auto conditionType = UnpackOptionalData(condition, condOpt);
  2448. MKQL_ENSURE(conditionType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool");
  2449. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2450. callableBuilder.Add(condition);
  2451. callableBuilder.Add(thenBranch);
  2452. callableBuilder.Add(elseBranch);
  2453. return TRuntimeNode(callableBuilder.Build(), false);
  2454. }
  2455. TRuntimeNode TProgramBuilder::Ensure(TRuntimeNode value, TRuntimeNode predicate, TRuntimeNode message, const std::string_view& file, ui32 row, ui32 column) {
  2456. bool isOptional;
  2457. const auto unpackedType = UnpackOptionalData(predicate, isOptional);
  2458. MKQL_ENSURE(unpackedType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool");
  2459. const auto& messageType = message.GetStaticType();
  2460. MKQL_ENSURE(messageType->IsData(), "Expected data");
  2461. const auto& messageTypeData = static_cast<const TDataType&>(*messageType);
  2462. MKQL_ENSURE(messageTypeData.GetSchemeType() == NUdf::TDataType<char*>::Id || messageTypeData.GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected string or utf8.");
  2463. TCallableBuilder callableBuilder(Env, __func__, value.GetStaticType());
  2464. callableBuilder.Add(value);
  2465. callableBuilder.Add(predicate);
  2466. callableBuilder.Add(message);
  2467. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  2468. callableBuilder.Add(NewDataLiteral(row));
  2469. callableBuilder.Add(NewDataLiteral(column));
  2470. return TRuntimeNode(callableBuilder.Build(), false);
  2471. }
// Builds an argument-less "SourceOf" callable with the given result type,
// which must be a flow or a stream.
TRuntimeNode TProgramBuilder::SourceOf(TType* returnType) {
    MKQL_ENSURE(returnType->IsFlow() || returnType->IsStream(), "Expected flow or stream.");
    TCallableBuilder callableBuilder(Env, __func__, returnType);
    return TRuntimeNode(callableBuilder.Build(), false);
}
  2477. TRuntimeNode TProgramBuilder::Source() {
  2478. if constexpr (RuntimeVersion < 18U) {
  2479. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  2480. }
  2481. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType({})));
  2482. return TRuntimeNode(callableBuilder.Build(), false);
  2483. }
  2484. TRuntimeNode TProgramBuilder::IfPresent(TRuntimeNode optional, const TUnaryLambda& thenBranch, TRuntimeNode elseBranch) {
  2485. bool isOptional;
  2486. const auto unpackedType = UnpackOptional(optional, isOptional);
  2487. if (!isOptional) {
  2488. return thenBranch(optional);
  2489. }
  2490. const auto itemArg = Arg(unpackedType);
  2491. const auto then = thenBranch(itemArg);
  2492. bool thenOpt, elseOpt;
  2493. const auto thenUnpacked = UnpackOptional(then, thenOpt);
  2494. const auto elseUnpacked = UnpackOptional(elseBranch, elseOpt);
  2495. MKQL_ENSURE(thenUnpacked->IsSameType(*elseUnpacked), "Different return types in branches.");
  2496. TCallableBuilder callableBuilder(Env, __func__, (thenOpt || elseOpt) ? NewOptionalType(thenUnpacked) : thenUnpacked);
  2497. callableBuilder.Add(optional);
  2498. callableBuilder.Add(itemArg);
  2499. callableBuilder.Add(then);
  2500. callableBuilder.Add(elseBranch);
  2501. return TRuntimeNode(callableBuilder.Build(), false);
  2502. }
  2503. TRuntimeNode TProgramBuilder::IfPresent(TRuntimeNode::TList optionals, const TNarrowLambda& thenBranch, TRuntimeNode elseBranch) {
  2504. switch (optionals.size()) {
  2505. case 0U:
  2506. return thenBranch({});
  2507. case 1U:
  2508. return IfPresent(optionals.front(), [&](TRuntimeNode unwrap){ return thenBranch({unwrap}); }, elseBranch);
  2509. default:
  2510. break;
  2511. }
  2512. const auto first = optionals.front();
  2513. optionals.erase(optionals.cbegin());
  2514. return IfPresent(first,
  2515. [&](TRuntimeNode head) {
  2516. return IfPresent(optionals,
  2517. [&](TRuntimeNode::TList tail) {
  2518. tail.insert(tail.cbegin(), head);
  2519. return thenBranch(tail);
  2520. },
  2521. elseBranch
  2522. );
  2523. },
  2524. elseBranch
  2525. );
  2526. }
  2527. TRuntimeNode TProgramBuilder::Not(TRuntimeNode data) {
  2528. return UnaryDataFunction(data, __func__, TDataFunctionFlags::CommonOptionalResult | TDataFunctionFlags::RequiresBooleanArgs | TDataFunctionFlags::AllowOptionalArgs);
  2529. }
  2530. TRuntimeNode TProgramBuilder::BuildBinaryLogical(const std::string_view& callableName, TRuntimeNode data1, TRuntimeNode data2) {
  2531. bool isOpt1, isOpt2;
  2532. MKQL_ENSURE(UnpackOptionalData(data1, isOpt1)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  2533. MKQL_ENSURE(UnpackOptionalData(data2, isOpt2)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  2534. const auto resultType = NewDataType(NUdf::TDataType<bool>::Id, isOpt1 || isOpt2);
  2535. TCallableBuilder callableBuilder(Env, callableName, resultType);
  2536. callableBuilder.Add(data1);
  2537. callableBuilder.Add(data2);
  2538. return TRuntimeNode(callableBuilder.Build(), false);
  2539. }
  2540. TRuntimeNode TProgramBuilder::BuildLogical(const std::string_view& callableName, const TArrayRef<const TRuntimeNode>& args) {
  2541. MKQL_ENSURE(!args.empty(), "Empty logical args.");
  2542. switch (args.size()) {
  2543. case 1U: return args.front();
  2544. case 2U: return BuildBinaryLogical(callableName, args.front(), args.back());
  2545. }
  2546. const auto half = (args.size() + 1U) >> 1U;
  2547. const TArrayRef<const TRuntimeNode> one(args.data(), half), two(args.data() + half, args.size() - half);
  2548. return BuildBinaryLogical(callableName, BuildLogical(callableName, one), BuildLogical(callableName, two));
  2549. }
// Delegates to BuildLogical, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::And(const TArrayRef<const TRuntimeNode>& args) {
    return BuildLogical(__func__, args);
}
// Delegates to BuildLogical, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::Or(const TArrayRef<const TRuntimeNode>& args) {
    return BuildLogical(__func__, args);
}
// Delegates to BuildLogical, passing this function's name (__func__) as the
// callable name.
TRuntimeNode TProgramBuilder::Xor(const TArrayRef<const TRuntimeNode>& args) {
    return BuildLogical(__func__, args);
}
  2559. TRuntimeNode TProgramBuilder::Exists(TRuntimeNode data) {
  2560. const auto& nodeType = data.GetStaticType();
  2561. if (nodeType->IsVoid()) {
  2562. return NewDataLiteral(false);
  2563. }
  2564. if (!nodeType->IsOptional() && !nodeType->IsPg()) {
  2565. return NewDataLiteral(true);
  2566. }
  2567. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<bool>::Id));
  2568. callableBuilder.Add(data);
  2569. return TRuntimeNode(callableBuilder.Build(), false);
  2570. }
  2571. TRuntimeNode TProgramBuilder::NewMTRand(TRuntimeNode seed) {
  2572. auto seedData = AS_TYPE(TDataType, seed);
  2573. MKQL_ENSURE(seedData->GetSchemeType() == NUdf::TDataType<ui64>::Id, "seed must be ui64");
  2574. TCallableBuilder callableBuilder(Env, __func__, NewResourceType(RandomMTResource), true);
  2575. callableBuilder.Add(seed);
  2576. return TRuntimeNode(callableBuilder.Build(), false);
  2577. }
  2578. TRuntimeNode TProgramBuilder::NextMTRand(TRuntimeNode rand) {
  2579. auto resType = AS_TYPE(TResourceType, rand);
  2580. MKQL_ENSURE(resType->GetTag() == RandomMTResource, "Expected MTRand resource");
  2581. const std::array<TType*, 2U> tupleTypes = {{ NewDataType(NUdf::TDataType<ui64>::Id), rand.GetStaticType() }};
  2582. auto returnType = NewTupleType(tupleTypes);
  2583. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2584. callableBuilder.Add(rand);
  2585. return TRuntimeNode(callableBuilder.Build(), false);
  2586. }
  2587. TRuntimeNode TProgramBuilder::AggrCountInit(TRuntimeNode value) {
  2588. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<ui64>::Id));
  2589. callableBuilder.Add(value);
  2590. return TRuntimeNode(callableBuilder.Build(), false);
  2591. }
  2592. TRuntimeNode TProgramBuilder::AggrCountUpdate(TRuntimeNode value, TRuntimeNode state) {
  2593. MKQL_ENSURE(AS_TYPE(TDataType, state)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64 type");
  2594. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<ui64>::Id));
  2595. callableBuilder.Add(value);
  2596. callableBuilder.Add(state);
  2597. return TRuntimeNode(callableBuilder.Build(), false);
  2598. }
// Invokes the binary "AggrMin" callable (name taken from __func__).
// Both operands must have exactly the same type, which is also the result
// type.
TRuntimeNode TProgramBuilder::AggrMin(TRuntimeNode data1, TRuntimeNode data2) {
    const auto type = data1.GetStaticType();
    MKQL_ENSURE(type->IsSameType(*data2.GetStaticType()), "Must be same type.");
    return InvokeBinary(__func__, type, data1, data2);
}
// Invokes the binary "AggrMax" callable (name taken from __func__).
// Both operands must have exactly the same type, which is also the result
// type.
TRuntimeNode TProgramBuilder::AggrMax(TRuntimeNode data1, TRuntimeNode data2) {
    const auto type = data1.GetStaticType();
    MKQL_ENSURE(type->IsSameType(*data2.GetStaticType()), "Must be same type.");
    return InvokeBinary(__func__, type, data1, data2);
}
  2609. TRuntimeNode TProgramBuilder::AggrAdd(TRuntimeNode data1, TRuntimeNode data2) {
  2610. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2611. bool isOptionalLeft;
  2612. const auto leftType = UnpackOptionalData(data1, isOptionalLeft);
  2613. if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2614. return Invoke(__func__, data1.GetStaticType(), args);
  2615. const auto decimalType = static_cast<TDataDecimalType*>(leftType);
  2616. bool isOptionalRight;
  2617. const auto rightType = static_cast<TDataDecimalType*>(UnpackOptionalData(data2, isOptionalRight));
  2618. MKQL_ENSURE(rightType->IsSameType(*decimalType), "Operands type mismatch");
  2619. return Invoke(TString("AggrAdd_") += ::ToString(decimalType->GetParams().first), data1.GetStaticType(), args);
  2620. }
  2621. TRuntimeNode TProgramBuilder::QueueCreate(TRuntimeNode initCapacity, TRuntimeNode initSize, const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) {
  2622. auto resType = AS_TYPE(TResourceType, returnType);
  2623. const auto tag = resType->GetTag();
  2624. if (initCapacity.GetStaticType()->IsVoid()) {
  2625. MKQL_ENSURE(RuntimeVersion >= 13, "Unbounded queue is not supported in runtime version " << RuntimeVersion);
  2626. } else {
  2627. auto initCapacityType = AS_TYPE(TDataType, initCapacity);
  2628. MKQL_ENSURE(initCapacityType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "init capcity must be ui64");
  2629. }
  2630. auto initSizeType = AS_TYPE(TDataType, initSize);
  2631. MKQL_ENSURE(initSizeType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "init size must be ui64");
  2632. TCallableBuilder callableBuilder(Env, __func__, returnType, true);
  2633. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(tag));
  2634. callableBuilder.Add(initCapacity);
  2635. callableBuilder.Add(initSize);
  2636. for (auto node : dependentNodes) {
  2637. callableBuilder.Add(node);
  2638. }
  2639. return TRuntimeNode(callableBuilder.Build(), false);
  2640. }
  2641. TRuntimeNode TProgramBuilder::QueuePush(TRuntimeNode resource, TRuntimeNode value) {
  2642. auto resType = AS_TYPE(TResourceType, resource);
  2643. const auto tag = resType->GetTag();
  2644. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "Expected Queue resource");
  2645. TCallableBuilder callableBuilder(Env, __func__, resource.GetStaticType());
  2646. callableBuilder.Add(resource);
  2647. callableBuilder.Add(value);
  2648. return TRuntimeNode(callableBuilder.Build(), false);
  2649. }
  2650. TRuntimeNode TProgramBuilder::QueuePop(TRuntimeNode resource) {
  2651. auto resType = AS_TYPE(TResourceType, resource);
  2652. const auto tag = resType->GetTag();
  2653. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "Expected Queue resource");
  2654. TCallableBuilder callableBuilder(Env, __func__, resource.GetStaticType());
  2655. callableBuilder.Add(resource);
  2656. return TRuntimeNode(callableBuilder.Build(), false);
  2657. }
  2658. TRuntimeNode TProgramBuilder::QueuePeek(TRuntimeNode resource, TRuntimeNode index, const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) {
  2659. MKQL_ENSURE(returnType->IsOptional(), "Expected optional type as result of QueuePeek");
  2660. auto resType = AS_TYPE(TResourceType, resource);
  2661. auto indexType = AS_TYPE(TDataType, index);
  2662. MKQL_ENSURE(indexType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "index size must be ui64");
  2663. const auto tag = resType->GetTag();
  2664. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "Expected Queue resource");
  2665. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2666. callableBuilder.Add(resource);
  2667. callableBuilder.Add(index);
  2668. for (auto node : dependentNodes) {
  2669. callableBuilder.Add(node);
  2670. }
  2671. return TRuntimeNode(callableBuilder.Build(), false);
  2672. }
  2673. TRuntimeNode TProgramBuilder::QueueRange(TRuntimeNode resource, TRuntimeNode begin, TRuntimeNode end, const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) {
  2674. MKQL_ENSURE(RuntimeVersion >= 14, "QueueRange is not supported in runtime version " << RuntimeVersion);
  2675. MKQL_ENSURE(returnType->IsList(), "Expected list type as result of QueueRange");
  2676. auto resType = AS_TYPE(TResourceType, resource);
  2677. auto beginType = AS_TYPE(TDataType, begin);
  2678. MKQL_ENSURE(beginType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "begin index must be ui64");
  2679. auto endType = AS_TYPE(TDataType, end);
  2680. MKQL_ENSURE(endType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "end index must be ui64");
  2681. const auto tag = resType->GetTag();
  2682. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "Expected Queue resource");
  2683. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2684. callableBuilder.Add(resource);
  2685. callableBuilder.Add(begin);
  2686. callableBuilder.Add(end);
  2687. for (auto node : dependentNodes) {
  2688. callableBuilder.Add(node);
  2689. }
  2690. return TRuntimeNode(callableBuilder.Build(), false);
  2691. }
  2692. TRuntimeNode TProgramBuilder::PreserveStream(TRuntimeNode stream, TRuntimeNode queue, TRuntimeNode outpace) {
  2693. auto streamType = AS_TYPE(TStreamType, stream);
  2694. auto resType = AS_TYPE(TResourceType, queue);
  2695. auto outpaceType = AS_TYPE(TDataType, outpace);
  2696. MKQL_ENSURE(outpaceType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "PreserveStream: outpace size must be ui64");
  2697. const auto tag = resType->GetTag();
  2698. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "PreserveStream: Expected Queue resource");
  2699. TCallableBuilder callableBuilder(Env, __func__, streamType);
  2700. callableBuilder.Add(stream);
  2701. callableBuilder.Add(queue);
  2702. callableBuilder.Add(outpace);
  2703. return TRuntimeNode(callableBuilder.Build(), false);
  2704. }
  2705. TRuntimeNode TProgramBuilder::Seq(const TArrayRef<const TRuntimeNode>& args, TType* returnType) {
  2706. MKQL_ENSURE(RuntimeVersion >= 15, "Seq is not supported in runtime version " << RuntimeVersion);
  2707. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2708. for (auto node : args) {
  2709. callableBuilder.Add(node);
  2710. }
  2711. return TRuntimeNode(callableBuilder.Build(), false);
  2712. }
  2713. TRuntimeNode TProgramBuilder::FromYsonSimpleType(TRuntimeNode input, NUdf::TDataTypeId schemeType) {
  2714. auto type = input.GetStaticType();
  2715. if (type->IsOptional()) {
  2716. type = static_cast<const TOptionalType&>(*type).GetItemType();
  2717. }
  2718. MKQL_ENSURE(type->IsData(), "Expected data type");
  2719. auto resDataType = NewDataType(schemeType);
  2720. auto resultType = NewOptionalType(resDataType);
  2721. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2722. callableBuilder.Add(input);
  2723. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(schemeType)));
  2724. return TRuntimeNode(callableBuilder.Build(), false);
  2725. }
  2726. TRuntimeNode TProgramBuilder::TryWeakMemberFromDict(TRuntimeNode other, TRuntimeNode rest, NUdf::TDataTypeId schemeType, const std::string_view& memberName) {
  2727. auto resDataType = NewDataType(schemeType);
  2728. auto resultType = NewOptionalType(resDataType);
  2729. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2730. callableBuilder.Add(other);
  2731. callableBuilder.Add(rest);
  2732. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(schemeType)));
  2733. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(memberName));
  2734. return TRuntimeNode(callableBuilder.Build(), false);
  2735. }
  2736. TRuntimeNode TProgramBuilder::TimezoneId(TRuntimeNode name) {
  2737. bool isOptional;
  2738. auto dataType = UnpackOptionalData(name, isOptional);
  2739. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expected string");
  2740. auto resultType = NewOptionalType(NewDataType(NUdf::EDataSlot::Uint16));
  2741. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2742. callableBuilder.Add(name);
  2743. return TRuntimeNode(callableBuilder.Build(), false);
  2744. }
  2745. TRuntimeNode TProgramBuilder::TimezoneName(TRuntimeNode id) {
  2746. bool isOptional;
  2747. auto dataType = UnpackOptionalData(id, isOptional);
  2748. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<ui16>::Id, "Expected ui32");
  2749. auto resultType = NewOptionalType(NewDataType(NUdf::EDataSlot::String));
  2750. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2751. callableBuilder.Add(id);
  2752. return TRuntimeNode(callableBuilder.Build(), false);
  2753. }
  2754. TRuntimeNode TProgramBuilder::AddTimezone(TRuntimeNode utc, TRuntimeNode id) {
  2755. bool isOptional1;
  2756. auto dataType1 = UnpackOptionalData(utc, isOptional1);
  2757. MKQL_ENSURE(NUdf::GetDataTypeInfo(*dataType1->GetDataSlot()).Features & NUdf::DateType, "Expected date type");
  2758. bool isOptional2;
  2759. auto dataType2 = UnpackOptionalData(id, isOptional2);
  2760. MKQL_ENSURE(dataType2->GetSchemeType() == NUdf::TDataType<ui16>::Id, "Expected ui16");
  2761. NUdf::EDataSlot tzType;
  2762. switch (*dataType1->GetDataSlot()) {
  2763. case NUdf::EDataSlot::Date: tzType = NUdf::EDataSlot::TzDate; break;
  2764. case NUdf::EDataSlot::Datetime: tzType = NUdf::EDataSlot::TzDatetime; break;
  2765. case NUdf::EDataSlot::Timestamp: tzType = NUdf::EDataSlot::TzTimestamp; break;
  2766. case NUdf::EDataSlot::Date32: tzType = NUdf::EDataSlot::TzDate32; break;
  2767. case NUdf::EDataSlot::Datetime64: tzType = NUdf::EDataSlot::TzDatetime64; break;
  2768. case NUdf::EDataSlot::Timestamp64: tzType = NUdf::EDataSlot::TzTimestamp64; break;
  2769. default:
  2770. ythrow yexception() << "Unknown date type: " << *dataType1->GetDataSlot();
  2771. }
  2772. auto resultType = NewOptionalType(NewDataType(tzType));
  2773. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2774. callableBuilder.Add(utc);
  2775. callableBuilder.Add(id);
  2776. return TRuntimeNode(callableBuilder.Build(), false);
  2777. }
  2778. TRuntimeNode TProgramBuilder::RemoveTimezone(TRuntimeNode local) {
  2779. bool isOptional1;
  2780. const auto dataType1 = UnpackOptionalData(local, isOptional1);
  2781. MKQL_ENSURE((NUdf::GetDataTypeInfo(*dataType1->GetDataSlot()).Features & NUdf::TzDateType), "Expected date with timezone type");
  2782. NUdf::EDataSlot type;
  2783. switch (*dataType1->GetDataSlot()) {
  2784. case NUdf::EDataSlot::TzDate: type = NUdf::EDataSlot::Date; break;
  2785. case NUdf::EDataSlot::TzDatetime: type = NUdf::EDataSlot::Datetime; break;
  2786. case NUdf::EDataSlot::TzTimestamp: type = NUdf::EDataSlot::Timestamp; break;
  2787. case NUdf::EDataSlot::TzDate32: type = NUdf::EDataSlot::Date32; break;
  2788. case NUdf::EDataSlot::TzDatetime64: type = NUdf::EDataSlot::Datetime64; break;
  2789. case NUdf::EDataSlot::TzTimestamp64: type = NUdf::EDataSlot::Timestamp64; break;
  2790. default:
  2791. ythrow yexception() << "Unknown date with timezone type: " << *dataType1->GetDataSlot();
  2792. }
  2793. return Convert(local, NewDataType(type, isOptional1));
  2794. }
  2795. TRuntimeNode TProgramBuilder::Nth(TRuntimeNode tuple, ui32 index) {
  2796. bool isOptional;
  2797. const auto type = AS_TYPE(TTupleType, UnpackOptional(tuple.GetStaticType(), isOptional));
  2798. MKQL_ENSURE(index < type->GetElementsCount(), "Index out of range: " << index <<
  2799. " is not less than " << type->GetElementsCount());
  2800. auto itemType = type->GetElementType(index);
  2801. if (isOptional && !itemType->IsOptional() && !itemType->IsNull() && !itemType->IsPg()) {
  2802. itemType = TOptionalType::Create(itemType, Env);
  2803. }
  2804. TCallableBuilder callableBuilder(Env, __func__, itemType);
  2805. callableBuilder.Add(tuple);
  2806. callableBuilder.Add(NewDataLiteral<ui32>(index));
  2807. return TRuntimeNode(callableBuilder.Build(), false);
  2808. }
// Alias for Nth: forwards both arguments unchanged, so the resulting
// callable is the one Nth builds.
TRuntimeNode TProgramBuilder::Element(TRuntimeNode tuple, ui32 index) {
    return Nth(tuple, index);
}
  2812. TRuntimeNode TProgramBuilder::Guess(TRuntimeNode variant, ui32 tupleIndex) {
  2813. bool isOptional;
  2814. auto unpacked = UnpackOptional(variant, isOptional);
  2815. auto type = AS_TYPE(TVariantType, unpacked);
  2816. auto underlyingType = AS_TYPE(TTupleType, type->GetUnderlyingType());
  2817. MKQL_ENSURE(tupleIndex < underlyingType->GetElementsCount(), "Wrong tuple index");
  2818. auto resType = TOptionalType::Create(underlyingType->GetElementType(tupleIndex), Env);
  2819. TCallableBuilder callableBuilder(Env, __func__, resType);
  2820. callableBuilder.Add(variant);
  2821. callableBuilder.Add(NewDataLiteral<ui32>(tupleIndex));
  2822. return TRuntimeNode(callableBuilder.Build(), false);
  2823. }
  2824. TRuntimeNode TProgramBuilder::Guess(TRuntimeNode variant, const std::string_view& memberName) {
  2825. bool isOptional;
  2826. auto unpacked = UnpackOptional(variant, isOptional);
  2827. auto type = AS_TYPE(TVariantType, unpacked);
  2828. auto underlyingType = AS_TYPE(TStructType, type->GetUnderlyingType());
  2829. auto structIndex = underlyingType->GetMemberIndex(memberName);
  2830. auto resType = TOptionalType::Create(underlyingType->GetMemberType(structIndex), Env);
  2831. TCallableBuilder callableBuilder(Env, __func__, resType);
  2832. callableBuilder.Add(variant);
  2833. callableBuilder.Add(NewDataLiteral<ui32>(structIndex));
  2834. return TRuntimeNode(callableBuilder.Build(), false);
  2835. }
  2836. TRuntimeNode TProgramBuilder::Way(TRuntimeNode variant) {
  2837. bool isOptional;
  2838. auto unpacked = UnpackOptional(variant, isOptional);
  2839. auto type = AS_TYPE(TVariantType, unpacked);
  2840. auto underlyingType = type->GetUnderlyingType();
  2841. auto dataType = NewDataType(underlyingType->IsTuple() ? NUdf::EDataSlot::Uint32 : NUdf::EDataSlot::Utf8);
  2842. auto resType = isOptional ? TOptionalType::Create(dataType, Env) : dataType;
  2843. TCallableBuilder callableBuilder(Env, __func__, resType);
  2844. callableBuilder.Add(variant);
  2845. return TRuntimeNode(callableBuilder.Build(), false);
  2846. }
  2847. TRuntimeNode TProgramBuilder::VariantItem(TRuntimeNode variant) {
  2848. bool isOptional;
  2849. auto unpacked = UnpackOptional(variant, isOptional);
  2850. auto type = AS_TYPE(TVariantType, unpacked);
  2851. auto underlyingType = type->GetAlternativeType(0);
  2852. auto resType = isOptional ? TOptionalType::Create(underlyingType, Env) : underlyingType;
  2853. TCallableBuilder callableBuilder(Env, __func__, resType);
  2854. callableBuilder.Add(variant);
  2855. return TRuntimeNode(callableBuilder.Build(), false);
  2856. }
  2857. TRuntimeNode TProgramBuilder::DynamicVariant(TRuntimeNode item, TRuntimeNode index, TType* variantType) {
  2858. if constexpr (RuntimeVersion < 56U) {
  2859. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  2860. }
  2861. auto type = AS_TYPE(TVariantType, variantType);
  2862. auto expectedIndexSlot = type->GetUnderlyingType()->IsTuple() ? NUdf::EDataSlot::Uint32 : NUdf::EDataSlot::Utf8;
  2863. bool isOptional;
  2864. auto indexType = UnpackOptionalData(index.GetStaticType(), isOptional);
  2865. MKQL_ENSURE(indexType->GetDataSlot() == expectedIndexSlot, "Mismatch type of index");
  2866. auto resType = TOptionalType::Create(type, Env);
  2867. TCallableBuilder callableBuilder(Env, __func__, resType);
  2868. callableBuilder.Add(item);
  2869. callableBuilder.Add(index);
  2870. callableBuilder.Add(TRuntimeNode(variantType, true));
  2871. return TRuntimeNode(callableBuilder.Build(), false);
  2872. }
  2873. TRuntimeNode TProgramBuilder::VisitAll(TRuntimeNode variant, std::function<TRuntimeNode(ui32, TRuntimeNode)> handler) {
  2874. const auto type = AS_TYPE(TVariantType, variant);
  2875. std::vector<TRuntimeNode> items;
  2876. std::vector<TRuntimeNode> newItems;
  2877. for (ui32 i = 0; i < type->GetAlternativesCount(); ++i) {
  2878. const auto itemType = type->GetAlternativeType(i);
  2879. const auto itemArg = Arg(itemType);
  2880. const auto res = handler(i, itemArg);
  2881. items.emplace_back(itemArg);
  2882. newItems.emplace_back(res);
  2883. }
  2884. bool hasOptional;
  2885. const auto firstUnpacked = UnpackOptional(newItems.front(), hasOptional);
  2886. bool allOptional = hasOptional;
  2887. for (size_t i = 1U; i < newItems.size(); ++i) {
  2888. bool isOptional;
  2889. const auto unpacked = UnpackOptional(newItems[i].GetStaticType(), isOptional);
  2890. MKQL_ENSURE(unpacked->IsSameType(*firstUnpacked), "Different return types in branches.");
  2891. hasOptional = hasOptional || isOptional;
  2892. allOptional = allOptional && isOptional;
  2893. }
  2894. if (hasOptional && !allOptional) {
  2895. for (auto& item : newItems) {
  2896. if (!item.GetStaticType()->IsOptional()) {
  2897. item = NewOptional(item);
  2898. }
  2899. }
  2900. }
  2901. TCallableBuilder callableBuilder(Env, __func__, newItems.front().GetStaticType());
  2902. callableBuilder.Add(variant);
  2903. for (ui32 i = 0; i < type->GetAlternativesCount(); ++i) {
  2904. callableBuilder.Add(items[i]);
  2905. callableBuilder.Add(newItems[i]);
  2906. }
  2907. return TRuntimeNode(callableBuilder.Build(), false);
  2908. }
  2909. TRuntimeNode TProgramBuilder::UnaryDataFunction(TRuntimeNode data, const std::string_view& callableName, ui32 flags) {
  2910. bool isOptional;
  2911. auto type = UnpackOptionalData(data, isOptional);
  2912. if (!(flags & TDataFunctionFlags::AllowOptionalArgs)) {
  2913. MKQL_ENSURE(!isOptional, "Optional data is not allowed");
  2914. }
  2915. auto schemeType = type->GetSchemeType();
  2916. if (flags & TDataFunctionFlags::RequiresBooleanArgs) {
  2917. MKQL_ENSURE(schemeType == NUdf::TDataType<bool>::Id, "Boolean data is required");
  2918. } else if (flags & TDataFunctionFlags::RequiresStringArgs) {
  2919. MKQL_ENSURE(schemeType == NUdf::TDataType<char*>::Id, "String data is required");
  2920. }
  2921. if (!schemeType) {
  2922. MKQL_ENSURE((flags & TDataFunctionFlags::AllowNull) != 0, "Null is not allowed");
  2923. }
  2924. TType* resultType;
  2925. if (flags & TDataFunctionFlags::HasBooleanResult) {
  2926. resultType = TDataType::Create(NUdf::TDataType<bool>::Id, Env);
  2927. } else if (flags & TDataFunctionFlags::HasUi32Result) {
  2928. resultType = TDataType::Create(NUdf::TDataType<ui32>::Id, Env);
  2929. } else if (flags & TDataFunctionFlags::HasStringResult) {
  2930. resultType = TDataType::Create(NUdf::TDataType<char*>::Id, Env);
  2931. } else if (flags & TDataFunctionFlags::HasOptionalResult) {
  2932. resultType = TOptionalType::Create(type, Env);
  2933. } else {
  2934. resultType = type;
  2935. }
  2936. if ((flags & TDataFunctionFlags::CommonOptionalResult) && isOptional) {
  2937. resultType = TOptionalType::Create(resultType, Env);
  2938. }
  2939. TCallableBuilder callableBuilder(Env, callableName, resultType);
  2940. callableBuilder.Add(data);
  2941. return TRuntimeNode(callableBuilder.Build(), false);
  2942. }
  2943. TRuntimeNode TProgramBuilder::ToDict(TRuntimeNode list, bool multi, const TUnaryLambda& keySelector,
  2944. const TUnaryLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint)
  2945. {
  2946. bool isOptional;
  2947. const auto type = UnpackOptional(list, isOptional);
  2948. MKQL_ENSURE(type->IsList(), "Expected list.");
  2949. if (isOptional) {
  2950. return Map(list, [&](TRuntimeNode unpacked) { return ToDict(unpacked, multi, keySelector, payloadSelector, callableName, isCompact, itemsCountHint); } );
  2951. }
  2952. const auto itemType = AS_TYPE(TListType, type)->GetItemType();
  2953. ThrowIfListOfVoid(itemType);
  2954. const auto itemArg = Arg(itemType);
  2955. const auto key = keySelector(itemArg);
  2956. const auto keyType = key.GetStaticType();
  2957. auto payload = payloadSelector(itemArg);
  2958. auto payloadType = payload.GetStaticType();
  2959. if (multi) {
  2960. payloadType = TListType::Create(payloadType, Env);
  2961. }
  2962. auto dictType = TDictType::Create(keyType, payloadType, Env);
  2963. TCallableBuilder callableBuilder(Env, callableName, dictType);
  2964. callableBuilder.Add(list);
  2965. callableBuilder.Add(itemArg);
  2966. callableBuilder.Add(key);
  2967. callableBuilder.Add(payload);
  2968. callableBuilder.Add(NewDataLiteral(multi));
  2969. callableBuilder.Add(NewDataLiteral(isCompact));
  2970. callableBuilder.Add(NewDataLiteral(itemsCountHint));
  2971. return TRuntimeNode(callableBuilder.Build(), false);
  2972. }
  2973. TRuntimeNode TProgramBuilder::SqueezeToDict(TRuntimeNode stream, bool multi, const TUnaryLambda& keySelector,
  2974. const TUnaryLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint)
  2975. {
  2976. if constexpr (RuntimeVersion < 21U) {
  2977. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  2978. }
  2979. const auto type = stream.GetStaticType();
  2980. MKQL_ENSURE(type->IsStream() || type->IsFlow(), "Expected stream or flow.");
  2981. const auto itemType = type->IsFlow() ? AS_TYPE(TFlowType, type)->GetItemType() : AS_TYPE(TStreamType, type)->GetItemType();
  2982. ThrowIfListOfVoid(itemType);
  2983. const auto itemArg = Arg(itemType);
  2984. const auto key = keySelector(itemArg);
  2985. const auto keyType = key.GetStaticType();
  2986. auto payload = payloadSelector(itemArg);
  2987. auto payloadType = payload.GetStaticType();
  2988. if (multi) {
  2989. payloadType = TListType::Create(payloadType, Env);
  2990. }
  2991. auto dictType = TDictType::Create(keyType, payloadType, Env);
  2992. auto returnType = type->IsFlow()
  2993. ? (TType*) TFlowType::Create(dictType, Env)
  2994. : (TType*) TStreamType::Create(dictType, Env);
  2995. TCallableBuilder callableBuilder(Env, callableName, returnType);
  2996. callableBuilder.Add(stream);
  2997. callableBuilder.Add(itemArg);
  2998. callableBuilder.Add(key);
  2999. callableBuilder.Add(payload);
  3000. callableBuilder.Add(NewDataLiteral(multi));
  3001. callableBuilder.Add(NewDataLiteral(isCompact));
  3002. callableBuilder.Add(NewDataLiteral(itemsCountHint));
  3003. return TRuntimeNode(callableBuilder.Build(), false);
  3004. }
  3005. TRuntimeNode TProgramBuilder::NarrowSqueezeToDict(TRuntimeNode flow, bool multi, const TNarrowLambda& keySelector,
  3006. const TNarrowLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint)
  3007. {
  3008. if constexpr (RuntimeVersion < 23U) {
  3009. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3010. }
  3011. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3012. TRuntimeNode::TList itemArgs;
  3013. itemArgs.reserve(wideComponents.size());
  3014. auto i = 0U;
  3015. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3016. const auto key = keySelector(itemArgs);
  3017. const auto keyType = key.GetStaticType();
  3018. auto payload = payloadSelector(itemArgs);
  3019. auto payloadType = payload.GetStaticType();
  3020. if (multi) {
  3021. payloadType = TListType::Create(payloadType, Env);
  3022. }
  3023. const auto dictType = TDictType::Create(keyType, payloadType, Env);
  3024. const auto returnType = TFlowType::Create(dictType, Env);
  3025. TCallableBuilder callableBuilder(Env, callableName, returnType);
  3026. callableBuilder.Add(flow);
  3027. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3028. callableBuilder.Add(key);
  3029. callableBuilder.Add(payload);
  3030. callableBuilder.Add(NewDataLiteral(multi));
  3031. callableBuilder.Add(NewDataLiteral(isCompact));
  3032. callableBuilder.Add(NewDataLiteral(itemsCountHint));
  3033. return TRuntimeNode(callableBuilder.Build(), false);
  3034. }
  3035. void TProgramBuilder::ThrowIfListOfVoid(TType* type) {
  3036. MKQL_ENSURE(!VoidWithEffects || !type->IsVoid(), "List of void is forbidden for current function");
  3037. }
  3038. TRuntimeNode TProgramBuilder::BuildFlatMap(const std::string_view& callableName, TRuntimeNode list, const TUnaryLambda& handler)
  3039. {
  3040. const auto listType = list.GetStaticType();
  3041. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsOptional() || listType->IsStream(), "Expected flow, list, stream or optional");
  3042. if (listType->IsOptional()) {
  3043. const auto itemArg = Arg(AS_TYPE(TOptionalType, listType)->GetItemType());
  3044. const auto newList = handler(itemArg);
  3045. const auto type = newList.GetStaticType();
  3046. MKQL_ENSURE(type->IsList() || type->IsOptional() || type->IsStream() || type->IsFlow(), "Expected flow, list, stream or optional");
  3047. return IfPresent(list, [&](TRuntimeNode item) {
  3048. return handler(item);
  3049. }, type->IsOptional() ? NewEmptyOptional(type) : type->IsList() ? NewEmptyList(AS_TYPE(TListType, type)->GetItemType()) : EmptyIterator(type));
  3050. }
  3051. const auto itemType = listType->IsFlow() ?
  3052. AS_TYPE(TFlowType, listType)->GetItemType():
  3053. listType->IsList() ?
  3054. AS_TYPE(TListType, listType)->GetItemType():
  3055. AS_TYPE(TStreamType, listType)->GetItemType();
  3056. ThrowIfListOfVoid(itemType);
  3057. const auto itemArg = Arg(itemType);
  3058. const auto newList = handler(itemArg);
  3059. const auto type = newList.GetStaticType();
  3060. TType* retItemType = nullptr;
  3061. if (type->IsOptional()) {
  3062. retItemType = AS_TYPE(TOptionalType, type)->GetItemType();
  3063. } else if (type->IsFlow()) {
  3064. retItemType = AS_TYPE(TFlowType, type)->GetItemType();
  3065. } else if (type->IsList()) {
  3066. retItemType = AS_TYPE(TListType, type)->GetItemType();
  3067. } else if (type->IsStream()) {
  3068. retItemType = AS_TYPE(TStreamType, type)->GetItemType();
  3069. } else {
  3070. THROW yexception() << "Expected flow, list or stream.";
  3071. }
  3072. const auto resultListType = listType->IsFlow() || type->IsFlow() ?
  3073. TFlowType::Create(retItemType, Env):
  3074. listType->IsList() ?
  3075. (TType*)TListType::Create(retItemType, Env):
  3076. (TType*)TStreamType::Create(retItemType, Env);
  3077. TCallableBuilder callableBuilder(Env, callableName, resultListType);
  3078. callableBuilder.Add(list);
  3079. callableBuilder.Add(itemArg);
  3080. callableBuilder.Add(newList);
  3081. return TRuntimeNode(callableBuilder.Build(), false);
  3082. }
  3083. TRuntimeNode TProgramBuilder::MultiMap(TRuntimeNode list, const TExpandLambda& handler)
  3084. {
  3085. if constexpr (RuntimeVersion < 16U) {
  3086. const auto single = [=](TRuntimeNode item) -> TRuntimeNode {
  3087. const auto newList = handler(item);
  3088. const auto retItemType = newList.front().GetStaticType();
  3089. MKQL_ENSURE(retItemType->IsSameType(*newList.back().GetStaticType()), "Must be same type.");
  3090. return NewList(retItemType, newList);
  3091. };
  3092. return OrderedFlatMap(list, single);
  3093. }
  3094. const auto listType = list.GetStaticType();
  3095. MKQL_ENSURE(listType->IsFlow() || listType->IsList(), "Expected flow, list, stream or optional");
  3096. const auto itemType = listType->IsFlow() ? AS_TYPE(TFlowType, listType)->GetItemType() : AS_TYPE(TListType, listType)->GetItemType();
  3097. const auto itemArg = Arg(itemType);
  3098. const auto newList = handler(itemArg);
  3099. MKQL_ENSURE(newList.size() > 1U, "Expected many items.");
  3100. const auto retItemType = newList.front().GetStaticType();
  3101. MKQL_ENSURE(retItemType->IsSameType(*newList.back().GetStaticType()), "Must be same type.");
  3102. const auto resultListType = listType->IsFlow() ?
  3103. (TType*)TFlowType::Create(retItemType, Env) : (TType*)TListType::Create(retItemType, Env);
  3104. TCallableBuilder callableBuilder(Env, __func__, resultListType);
  3105. callableBuilder.Add(list);
  3106. callableBuilder.Add(itemArg);
  3107. std::for_each(newList.cbegin(), newList.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3108. return TRuntimeNode(callableBuilder.Build(), false);
  3109. }
  3110. TRuntimeNode TProgramBuilder::NarrowMultiMap(TRuntimeNode flow, const TWideLambda& handler) {
  3111. if constexpr (RuntimeVersion < 18U) {
  3112. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3113. }
  3114. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3115. TRuntimeNode::TList itemArgs;
  3116. itemArgs.reserve(wideComponents.size());
  3117. auto i = 0U;
  3118. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3119. const auto newList = handler(itemArgs);
  3120. MKQL_ENSURE(newList.size() > 1U, "Expected many items.");
  3121. const auto retItemType = newList.front().GetStaticType();
  3122. MKQL_ENSURE(retItemType->IsSameType(*newList.back().GetStaticType()), "Must be same type.");
  3123. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(newList.front().GetStaticType()));
  3124. callableBuilder.Add(flow);
  3125. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3126. std::for_each(newList.cbegin(), newList.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3127. return TRuntimeNode(callableBuilder.Build(), false);
  3128. }
  3129. TRuntimeNode TProgramBuilder::ExpandMap(TRuntimeNode flow, const TExpandLambda& handler) {
  3130. if constexpr (RuntimeVersion < 18U) {
  3131. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3132. }
  3133. const auto itemType = AS_TYPE(TFlowType, flow.GetStaticType())->GetItemType();
  3134. const auto itemArg = Arg(itemType);
  3135. const auto newItems = handler(itemArg);
  3136. std::vector<TType*> tupleItems;
  3137. tupleItems.reserve(newItems.size());
  3138. std::transform(newItems.cbegin(), newItems.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  3139. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  3140. callableBuilder.Add(flow);
  3141. callableBuilder.Add(itemArg);
  3142. std::for_each(newItems.cbegin(), newItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3143. return TRuntimeNode(callableBuilder.Build(), false);
  3144. }
  3145. TRuntimeNode TProgramBuilder::WideMap(TRuntimeNode flow, const TWideLambda& handler) {
  3146. if constexpr (RuntimeVersion < 18U) {
  3147. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3148. }
  3149. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3150. TRuntimeNode::TList itemArgs;
  3151. itemArgs.reserve(wideComponents.size());
  3152. auto i = 0U;
  3153. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3154. const auto newItems = handler(itemArgs);
  3155. std::vector<TType*> tupleItems;
  3156. tupleItems.reserve(newItems.size());
  3157. std::transform(newItems.cbegin(), newItems.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  3158. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  3159. callableBuilder.Add(flow);
  3160. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3161. std::for_each(newItems.cbegin(), newItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3162. return TRuntimeNode(callableBuilder.Build(), false);
  3163. }
  3164. TRuntimeNode TProgramBuilder::WideChain1Map(TRuntimeNode flow, const TWideLambda& init, const TBinaryWideLambda& update) {
  3165. if constexpr (RuntimeVersion < 23U) {
  3166. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3167. }
  3168. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3169. TRuntimeNode::TList inputArgs;
  3170. inputArgs.reserve(wideComponents.size());
  3171. auto i = 0U;
  3172. std::generate_n(std::back_inserter(inputArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3173. const auto initItems = init(inputArgs);
  3174. std::vector<TType*> tupleItems;
  3175. tupleItems.reserve(initItems.size());
  3176. std::transform(initItems.cbegin(), initItems.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  3177. TRuntimeNode::TList outputArgs;
  3178. outputArgs.reserve(tupleItems.size());
  3179. std::transform(tupleItems.cbegin(), tupleItems.cend(), std::back_inserter(outputArgs), std::bind(&TProgramBuilder::Arg, this, std::placeholders::_1));
  3180. const auto updateItems = update(inputArgs, outputArgs);
  3181. MKQL_ENSURE(initItems.size() == updateItems.size(), "Expected same width.");
  3182. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  3183. callableBuilder.Add(flow);
  3184. std::for_each(inputArgs.cbegin(), inputArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3185. std::for_each(initItems.cbegin(), initItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3186. std::for_each(outputArgs.cbegin(), outputArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3187. std::for_each(updateItems.cbegin(), updateItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3188. return TRuntimeNode(callableBuilder.Build(), false);
  3189. }
  3190. TRuntimeNode TProgramBuilder::NarrowMap(TRuntimeNode flow, const TNarrowLambda& handler) {
  3191. if constexpr (RuntimeVersion < 18U) {
  3192. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3193. }
  3194. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3195. TRuntimeNode::TList itemArgs;
  3196. itemArgs.reserve(wideComponents.size());
  3197. auto i = 0U;
  3198. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3199. const auto newItem = handler(itemArgs);
  3200. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(newItem.GetStaticType()));
  3201. callableBuilder.Add(flow);
  3202. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3203. callableBuilder.Add(newItem);
  3204. return TRuntimeNode(callableBuilder.Build(), false);
  3205. }
  3206. TRuntimeNode TProgramBuilder::NarrowFlatMap(TRuntimeNode flow, const TNarrowLambda& handler) {
  3207. if constexpr (RuntimeVersion < 18U) {
  3208. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3209. }
  3210. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3211. TRuntimeNode::TList itemArgs;
  3212. itemArgs.reserve(wideComponents.size());
  3213. auto i = 0U;
  3214. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3215. const auto newList = handler(itemArgs);
  3216. const auto type = newList.GetStaticType();
  3217. TType* retItemType = nullptr;
  3218. if (type->IsOptional()) {
  3219. retItemType = AS_TYPE(TOptionalType, type)->GetItemType();
  3220. } else if (type->IsFlow()) {
  3221. retItemType = AS_TYPE(TFlowType, type)->GetItemType();
  3222. } else if (type->IsList()) {
  3223. retItemType = AS_TYPE(TListType, type)->GetItemType();
  3224. } else if (type->IsStream()) {
  3225. retItemType = AS_TYPE(TStreamType, type)->GetItemType();
  3226. } else {
  3227. THROW yexception() << "Expected flow, list or stream.";
  3228. }
  3229. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(retItemType));
  3230. callableBuilder.Add(flow);
  3231. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3232. callableBuilder.Add(newList);
  3233. return TRuntimeNode(callableBuilder.Build(), false);
  3234. }
  3235. TRuntimeNode TProgramBuilder::BuildWideFilter(const std::string_view& callableName, TRuntimeNode flow, const TNarrowLambda& handler) {
  3236. if constexpr (RuntimeVersion < 18U) {
  3237. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3238. }
  3239. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3240. TRuntimeNode::TList itemArgs;
  3241. itemArgs.reserve(wideComponents.size());
  3242. auto i = 0U;
  3243. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3244. const auto predicate = handler(itemArgs);
  3245. TCallableBuilder callableBuilder(Env, callableName, flow.GetStaticType());
  3246. callableBuilder.Add(flow);
  3247. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3248. callableBuilder.Add(predicate);
  3249. return TRuntimeNode(callableBuilder.Build(), false);
  3250. }
  3251. TRuntimeNode TProgramBuilder::WideFilter(TRuntimeNode flow, const TNarrowLambda& handler) {
  3252. return BuildWideFilter(__func__, flow, handler);
  3253. }
  3254. TRuntimeNode TProgramBuilder::WideTakeWhile(TRuntimeNode flow, const TNarrowLambda& handler) {
  3255. return BuildWideFilter(__func__, flow, handler);
  3256. }
  3257. TRuntimeNode TProgramBuilder::WideSkipWhile(TRuntimeNode flow, const TNarrowLambda& handler) {
  3258. return BuildWideFilter(__func__, flow, handler);
  3259. }
  3260. TRuntimeNode TProgramBuilder::WideTakeWhileInclusive(TRuntimeNode flow, const TNarrowLambda& handler) {
  3261. return BuildWideFilter(__func__, flow, handler);
  3262. }
  3263. TRuntimeNode TProgramBuilder::WideSkipWhileInclusive(TRuntimeNode flow, const TNarrowLambda& handler) {
  3264. return BuildWideFilter(__func__, flow, handler);
  3265. }
  3266. TRuntimeNode TProgramBuilder::WideFilter(TRuntimeNode flow, TRuntimeNode limit, const TNarrowLambda& handler) {
  3267. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3268. TRuntimeNode::TList itemArgs;
  3269. itemArgs.reserve(wideComponents.size());
  3270. auto i = 0U;
  3271. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3272. const auto predicate = handler(itemArgs);
  3273. TCallableBuilder callableBuilder(Env, __func__, flow.GetStaticType());
  3274. callableBuilder.Add(flow);
  3275. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3276. callableBuilder.Add(predicate);
  3277. callableBuilder.Add(limit);
  3278. return TRuntimeNode(callableBuilder.Build(), false);
  3279. }
  3280. TRuntimeNode TProgramBuilder::BuildFilter(const std::string_view& callableName, TRuntimeNode list, const TUnaryLambda& handler, TType* resultType)
  3281. {
  3282. const auto listType = list.GetStaticType();
  3283. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream.");
  3284. const auto outputType = resultType ? resultType : listType;
  3285. const auto itemType = listType->IsFlow() ?
  3286. AS_TYPE(TFlowType, listType)->GetItemType():
  3287. listType->IsList() ?
  3288. AS_TYPE(TListType, listType)->GetItemType():
  3289. AS_TYPE(TStreamType, listType)->GetItemType();
  3290. ThrowIfListOfVoid(itemType);
  3291. const auto itemArg = Arg(itemType);
  3292. const auto predicate = handler(itemArg);
  3293. MKQL_ENSURE(predicate.GetStaticType()->IsData(), "Expected boolean data");
  3294. const auto& detailedPredicateType = static_cast<const TDataType&>(*predicate.GetStaticType());
  3295. MKQL_ENSURE(detailedPredicateType.GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected boolean data");
  3296. TCallableBuilder callableBuilder(Env, callableName, outputType);
  3297. callableBuilder.Add(list);
  3298. callableBuilder.Add(itemArg);
  3299. callableBuilder.Add(predicate);
  3300. return TRuntimeNode(callableBuilder.Build(), false);
  3301. }
  3302. TRuntimeNode TProgramBuilder::BuildFilter(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode limit, const TUnaryLambda& handler, TType* resultType)
  3303. {
  3304. if constexpr (RuntimeVersion < 4U) {
  3305. return Take(BuildFilter(callableName, list, handler, resultType), limit);
  3306. }
  3307. const auto listType = list.GetStaticType();
  3308. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream.");
  3309. MKQL_ENSURE(limit.GetStaticType()->IsData(), "Expected data");
  3310. const auto outputType = resultType ? resultType : listType;
  3311. const auto itemType = listType->IsFlow() ?
  3312. AS_TYPE(TFlowType, listType)->GetItemType():
  3313. listType->IsList() ?
  3314. AS_TYPE(TListType, listType)->GetItemType():
  3315. AS_TYPE(TStreamType, listType)->GetItemType();
  3316. ThrowIfListOfVoid(itemType);
  3317. const auto itemArg = Arg(itemType);
  3318. const auto predicate = handler(itemArg);
  3319. MKQL_ENSURE(predicate.GetStaticType()->IsData(), "Expected boolean data");
  3320. const auto& detailedPredicateType = static_cast<const TDataType&>(*predicate.GetStaticType());
  3321. MKQL_ENSURE(detailedPredicateType.GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected boolean data");
  3322. TCallableBuilder callableBuilder(Env, callableName, outputType);
  3323. callableBuilder.Add(list);
  3324. callableBuilder.Add(limit);
  3325. callableBuilder.Add(itemArg);
  3326. callableBuilder.Add(predicate);
  3327. return TRuntimeNode(callableBuilder.Build(), false);
  3328. }
  3329. TRuntimeNode TProgramBuilder::Filter(TRuntimeNode list, const TUnaryLambda& handler, TType* resultType)
  3330. {
  3331. const auto type = list.GetStaticType();
  3332. if (type->IsOptional()) {
  3333. return
  3334. IfPresent(list,
  3335. [&](TRuntimeNode item) {
  3336. return If(handler(item), item, NewEmptyOptional(resultType), resultType);
  3337. },
  3338. NewEmptyOptional(resultType)
  3339. );
  3340. }
  3341. return BuildFilter(__func__, list, handler, resultType);
  3342. }
  3343. TRuntimeNode TProgramBuilder::BuildHeap(const std::string_view& callableName, TRuntimeNode list, const TBinaryLambda& comparator) {
  3344. const auto listType = list.GetStaticType();
  3345. MKQL_ENSURE(listType->IsList(), "Expected list.");
  3346. const auto itemType = AS_TYPE(TListType, listType)->GetItemType();
  3347. const auto leftArg = Arg(itemType);
  3348. const auto rightArg = Arg(itemType);
  3349. const auto predicate = comparator(leftArg, rightArg);
  3350. TCallableBuilder callableBuilder(Env, callableName, listType);
  3351. callableBuilder.Add(list);
  3352. callableBuilder.Add(leftArg);
  3353. callableBuilder.Add(rightArg);
  3354. callableBuilder.Add(predicate);
  3355. return TRuntimeNode(callableBuilder.Build(), false);
  3356. }
  3357. TRuntimeNode TProgramBuilder::BuildNth(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode n, const TBinaryLambda& comparator) {
  3358. const auto listType = list.GetStaticType();
  3359. MKQL_ENSURE(listType->IsList(), "Expected list.");
  3360. const auto itemType = AS_TYPE(TListType, listType)->GetItemType();
  3361. MKQL_ENSURE(n.GetStaticType()->IsData(), "Expected data");
  3362. MKQL_ENSURE(static_cast<const TDataType&>(*n.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  3363. const auto leftArg = Arg(itemType);
  3364. const auto rightArg = Arg(itemType);
  3365. const auto predicate = comparator(leftArg, rightArg);
  3366. TCallableBuilder callableBuilder(Env, callableName, listType);
  3367. callableBuilder.Add(list);
  3368. callableBuilder.Add(n);
  3369. callableBuilder.Add(leftArg);
  3370. callableBuilder.Add(rightArg);
  3371. callableBuilder.Add(predicate);
  3372. return TRuntimeNode(callableBuilder.Build(), false);
  3373. }
  3374. TRuntimeNode TProgramBuilder::MakeHeap(TRuntimeNode list, const TBinaryLambda& comparator) {
  3375. return BuildHeap(__func__, list, std::move(comparator));
  3376. }
  3377. TRuntimeNode TProgramBuilder::PushHeap(TRuntimeNode list, const TBinaryLambda& comparator) {
  3378. return BuildHeap(__func__, list, std::move(comparator));
  3379. }
  3380. TRuntimeNode TProgramBuilder::PopHeap(TRuntimeNode list, const TBinaryLambda& comparator) {
  3381. return BuildHeap(__func__, list, std::move(comparator));
  3382. }
  3383. TRuntimeNode TProgramBuilder::SortHeap(TRuntimeNode list, const TBinaryLambda& comparator) {
  3384. return BuildHeap(__func__, list, std::move(comparator));
  3385. }
  3386. TRuntimeNode TProgramBuilder::StableSort(TRuntimeNode list, const TBinaryLambda& comparator) {
  3387. return BuildHeap(__func__, list, std::move(comparator));
  3388. }
  3389. TRuntimeNode TProgramBuilder::NthElement(TRuntimeNode list, TRuntimeNode n, const TBinaryLambda& comparator) {
  3390. return BuildNth(__func__, list, n, std::move(comparator));
  3391. }
  3392. TRuntimeNode TProgramBuilder::PartialSort(TRuntimeNode list, TRuntimeNode n, const TBinaryLambda& comparator) {
  3393. return BuildNth(__func__, list, n, std::move(comparator));
  3394. }
  3395. TRuntimeNode TProgramBuilder::BuildMap(const std::string_view& callableName, TRuntimeNode list, const TUnaryLambda& handler)
  3396. {
  3397. const auto listType = list.GetStaticType();
  3398. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream() || listType->IsOptional(), "Expected flow, list, stream or optional");
  3399. if (listType->IsOptional()) {
  3400. const auto itemArg = Arg(AS_TYPE(TOptionalType, listType)->GetItemType());
  3401. const auto newItem = handler(itemArg);
  3402. return IfPresent(list,
  3403. [&](TRuntimeNode item) { return NewOptional(handler(item)); },
  3404. NewEmptyOptional(NewOptionalType(newItem.GetStaticType()))
  3405. );
  3406. }
  3407. const auto itemType = listType->IsFlow() ?
  3408. AS_TYPE(TFlowType, listType)->GetItemType():
  3409. listType->IsList() ?
  3410. AS_TYPE(TListType, listType)->GetItemType():
  3411. AS_TYPE(TStreamType, listType)->GetItemType();
  3412. ThrowIfListOfVoid(itemType);
  3413. const auto itemArg = Arg(itemType);
  3414. const auto newItem = handler(itemArg);
  3415. const auto resultListType = listType->IsFlow() ?
  3416. (TType*)TFlowType::Create(newItem.GetStaticType(), Env):
  3417. listType->IsList() ?
  3418. (TType*)TListType::Create(newItem.GetStaticType(), Env):
  3419. (TType*)TStreamType::Create(newItem.GetStaticType(), Env);
  3420. TCallableBuilder callableBuilder(Env, callableName, resultListType);
  3421. callableBuilder.Add(list);
  3422. callableBuilder.Add(itemArg);
  3423. callableBuilder.Add(newItem);
  3424. return TRuntimeNode(callableBuilder.Build(), false);
  3425. }
  3426. TRuntimeNode TProgramBuilder::Invoke(const std::string_view& funcName, TType* resultType, const TArrayRef<const TRuntimeNode>& args) {
  3427. MKQL_ENSURE(args.size() >= 1U && args.size() <= 3U, "Expected from one to three arguments.");
  3428. std::array<TArgType, 4U> argTypes;
  3429. argTypes.front().first = UnpackOptionalData(resultType, argTypes.front().second)->GetSchemeType();
  3430. auto i = 0U;
  3431. for (const auto& arg : args) {
  3432. ++i;
  3433. argTypes[i].first = UnpackOptionalData(arg, argTypes[i].second)->GetSchemeType();
  3434. }
  3435. FunctionRegistry.GetBuiltins()->GetBuiltin(funcName, argTypes.data(), 1U + args.size());
  3436. TCallableBuilder callableBuilder(Env, __func__, resultType);
  3437. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(funcName));
  3438. for (const auto& arg : args) {
  3439. callableBuilder.Add(arg);
  3440. }
  3441. return TRuntimeNode(callableBuilder.Build(), false);
  3442. }
  3443. TRuntimeNode TProgramBuilder::Udf(
  3444. const std::string_view& funcName,
  3445. TRuntimeNode runConfig,
  3446. TType* userType,
  3447. const std::string_view& typeConfig
  3448. )
  3449. {
  3450. TRuntimeNode userTypeNode = userType ? TRuntimeNode(userType, true) : TRuntimeNode(Env.GetVoidLazy()->GetType(), true);
  3451. const ui32 flags = NUdf::IUdfModule::TFlags::TypesOnly;
  3452. if (!TypeInfoHelper) {
  3453. TypeInfoHelper = new TTypeInfoHelper();
  3454. }
  3455. TFunctionTypeInfo funcInfo;
  3456. TStatus status = FunctionRegistry.FindFunctionTypeInfo(
  3457. Env, TypeInfoHelper, nullptr, funcName, userType, typeConfig, flags, {}, nullptr, &funcInfo);
  3458. MKQL_ENSURE(status.IsOk(), status.GetError());
  3459. auto runConfigType = funcInfo.RunConfigType;
  3460. if (runConfig) {
  3461. bool typesMatch = runConfigType->IsSameType(*runConfig.GetStaticType());
  3462. MKQL_ENSURE(typesMatch, "RunConfig type mismatch");
  3463. } else {
  3464. MKQL_ENSURE(runConfigType->IsVoid() || runConfigType->IsOptional(), "RunConfig must be void or optional");
  3465. if (runConfigType->IsVoid()) {
  3466. runConfig = NewVoid();
  3467. } else {
  3468. runConfig = NewEmptyOptional(const_cast<TType*>(runConfigType));
  3469. }
  3470. }
  3471. auto funNameNode = NewDataLiteral<NUdf::EDataSlot::String>(funcName);
  3472. auto typeConfigNode = NewDataLiteral<NUdf::EDataSlot::String>(typeConfig);
  3473. TCallableBuilder callableBuilder(Env, __func__, funcInfo.FunctionType);
  3474. callableBuilder.Add(funNameNode);
  3475. callableBuilder.Add(userTypeNode);
  3476. callableBuilder.Add(typeConfigNode);
  3477. callableBuilder.Add(runConfig);
  3478. return TRuntimeNode(callableBuilder.Build(), false);
  3479. }
  3480. TRuntimeNode TProgramBuilder::TypedUdf(
  3481. const std::string_view& funcName,
  3482. TType* funcType,
  3483. TRuntimeNode runConfig,
  3484. TType* userType,
  3485. const std::string_view& typeConfig,
  3486. const std::string_view& file,
  3487. ui32 row,
  3488. ui32 column)
  3489. {
  3490. auto funNameNode = NewDataLiteral<NUdf::EDataSlot::String>(funcName);
  3491. auto typeConfigNode = NewDataLiteral<NUdf::EDataSlot::String>(typeConfig);
  3492. TRuntimeNode userTypeNode = userType ? TRuntimeNode(userType, true) : TRuntimeNode(Env.GetVoidLazy(), true);
  3493. TCallableBuilder callableBuilder(Env, "Udf", funcType);
  3494. callableBuilder.Add(funNameNode);
  3495. callableBuilder.Add(userTypeNode);
  3496. callableBuilder.Add(typeConfigNode);
  3497. callableBuilder.Add(runConfig);
  3498. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  3499. callableBuilder.Add(NewDataLiteral(row));
  3500. callableBuilder.Add(NewDataLiteral(column));
  3501. return TRuntimeNode(callableBuilder.Build(), false);
  3502. }
  3503. TRuntimeNode TProgramBuilder::ScriptUdf(
  3504. const std::string_view& moduleName,
  3505. const std::string_view& funcName,
  3506. TType* funcType,
  3507. TRuntimeNode script,
  3508. const std::string_view& file,
  3509. ui32 row,
  3510. ui32 column)
  3511. {
  3512. MKQL_ENSURE(funcType, "UDF callable type must not be empty");
  3513. MKQL_ENSURE(funcType->IsCallable(), "type must be callable");
  3514. auto scriptType = NKikimr::NMiniKQL::ScriptTypeFromStr(moduleName);
  3515. MKQL_ENSURE(scriptType != EScriptType::Unknown, "unknown script type '" << moduleName << "'");
  3516. EnsureScriptSpecificTypes(scriptType, static_cast<TCallableType*>(funcType), Env);
  3517. auto scriptTypeStr = IsCustomPython(scriptType) ? moduleName : ScriptTypeAsStr(CanonizeScriptType(scriptType));
  3518. TStringBuilder name;
  3519. name.reserve(scriptTypeStr.size() + funcName.size() + 1);
  3520. name << scriptTypeStr << '.' << funcName;
  3521. auto funcNameNode = NewDataLiteral<NUdf::EDataSlot::String>(name);
  3522. TRuntimeNode userTypeNode(funcType, true);
  3523. auto typeConfigNode = NewDataLiteral<NUdf::EDataSlot::String>("");
  3524. TCallableBuilder callableBuilder(Env, __func__, funcType);
  3525. callableBuilder.Add(funcNameNode);
  3526. callableBuilder.Add(userTypeNode);
  3527. callableBuilder.Add(typeConfigNode);
  3528. callableBuilder.Add(script);
  3529. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  3530. callableBuilder.Add(NewDataLiteral(row));
  3531. callableBuilder.Add(NewDataLiteral(column));
  3532. return TRuntimeNode(callableBuilder.Build(), false);
  3533. }
  3534. TRuntimeNode TProgramBuilder::Apply(TRuntimeNode callableNode, const TArrayRef<const TRuntimeNode>& args,
  3535. const std::string_view& file, ui32 row, ui32 column, ui32 dependentCount) {
  3536. MKQL_ENSURE(dependentCount <= args.size(), "Too many dependent nodes");
  3537. ui32 usedArgs = args.size() - dependentCount;
  3538. MKQL_ENSURE(!callableNode.IsImmediate() && callableNode.GetNode()->GetType()->IsCallable(),
  3539. "Expected callable");
  3540. auto callable = static_cast<TCallable*>(callableNode.GetNode());
  3541. TType* returnType = callable->GetType()->GetReturnType();
  3542. MKQL_ENSURE(returnType->IsCallable(), "Expected callable as return type");
  3543. auto callableType = static_cast<TCallableType*>(returnType);
  3544. MKQL_ENSURE(usedArgs <= callableType->GetArgumentsCount(), "Too many arguments");
  3545. MKQL_ENSURE(usedArgs >= callableType->GetArgumentsCount() - callableType->GetOptionalArgumentsCount(), "Too few arguments");
  3546. for (ui32 i = 0; i < usedArgs; i++) {
  3547. TType* argType = callableType->GetArgumentType(i);
  3548. TRuntimeNode arg = args[i];
  3549. MKQL_ENSURE(arg.GetStaticType()->IsConvertableTo(*argType),
  3550. "Argument type mismatch for argument " << i << ": runtime " << argType->GetKindAsStr()
  3551. << " with static " << arg.GetStaticType()->GetKindAsStr());
  3552. }
  3553. TCallableBuilder callableBuilder(Env, RuntimeVersion >= 8 ? "Apply2" : "Apply", callableType->GetReturnType());
  3554. callableBuilder.Add(callableNode);
  3555. callableBuilder.Add(NewDataLiteral<ui32>(dependentCount));
  3556. if constexpr (RuntimeVersion >= 8) {
  3557. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  3558. callableBuilder.Add(NewDataLiteral(row));
  3559. callableBuilder.Add(NewDataLiteral(column));
  3560. }
  3561. for (const auto& arg: args) {
  3562. callableBuilder.Add(arg);
  3563. }
  3564. return TRuntimeNode(callableBuilder.Build(), false);
  3565. }
  3566. TRuntimeNode TProgramBuilder::Apply(
  3567. TRuntimeNode callableNode,
  3568. const TArrayRef<const TRuntimeNode>& args,
  3569. ui32 dependentCount) {
  3570. return Apply(callableNode, args, {}, 0, 0, dependentCount);
  3571. }
  3572. TRuntimeNode TProgramBuilder::Callable(TType* callableType, const TArrayLambda& handler) {
  3573. auto castedCallableType = AS_TYPE(TCallableType, callableType);
  3574. std::vector<TRuntimeNode> args;
  3575. args.reserve(castedCallableType->GetArgumentsCount());
  3576. for (ui32 i = 0; i < castedCallableType->GetArgumentsCount(); ++i) {
  3577. args.push_back(Arg(castedCallableType->GetArgumentType(i)));
  3578. }
  3579. auto res = handler(args);
  3580. TCallableBuilder callableBuilder(Env, __func__, callableType);
  3581. for (ui32 i = 0; i < castedCallableType->GetArgumentsCount(); ++i) {
  3582. callableBuilder.Add(args[i]);
  3583. }
  3584. callableBuilder.Add(res);
  3585. return TRuntimeNode(callableBuilder.Build(), false);
  3586. }
  3587. TRuntimeNode TProgramBuilder::NewNull() {
  3588. if (!UseNullType || RuntimeVersion < 11) {
  3589. TCallableBuilder callableBuilder(Env, "Null", NewOptionalType(Env.GetVoidLazy()->GetType()));
  3590. return TRuntimeNode(callableBuilder.Build(), false);
  3591. } else {
  3592. return TRuntimeNode(Env.GetNullLazy(), true);
  3593. }
  3594. }
  3595. TRuntimeNode TProgramBuilder::Concat(TRuntimeNode data1, TRuntimeNode data2) {
  3596. bool isOpt1, isOpt2;
  3597. const auto type1 = UnpackOptionalData(data1, isOpt1)->GetSchemeType();
  3598. const auto type2 = UnpackOptionalData(data2, isOpt2)->GetSchemeType();
  3599. const auto resultType = NewDataType(type1 == type2 ? type1 : NUdf::TDataType<char*>::Id);
  3600. return InvokeBinary(__func__, isOpt1 || isOpt2 ? NewOptionalType(resultType) : resultType, data1, data2);
  3601. }
  3602. TRuntimeNode TProgramBuilder::AggrConcat(TRuntimeNode data1, TRuntimeNode data2) {
  3603. MKQL_ENSURE(data1.GetStaticType()->IsSameType(*data2.GetStaticType()), "Operands type mismatch.");
  3604. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  3605. return Invoke(__func__, data1.GetStaticType(), args);
  3606. }
  3607. TRuntimeNode TProgramBuilder::Substring(TRuntimeNode data, TRuntimeNode start, TRuntimeNode count) {
  3608. const std::array<TRuntimeNode, 3U> args = {{ data, start, count }};
  3609. return Invoke(__func__, data.GetStaticType(), args);
  3610. }
  3611. TRuntimeNode TProgramBuilder::Find(TRuntimeNode haystack, TRuntimeNode needle, TRuntimeNode pos) {
  3612. const std::array<TRuntimeNode, 3U> args = {{ haystack, needle, pos }};
  3613. return Invoke(__func__, NewOptionalType(NewDataType(NUdf::TDataType<ui32>::Id)), args);
  3614. }
  3615. TRuntimeNode TProgramBuilder::RFind(TRuntimeNode haystack, TRuntimeNode needle, TRuntimeNode pos) {
  3616. const std::array<TRuntimeNode, 3U> args = {{ haystack, needle, pos }};
  3617. return Invoke(__func__, NewOptionalType(NewDataType(NUdf::TDataType<ui32>::Id)), args);
  3618. }
  3619. TRuntimeNode TProgramBuilder::StartsWith(TRuntimeNode string, TRuntimeNode prefix) {
  3620. if constexpr (RuntimeVersion < 19U) {
  3621. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3622. }
  3623. return DataCompare(__func__, string, prefix);
  3624. }
  3625. TRuntimeNode TProgramBuilder::EndsWith(TRuntimeNode string, TRuntimeNode suffix) {
  3626. if constexpr (RuntimeVersion < 19U) {
  3627. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3628. }
  3629. return DataCompare(__func__, string, suffix);
  3630. }
  3631. TRuntimeNode TProgramBuilder::StringContains(TRuntimeNode string, TRuntimeNode pattern) {
  3632. bool isOpt1, isOpt2;
  3633. TDataType* type1 = UnpackOptionalData(string, isOpt1);
  3634. TDataType* type2 = UnpackOptionalData(pattern, isOpt2);
  3635. MKQL_ENSURE(type1->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id ||
  3636. type1->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expecting string as first argument");
  3637. MKQL_ENSURE(type2->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id ||
  3638. type2->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expecting string as second argument");
  3639. if constexpr (RuntimeVersion < 32U) {
  3640. auto stringCasted = (type1->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id) ? ToString(string) : string;
  3641. auto patternCasted = (type2->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id) ? ToString(pattern) : pattern;
  3642. auto found = Exists(Find(stringCasted, patternCasted, NewDataLiteral(ui32(0))));
  3643. if (!isOpt1 && !isOpt2) {
  3644. return found;
  3645. }
  3646. TVector<TRuntimeNode> predicates;
  3647. if (isOpt1) {
  3648. predicates.push_back(Exists(string));
  3649. }
  3650. if (isOpt2) {
  3651. predicates.push_back(Exists(pattern));
  3652. }
  3653. TRuntimeNode argsNotNull = (predicates.size() == 1) ? predicates.front() : And(predicates);
  3654. return If(argsNotNull, NewOptional(found), NewEmptyOptionalDataLiteral(NUdf::TDataType<bool>::Id));
  3655. }
  3656. return DataCompare(__func__, string, pattern);
  3657. }
  3658. TRuntimeNode TProgramBuilder::ByteAt(TRuntimeNode data, TRuntimeNode index) {
  3659. const std::array<TRuntimeNode, 2U> args = {{ data, index }};
  3660. return Invoke(__func__, NewOptionalType(NewDataType(NUdf::TDataType<ui8>::Id)), args);
  3661. }
  3662. TRuntimeNode TProgramBuilder::Size(TRuntimeNode data) {
  3663. return UnaryDataFunction(data, __func__, TDataFunctionFlags::HasUi32Result | TDataFunctionFlags::AllowNull | TDataFunctionFlags::AllowOptionalArgs | TDataFunctionFlags::CommonOptionalResult);
  3664. }
  3665. template <bool Utf8>
  3666. TRuntimeNode TProgramBuilder::ToString(TRuntimeNode data) {
  3667. bool isOptional;
  3668. UnpackOptionalData(data, isOptional);
  3669. const auto resultType = NewDataType(Utf8 ? NUdf::EDataSlot::Utf8 : NUdf::EDataSlot::String, isOptional);
  3670. TCallableBuilder callableBuilder(Env, __func__, resultType);
  3671. callableBuilder.Add(data);
  3672. return TRuntimeNode(callableBuilder.Build(), false);
  3673. }
  3674. TRuntimeNode TProgramBuilder::FromString(TRuntimeNode data, TType* type) {
  3675. bool isOptional;
  3676. const auto sourceType = UnpackOptionalData(data, isOptional);
  3677. const auto targetType = UnpackOptionalData(type, isOptional);
  3678. MKQL_ENSURE(sourceType->GetSchemeType() == NUdf::TDataType<char*>::Id || sourceType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected String");
  3679. MKQL_ENSURE(targetType->GetSchemeType() != 0, "Null is not allowed");
  3680. TCallableBuilder callableBuilder(Env, __func__, type);
  3681. callableBuilder.Add(data);
  3682. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(targetType->GetSchemeType())));
  3683. if (targetType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3684. const auto& params = static_cast<const TDataDecimalType*>(targetType)->GetParams();
  3685. callableBuilder.Add(NewDataLiteral(params.first));
  3686. callableBuilder.Add(NewDataLiteral(params.second));
  3687. }
  3688. return TRuntimeNode(callableBuilder.Build(), false);
  3689. }
  3690. TRuntimeNode TProgramBuilder::StrictFromString(TRuntimeNode data, TType* type) {
  3691. bool isOptional;
  3692. const auto sourceType = UnpackOptionalData(data, isOptional);
  3693. const auto targetType = UnpackOptionalData(type, isOptional);
  3694. MKQL_ENSURE(sourceType->GetSchemeType() == NUdf::TDataType<char*>::Id || sourceType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected String");
  3695. MKQL_ENSURE(targetType->GetSchemeType() != 0, "Null is not allowed");
  3696. TCallableBuilder callableBuilder(Env, __func__, type);
  3697. callableBuilder.Add(data);
  3698. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(targetType->GetSchemeType())));
  3699. if (targetType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3700. const auto& params = static_cast<const TDataDecimalType*>(targetType)->GetParams();
  3701. callableBuilder.Add(NewDataLiteral(params.first));
  3702. callableBuilder.Add(NewDataLiteral(params.second));
  3703. }
  3704. return TRuntimeNode(callableBuilder.Build(), false);
  3705. }
  3706. TRuntimeNode TProgramBuilder::ToBytes(TRuntimeNode data) {
  3707. return UnaryDataFunction(data, __func__, TDataFunctionFlags::HasStringResult | TDataFunctionFlags::AllowOptionalArgs | TDataFunctionFlags::CommonOptionalResult);
  3708. }
  3709. TRuntimeNode TProgramBuilder::FromBytes(TRuntimeNode data, TType* targetType) {
  3710. auto type = data.GetStaticType();
  3711. bool isOptional;
  3712. auto dataType = UnpackOptionalData(type, isOptional);
  3713. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expected String");
  3714. auto resultType = NewOptionalType(targetType);
  3715. TCallableBuilder callableBuilder(Env, __func__, resultType);
  3716. callableBuilder.Add(data);
  3717. auto targetDataType = AS_TYPE(TDataType, targetType);
  3718. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(targetDataType->GetSchemeType())));
  3719. if (targetDataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3720. const auto& params = static_cast<const TDataDecimalType*>(targetType)->GetParams();
  3721. callableBuilder.Add(NewDataLiteral(params.first));
  3722. callableBuilder.Add(NewDataLiteral(params.second));
  3723. }
  3724. return TRuntimeNode(callableBuilder.Build(), false);
  3725. }
  3726. TRuntimeNode TProgramBuilder::InversePresortString(TRuntimeNode data) {
  3727. const std::array<TRuntimeNode, 1U> args = {{ data }};
  3728. return Invoke(__func__, NewDataType(NUdf::TDataType<char*>::Id), args);
  3729. }
  3730. TRuntimeNode TProgramBuilder::InverseString(TRuntimeNode data) {
  3731. const std::array<TRuntimeNode, 1U> args = {{ data }};
  3732. return Invoke(__func__, NewDataType(NUdf::TDataType<char*>::Id), args);
  3733. }
  3734. TRuntimeNode TProgramBuilder::Random(const TArrayRef<const TRuntimeNode>& dependentNodes) {
  3735. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<double>::Id));
  3736. for (auto& x : dependentNodes) {
  3737. callableBuilder.Add(x);
  3738. }
  3739. return TRuntimeNode(callableBuilder.Build(), false);
  3740. }
  3741. TRuntimeNode TProgramBuilder::RandomNumber(const TArrayRef<const TRuntimeNode>& dependentNodes) {
  3742. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<ui64>::Id));
  3743. for (auto& x : dependentNodes) {
  3744. callableBuilder.Add(x);
  3745. }
  3746. return TRuntimeNode(callableBuilder.Build(), false);
  3747. }
  3748. TRuntimeNode TProgramBuilder::RandomUuid(const TArrayRef<const TRuntimeNode>& dependentNodes) {
  3749. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<NUdf::TUuid>::Id));
  3750. for (auto& x : dependentNodes) {
  3751. callableBuilder.Add(x);
  3752. }
  3753. return TRuntimeNode(callableBuilder.Build(), false);
  3754. }
  3755. TRuntimeNode TProgramBuilder::Now(const TArrayRef<const TRuntimeNode>& args) {
  3756. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<ui64>::Id));
  3757. for (const auto& x : args) {
  3758. callableBuilder.Add(x);
  3759. }
  3760. return TRuntimeNode(callableBuilder.Build(), false);
  3761. }
  3762. TRuntimeNode TProgramBuilder::CurrentUtcDate(const TArrayRef<const TRuntimeNode>& args) {
  3763. return Cast(CurrentUtcTimestamp(args), NewDataType(NUdf::TDataType<NUdf::TDate>::Id));
  3764. }
  3765. TRuntimeNode TProgramBuilder::CurrentUtcDatetime(const TArrayRef<const TRuntimeNode>& args) {
  3766. return Cast(CurrentUtcTimestamp(args), NewDataType(NUdf::TDataType<NUdf::TDatetime>::Id));
  3767. }
  3768. TRuntimeNode TProgramBuilder::CurrentUtcTimestamp(const TArrayRef<const TRuntimeNode>& args) {
  3769. return Coalesce(ToIntegral(Now(args), NewDataType(NUdf::TDataType<NUdf::TTimestamp>::Id, true)),
  3770. TRuntimeNode(BuildDataLiteral(NUdf::TUnboxedValuePod(ui64(NUdf::MAX_TIMESTAMP - 1ULL)), NUdf::TDataType<NUdf::TTimestamp>::Id, Env), true));
  3771. }
  3772. TRuntimeNode TProgramBuilder::Pickle(TRuntimeNode data) {
  3773. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::EDataSlot::String));
  3774. callableBuilder.Add(data);
  3775. return TRuntimeNode(callableBuilder.Build(), false);
  3776. }
  3777. TRuntimeNode TProgramBuilder::StablePickle(TRuntimeNode data) {
  3778. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::EDataSlot::String));
  3779. callableBuilder.Add(data);
  3780. return TRuntimeNode(callableBuilder.Build(), false);
  3781. }
  3782. TRuntimeNode TProgramBuilder::Unpickle(TType* type, TRuntimeNode serialized) {
  3783. MKQL_ENSURE(AS_TYPE(TDataType, serialized)->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expected String");
  3784. TCallableBuilder callableBuilder(Env, __func__, type);
  3785. callableBuilder.Add(TRuntimeNode(type, true));
  3786. callableBuilder.Add(serialized);
  3787. return TRuntimeNode(callableBuilder.Build(), false);
  3788. }
  3789. TRuntimeNode TProgramBuilder::Ascending(TRuntimeNode data) {
  3790. auto dataType = NewDataType(NUdf::EDataSlot::String);
  3791. TCallableBuilder callableBuilder(Env, __func__, dataType);
  3792. callableBuilder.Add(data);
  3793. return TRuntimeNode(callableBuilder.Build(), false);
  3794. }
  3795. TRuntimeNode TProgramBuilder::Descending(TRuntimeNode data) {
  3796. auto dataType = NewDataType(NUdf::EDataSlot::String);
  3797. TCallableBuilder callableBuilder(Env, __func__, dataType);
  3798. callableBuilder.Add(data);
  3799. return TRuntimeNode(callableBuilder.Build(), false);
  3800. }
  3801. TRuntimeNode TProgramBuilder::Convert(TRuntimeNode data, TType* type) {
  3802. if (data.GetStaticType()->IsSameType(*type)) {
  3803. return data;
  3804. }
  3805. bool isOptional;
  3806. const auto dataType = UnpackOptionalData(data, isOptional);
  3807. const std::array<TRuntimeNode, 1> args = {{ data }};
  3808. if (dataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3809. const auto targetSchemeType = UnpackOptionalData(type, isOptional)->GetSchemeType();
  3810. TStringStream str;
  3811. str << "To" << NUdf::GetDataTypeInfo(NUdf::GetDataSlot(targetSchemeType)).Name
  3812. << '_' << ::ToString(static_cast<const TDataDecimalType*>(dataType)->GetParams().second);
  3813. return Invoke(str.Str().c_str(), type, args);
  3814. }
  3815. return Invoke(__func__, type, args);
  3816. }
  3817. TRuntimeNode TProgramBuilder::ToDecimal(TRuntimeNode data, ui8 precision, ui8 scale) {
  3818. bool isOptional;
  3819. auto dataType = UnpackOptionalData(data, isOptional);
  3820. TType* decimal = TDataDecimalType::Create(precision, scale, Env);
  3821. if (isOptional)
  3822. decimal = TOptionalType::Create(decimal, Env);
  3823. const std::array<TRuntimeNode, 1> args = {{ data }};
  3824. if (dataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3825. const auto& params = static_cast<const TDataDecimalType*>(dataType)->GetParams();
  3826. if (precision - scale < params.first - params.second && scale != params.second) {
  3827. return ToDecimal(ToDecimal(data, precision - scale + params.second, params.second), precision, scale);
  3828. } else if (params.second < scale) {
  3829. return Invoke("ScaleUp_" + ::ToString(scale - params.second), decimal, args);
  3830. } else if (params.second > scale) {
  3831. TRuntimeNode scaled = Invoke("ScaleDown_" + ::ToString(params.second - scale), decimal, args);
  3832. return Invoke("CheckBounds_" + ::ToString(precision), decimal, {{ scaled }});
  3833. } else if (precision < params.first) {
  3834. return Invoke("CheckBounds_" + ::ToString(precision), decimal, args);
  3835. } else if (precision > params.first) {
  3836. return Invoke("Plus", decimal, args);
  3837. } else {
  3838. return data;
  3839. }
  3840. } else {
  3841. const auto digits = NUdf::GetDataTypeInfo(*dataType->GetDataSlot()).DecimalDigits;
  3842. MKQL_ENSURE(digits, "Can't cast into Decimal.");
  3843. if (digits <= precision && !scale)
  3844. return Invoke(__func__, decimal, args);
  3845. else
  3846. return ToDecimal(ToDecimal(data, digits, 0), precision, scale);
  3847. }
  3848. }
  3849. TRuntimeNode TProgramBuilder::ToIntegral(TRuntimeNode data, TType* type) {
  3850. bool isOptional;
  3851. auto dataType = UnpackOptionalData(data, isOptional);
  3852. if (dataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3853. const auto& params = static_cast<const TDataDecimalType*>(dataType)->GetParams();
  3854. if (params.second)
  3855. return ToIntegral(ToDecimal(data, params.first - params.second, 0), type);
  3856. }
  3857. const std::array<TRuntimeNode, 1> args = {{ data }};
  3858. return Invoke(__func__, type, args);
  3859. }
  3860. TRuntimeNode TProgramBuilder::ListIf(TRuntimeNode predicate, TRuntimeNode item) {
  3861. return If(predicate, NewList(item.GetStaticType(), {item}), NewEmptyList(item.GetStaticType()));
  3862. }
  3863. TRuntimeNode TProgramBuilder::AsList(TRuntimeNode item) {
  3864. TListLiteralBuilder builder(Env, item.GetStaticType());
  3865. builder.Add(item);
  3866. return TRuntimeNode(builder.Build(), true);
  3867. }
  3868. TRuntimeNode TProgramBuilder::AsList(const TArrayRef<const TRuntimeNode>& items) {
  3869. MKQL_ENSURE(!items.empty(), "required not empty list of items");
  3870. TListLiteralBuilder builder(Env, items[0].GetStaticType());
  3871. for (auto item : items) {
  3872. builder.Add(item);
  3873. }
  3874. return TRuntimeNode(builder.Build(), true);
  3875. }
  3876. TRuntimeNode TProgramBuilder::MapJoinCore(TRuntimeNode flow, TRuntimeNode dict, EJoinKind joinKind,
  3877. const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& leftRenames,
  3878. const TArrayRef<const ui32>& rightRenames, TType* returnType) {
  3879. MKQL_ENSURE(joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left || joinKind == EJoinKind::LeftSemi || joinKind == EJoinKind::LeftOnly, "Unsupported join kind");
  3880. MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified");
  3881. MKQL_ENSURE(leftRenames.size() % 2U == 0U, "Expected even count");
  3882. MKQL_ENSURE(rightRenames.size() % 2U == 0U, "Expected even count");
  3883. TRuntimeNode::TList leftKeyColumnsNodes, leftRenamesNodes, rightRenamesNodes;
  3884. leftKeyColumnsNodes.reserve(leftKeyColumns.size());
  3885. std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(), std::back_inserter(leftKeyColumnsNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  3886. leftRenamesNodes.reserve(leftRenames.size());
  3887. std::transform(leftRenames.cbegin(), leftRenames.cend(), std::back_inserter(leftRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  3888. rightRenamesNodes.reserve(rightRenames.size());
  3889. std::transform(rightRenames.cbegin(), rightRenames.cend(), std::back_inserter(rightRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  3890. TCallableBuilder callableBuilder(Env, __func__, returnType);
  3891. callableBuilder.Add(flow);
  3892. callableBuilder.Add(dict);
  3893. callableBuilder.Add(NewDataLiteral((ui32)joinKind));
  3894. callableBuilder.Add(NewTuple(leftKeyColumnsNodes));
  3895. callableBuilder.Add(NewTuple(leftRenamesNodes));
  3896. callableBuilder.Add(NewTuple(rightRenamesNodes));
  3897. return TRuntimeNode(callableBuilder.Build(), false);
  3898. }
  3899. TRuntimeNode TProgramBuilder::CommonJoinCore(TRuntimeNode flow, EJoinKind joinKind,
  3900. const TArrayRef<const ui32>& leftColumns, const TArrayRef<const ui32>& rightColumns,
  3901. const TArrayRef<const ui32>& requiredColumns, const TArrayRef<const ui32>& keyColumns,
  3902. ui64 memLimit, std::optional<ui32> sortedTableOrder,
  3903. EAnyJoinSettings anyJoinSettings, const ui32 tableIndexField, TType* returnType) {
  3904. if constexpr (RuntimeVersion < 17U) {
  3905. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3906. }
  3907. MKQL_ENSURE(leftColumns.size() % 2U == 0U, "Expected even count");
  3908. MKQL_ENSURE(rightColumns.size() % 2U == 0U, "Expected even count");
  3909. TRuntimeNode::TList leftInputColumnsNodes, rightInputColumnsNodes, requiredColumnsNodes,
  3910. leftOutputColumnsNodes, rightOutputColumnsNodes, keyColumnsNodes;
  3911. bool s = false;
  3912. for (const auto idx : leftColumns) {
  3913. ((s = !s) ? leftInputColumnsNodes : leftOutputColumnsNodes).emplace_back(NewDataLiteral(idx));
  3914. }
  3915. for (const auto idx : rightColumns) {
  3916. ((s = !s) ? rightInputColumnsNodes : rightOutputColumnsNodes).emplace_back(NewDataLiteral(idx));
  3917. }
  3918. const std::unordered_set<ui32> requiredIndices(requiredColumns.cbegin(), requiredColumns.cend());
  3919. MKQL_ENSURE(requiredIndices.size() == requiredColumns.size(), "Duplication of requred columns.");
  3920. requiredColumnsNodes.reserve(requiredColumns.size());
  3921. std::transform(requiredColumns.cbegin(), requiredColumns.cend(), std::back_inserter(requiredColumnsNodes),
  3922. std::bind(&TProgramBuilder::NewDataLiteral<ui32>, this, std::placeholders::_1));
  3923. const std::unordered_set<ui32> keyIndices(keyColumns.cbegin(), keyColumns.cend());
  3924. MKQL_ENSURE(keyIndices.size() == keyColumns.size(), "Duplication of key columns.");
  3925. keyColumnsNodes.reserve(keyColumns.size());
  3926. std::transform(keyColumns.cbegin(), keyColumns.cend(), std::back_inserter(keyColumnsNodes),
  3927. std::bind(&TProgramBuilder::NewDataLiteral<ui32>, this, std::placeholders::_1));
  3928. TCallableBuilder callableBuilder(Env, __func__, returnType);
  3929. callableBuilder.Add(flow);
  3930. callableBuilder.Add(NewDataLiteral((ui32)joinKind));
  3931. callableBuilder.Add(NewTuple(leftInputColumnsNodes));
  3932. callableBuilder.Add(NewTuple(rightInputColumnsNodes));
  3933. callableBuilder.Add(NewTuple(requiredColumnsNodes));
  3934. callableBuilder.Add(NewTuple(leftOutputColumnsNodes));
  3935. callableBuilder.Add(NewTuple(rightOutputColumnsNodes));
  3936. callableBuilder.Add(NewTuple(keyColumnsNodes));
  3937. callableBuilder.Add(NewDataLiteral(memLimit));
  3938. callableBuilder.Add(sortedTableOrder ? NewDataLiteral(*sortedTableOrder) : NewVoid());
  3939. callableBuilder.Add(NewDataLiteral((ui32)anyJoinSettings));
  3940. callableBuilder.Add(NewDataLiteral(tableIndexField));
  3941. return TRuntimeNode(callableBuilder.Build(), false);
  3942. }
  3943. TRuntimeNode TProgramBuilder::WideCombiner(TRuntimeNode flow, i64 memLimit, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) {
  3944. if constexpr (RuntimeVersion < 18U) {
  3945. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3946. }
  3947. if (memLimit < 0) {
  3948. if constexpr (RuntimeVersion < 46U) {
  3949. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__ << " with limit " << memLimit;
  3950. }
  3951. }
  3952. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3953. TRuntimeNode::TList itemArgs;
  3954. itemArgs.reserve(wideComponents.size());
  3955. auto i = 0U;
  3956. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3957. const auto keys = extractor(itemArgs);
  3958. TRuntimeNode::TList keyArgs;
  3959. keyArgs.reserve(keys.size());
  3960. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  3961. const auto first = init(keyArgs, itemArgs);
  3962. TRuntimeNode::TList stateArgs;
  3963. stateArgs.reserve(first.size());
  3964. std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  3965. const auto next = update(keyArgs, itemArgs, stateArgs);
  3966. MKQL_ENSURE(next.size() == first.size(), "Mismatch init and update state size.");
  3967. TRuntimeNode::TList finishKeyArgs;
  3968. finishKeyArgs.reserve(keys.size());
  3969. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(finishKeyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  3970. TRuntimeNode::TList finishStateArgs;
  3971. finishStateArgs.reserve(next.size());
  3972. std::transform(next.cbegin(), next.cend(), std::back_inserter(finishStateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  3973. const auto output = finish(finishKeyArgs, finishStateArgs);
  3974. std::vector<TType*> tupleItems;
  3975. tupleItems.reserve(output.size());
  3976. std::transform(output.cbegin(), output.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  3977. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  3978. callableBuilder.Add(flow);
  3979. if constexpr (RuntimeVersion < 46U)
  3980. callableBuilder.Add(NewDataLiteral(ui64(memLimit)));
  3981. else
  3982. callableBuilder.Add(NewDataLiteral(memLimit));
  3983. callableBuilder.Add(NewDataLiteral(ui32(keyArgs.size())));
  3984. callableBuilder.Add(NewDataLiteral(ui32(stateArgs.size())));
  3985. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3986. std::for_each(keys.cbegin(), keys.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3987. std::for_each(keyArgs.cbegin(), keyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3988. std::for_each(first.cbegin(), first.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3989. std::for_each(stateArgs.cbegin(), stateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3990. std::for_each(next.cbegin(), next.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3991. std::for_each(finishKeyArgs.cbegin(), finishKeyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3992. std::for_each(finishStateArgs.cbegin(), finishStateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3993. std::for_each(output.cbegin(), output.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3994. return TRuntimeNode(callableBuilder.Build(), false);
  3995. }
  3996. TRuntimeNode TProgramBuilder::WideLastCombinerCommon(const TStringBuf& funcName, TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) {
  3997. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3998. TRuntimeNode::TList itemArgs;
  3999. itemArgs.reserve(wideComponents.size());
  4000. auto i = 0U;
  4001. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  4002. const auto keys = extractor(itemArgs);
  4003. TRuntimeNode::TList keyArgs;
  4004. keyArgs.reserve(keys.size());
  4005. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  4006. const auto first = init(keyArgs, itemArgs);
  4007. TRuntimeNode::TList stateArgs;
  4008. stateArgs.reserve(first.size());
  4009. std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  4010. const auto next = update(keyArgs, itemArgs, stateArgs);
  4011. MKQL_ENSURE(next.size() == first.size(), "Mismatch init and update state size.");
  4012. TRuntimeNode::TList finishKeyArgs;
  4013. finishKeyArgs.reserve(keys.size());
  4014. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(finishKeyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  4015. TRuntimeNode::TList finishStateArgs;
  4016. finishStateArgs.reserve(next.size());
  4017. std::transform(next.cbegin(), next.cend(), std::back_inserter(finishStateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  4018. const auto output = finish(finishKeyArgs, finishStateArgs);
  4019. std::vector<TType*> tupleItems;
  4020. tupleItems.reserve(output.size());
  4021. std::transform(output.cbegin(), output.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  4022. TCallableBuilder callableBuilder(Env, funcName, NewFlowType(NewMultiType(tupleItems)));
  4023. callableBuilder.Add(flow);
  4024. callableBuilder.Add(NewDataLiteral(ui32(keyArgs.size())));
  4025. callableBuilder.Add(NewDataLiteral(ui32(stateArgs.size())));
  4026. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4027. std::for_each(keys.cbegin(), keys.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4028. std::for_each(keyArgs.cbegin(), keyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4029. std::for_each(first.cbegin(), first.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4030. std::for_each(stateArgs.cbegin(), stateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4031. std::for_each(next.cbegin(), next.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4032. std::for_each(finishKeyArgs.cbegin(), finishKeyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4033. std::for_each(finishStateArgs.cbegin(), finishStateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4034. std::for_each(output.cbegin(), output.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4035. return TRuntimeNode(callableBuilder.Build(), false);
  4036. }
  4037. TRuntimeNode TProgramBuilder::WideLastCombiner(TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) {
  4038. if constexpr (RuntimeVersion < 29U) {
  4039. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4040. }
  4041. return WideLastCombinerCommon(__func__, flow, extractor, init, update, finish);
  4042. }
  4043. TRuntimeNode TProgramBuilder::WideLastCombinerWithSpilling(TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) {
  4044. if constexpr (RuntimeVersion < 49U) {
  4045. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4046. }
  4047. return WideLastCombinerCommon(__func__, flow, extractor, init, update, finish);
  4048. }
  4049. TRuntimeNode TProgramBuilder::WideCondense1(TRuntimeNode flow, const TWideLambda& init, const TWideSwitchLambda& switcher, const TBinaryWideLambda& update, bool useCtx) {
  4050. if constexpr (RuntimeVersion < 18U) {
  4051. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4052. }
  4053. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  4054. TRuntimeNode::TList itemArgs;
  4055. itemArgs.reserve(wideComponents.size());
  4056. auto i = 0U;
  4057. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  4058. const auto first = init(itemArgs);
  4059. TRuntimeNode::TList stateArgs;
  4060. stateArgs.reserve(first.size());
  4061. std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  4062. const auto chop = switcher(itemArgs, stateArgs);
  4063. const auto next = update(itemArgs, stateArgs);
  4064. MKQL_ENSURE(next.size() == first.size(), "Mismatch init and update state size.");
  4065. std::vector<TType*> tupleItems;
  4066. tupleItems.reserve(next.size());
  4067. std::transform(next.cbegin(), next.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  4068. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  4069. callableBuilder.Add(flow);
  4070. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4071. std::for_each(first.cbegin(), first.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4072. std::for_each(stateArgs.cbegin(), stateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4073. callableBuilder.Add(chop);
  4074. std::for_each(next.cbegin(), next.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4075. if (useCtx) {
  4076. MKQL_ENSURE(RuntimeVersion >= 30U, "Too old runtime version");
  4077. callableBuilder.Add(NewDataLiteral<bool>(useCtx));
  4078. }
  4079. return TRuntimeNode(callableBuilder.Build(), false);
  4080. }
  4081. TRuntimeNode TProgramBuilder::CombineCore(TRuntimeNode stream,
  4082. const TUnaryLambda& keyExtractor,
  4083. const TBinaryLambda& init,
  4084. const TTernaryLambda& update,
  4085. const TBinaryLambda& finish,
  4086. ui64 memLimit)
  4087. {
  4088. if constexpr (RuntimeVersion < 3U) {
  4089. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4090. }
  4091. const bool isStream = stream.GetStaticType()->IsStream();
  4092. const auto itemType = isStream ? AS_TYPE(TStreamType, stream)->GetItemType() : AS_TYPE(TFlowType, stream)->GetItemType();
  4093. const auto itemArg = Arg(itemType);
  4094. const auto key = keyExtractor(itemArg);
  4095. const auto keyType = key.GetStaticType();
  4096. const auto keyArg = Arg(keyType);
  4097. const auto stateInit = init(keyArg, itemArg);
  4098. const auto stateType = stateInit.GetStaticType();
  4099. const auto stateArg = Arg(stateType);
  4100. const auto stateUpdate = update(keyArg, itemArg, stateArg);
  4101. const auto finishItem = finish(keyArg, stateArg);
  4102. const auto finishType = finishItem.GetStaticType();
  4103. MKQL_ENSURE(finishType->IsList() || finishType->IsStream() || finishType->IsOptional(), "Expected list, stream or optional");
  4104. TType* retItemType = nullptr;
  4105. if (finishType->IsOptional()) {
  4106. retItemType = AS_TYPE(TOptionalType, finishType)->GetItemType();
  4107. } else if (finishType->IsList()) {
  4108. retItemType = AS_TYPE(TListType, finishType)->GetItemType();
  4109. } else if (finishType->IsStream()) {
  4110. retItemType = AS_TYPE(TStreamType, finishType)->GetItemType();
  4111. }
  4112. const auto resultStreamType = isStream ? NewStreamType(retItemType) : NewFlowType(retItemType);
  4113. TCallableBuilder callableBuilder(Env, __func__, resultStreamType);
  4114. callableBuilder.Add(stream);
  4115. callableBuilder.Add(itemArg);
  4116. callableBuilder.Add(key);
  4117. callableBuilder.Add(keyArg);
  4118. callableBuilder.Add(stateInit);
  4119. callableBuilder.Add(stateArg);
  4120. callableBuilder.Add(stateUpdate);
  4121. callableBuilder.Add(finishItem);
  4122. callableBuilder.Add(NewDataLiteral(memLimit));
  4123. return TRuntimeNode(callableBuilder.Build(), false);
  4124. }
  4125. TRuntimeNode TProgramBuilder::GroupingCore(TRuntimeNode stream,
  4126. const TBinaryLambda& groupSwitch,
  4127. const TUnaryLambda& keyExtractor,
  4128. const TUnaryLambda& handler)
  4129. {
  4130. if (handler && RuntimeVersion < 20U) {
  4131. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__ << " with handler";
  4132. }
  4133. auto itemType = AS_TYPE(TStreamType, stream)->GetItemType();
  4134. TRuntimeNode keyExtractorItemArg = Arg(itemType);
  4135. TRuntimeNode keyExtractorResult = keyExtractor(keyExtractorItemArg);
  4136. TRuntimeNode groupSwitchKeyArg = Arg(keyExtractorResult.GetStaticType());
  4137. TRuntimeNode groupSwitchItemArg = Arg(itemType);
  4138. TRuntimeNode groupSwitchResult = groupSwitch(groupSwitchKeyArg, groupSwitchItemArg);
  4139. MKQL_ENSURE(AS_TYPE(TDataType, groupSwitchResult)->GetSchemeType() == NUdf::TDataType<bool>::Id,
  4140. "Expected bool type");
  4141. TRuntimeNode handlerItemArg;
  4142. TRuntimeNode handlerResult;
  4143. if (handler) {
  4144. handlerItemArg = Arg(itemType);
  4145. handlerResult = handler(handlerItemArg);
  4146. itemType = handlerResult.GetStaticType();
  4147. }
  4148. const std::array<TType*, 2U> tupleItems = {{ keyExtractorResult.GetStaticType(), NewStreamType(itemType) }};
  4149. const auto finishType = NewStreamType(NewTupleType(tupleItems));
  4150. TCallableBuilder callableBuilder(Env, __func__, finishType);
  4151. callableBuilder.Add(stream);
  4152. callableBuilder.Add(keyExtractorResult);
  4153. callableBuilder.Add(groupSwitchResult);
  4154. callableBuilder.Add(keyExtractorItemArg);
  4155. callableBuilder.Add(groupSwitchKeyArg);
  4156. callableBuilder.Add(groupSwitchItemArg);
  4157. if (handler) {
  4158. callableBuilder.Add(handlerResult);
  4159. callableBuilder.Add(handlerItemArg);
  4160. }
  4161. return TRuntimeNode(callableBuilder.Build(), false);
  4162. }
  4163. TRuntimeNode TProgramBuilder::Chopper(TRuntimeNode flow, const TUnaryLambda& keyExtractor, const TBinaryLambda& groupSwitch, const TBinaryLambda& groupHandler) {
  4164. const auto flowType = flow.GetStaticType();
  4165. MKQL_ENSURE(flowType->IsFlow() || flowType->IsStream(), "Expected flow or stream.");
  4166. if constexpr (RuntimeVersion < 9U) {
  4167. return FlatMap(GroupingCore(flow, groupSwitch, keyExtractor),
  4168. [&](TRuntimeNode item) -> TRuntimeNode { return groupHandler(Nth(item, 0U), Nth(item, 1U)); }
  4169. );
  4170. }
  4171. const bool isStream = flowType->IsStream();
  4172. const auto itemType = isStream ? AS_TYPE(TStreamType, flow)->GetItemType() : AS_TYPE(TFlowType, flow)->GetItemType();
  4173. const auto itemArg = Arg(itemType);
  4174. const auto keyExtractorResult = keyExtractor(itemArg);
  4175. const auto keyArg = Arg(keyExtractorResult.GetStaticType());
  4176. const auto groupSwitchResult = groupSwitch(keyArg, itemArg);
  4177. const auto input = Arg(flowType);
  4178. const auto output = groupHandler(keyArg, input);
  4179. TCallableBuilder callableBuilder(Env, __func__, output.GetStaticType());
  4180. callableBuilder.Add(flow);
  4181. callableBuilder.Add(itemArg);
  4182. callableBuilder.Add(keyExtractorResult);
  4183. callableBuilder.Add(keyArg);
  4184. callableBuilder.Add(groupSwitchResult);
  4185. callableBuilder.Add(input);
  4186. callableBuilder.Add(output);
  4187. return TRuntimeNode(callableBuilder.Build(), false);
  4188. }
  4189. TRuntimeNode TProgramBuilder::WideChopper(TRuntimeNode flow, const TWideLambda& extractor, const TWideSwitchLambda& groupSwitch,
  4190. const std::function<TRuntimeNode (TRuntimeNode::TList, TRuntimeNode)>& groupHandler
  4191. ) {
  4192. if constexpr (RuntimeVersion < 18U) {
  4193. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4194. }
  4195. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  4196. TRuntimeNode::TList itemArgs, keyArgs;
  4197. itemArgs.reserve(wideComponents.size());
  4198. auto i = 0U;
  4199. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  4200. const auto keys = extractor(itemArgs);
  4201. keyArgs.reserve(keys.size());
  4202. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  4203. const auto groupSwitchResult = groupSwitch(keyArgs, itemArgs);
  4204. const auto input = WideFlowArg(flow.GetStaticType());
  4205. const auto output = groupHandler(keyArgs, input);
  4206. TCallableBuilder callableBuilder(Env, __func__, output.GetStaticType());
  4207. callableBuilder.Add(flow);
  4208. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4209. std::for_each(keys.cbegin(), keys.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4210. std::for_each(keyArgs.cbegin(), keyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4211. callableBuilder.Add(groupSwitchResult);
  4212. callableBuilder.Add(input);
  4213. callableBuilder.Add(output);
  4214. return TRuntimeNode(callableBuilder.Build(), false);
  4215. }
  4216. TRuntimeNode TProgramBuilder::HoppingCore(TRuntimeNode list,
  4217. const TUnaryLambda& timeExtractor,
  4218. const TUnaryLambda& init,
  4219. const TBinaryLambda& update,
  4220. const TUnaryLambda& save,
  4221. const TUnaryLambda& load,
  4222. const TBinaryLambda& merge,
  4223. const TBinaryLambda& finish,
  4224. TRuntimeNode hop, TRuntimeNode interval, TRuntimeNode delay)
  4225. {
  4226. auto streamType = AS_TYPE(TStreamType, list);
  4227. auto itemType = AS_TYPE(TStructType, streamType->GetItemType());
  4228. auto timestampType = TOptionalType::Create(TDataType::Create(NUdf::TDataType<NUdf::TTimestamp>::Id, Env), Env);
  4229. TRuntimeNode itemArg = Arg(itemType);
  4230. auto outTime = timeExtractor(itemArg);
  4231. auto outStateInit = init(itemArg);
  4232. auto stateType = outStateInit.GetStaticType();
  4233. TRuntimeNode stateArg = Arg(stateType);
  4234. auto outStateUpdate = update(itemArg, stateArg);
  4235. auto hasSaveLoad = (bool)save;
  4236. TRuntimeNode saveArg, outSave, loadArg, outLoad;
  4237. if (hasSaveLoad) {
  4238. saveArg = Arg(stateType);
  4239. outSave = save(saveArg);
  4240. loadArg = Arg(outSave.GetStaticType());
  4241. outLoad = load(loadArg);
  4242. MKQL_ENSURE(outLoad.GetStaticType()->IsSameType(*stateType), "Loaded type is changed by the load handler");
  4243. } else {
  4244. saveArg = outSave = loadArg = outLoad = NewVoid();
  4245. }
  4246. TRuntimeNode state2Arg = Arg(stateType);
  4247. TRuntimeNode timeArg = Arg(timestampType);
  4248. auto outStateMerge = merge(stateArg, state2Arg);
  4249. auto outItemFinish = finish(stateArg, timeArg);
  4250. auto finishType = outItemFinish.GetStaticType();
  4251. MKQL_ENSURE(finishType->IsStruct(), "Expected struct type as finish lambda output");
  4252. auto resultType = TStreamType::Create(outItemFinish.GetStaticType(), Env);
  4253. TCallableBuilder callableBuilder(Env, __func__, resultType);
  4254. callableBuilder.Add(list);
  4255. callableBuilder.Add(itemArg);
  4256. callableBuilder.Add(stateArg);
  4257. callableBuilder.Add(state2Arg);
  4258. callableBuilder.Add(timeArg);
  4259. callableBuilder.Add(saveArg);
  4260. callableBuilder.Add(loadArg);
  4261. callableBuilder.Add(outTime);
  4262. callableBuilder.Add(outStateInit);
  4263. callableBuilder.Add(outStateUpdate);
  4264. callableBuilder.Add(outSave);
  4265. callableBuilder.Add(outLoad);
  4266. callableBuilder.Add(outStateMerge);
  4267. callableBuilder.Add(outItemFinish);
  4268. callableBuilder.Add(hop);
  4269. callableBuilder.Add(interval);
  4270. callableBuilder.Add(delay);
  4271. return TRuntimeNode(callableBuilder.Build(), false);
  4272. }
  4273. TRuntimeNode TProgramBuilder::MultiHoppingCore(TRuntimeNode list,
  4274. const TUnaryLambda& keyExtractor,
  4275. const TUnaryLambda& timeExtractor,
  4276. const TUnaryLambda& init,
  4277. const TBinaryLambda& update,
  4278. const TUnaryLambda& save,
  4279. const TUnaryLambda& load,
  4280. const TBinaryLambda& merge,
  4281. const TTernaryLambda& finish,
  4282. TRuntimeNode hop, TRuntimeNode interval, TRuntimeNode delay,
  4283. TRuntimeNode dataWatermarks, TRuntimeNode watermarksMode)
  4284. {
  4285. if constexpr (RuntimeVersion < 22U) {
  4286. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4287. }
  4288. auto streamType = AS_TYPE(TStreamType, list);
  4289. auto itemType = AS_TYPE(TStructType, streamType->GetItemType());
  4290. auto timestampType = TOptionalType::Create(TDataType::Create(NUdf::TDataType<NUdf::TTimestamp>::Id, Env), Env);
  4291. TRuntimeNode itemArg = Arg(itemType);
  4292. auto keyExtract = keyExtractor(itemArg);
  4293. auto keyType = keyExtract.GetStaticType();
  4294. TRuntimeNode keyArg = Arg(keyType);
  4295. auto outTime = timeExtractor(itemArg);
  4296. auto outStateInit = init(itemArg);
  4297. auto stateType = outStateInit.GetStaticType();
  4298. TRuntimeNode stateArg = Arg(stateType);
  4299. auto outStateUpdate = update(itemArg, stateArg);
  4300. auto hasSaveLoad = (bool)save;
  4301. TRuntimeNode saveArg, outSave, loadArg, outLoad;
  4302. if (hasSaveLoad) {
  4303. saveArg = Arg(stateType);
  4304. outSave = save(saveArg);
  4305. loadArg = Arg(outSave.GetStaticType());
  4306. outLoad = load(loadArg);
  4307. MKQL_ENSURE(outLoad.GetStaticType()->IsSameType(*stateType), "Loaded type is changed by the load handler");
  4308. } else {
  4309. saveArg = outSave = loadArg = outLoad = NewVoid();
  4310. }
  4311. TRuntimeNode state2Arg = Arg(stateType);
  4312. TRuntimeNode timeArg = Arg(timestampType);
  4313. auto outStateMerge = merge(stateArg, state2Arg);
  4314. auto outItemFinish = finish(keyArg, stateArg, timeArg);
  4315. auto finishType = outItemFinish.GetStaticType();
  4316. MKQL_ENSURE(finishType->IsStruct(), "Expected struct type as finish lambda output");
  4317. auto resultType = TStreamType::Create(outItemFinish.GetStaticType(), Env);
  4318. TCallableBuilder callableBuilder(Env, __func__, resultType);
  4319. callableBuilder.Add(list);
  4320. callableBuilder.Add(itemArg);
  4321. callableBuilder.Add(keyArg);
  4322. callableBuilder.Add(stateArg);
  4323. callableBuilder.Add(state2Arg);
  4324. callableBuilder.Add(timeArg);
  4325. callableBuilder.Add(saveArg);
  4326. callableBuilder.Add(loadArg);
  4327. callableBuilder.Add(keyExtract);
  4328. callableBuilder.Add(outTime);
  4329. callableBuilder.Add(outStateInit);
  4330. callableBuilder.Add(outStateUpdate);
  4331. callableBuilder.Add(outSave);
  4332. callableBuilder.Add(outLoad);
  4333. callableBuilder.Add(outStateMerge);
  4334. callableBuilder.Add(outItemFinish);
  4335. callableBuilder.Add(hop);
  4336. callableBuilder.Add(interval);
  4337. callableBuilder.Add(delay);
  4338. callableBuilder.Add(dataWatermarks);
  4339. callableBuilder.Add(watermarksMode);
  4340. return TRuntimeNode(callableBuilder.Build(), false);
  4341. }
  4342. TRuntimeNode TProgramBuilder::Default(TType* type) {
  4343. bool isOptional;
  4344. const auto targetType = UnpackOptionalData(type, isOptional);
  4345. if (isOptional) {
  4346. return NewOptional(Default(targetType));
  4347. }
  4348. const auto scheme = targetType->GetSchemeType();
  4349. const auto value = scheme == NUdf::TDataType<NUdf::TUuid>::Id ?
  4350. Env.NewStringValue("\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"sv) :
  4351. scheme == NUdf::TDataType<NUdf::TDyNumber>::Id ? NUdf::TUnboxedValuePod::Embedded("\1") : NUdf::TUnboxedValuePod::Zero();
  4352. return TRuntimeNode(TDataLiteral::Create(value, targetType, Env), true);
  4353. }
  4354. TRuntimeNode TProgramBuilder::Cast(TRuntimeNode arg, TType* type) {
  4355. if (arg.GetStaticType()->IsSameType(*type)) {
  4356. return arg;
  4357. }
  4358. bool isOptional;
  4359. const auto targetType = UnpackOptionalData(type, isOptional);
  4360. const auto sourceType = UnpackOptionalData(arg, isOptional);
  4361. const auto sId = sourceType->GetSchemeType();
  4362. const auto tId = targetType->GetSchemeType();
  4363. if (sId == NUdf::TDataType<char*>::Id) {
  4364. if (tId != NUdf::TDataType<char*>::Id) {
  4365. return FromString(arg, type);
  4366. } else {
  4367. return arg;
  4368. }
  4369. }
  4370. if (sId == NUdf::TDataType<NUdf::TUtf8>::Id) {
  4371. if (tId != NUdf::TDataType<char*>::Id) {
  4372. return FromString(arg, type);
  4373. } else {
  4374. return ToString(arg);
  4375. }
  4376. }
  4377. if (tId == NUdf::TDataType<char*>::Id) {
  4378. return ToString(arg);
  4379. }
  4380. if (tId == NUdf::TDataType<NUdf::TUtf8>::Id) {
  4381. return ToString<true>(arg);
  4382. }
  4383. if (tId == NUdf::TDataType<NUdf::TDecimal>::Id) {
  4384. const auto& params = static_cast<const TDataDecimalType*>(targetType)->GetParams();
  4385. return ToDecimal(arg, params.first, params.second);
  4386. }
  4387. const auto options = NKikimr::NUdf::GetCastResult(*sourceType->GetDataSlot(), *targetType->GetDataSlot());
  4388. MKQL_ENSURE((*options & NKikimr::NUdf::ECastOptions::Undefined) ||
  4389. !(*options & NKikimr::NUdf::ECastOptions::Impossible),
  4390. "Impossible to cast " << *static_cast<TType*>(sourceType) << " into " << *static_cast<TType*>(targetType));
  4391. const bool useToIntegral = (*options & NKikimr::NUdf::ECastOptions::Undefined) ||
  4392. (*options & NKikimr::NUdf::ECastOptions::MayFail);
  4393. return useToIntegral ? ToIntegral(arg, type) : Convert(arg, type);
  4394. }
  4395. TRuntimeNode TProgramBuilder::RangeCreate(TRuntimeNode list) {
  4396. MKQL_ENSURE(list.GetStaticType()->IsList(), "Expecting list");
  4397. auto itemType = static_cast<TListType*>(list.GetStaticType())->GetItemType();
  4398. MKQL_ENSURE(itemType->IsTuple(), "Expecting list of tuples");
  4399. auto tupleType = static_cast<TTupleType*>(itemType);
  4400. MKQL_ENSURE(tupleType->GetElementsCount() == 2,
  4401. "Expecting list ot 2-element tuples, got: " << tupleType->GetElementsCount() << " elements");
  4402. MKQL_ENSURE(tupleType->GetElementType(0)->IsSameType(*tupleType->GetElementType(1)),
  4403. "Expecting list ot 2-element tuples of same type");
  4404. MKQL_ENSURE(tupleType->GetElementType(0)->IsTuple(),
  4405. "Expecting range boundary to be tuple");
  4406. auto boundaryType = static_cast<TTupleType*>(tupleType->GetElementType(0));
  4407. MKQL_ENSURE(boundaryType->GetElementsCount() >= 2,
  4408. "Range boundary should have at least 2 components, got: " << boundaryType->GetElementsCount());
  4409. auto lastComp = boundaryType->GetElementType(boundaryType->GetElementsCount() - 1);
  4410. std::vector<TType*> outputComponents;
  4411. for (ui32 i = 0; i < boundaryType->GetElementsCount() - 1; ++i) {
  4412. outputComponents.push_back(lastComp);
  4413. outputComponents.push_back(boundaryType->GetElementType(i));
  4414. }
  4415. outputComponents.push_back(lastComp);
  4416. auto outputBoundary = TTupleType::Create(outputComponents.size(), &outputComponents.front(), Env);
  4417. std::vector<TType*> outputRangeComps(2, outputBoundary);
  4418. auto outputRange = TTupleType::Create(outputRangeComps.size(), &outputRangeComps.front(), Env);
  4419. TCallableBuilder callableBuilder(Env, __func__, TListType::Create(outputRange, Env));
  4420. callableBuilder.Add(list);
  4421. return TRuntimeNode(callableBuilder.Build(), false);
  4422. }
  4423. TRuntimeNode TProgramBuilder::RangeUnion(const TArrayRef<const TRuntimeNode>& lists) {
  4424. return BuildRangeLogical(__func__, lists);
  4425. }
  4426. TRuntimeNode TProgramBuilder::RangeIntersect(const TArrayRef<const TRuntimeNode>& lists) {
  4427. return BuildRangeLogical(__func__, lists);
  4428. }
  4429. TRuntimeNode TProgramBuilder::RangeMultiply(const TArrayRef<const TRuntimeNode>& args) {
  4430. MKQL_ENSURE(args.size() >= 2, "Expecting at least two arguments");
  4431. bool unlimited = false;
  4432. if (args.front().GetStaticType()->IsVoid()) {
  4433. unlimited = true;
  4434. } else {
  4435. MKQL_ENSURE(args.front().GetStaticType()->IsData() &&
  4436. static_cast<TDataType*>(args.front().GetStaticType())->GetSchemeType() == NUdf::TDataType<ui64>::Id,
  4437. "Expected ui64 as first argument");
  4438. }
  4439. std::vector<TType*> outputComponents;
  4440. for (size_t i = 1; i < args.size(); ++i) {
  4441. const auto& list = args[i];
  4442. MKQL_ENSURE(list.GetStaticType()->IsList(), "Expecting list");
  4443. auto listItemType = static_cast<TListType*>(list.GetStaticType())->GetItemType();
  4444. MKQL_ENSURE(listItemType->IsTuple(), "Expecting list of tuples");
  4445. auto rangeType = static_cast<TTupleType*>(listItemType);
  4446. MKQL_ENSURE(rangeType->GetElementsCount() == 2, "Expecting list of 2-element tuples");
  4447. MKQL_ENSURE(rangeType->GetElementType(0)->IsTuple(), "Range boundary should be tuple");
  4448. auto boundaryType = static_cast<TTupleType*>(rangeType->GetElementType(0));
  4449. ui32 elementsCount = boundaryType->GetElementsCount();
  4450. MKQL_ENSURE(elementsCount >= 3 && elementsCount % 2 == 1, "Range boundary should have odd number components (at least 3)");
  4451. for (size_t j = 0; j < elementsCount - 1; ++j) {
  4452. outputComponents.push_back(boundaryType->GetElementType(j));
  4453. }
  4454. }
  4455. outputComponents.push_back(TDataType::Create(NUdf::TDataType<i32>::Id, Env));
  4456. auto outputBoundary = TTupleType::Create(outputComponents.size(), &outputComponents.front(), Env);
  4457. std::vector<TType*> outputRangeComps(2, outputBoundary);
  4458. auto outputRange = TTupleType::Create(outputRangeComps.size(), &outputRangeComps.front(), Env);
  4459. TCallableBuilder callableBuilder(Env, __func__, TListType::Create(outputRange, Env));
  4460. if (unlimited) {
  4461. callableBuilder.Add(NewDataLiteral<ui64>(std::numeric_limits<ui64>::max()));
  4462. } else {
  4463. callableBuilder.Add(args[0]);
  4464. }
  4465. for (size_t i = 1; i < args.size(); ++i) {
  4466. callableBuilder.Add(args[i]);
  4467. }
  4468. return TRuntimeNode(callableBuilder.Build(), false);
  4469. }
  4470. TRuntimeNode TProgramBuilder::RangeFinalize(TRuntimeNode list) {
  4471. MKQL_ENSURE(list.GetStaticType()->IsList(), "Expecting list");
  4472. auto listItemType = static_cast<TListType*>(list.GetStaticType())->GetItemType();
  4473. MKQL_ENSURE(listItemType->IsTuple(), "Expecting list of tuples");
  4474. auto rangeType = static_cast<TTupleType*>(listItemType);
  4475. MKQL_ENSURE(rangeType->GetElementsCount() == 2, "Expecting list of 2-element tuples");
  4476. MKQL_ENSURE(rangeType->GetElementType(0)->IsTuple(), "Range boundary should be tuple");
  4477. auto boundaryType = static_cast<TTupleType*>(rangeType->GetElementType(0));
  4478. ui32 elementsCount = boundaryType->GetElementsCount();
  4479. MKQL_ENSURE(elementsCount >= 3 && elementsCount % 2 == 1, "Range boundary should have odd number components (at least 3)");
  4480. std::vector<TType*> outputComponents;
  4481. for (ui32 i = 0; i < elementsCount; ++i) {
  4482. if (i % 2 == 1 || i + 1 == elementsCount) {
  4483. outputComponents.push_back(boundaryType->GetElementType(i));
  4484. }
  4485. }
  4486. auto outputBoundary = TTupleType::Create(outputComponents.size(), &outputComponents.front(), Env);
  4487. std::vector<TType*> outputRangeComps(2, outputBoundary);
  4488. auto outputRange = TTupleType::Create(outputRangeComps.size(), &outputRangeComps.front(), Env);
  4489. TCallableBuilder callableBuilder(Env, __func__, TListType::Create(outputRange, Env));
  4490. callableBuilder.Add(list);
  4491. return TRuntimeNode(callableBuilder.Build(), false);
  4492. }
  4493. TRuntimeNode TProgramBuilder::Round(const std::string_view& callableName, TRuntimeNode source, TType* targetType) {
  4494. const auto sourceType = source.GetStaticType();
  4495. MKQL_ENSURE(sourceType->IsData(), "Expecting first arg to be of Data type");
  4496. MKQL_ENSURE(targetType->IsData(), "Expecting second arg to be Data type");
  4497. const auto ss = *static_cast<TDataType*>(sourceType)->GetDataSlot();
  4498. const auto ts = *static_cast<TDataType*>(targetType)->GetDataSlot();
  4499. const auto options = NKikimr::NUdf::GetCastResult(ss, ts);
  4500. MKQL_ENSURE(!(*options & NKikimr::NUdf::ECastOptions::Impossible),
  4501. "Impossible to cast " << *sourceType << " into " << *targetType);
  4502. MKQL_ENSURE(*options & (NKikimr::NUdf::ECastOptions::MayFail |
  4503. NKikimr::NUdf::ECastOptions::MayLoseData |
  4504. NKikimr::NUdf::ECastOptions::AnywayLoseData),
  4505. "Rounding from " << *sourceType << " to " << *targetType << " is trivial");
  4506. TCallableBuilder callableBuilder(Env, callableName, TOptionalType::Create(targetType, Env));
  4507. callableBuilder.Add(source);
  4508. return TRuntimeNode(callableBuilder.Build(), false);
  4509. }
  4510. TRuntimeNode TProgramBuilder::NextValue(TRuntimeNode value) {
  4511. const auto valueType = value.GetStaticType();
  4512. MKQL_ENSURE(valueType->IsData(), "Expecting argument of Data type");
  4513. const auto slot = *static_cast<TDataType*>(valueType)->GetDataSlot();
  4514. MKQL_ENSURE(slot == NUdf::EDataSlot::String || slot == NUdf::EDataSlot::Utf8,
  4515. "Unsupported type: " << *valueType);
  4516. TCallableBuilder callableBuilder(Env, __func__, TOptionalType::Create(valueType, Env));
  4517. callableBuilder.Add(value);
  4518. return TRuntimeNode(callableBuilder.Build(), false);
  4519. }
  4520. TRuntimeNode TProgramBuilder::Nop(TRuntimeNode value, TType* returnType) {
  4521. if constexpr (RuntimeVersion < 35U) {
  4522. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4523. }
  4524. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4525. callableBuilder.Add(value);
  4526. return TRuntimeNode(callableBuilder.Build(), false);
  4527. }
  4528. bool TProgramBuilder::IsNull(TRuntimeNode arg) {
  4529. return arg.GetStaticType()->IsSameType(*NewNull().GetStaticType()); // TODO ->IsNull();
  4530. }
  4531. TRuntimeNode TProgramBuilder::Replicate(TRuntimeNode item, TRuntimeNode count, const std::string_view& file, ui32 row, ui32 column) {
  4532. MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data");
  4533. MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  4534. const auto listType = TListType::Create(item.GetStaticType(), Env);
  4535. TCallableBuilder callableBuilder(Env, __func__, listType);
  4536. callableBuilder.Add(item);
  4537. callableBuilder.Add(count);
  4538. if constexpr (RuntimeVersion >= 2) {
  4539. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  4540. callableBuilder.Add(NewDataLiteral(row));
  4541. callableBuilder.Add(NewDataLiteral(column));
  4542. }
  4543. return TRuntimeNode(callableBuilder.Build(), false);
  4544. }
  4545. TRuntimeNode TProgramBuilder::PgConst(TPgType* pgType, const std::string_view& value, TRuntimeNode typeMod) {
  4546. if constexpr (RuntimeVersion < 30U) {
  4547. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4548. }
  4549. TCallableBuilder callableBuilder(Env, __func__, pgType);
  4550. callableBuilder.Add(NewDataLiteral(pgType->GetTypeId()));
  4551. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(value));
  4552. if (typeMod) {
  4553. callableBuilder.Add(typeMod);
  4554. }
  4555. return TRuntimeNode(callableBuilder.Build(), false);
  4556. }
  4557. TRuntimeNode TProgramBuilder::PgResolvedCall(bool useContext, const std::string_view& name,
  4558. ui32 id, const TArrayRef<const TRuntimeNode>& args,
  4559. TType* returnType, bool rangeFunction) {
  4560. if constexpr (RuntimeVersion < 45U) {
  4561. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4562. }
  4563. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4564. callableBuilder.Add(NewDataLiteral(useContext));
  4565. callableBuilder.Add(NewDataLiteral(rangeFunction));
  4566. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(name));
  4567. callableBuilder.Add(NewDataLiteral(id));
  4568. for (const auto& arg : args) {
  4569. callableBuilder.Add(arg);
  4570. }
  4571. return TRuntimeNode(callableBuilder.Build(), false);
  4572. }
  4573. TRuntimeNode TProgramBuilder::BlockPgResolvedCall(const std::string_view& name, ui32 id,
  4574. const TArrayRef<const TRuntimeNode>& args, TType* returnType) {
  4575. if constexpr (RuntimeVersion < 30U) {
  4576. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4577. }
  4578. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4579. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(name));
  4580. callableBuilder.Add(NewDataLiteral(id));
  4581. for (const auto& arg : args) {
  4582. callableBuilder.Add(arg);
  4583. }
  4584. return TRuntimeNode(callableBuilder.Build(), false);
  4585. }
  4586. TRuntimeNode TProgramBuilder::PgArray(const TArrayRef<const TRuntimeNode>& args, TType* returnType) {
  4587. if constexpr (RuntimeVersion < 30U) {
  4588. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4589. }
  4590. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4591. for (const auto& arg : args) {
  4592. callableBuilder.Add(arg);
  4593. }
  4594. return TRuntimeNode(callableBuilder.Build(), false);
  4595. }
  4596. TRuntimeNode TProgramBuilder::PgTableContent(
  4597. const std::string_view& cluster,
  4598. const std::string_view& table,
  4599. TType* returnType) {
  4600. if constexpr (RuntimeVersion < 47U) {
  4601. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4602. }
  4603. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4604. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(cluster));
  4605. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(table));
  4606. return TRuntimeNode(callableBuilder.Build(), false);
  4607. }
  4608. TRuntimeNode TProgramBuilder::PgToRecord(TRuntimeNode input, const TArrayRef<std::pair<std::string_view, std::string_view>>& members) {
  4609. if constexpr (RuntimeVersion < 48U) {
  4610. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4611. }
  4612. MKQL_ENSURE(input.GetStaticType()->IsStruct(), "Expected struct");
  4613. auto structType = AS_TYPE(TStructType, input.GetStaticType());
  4614. for (ui32 i = 0; i < structType->GetMembersCount(); ++i) {
  4615. auto itemType = structType->GetMemberType(i);
  4616. MKQL_ENSURE(itemType->IsNull() || itemType->IsPg(), "Expected null or pg");
  4617. }
  4618. auto returnType = NewPgType(NYql::NPg::LookupType("record").TypeId);
  4619. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4620. callableBuilder.Add(input);
  4621. TVector<TRuntimeNode> names;
  4622. for (const auto& x : members) {
  4623. names.push_back(NewDataLiteral<NUdf::EDataSlot::String>(x.first));
  4624. names.push_back(NewDataLiteral<NUdf::EDataSlot::String>(x.second));
  4625. }
  4626. callableBuilder.Add(NewTuple(names));
  4627. return TRuntimeNode(callableBuilder.Build(), false);
  4628. }
  4629. TRuntimeNode TProgramBuilder::PgCast(TRuntimeNode input, TType* returnType, TRuntimeNode typeMod) {
  4630. if constexpr (RuntimeVersion < 30U) {
  4631. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4632. }
  4633. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4634. callableBuilder.Add(input);
  4635. if (typeMod) {
  4636. callableBuilder.Add(typeMod);
  4637. }
  4638. return TRuntimeNode(callableBuilder.Build(), false);
  4639. }
  4640. TRuntimeNode TProgramBuilder::FromPg(TRuntimeNode input, TType* returnType) {
  4641. if constexpr (RuntimeVersion < 30U) {
  4642. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4643. }
  4644. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4645. callableBuilder.Add(input);
  4646. return TRuntimeNode(callableBuilder.Build(), false);
  4647. }
  4648. TRuntimeNode TProgramBuilder::ToPg(TRuntimeNode input, TType* returnType) {
  4649. if constexpr (RuntimeVersion < 30U) {
  4650. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4651. }
  4652. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4653. callableBuilder.Add(input);
  4654. return TRuntimeNode(callableBuilder.Build(), false);
  4655. }
  4656. TRuntimeNode TProgramBuilder::PgClone(TRuntimeNode input, const TArrayRef<const TRuntimeNode>& dependentNodes) {
  4657. if constexpr (RuntimeVersion < 38U) {
  4658. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4659. }
  4660. TCallableBuilder callableBuilder(Env, __func__, input.GetStaticType());
  4661. callableBuilder.Add(input);
  4662. for (const auto& node : dependentNodes) {
  4663. callableBuilder.Add(node);
  4664. }
  4665. return TRuntimeNode(callableBuilder.Build(), false);
  4666. }
  4667. TRuntimeNode TProgramBuilder::WithContext(TRuntimeNode input, const std::string_view& contextType) {
  4668. if constexpr (RuntimeVersion < 30U) {
  4669. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4670. }
  4671. TCallableBuilder callableBuilder(Env, __func__, input.GetStaticType());
  4672. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(contextType));
  4673. callableBuilder.Add(input);
  4674. return TRuntimeNode(callableBuilder.Build(), false);
  4675. }
  4676. TRuntimeNode TProgramBuilder::PgInternal0(TType* returnType) {
  4677. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4678. return TRuntimeNode(callableBuilder.Build(), false);
  4679. }
  4680. TRuntimeNode TProgramBuilder::BlockIf(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch) {
  4681. const auto conditionType = AS_TYPE(TBlockType, condition.GetStaticType());
  4682. MKQL_ENSURE(AS_TYPE(TDataType, conditionType->GetItemType())->GetSchemeType() == NUdf::TDataType<bool>::Id,
  4683. "Expected bool as first argument");
  4684. const auto thenType = AS_TYPE(TBlockType, thenBranch.GetStaticType());
  4685. const auto elseType = AS_TYPE(TBlockType, elseBranch.GetStaticType());
  4686. MKQL_ENSURE(thenType->GetItemType()->IsSameType(*elseType->GetItemType()), "Different return types in branches.");
  4687. auto returnType = NewBlockType(thenType->GetItemType(), GetResultShape({conditionType, thenType, elseType}));
  4688. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4689. callableBuilder.Add(condition);
  4690. callableBuilder.Add(thenBranch);
  4691. callableBuilder.Add(elseBranch);
  4692. return TRuntimeNode(callableBuilder.Build(), false);
  4693. }
  4694. TRuntimeNode TProgramBuilder::BlockJust(TRuntimeNode data) {
  4695. const auto initialType = AS_TYPE(TBlockType, data.GetStaticType());
  4696. auto returnType = NewBlockType(NewOptionalType(initialType->GetItemType()), initialType->GetShape());
  4697. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4698. callableBuilder.Add(data);
  4699. return TRuntimeNode(callableBuilder.Build(), false);
  4700. }
  4701. TRuntimeNode TProgramBuilder::BlockFunc(const std::string_view& funcName, TType* returnType, const TArrayRef<const TRuntimeNode>& args) {
  4702. for (const auto& arg : args) {
  4703. MKQL_ENSURE(arg.GetStaticType()->IsBlock(), "Expected Block type");
  4704. }
  4705. TCallableBuilder builder(Env, __func__, returnType);
  4706. builder.Add(NewDataLiteral<NUdf::EDataSlot::String>(funcName));
  4707. for (const auto& arg : args) {
  4708. builder.Add(arg);
  4709. }
  4710. return TRuntimeNode(builder.Build(), false);
  4711. }
  4712. TRuntimeNode TProgramBuilder::BuildBlockCombineAll(const std::string_view& callableName, TRuntimeNode input, std::optional<ui32> filterColumn,
  4713. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4714. const auto inputType = input.GetStaticType();
  4715. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type");
  4716. MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type");
  4717. TCallableBuilder builder(Env, callableName, returnType);
  4718. builder.Add(input);
  4719. if (!filterColumn) {
  4720. builder.Add(NewEmptyOptionalDataLiteral(NUdf::TDataType<ui32>::Id));
  4721. } else {
  4722. builder.Add(NewOptional(NewDataLiteral<ui32>(*filterColumn)));
  4723. }
  4724. TVector<TRuntimeNode> aggsNodes;
  4725. for (const auto& agg : aggs) {
  4726. TVector<TRuntimeNode> params;
  4727. params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name));
  4728. for (const auto& col : agg.ArgsColumns) {
  4729. params.push_back(NewDataLiteral<ui32>(col));
  4730. }
  4731. aggsNodes.push_back(NewTuple(params));
  4732. }
  4733. builder.Add(NewTuple(aggsNodes));
  4734. return TRuntimeNode(builder.Build(), false);
  4735. }
  4736. TRuntimeNode TProgramBuilder::BlockCombineAll(TRuntimeNode stream, std::optional<ui32> filterColumn,
  4737. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4738. if constexpr (RuntimeVersion < 31U) {
  4739. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4740. }
  4741. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type");
  4742. MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type");
  4743. if constexpr (RuntimeVersion < 52U) {
  4744. const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType());
  4745. return FromFlow(BuildBlockCombineAll(__func__, ToFlow(stream), filterColumn, aggs, flowReturnType));
  4746. } else {
  4747. return BuildBlockCombineAll(__func__, stream, filterColumn, aggs, returnType);
  4748. }
  4749. }
  4750. TRuntimeNode TProgramBuilder::BuildBlockCombineHashed(const std::string_view& callableName, TRuntimeNode input, std::optional<ui32> filterColumn,
  4751. const TArrayRef<ui32>& keys, const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4752. const auto inputType = input.GetStaticType();
  4753. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type");
  4754. MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type");
  4755. TCallableBuilder builder(Env, callableName, returnType);
  4756. builder.Add(input);
  4757. if (!filterColumn) {
  4758. builder.Add(NewEmptyOptionalDataLiteral(NUdf::TDataType<ui32>::Id));
  4759. } else {
  4760. builder.Add(NewOptional(NewDataLiteral<ui32>(*filterColumn)));
  4761. }
  4762. TVector<TRuntimeNode> keyNodes;
  4763. for (const auto& key : keys) {
  4764. keyNodes.push_back(NewDataLiteral<ui32>(key));
  4765. }
  4766. builder.Add(NewTuple(keyNodes));
  4767. TVector<TRuntimeNode> aggsNodes;
  4768. for (const auto& agg : aggs) {
  4769. TVector<TRuntimeNode> params;
  4770. params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name));
  4771. for (const auto& col : agg.ArgsColumns) {
  4772. params.push_back(NewDataLiteral<ui32>(col));
  4773. }
  4774. aggsNodes.push_back(NewTuple(params));
  4775. }
  4776. builder.Add(NewTuple(aggsNodes));
  4777. return TRuntimeNode(builder.Build(), false);
  4778. }
  4779. TRuntimeNode TProgramBuilder::BlockCombineHashed(TRuntimeNode stream, std::optional<ui32> filterColumn, const TArrayRef<ui32>& keys,
  4780. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4781. if constexpr (RuntimeVersion < 31U) {
  4782. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4783. }
  4784. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type");
  4785. MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type");
  4786. if constexpr (RuntimeVersion < 52U) {
  4787. const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType());
  4788. return FromFlow(BuildBlockCombineHashed(__func__, ToFlow(stream), filterColumn, keys, aggs, flowReturnType));
  4789. } else {
  4790. return BuildBlockCombineHashed(__func__, stream, filterColumn, keys, aggs, returnType);
  4791. }
  4792. }
  4793. TRuntimeNode TProgramBuilder::BuildBlockMergeFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys,
  4794. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4795. const auto inputType = input.GetStaticType();
  4796. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type");
  4797. MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type");
  4798. TCallableBuilder builder(Env, callableName, returnType);
  4799. builder.Add(input);
  4800. TVector<TRuntimeNode> keyNodes;
  4801. for (const auto& key : keys) {
  4802. keyNodes.push_back(NewDataLiteral<ui32>(key));
  4803. }
  4804. builder.Add(NewTuple(keyNodes));
  4805. TVector<TRuntimeNode> aggsNodes;
  4806. for (const auto& agg : aggs) {
  4807. TVector<TRuntimeNode> params;
  4808. params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name));
  4809. for (const auto& col : agg.ArgsColumns) {
  4810. params.push_back(NewDataLiteral<ui32>(col));
  4811. }
  4812. aggsNodes.push_back(NewTuple(params));
  4813. }
  4814. builder.Add(NewTuple(aggsNodes));
  4815. return TRuntimeNode(builder.Build(), false);
  4816. }
  4817. TRuntimeNode TProgramBuilder::BlockMergeFinalizeHashed(TRuntimeNode stream, const TArrayRef<ui32>& keys,
  4818. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4819. if constexpr (RuntimeVersion < 31U) {
  4820. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4821. }
  4822. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type");
  4823. MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type");
  4824. if constexpr (RuntimeVersion < 52U) {
  4825. const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType());
  4826. return FromFlow(BuildBlockMergeFinalizeHashed(__func__, ToFlow(stream), keys, aggs, flowReturnType));
  4827. } else {
  4828. return BuildBlockMergeFinalizeHashed(__func__, stream, keys, aggs, returnType);
  4829. }
  4830. }
  4831. TRuntimeNode TProgramBuilder::BuildBlockMergeManyFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys,
  4832. const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType) {
  4833. const auto inputType = input.GetStaticType();
  4834. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type");
  4835. MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type");
  4836. TCallableBuilder builder(Env, callableName, returnType);
  4837. builder.Add(input);
  4838. TVector<TRuntimeNode> keyNodes;
  4839. for (const auto& key : keys) {
  4840. keyNodes.push_back(NewDataLiteral<ui32>(key));
  4841. }
  4842. builder.Add(NewTuple(keyNodes));
  4843. TVector<TRuntimeNode> aggsNodes;
  4844. for (const auto& agg : aggs) {
  4845. TVector<TRuntimeNode> params;
  4846. params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name));
  4847. for (const auto& col : agg.ArgsColumns) {
  4848. params.push_back(NewDataLiteral<ui32>(col));
  4849. }
  4850. aggsNodes.push_back(NewTuple(params));
  4851. }
  4852. builder.Add(NewTuple(aggsNodes));
  4853. builder.Add(NewDataLiteral<ui32>(streamIndex));
  4854. TVector<TRuntimeNode> streamsNodes;
  4855. for (const auto& s : streams) {
  4856. TVector<TRuntimeNode> streamNodes;
  4857. for (const auto& i : s) {
  4858. streamNodes.push_back(NewDataLiteral<ui32>(i));
  4859. }
  4860. streamsNodes.push_back(NewTuple(streamNodes));
  4861. }
  4862. builder.Add(NewTuple(streamsNodes));
  4863. return TRuntimeNode(builder.Build(), false);
  4864. }
  4865. TRuntimeNode TProgramBuilder::BlockMergeManyFinalizeHashed(TRuntimeNode stream, const TArrayRef<ui32>& keys,
  4866. const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType) {
  4867. if constexpr (RuntimeVersion < 31U) {
  4868. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4869. }
  4870. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type");
  4871. MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type");
  4872. if constexpr (RuntimeVersion < 52U) {
  4873. const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType());
  4874. return FromFlow(BuildBlockMergeManyFinalizeHashed(__func__, ToFlow(stream), keys, aggs, streamIndex, streams, flowReturnType));
  4875. } else {
  4876. return BuildBlockMergeManyFinalizeHashed(__func__, stream, keys, aggs, streamIndex, streams, returnType);
  4877. }
  4878. }
  4879. TRuntimeNode TProgramBuilder::ScalarApply(const TArrayRef<const TRuntimeNode>& args, const TArrayLambda& handler) {
  4880. if constexpr (RuntimeVersion < 39U) {
  4881. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4882. }
  4883. MKQL_ENSURE(!args.empty(), "Required at least one argument");
  4884. TVector<TRuntimeNode> lambdaArgs;
  4885. bool scalarOnly = true;
  4886. std::shared_ptr<arrow::DataType> arrowType;
  4887. for (const auto& arg : args) {
  4888. auto blockType = AS_TYPE(TBlockType, arg.GetStaticType());
  4889. scalarOnly = scalarOnly && blockType->GetShape() == TBlockType::EShape::Scalar;
  4890. MKQL_ENSURE(ConvertArrowType(blockType->GetItemType(), arrowType), "Unsupported arrow type");
  4891. lambdaArgs.emplace_back(Arg(blockType->GetItemType()));
  4892. }
  4893. auto ret = handler(lambdaArgs);
  4894. MKQL_ENSURE(ConvertArrowType(ret.GetStaticType(), arrowType), "Unsupported arrow type");
  4895. auto returnType = NewBlockType(ret.GetStaticType(), scalarOnly ? TBlockType::EShape::Scalar : TBlockType::EShape::Many);
  4896. TCallableBuilder builder(Env, __func__, returnType);
  4897. for (const auto& arg : args) {
  4898. builder.Add(arg);
  4899. }
  4900. for (const auto& arg : lambdaArgs) {
  4901. builder.Add(arg);
  4902. }
  4903. builder.Add(ret);
  4904. return TRuntimeNode(builder.Build(), false);
  4905. }
  4906. TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode leftStream, TRuntimeNode rightStream, EJoinKind joinKind,
  4907. const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& leftKeyDrops,
  4908. const TArrayRef<const ui32>& rightKeyColumns, const TArrayRef<const ui32>& rightKeyDrops, bool rightAny, TType* returnType
  4909. ) {
  4910. if constexpr (RuntimeVersion < 53U) {
  4911. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4912. }
  4913. if (RuntimeVersion < 57U && joinKind == EJoinKind::Cross) {
  4914. THROW yexception() << __func__ << " does not support cross join in runtime version (" << RuntimeVersion << ")";
  4915. }
  4916. MKQL_ENSURE(joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left ||
  4917. joinKind == EJoinKind::LeftSemi || joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::Cross,
  4918. "Unsupported join kind");
  4919. MKQL_ENSURE(leftKeyColumns.size() == rightKeyColumns.size(), "Key column count mismatch");
  4920. if (joinKind == EJoinKind::Cross) {
  4921. MKQL_ENSURE(leftKeyColumns.empty(), "Specifying key columns is not allowed for cross join");
  4922. } else {
  4923. MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified");
  4924. }
  4925. ValidateBlockStreamType(leftStream.GetStaticType());
  4926. ValidateBlockStreamType(rightStream.GetStaticType());
  4927. ValidateBlockStreamType(returnType);
  4928. TRuntimeNode::TList leftKeyColumnsNodes;
  4929. leftKeyColumnsNodes.reserve(leftKeyColumns.size());
  4930. std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(),
  4931. std::back_inserter(leftKeyColumnsNodes), [this](const ui32 idx) {
  4932. return NewDataLiteral(idx);
  4933. });
  4934. TRuntimeNode::TList leftKeyDropsNodes;
  4935. leftKeyDropsNodes.reserve(leftKeyDrops.size());
  4936. std::transform(leftKeyDrops.cbegin(), leftKeyDrops.cend(),
  4937. std::back_inserter(leftKeyDropsNodes), [this](const ui32 idx) {
  4938. return NewDataLiteral(idx);
  4939. });
  4940. TRuntimeNode::TList rightKeyColumnsNodes;
  4941. rightKeyColumnsNodes.reserve(rightKeyColumns.size());
  4942. std::transform(rightKeyColumns.cbegin(), rightKeyColumns.cend(),
  4943. std::back_inserter(rightKeyColumnsNodes), [this](const ui32 idx) {
  4944. return NewDataLiteral(idx);
  4945. });
  4946. TRuntimeNode::TList rightKeyDropsNodes;
  4947. rightKeyDropsNodes.reserve(leftKeyDrops.size());
  4948. std::transform(rightKeyDrops.cbegin(), rightKeyDrops.cend(),
  4949. std::back_inserter(rightKeyDropsNodes), [this](const ui32 idx) {
  4950. return NewDataLiteral(idx);
  4951. });
  4952. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4953. callableBuilder.Add(leftStream);
  4954. callableBuilder.Add(rightStream);
  4955. callableBuilder.Add(NewDataLiteral((ui32)joinKind));
  4956. callableBuilder.Add(NewTuple(leftKeyColumnsNodes));
  4957. callableBuilder.Add(NewTuple(leftKeyDropsNodes));
  4958. callableBuilder.Add(NewTuple(rightKeyColumnsNodes));
  4959. callableBuilder.Add(NewTuple(rightKeyDropsNodes));
  4960. callableBuilder.Add(NewDataLiteral((bool)rightAny));
  4961. return TRuntimeNode(callableBuilder.Build(), false);
  4962. }
  4963. namespace {
  4964. using namespace NYql::NMatchRecognize;
  4965. TRuntimeNode PatternToRuntimeNode(const TRowPattern& pattern, const TProgramBuilder& programBuilder) {
  4966. const auto& env = programBuilder.GetTypeEnvironment();
  4967. TTupleLiteralBuilder patternBuilder(env);
  4968. for (const auto& term: pattern) {
  4969. TTupleLiteralBuilder termBuilder(env);
  4970. for (const auto& factor: term) {
  4971. TTupleLiteralBuilder factorBuilder(env);
  4972. factorBuilder.Add(std::visit(TOverloaded {
  4973. [&](const TString& s) {
  4974. return programBuilder.NewDataLiteral<NUdf::EDataSlot::String>(s);
  4975. },
  4976. [&](const TRowPattern& pattern) {
  4977. return PatternToRuntimeNode(pattern, programBuilder);
  4978. },
  4979. }, factor.Primary));
  4980. factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMin));
  4981. factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMax));
  4982. factorBuilder.Add(programBuilder.NewDataLiteral(factor.Greedy));
  4983. factorBuilder.Add(programBuilder.NewDataLiteral(factor.Output));
  4984. factorBuilder.Add(programBuilder.NewDataLiteral(factor.Unused));
  4985. termBuilder.Add({factorBuilder.Build(), true});
  4986. }
  4987. patternBuilder.Add({termBuilder.Build(), true});
  4988. }
  4989. return {patternBuilder.Build(), true};
  4990. };
  4991. } //namespace
  4992. TRuntimeNode TProgramBuilder::MatchRecognizeCore(
  4993. TRuntimeNode inputStream,
  4994. const TUnaryLambda& getPartitionKeySelectorNode,
  4995. const TArrayRef<TStringBuf>& partitionColumnNames,
  4996. const TVector<TStringBuf>& measureColumnNames,
  4997. const TVector<TBinaryLambda>& getMeasures,
  4998. const NYql::NMatchRecognize::TRowPattern& pattern,
  4999. const TVector<TStringBuf>& defineVarNames,
  5000. const TVector<TTernaryLambda>& getDefines,
  5001. bool streamingMode,
  5002. const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo,
  5003. NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch
  5004. ) {
  5005. MKQL_ENSURE(RuntimeVersion >= 42, "MatchRecognize is not supported in runtime version " << RuntimeVersion);
  5006. const auto inputRowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputStream.GetStaticType())->GetItemType());
  5007. const auto inputRowArg = Arg(inputRowType);
  5008. const auto partitionKeySelectorNode = getPartitionKeySelectorNode(inputRowArg);
  5009. const auto partitionColumnTypes = AS_TYPE(TTupleType, partitionKeySelectorNode.GetStaticType())->GetElements();
  5010. const auto rangeList = NewListType(NewStructType({
  5011. {"From", NewDataType(NUdf::EDataSlot::Uint64)},
  5012. {"To", NewDataType(NUdf::EDataSlot::Uint64)}
  5013. }));
  5014. TStructTypeBuilder matchedVarsTypeBuilder(Env);
  5015. for (const auto& var: GetPatternVars(pattern)) {
  5016. matchedVarsTypeBuilder.Add(var, rangeList);
  5017. }
  5018. const auto matchedVarsType = matchedVarsTypeBuilder.Build();
  5019. TRuntimeNode matchedVarsArg = Arg(matchedVarsType);
  5020. //---These vars may be empty in case of no measures
  5021. TRuntimeNode measureInputDataArg;
  5022. std::vector<TRuntimeNode> specialColumnIndexesInMeasureInputDataRow;
  5023. TVector<TRuntimeNode> measures;
  5024. //---
  5025. if (getMeasures.empty()) {
  5026. measureInputDataArg = Arg(Env.GetTypeOfVoidLazy());
  5027. } else {
  5028. measures.reserve(getMeasures.size());
  5029. specialColumnIndexesInMeasureInputDataRow.resize(static_cast<size_t>(NYql::NMatchRecognize::EMeasureInputDataSpecialColumns::Last));
  5030. TStructTypeBuilder measureInputDataRowTypeBuilder(Env);
  5031. for (ui32 i = 0; i < inputRowType->GetMembersCount(); ++i) {
  5032. measureInputDataRowTypeBuilder.Add(inputRowType->GetMemberName(i), inputRowType->GetMemberType(i));
  5033. }
  5034. measureInputDataRowTypeBuilder.Add(
  5035. MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::Classifier),
  5036. NewDataType(NUdf::EDataSlot::Utf8)
  5037. );
  5038. measureInputDataRowTypeBuilder.Add(
  5039. MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::MatchNumber),
  5040. NewDataType(NUdf::EDataSlot::Uint64)
  5041. );
  5042. const auto measureInputDataRowType = measureInputDataRowTypeBuilder.Build();
  5043. for (ui32 i = 0; i < measureInputDataRowType->GetMembersCount(); ++i) {
  5044. //assume a few, if grows, it's better to use a lookup table here
  5045. static_assert(static_cast<size_t>(EMeasureInputDataSpecialColumns::Last) < 5);
  5046. for (size_t j = 0; j != static_cast<size_t>(EMeasureInputDataSpecialColumns::Last); ++j) {
  5047. if (measureInputDataRowType->GetMemberName(i) ==
  5048. NYql::NMatchRecognize::MeasureInputDataSpecialColumnName(static_cast<EMeasureInputDataSpecialColumns>(j)))
  5049. specialColumnIndexesInMeasureInputDataRow[j] = NewDataLiteral(i);
  5050. }
  5051. }
  5052. measureInputDataArg = Arg(NewListType(measureInputDataRowType));
  5053. for (size_t i = 0; i != getMeasures.size(); ++i) {
  5054. measures.push_back(getMeasures[i](measureInputDataArg, matchedVarsArg));
  5055. }
  5056. }
  5057. TStructTypeBuilder outputRowTypeBuilder(Env);
  5058. THashMap<TStringBuf, size_t> partitionColumnLookup;
  5059. THashMap<TStringBuf, size_t> measureColumnLookup;
  5060. THashMap<TStringBuf, size_t> otherColumnLookup;
  5061. for (size_t i = 0; i < measureColumnNames.size(); ++i) {
  5062. const auto name = measureColumnNames[i];
  5063. measureColumnLookup.emplace(name, i);
  5064. outputRowTypeBuilder.Add(name, measures[i].GetStaticType());
  5065. }
  5066. switch (rowsPerMatch) {
  5067. case NYql::NMatchRecognize::ERowsPerMatch::OneRow:
  5068. for (size_t i = 0; i < partitionColumnNames.size(); ++i) {
  5069. const auto name = partitionColumnNames[i];
  5070. partitionColumnLookup.emplace(name, i);
  5071. outputRowTypeBuilder.Add(name, partitionColumnTypes[i]);
  5072. }
  5073. break;
  5074. case NYql::NMatchRecognize::ERowsPerMatch::AllRows:
  5075. for (size_t i = 0; i < inputRowType->GetMembersCount(); ++i) {
  5076. const auto name = inputRowType->GetMemberName(i);
  5077. otherColumnLookup.emplace(name, i);
  5078. outputRowTypeBuilder.Add(name, inputRowType->GetMemberType(i));
  5079. }
  5080. break;
  5081. }
  5082. auto outputRowType = outputRowTypeBuilder.Build();
  5083. std::vector<TRuntimeNode> partitionColumnIndexes(partitionColumnLookup.size());
  5084. std::vector<TRuntimeNode> measureColumnIndexes(measureColumnLookup.size());
  5085. TVector<TRuntimeNode> outputColumnOrder(NDetail::TReserveTag{outputRowType->GetMembersCount()});
  5086. for (ui32 i = 0; i < outputRowType->GetMembersCount(); ++i) {
  5087. const auto name = outputRowType->GetMemberName(i);
  5088. if (auto iter = partitionColumnLookup.find(name);
  5089. iter != partitionColumnLookup.end()) {
  5090. partitionColumnIndexes[iter->second] = NewDataLiteral(i);
  5091. outputColumnOrder.push_back(NewStruct({
  5092. std::pair{"Index", NewDataLiteral(iter->second)},
  5093. std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::PartitionKey))},
  5094. }));
  5095. } else if (auto iter = measureColumnLookup.find(name);
  5096. iter != measureColumnLookup.end()) {
  5097. measureColumnIndexes[iter->second] = NewDataLiteral(i);
  5098. outputColumnOrder.push_back(NewStruct({
  5099. std::pair{"Index", NewDataLiteral(iter->second)},
  5100. std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Measure))},
  5101. }));
  5102. } else if (auto iter = otherColumnLookup.find(name);
  5103. iter != otherColumnLookup.end()) {
  5104. outputColumnOrder.push_back(NewStruct({
  5105. std::pair{"Index", NewDataLiteral(iter->second)},
  5106. std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Other))},
  5107. }));
  5108. }
  5109. }
  5110. const auto outputType = NewFlowType(outputRowType);
  5111. THashMap<TStringBuf, size_t> patternVarLookup;
  5112. for (ui32 i = 0; i < matchedVarsType->GetMembersCount(); ++i) {
  5113. patternVarLookup[matchedVarsType->GetMemberName(i)] = i;
  5114. }
  5115. THashMap<TStringBuf, size_t> defineLookup;
  5116. for (size_t i = 0; i < defineVarNames.size(); ++i) {
  5117. const auto name = defineVarNames[i];
  5118. defineLookup[name] = i;
  5119. }
  5120. TVector<TRuntimeNode> defineNames(patternVarLookup.size());
  5121. TVector<TRuntimeNode> defineNodes(patternVarLookup.size());
  5122. const auto inputDataArg = Arg(NewListType(inputRowType));
  5123. const auto currentRowIndexArg = Arg(NewDataType(NUdf::EDataSlot::Uint64));
  5124. for (const auto& [v, i]: patternVarLookup) {
  5125. defineNames[i] = NewDataLiteral<NUdf::EDataSlot::String>(v);
  5126. if (auto iter = defineLookup.find(v);
  5127. iter != defineLookup.end()) {
  5128. defineNodes[i] = getDefines[iter->second](inputDataArg, matchedVarsArg, currentRowIndexArg);
  5129. } else if ("$" == v || "^" == v) {
  5130. //DO nothing, //will be handled in a specific way
  5131. } else { // a var without a predicate matches any row
  5132. defineNodes[i] = NewDataLiteral(true);
  5133. }
  5134. }
  5135. TCallableBuilder callableBuilder(GetTypeEnvironment(), "MatchRecognizeCore", outputType);
  5136. const auto indexType = NewDataType(NUdf::EDataSlot::Uint32);
  5137. const auto outputColumnEntryType = NewStructType({
  5138. {"Index", NewDataType(NUdf::EDataSlot::Uint64)},
  5139. {"SourceType", NewDataType(NUdf::EDataSlot::Int32)},
  5140. });
  5141. callableBuilder.Add(inputStream);
  5142. callableBuilder.Add(inputRowArg);
  5143. callableBuilder.Add(partitionKeySelectorNode);
  5144. callableBuilder.Add(NewList(indexType, partitionColumnIndexes));
  5145. callableBuilder.Add(measureInputDataArg);
  5146. callableBuilder.Add(NewList(indexType, specialColumnIndexesInMeasureInputDataRow));
  5147. callableBuilder.Add(NewDataLiteral(inputRowType->GetMembersCount()));
  5148. callableBuilder.Add(matchedVarsArg);
  5149. callableBuilder.Add(NewList(indexType, measureColumnIndexes));
  5150. for (const auto& m: measures) {
  5151. callableBuilder.Add(m);
  5152. }
  5153. callableBuilder.Add(PatternToRuntimeNode(pattern, *this));
  5154. callableBuilder.Add(currentRowIndexArg);
  5155. callableBuilder.Add(inputDataArg);
  5156. callableBuilder.Add(NewList(NewDataType(NUdf::EDataSlot::String), defineNames));
  5157. for (const auto& d: defineNodes) {
  5158. callableBuilder.Add(d);
  5159. }
  5160. callableBuilder.Add(NewDataLiteral(streamingMode));
  5161. if constexpr (RuntimeVersion >= 52U) {
  5162. callableBuilder.Add(NewDataLiteral(static_cast<i32>(skipTo.To)));
  5163. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(skipTo.Var));
  5164. }
  5165. if constexpr (RuntimeVersion >= 54U) {
  5166. callableBuilder.Add(NewDataLiteral(static_cast<i32>(rowsPerMatch)));
  5167. callableBuilder.Add(NewList(outputColumnEntryType, outputColumnOrder));
  5168. }
  5169. return TRuntimeNode(callableBuilder.Build(), false);
  5170. }
  5171. TRuntimeNode TProgramBuilder::TimeOrderRecover(
  5172. TRuntimeNode inputStream,
  5173. const TUnaryLambda& getTimeExtractor,
  5174. TRuntimeNode delay,
  5175. TRuntimeNode ahead,
  5176. TRuntimeNode rowLimit
  5177. )
  5178. {
  5179. MKQL_ENSURE(RuntimeVersion >= 44, "TimeOrderRecover is not supported in runtime version " << RuntimeVersion);
  5180. auto& inputRowType = *static_cast<TStructType*>(AS_TYPE(TStructType, AS_TYPE(TFlowType, inputStream.GetStaticType())->GetItemType()));
  5181. const auto inputRowArg = Arg(&inputRowType);
  5182. TStructTypeBuilder outputRowTypeBuilder(Env);
  5183. outputRowTypeBuilder.Reserve(inputRowType.GetMembersCount() + 1);
  5184. const ui32 inputRowColumnCount = inputRowType.GetMembersCount();
  5185. for (ui32 i = 0; i != inputRowColumnCount; ++i) {
  5186. outputRowTypeBuilder.Add(inputRowType.GetMemberName(i), inputRowType.GetMemberType(i));
  5187. }
  5188. using NYql::NTimeOrderRecover::OUT_OF_ORDER_MARKER;
  5189. outputRowTypeBuilder.Add(OUT_OF_ORDER_MARKER, TDataType::Create(NUdf::TDataType<bool>::Id, Env));
  5190. const auto outputRowType = outputRowTypeBuilder.Build();
  5191. const auto outOfOrderColumnIndex = outputRowType->GetMemberIndex(OUT_OF_ORDER_MARKER);
  5192. TCallableBuilder callableBuilder(GetTypeEnvironment(), "TimeOrderRecover", TFlowType::Create(outputRowType, Env));
  5193. callableBuilder.Add(inputStream);
  5194. callableBuilder.Add(inputRowArg);
  5195. callableBuilder.Add(getTimeExtractor(inputRowArg));
  5196. callableBuilder.Add(NewDataLiteral(inputRowColumnCount));
  5197. callableBuilder.Add(NewDataLiteral(outOfOrderColumnIndex));
  5198. callableBuilder.Add(delay),
  5199. callableBuilder.Add(ahead),
  5200. callableBuilder.Add(rowLimit);
  5201. return TRuntimeNode(callableBuilder.Build(), false);
  5202. }
  5203. bool CanExportType(TType* type, const TTypeEnvironment& env) {
  5204. if (type->GetKind() == TType::EKind::Type) {
  5205. return false; // Type of Type
  5206. }
  5207. TExploringNodeVisitor explorer;
  5208. explorer.Walk(type, env);
  5209. bool canExport = true;
  5210. for (auto& node : explorer.GetNodes()) {
  5211. switch (static_cast<TType*>(node)->GetKind()) {
  5212. case TType::EKind::Void:
  5213. node->SetCookie(1);
  5214. break;
  5215. case TType::EKind::Data:
  5216. node->SetCookie(1);
  5217. break;
  5218. case TType::EKind::Pg:
  5219. node->SetCookie(1);
  5220. break;
  5221. case TType::EKind::Optional: {
  5222. auto optionalType = static_cast<TOptionalType*>(node);
  5223. if (!optionalType->GetItemType()->GetCookie()) {
  5224. canExport = false;
  5225. } else {
  5226. node->SetCookie(1);
  5227. }
  5228. break;
  5229. }
  5230. case TType::EKind::List: {
  5231. auto listType = static_cast<TListType*>(node);
  5232. if (!listType->GetItemType()->GetCookie()) {
  5233. canExport = false;
  5234. } else {
  5235. node->SetCookie(1);
  5236. }
  5237. break;
  5238. }
  5239. case TType::EKind::Struct: {
  5240. auto structType = static_cast<TStructType*>(node);
  5241. for (ui32 index = 0; index < structType->GetMembersCount(); ++index) {
  5242. if (!structType->GetMemberType(index)->GetCookie()) {
  5243. canExport = false;
  5244. break;
  5245. }
  5246. }
  5247. if (canExport) {
  5248. node->SetCookie(1);
  5249. }
  5250. break;
  5251. }
  5252. case TType::EKind::Tuple: {
  5253. auto tupleType = static_cast<TTupleType*>(node);
  5254. for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) {
  5255. if (!tupleType->GetElementType(index)->GetCookie()) {
  5256. canExport = false;
  5257. break;
  5258. }
  5259. }
  5260. if (canExport) {
  5261. node->SetCookie(1);
  5262. }
  5263. break;
  5264. }
  5265. case TType::EKind::Dict: {
  5266. auto dictType = static_cast<TDictType*>(node);
  5267. if (!dictType->GetKeyType()->GetCookie() || !dictType->GetPayloadType()->GetCookie()) {
  5268. canExport = false;
  5269. } else {
  5270. node->SetCookie(1);
  5271. }
  5272. break;
  5273. }
  5274. case TType::EKind::Variant: {
  5275. auto variantType = static_cast<TVariantType*>(node);
  5276. TType* innerType = variantType->GetUnderlyingType();
  5277. if (innerType->IsStruct()) {
  5278. auto structType = static_cast<TStructType*>(innerType);
  5279. for (ui32 index = 0; index < structType->GetMembersCount(); ++index) {
  5280. if (!structType->GetMemberType(index)->GetCookie()) {
  5281. canExport = false;
  5282. break;
  5283. }
  5284. }
  5285. }
  5286. if (innerType->IsTuple()) {
  5287. auto tupleType = static_cast<TTupleType*>(innerType);
  5288. for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) {
  5289. if (!tupleType->GetElementType(index)->GetCookie()) {
  5290. canExport = false;
  5291. break;
  5292. }
  5293. }
  5294. }
  5295. if (canExport) {
  5296. node->SetCookie(1);
  5297. }
  5298. break;
  5299. }
  5300. case TType::EKind::Type:
  5301. break;
  5302. default:
  5303. canExport = false;
  5304. }
  5305. if (!canExport) {
  5306. break;
  5307. }
  5308. }
  5309. for (auto& node : explorer.GetNodes()) {
  5310. node->SetCookie(0);
  5311. }
  5312. return canExport;
  5313. }
  5314. }
  5315. }