mkql_program_builder.cpp 279 KB


  1. #include "mkql_program_builder.h"
  2. #include "mkql_node_visitor.h"
  3. #include "mkql_node_cast.h"
  4. #include "mkql_runtime_version.h"
  5. #include "yql/essentials/minikql/mkql_node_printer.h"
  6. #include "yql/essentials/minikql/mkql_function_registry.h"
  7. #include "yql/essentials/minikql/mkql_utils.h"
  8. #include "yql/essentials/minikql/mkql_type_builder.h"
  9. #include "yql/essentials/core/sql_types/match_recognize.h"
  10. #include "yql/essentials/core/sql_types/time_order_recover.h"
  11. #include <yql/essentials/parser/pg_catalog/catalog.h>
  12. #include <util/generic/overloaded.h>
  13. #include <util/string/cast.h>
  14. #include <util/string/printf.h>
  15. #include <array>
  16. using namespace std::string_view_literals;
  17. namespace NKikimr {
  18. namespace NMiniKQL {
  19. namespace {
  // Bit flags for data functions. Every enumerator except Default occupies its
  // own bit, so values can be OR-ed together and tested with bitwise AND.
  // NOTE(review): what each flag means is only suggested by its name; the code
  // that consumes these flags is outside this excerpt - confirm before relying on it.
  struct TDataFunctionFlags {
      enum {
          Default = 0,
          HasBooleanResult = 1 << 0,
          RequiresBooleanArgs = 1 << 1,
          HasOptionalResult = 1 << 2,
          AllowOptionalArgs = 1 << 3,
          HasUi32Result = 1 << 4,
          RequiresCompare = 1 << 5,
          HasStringResult = 1 << 6,
          RequiresStringArgs = 1 << 7,
          RequiresHash = 1 << 8,
          RequiresEquals = 1 << 9,
          AllowNull = 1 << 10,
          CommonOptionalResult = 1 << 11,
          SupportsTuple = 1 << 12,
          SameOptionalArgs = 1 << 13
      };
  };
  39. #define MKQL_BAD_TYPE_VISIT(NodeType, ScriptName) \
  40. void Visit(NodeType& node) override { \
  41. Y_UNUSED(node); \
  42. MKQL_ENSURE(false, "Can't convert " #NodeType " to " ScriptName " object"); \
  43. }
  44. class TPythonTypeChecker : public TExploringNodeVisitor {
  45. using TExploringNodeVisitor::Visit;
  46. MKQL_BAD_TYPE_VISIT(TAnyType, "Python");
  47. };
  48. class TLuaTypeChecker : public TExploringNodeVisitor {
  49. using TExploringNodeVisitor::Visit;
  50. MKQL_BAD_TYPE_VISIT(TVoidType, "Lua");
  51. MKQL_BAD_TYPE_VISIT(TAnyType, "Lua");
  52. MKQL_BAD_TYPE_VISIT(TVariantType, "Lua");
  53. };
  54. class TJavascriptTypeChecker : public TExploringNodeVisitor {
  55. using TExploringNodeVisitor::Visit;
  56. MKQL_BAD_TYPE_VISIT(TAnyType, "Javascript");
  57. };
  58. #undef MKQL_BAD_TYPE_VISIT
  59. void EnsureScriptSpecificTypes(
  60. EScriptType scriptType,
  61. TCallableType* funcType,
  62. const TTypeEnvironment& env)
  63. {
  64. switch (scriptType) {
  65. case EScriptType::Lua:
  66. return TLuaTypeChecker().Walk(funcType, env);
  67. case EScriptType::Python:
  68. case EScriptType::Python2:
  69. case EScriptType::Python3:
  70. case EScriptType::ArcPython:
  71. case EScriptType::ArcPython2:
  72. case EScriptType::ArcPython3:
  73. case EScriptType::CustomPython:
  74. case EScriptType::CustomPython2:
  75. case EScriptType::CustomPython3:
  76. case EScriptType::SystemPython2:
  77. case EScriptType::SystemPython3:
  78. case EScriptType::SystemPython3_8:
  79. case EScriptType::SystemPython3_9:
  80. case EScriptType::SystemPython3_10:
  81. case EScriptType::SystemPython3_11:
  82. case EScriptType::SystemPython3_12:
  83. case EScriptType::SystemPython3_13:
  84. return TPythonTypeChecker().Walk(funcType, env);
  85. case EScriptType::Javascript:
  86. return TJavascriptTypeChecker().Walk(funcType, env);
  87. default:
  88. MKQL_ENSURE(false, "Unknown script type " << static_cast<ui32>(scriptType));
  89. }
  90. }
  91. ui32 GetNumericSchemeTypeLevel(NUdf::TDataTypeId typeId) {
  92. switch (typeId) {
  93. case NUdf::TDataType<ui8>::Id:
  94. return 0;
  95. case NUdf::TDataType<i8>::Id:
  96. return 1;
  97. case NUdf::TDataType<ui16>::Id:
  98. return 2;
  99. case NUdf::TDataType<i16>::Id:
  100. return 3;
  101. case NUdf::TDataType<ui32>::Id:
  102. return 4;
  103. case NUdf::TDataType<i32>::Id:
  104. return 5;
  105. case NUdf::TDataType<ui64>::Id:
  106. return 6;
  107. case NUdf::TDataType<i64>::Id:
  108. return 7;
  109. case NUdf::TDataType<float>::Id:
  110. return 8;
  111. case NUdf::TDataType<double>::Id:
  112. return 9;
  113. default:
  114. ythrow yexception() << "Unknown numeric type: " << typeId;
  115. }
  116. }
  117. NUdf::TDataTypeId GetNumericSchemeTypeByLevel(ui32 level) {
  118. switch (level) {
  119. case 0:
  120. return NUdf::TDataType<ui8>::Id;
  121. case 1:
  122. return NUdf::TDataType<i8>::Id;
  123. case 2:
  124. return NUdf::TDataType<ui16>::Id;
  125. case 3:
  126. return NUdf::TDataType<i16>::Id;
  127. case 4:
  128. return NUdf::TDataType<ui32>::Id;
  129. case 5:
  130. return NUdf::TDataType<i32>::Id;
  131. case 6:
  132. return NUdf::TDataType<ui64>::Id;
  133. case 7:
  134. return NUdf::TDataType<i64>::Id;
  135. case 8:
  136. return NUdf::TDataType<float>::Id;
  137. case 9:
  138. return NUdf::TDataType<double>::Id;
  139. default:
  140. ythrow yexception() << "Unknown numeric level: " << level;
  141. }
  142. }
  143. NUdf::TDataTypeId MakeNumericDataSuperType(NUdf::TDataTypeId typeId1, NUdf::TDataTypeId typeId2) {
  144. return typeId1 == typeId2 ? typeId1 :
  145. GetNumericSchemeTypeByLevel(std::max(GetNumericSchemeTypeLevel(typeId1), GetNumericSchemeTypeLevel(typeId2)));
  146. }
  147. template<bool IsFilter>
  148. bool CollectOptionalElements(const TType* type, std::vector<std::string_view>& test, std::vector<std::pair<std::string_view, TType*>>& output) {
  149. const auto structType = AS_TYPE(TStructType, type);
  150. test.reserve(structType->GetMembersCount());
  151. output.reserve(structType->GetMembersCount());
  152. bool multiOptional = false;
  153. for (ui32 i = 0; i < structType->GetMembersCount(); ++i) {
  154. output.emplace_back(structType->GetMemberName(i), structType->GetMemberType(i));
  155. auto& memberType = output.back().second;
  156. if (memberType->IsOptional()) {
  157. test.emplace_back(output.back().first);
  158. if constexpr (IsFilter) {
  159. memberType = AS_TYPE(TOptionalType, memberType)->GetItemType();
  160. multiOptional = multiOptional || memberType->IsOptional();
  161. }
  162. }
  163. }
  164. return multiOptional;
  165. }
  166. template<bool IsFilter>
  167. bool CollectOptionalElements(const TType* type, std::vector<ui32>& test, std::vector<TType*>& output) {
  168. const auto typleType = AS_TYPE(TTupleType, type);
  169. test.reserve(typleType->GetElementsCount());
  170. output.reserve(typleType->GetElementsCount());
  171. bool multiOptional = false;
  172. for (ui32 i = 0; i < typleType->GetElementsCount(); ++i) {
  173. output.emplace_back(typleType->GetElementType(i));
  174. auto& elementType = output.back();
  175. if (elementType->IsOptional()) {
  176. test.emplace_back(i);
  177. if constexpr (IsFilter) {
  178. elementType = AS_TYPE(TOptionalType, elementType)->GetItemType();
  179. multiOptional = multiOptional || elementType->IsOptional();
  180. }
  181. }
  182. }
  183. return multiOptional;
  184. }
  185. bool ReduceOptionalElements(const TType* type, const TArrayRef<const std::string_view>& test, std::vector<std::pair<std::string_view, TType*>>& output) {
  186. const auto structType = AS_TYPE(TStructType, type);
  187. output.reserve(structType->GetMembersCount());
  188. for (ui32 i = 0U; i < structType->GetMembersCount(); ++i) {
  189. output.emplace_back(structType->GetMemberName(i), structType->GetMemberType(i));
  190. }
  191. bool multiOptional = false;
  192. for (const auto& member : test) {
  193. auto& memberType = output[structType->GetMemberIndex(member)].second;
  194. MKQL_ENSURE(memberType->IsOptional(), "Required optional column type");
  195. memberType = AS_TYPE(TOptionalType, memberType)->GetItemType();
  196. multiOptional = multiOptional || memberType->IsOptional();
  197. }
  198. return multiOptional;
  199. }
  200. bool ReduceOptionalElements(const TType* type, const TArrayRef<const ui32>& test, std::vector<TType*>& output) {
  201. const auto typleType = AS_TYPE(TTupleType, type);
  202. output.reserve(typleType->GetElementsCount());
  203. for (ui32 i = 0U; i < typleType->GetElementsCount(); ++i) {
  204. output.emplace_back(typleType->GetElementType(i));
  205. }
  206. bool multiOptional = false;
  207. for (const auto& member : test) {
  208. auto& memberType = output[member];
  209. MKQL_ENSURE(memberType->IsOptional(), "Required optional column type");
  210. memberType = AS_TYPE(TOptionalType, memberType)->GetItemType();
  211. multiOptional = multiOptional || memberType->IsOptional();
  212. }
  213. return multiOptional;
  214. }
  215. static std::vector<TType*> ValidateBlockItems(const TArrayRef<TType* const>& wideComponents, bool unwrap) {
  216. MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column");
  217. std::vector<TType*> items;
  218. items.reserve(wideComponents.size());
  219. // XXX: Declare these variables outside the loop body to use for the last
  220. // item (i.e. block length column) in the assertions below.
  221. bool isScalar;
  222. TType* itemType;
  223. for (const auto& wideComponent : wideComponents) {
  224. auto blockType = AS_TYPE(TBlockType, wideComponent);
  225. isScalar = blockType->GetShape() == TBlockType::EShape::Scalar;
  226. itemType = blockType->GetItemType();
  227. items.push_back(unwrap ? itemType : blockType);
  228. }
  229. MKQL_ENSURE(isScalar, "Last column should be scalar");
  230. MKQL_ENSURE(AS_TYPE(TDataType, itemType)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64");
  231. return items;
  232. }
  233. } // namespace
  234. std::string_view ScriptTypeAsStr(EScriptType type) {
  235. switch (type) {
  236. #define MKQL_SCRIPT_TYPE_CASE(name, value, ...) \
  237. case EScriptType::name: return std::string_view(#name);
  238. MKQL_SCRIPT_TYPES(MKQL_SCRIPT_TYPE_CASE)
  239. #undef MKQL_SCRIPT_TYPE_CASE
  240. } // switch
  241. return std::string_view("Unknown");
  242. }
  243. EScriptType ScriptTypeFromStr(std::string_view str) {
  244. TString lowerStr = TString(str);
  245. lowerStr.to_lower();
  246. #define MKQL_SCRIPT_TYPE_FROM_STR(name, value, lowerName, allowSuffix) \
  247. if ((allowSuffix && lowerStr.StartsWith(#lowerName)) || lowerStr == #lowerName) return EScriptType::name;
  248. MKQL_SCRIPT_TYPES(MKQL_SCRIPT_TYPE_FROM_STR)
  249. #undef MKQL_SCRIPT_TYPE_FROM_STR
  250. return EScriptType::Unknown;
  251. }
  252. bool IsCustomPython(EScriptType type) {
  253. return type == EScriptType::CustomPython ||
  254. type == EScriptType::CustomPython2 ||
  255. type == EScriptType::CustomPython3;
  256. }
  257. bool IsSystemPython(EScriptType type) {
  258. return type == EScriptType::SystemPython2
  259. || type == EScriptType::SystemPython3
  260. || type == EScriptType::SystemPython3_8
  261. || type == EScriptType::SystemPython3_9
  262. || type == EScriptType::SystemPython3_10
  263. || type == EScriptType::SystemPython3_11
  264. || type == EScriptType::SystemPython3_12
  265. || type == EScriptType::SystemPython3_13
  266. || type == EScriptType::Python
  267. || type == EScriptType::Python2;
  268. }
  269. EScriptType CanonizeScriptType(EScriptType type) {
  270. if (type == EScriptType::Python) {
  271. return EScriptType::Python2;
  272. }
  273. if (type == EScriptType::ArcPython) {
  274. return EScriptType::ArcPython2;
  275. }
  276. return type;
  277. }
  278. void EnsureDataOrOptionalOfData(TRuntimeNode node) {
  279. MKQL_ENSURE(node.GetStaticType()->IsData() ||
  280. node.GetStaticType()->IsOptional() && AS_TYPE(TOptionalType, node.GetStaticType())
  281. ->GetItemType()->IsData(), "Expected data or optional of data");
  282. }
  283. std::vector<TType*> ValidateBlockType(const TType* type, bool unwrap) {
  284. const auto wideComponents = AS_TYPE(TMultiType, type)->GetElements();
  285. return ValidateBlockItems(wideComponents, unwrap);
  286. }
  287. std::vector<TType*> ValidateBlockStreamType(const TType* streamType, bool unwrap) {
  288. const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, streamType));
  289. return ValidateBlockItems(wideComponents, unwrap);
  290. }
  291. std::vector<TType*> ValidateBlockFlowType(const TType* flowType, bool unwrap) {
  292. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flowType));
  293. return ValidateBlockItems(wideComponents, unwrap);
  294. }
  295. TProgramBuilder::TProgramBuilder(const TTypeEnvironment& env, const IFunctionRegistry& functionRegistry, bool voidWithEffects)
  296. : TTypeBuilder(env)
  297. , FunctionRegistry(functionRegistry)
  298. , VoidWithEffects(voidWithEffects)
  299. {}
  300. const TTypeEnvironment& TProgramBuilder::GetTypeEnvironment() const {
  301. return Env;
  302. }
  303. const IFunctionRegistry& TProgramBuilder::GetFunctionRegistry() const {
  304. return FunctionRegistry;
  305. }
  306. TType* TProgramBuilder::ChooseCommonType(TType* type1, TType* type2) {
  307. bool isOptional1, isOptional2;
  308. const auto data1 = UnpackOptionalData(type1, isOptional1);
  309. const auto data2 = UnpackOptionalData(type2, isOptional2);
  310. if (data1->IsSameType(*data2)) {
  311. return isOptional1 ? type1 : type2;
  312. }
  313. MKQL_ENSURE(!
  314. ((NUdf::GetDataTypeInfo(*data1->GetDataSlot()).Features | NUdf::GetDataTypeInfo(*data2->GetDataSlot()).Features) & (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType)),
  315. "Not same date types: " << *type1 << " and " << *type2
  316. );
  317. const auto data = NewDataType(MakeNumericDataSuperType(data1->GetSchemeType(), data2->GetSchemeType()));
  318. return isOptional1 || isOptional2 ? NewOptionalType(data) : data;
  319. }
  320. TType* TProgramBuilder::BuildArithmeticCommonType(TType* type1, TType* type2) {
  321. bool isOptional1, isOptional2;
  322. const auto data1 = UnpackOptionalData(type1, isOptional1);
  323. const auto data2 = UnpackOptionalData(type2, isOptional2);
  324. const auto features1 = NUdf::GetDataTypeInfo(*data1->GetDataSlot()).Features;
  325. const auto features2 = NUdf::GetDataTypeInfo(*data2->GetDataSlot()).Features;
  326. const bool isOptional = isOptional1 || isOptional2;
  327. if (features1 & features2 & NUdf::EDataTypeFeatures::TimeIntervalType) {
  328. return NewOptionalType(features1 & NUdf::EDataTypeFeatures::BigDateType ? data1 : data2);
  329. } else if (features1 & NUdf::EDataTypeFeatures::TimeIntervalType) {
  330. return NewOptionalType(features2 & NUdf::EDataTypeFeatures::IntegralType ? data1 : data2);
  331. } else if (features2 & NUdf::EDataTypeFeatures::TimeIntervalType) {
  332. return NewOptionalType(features1 & NUdf::EDataTypeFeatures::IntegralType ? data2 : data1);
  333. } else if (
  334. features1 & (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType) &&
  335. features2 & (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType)
  336. ) {
  337. const auto used = ((features1 | features2) & NUdf::EDataTypeFeatures::BigDateType)
  338. ? NewDataType(NUdf::EDataSlot::Interval64)
  339. : NewDataType(NUdf::EDataSlot::Interval);
  340. return isOptional ? NewOptionalType(used) : used;
  341. } else if (data1->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  342. MKQL_ENSURE(data1->IsSameType(*data2), "Must be same type.");
  343. return isOptional ? NewOptionalType(data1) : data2;
  344. }
  345. const auto data = NewDataType(MakeNumericDataSuperType(data1->GetSchemeType(), data2->GetSchemeType()));
  346. return isOptional ? NewOptionalType(data) : data;
  347. }
  348. TRuntimeNode TProgramBuilder::Arg(TType* type) const {
  349. TCallableBuilder builder(Env, __func__, type, true);
  350. return TRuntimeNode(builder.Build(), false);
  351. }
  352. TRuntimeNode TProgramBuilder::WideFlowArg(TType* type) const {
  353. if constexpr (RuntimeVersion < 18U) {
  354. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  355. }
  356. TCallableBuilder builder(Env, __func__, type, true);
  357. return TRuntimeNode(builder.Build(), false);
  358. }
  359. TRuntimeNode TProgramBuilder::Member(TRuntimeNode structObj, const std::string_view& memberName) {
  360. bool isOptional;
  361. const auto type = AS_TYPE(TStructType, UnpackOptional(structObj.GetStaticType(), isOptional));
  362. const auto memberIndex = type->GetMemberIndex(memberName);
  363. auto memberType = type->GetMemberType(memberIndex);
  364. if (isOptional && !memberType->IsOptional() && !memberType->IsNull() && !memberType->IsPg()) {
  365. memberType = NewOptionalType(memberType);
  366. }
  367. TCallableBuilder callableBuilder(Env, __func__, memberType);
  368. callableBuilder.Add(structObj);
  369. callableBuilder.Add(NewDataLiteral<ui32>(memberIndex));
  370. return TRuntimeNode(callableBuilder.Build(), false);
  371. }
  372. TRuntimeNode TProgramBuilder::Element(TRuntimeNode structObj, const std::string_view& memberName) {
  373. return Member(structObj, memberName);
  374. }
  375. TRuntimeNode TProgramBuilder::AddMember(TRuntimeNode structObj, const std::string_view& memberName, TRuntimeNode memberValue) {
  376. auto oldType = structObj.GetStaticType();
  377. MKQL_ENSURE(oldType->IsStruct(), "Expected struct");
  378. const auto& oldTypeDetailed = static_cast<const TStructType&>(*oldType);
  379. TStructTypeBuilder newTypeBuilder(Env);
  380. newTypeBuilder.Reserve(oldTypeDetailed.GetMembersCount() + 1);
  381. for (ui32 i = 0, e = oldTypeDetailed.GetMembersCount(); i < e; ++i) {
  382. newTypeBuilder.Add(oldTypeDetailed.GetMemberName(i), oldTypeDetailed.GetMemberType(i));
  383. }
  384. newTypeBuilder.Add(memberName, memberValue.GetStaticType());
  385. auto newType = newTypeBuilder.Build();
  386. for (ui32 i = 0, e = newType->GetMembersCount(); i < e; ++i) {
  387. if (newType->GetMemberName(i) == memberName) {
  388. // insert at position i in the struct
  389. TCallableBuilder callableBuilder(Env, __func__, newType);
  390. callableBuilder.Add(structObj);
  391. callableBuilder.Add(memberValue);
  392. callableBuilder.Add(NewDataLiteral<ui32>(i));
  393. return TRuntimeNode(callableBuilder.Build(), false);
  394. }
  395. }
  396. Y_ABORT();
  397. }
  398. TRuntimeNode TProgramBuilder::RemoveMember(TRuntimeNode structObj, const std::string_view& memberName, bool forced) {
  399. auto oldType = structObj.GetStaticType();
  400. MKQL_ENSURE(oldType->IsStruct(), "Expected struct");
  401. const auto& oldTypeDetailed = static_cast<const TStructType&>(*oldType);
  402. MKQL_ENSURE(oldTypeDetailed.GetMembersCount() > 0, "Expected non-empty struct");
  403. TStructTypeBuilder newTypeBuilder(Env);
  404. newTypeBuilder.Reserve(oldTypeDetailed.GetMembersCount() - 1);
  405. std::optional<ui32> memberIndex;
  406. for (ui32 i = 0, e = oldTypeDetailed.GetMembersCount(); i < e; ++i) {
  407. if (oldTypeDetailed.GetMemberName(i) != memberName) {
  408. newTypeBuilder.Add(oldTypeDetailed.GetMemberName(i), oldTypeDetailed.GetMemberType(i));
  409. }
  410. else {
  411. memberIndex = i;
  412. }
  413. }
  414. if (!memberIndex && forced) {
  415. return structObj;
  416. }
  417. MKQL_ENSURE(memberIndex, "Unknown member name: " << memberName);
  418. // remove at position i in the struct
  419. auto newType = newTypeBuilder.Build();
  420. TCallableBuilder callableBuilder(Env, __func__, newType);
  421. callableBuilder.Add(structObj);
  422. callableBuilder.Add(NewDataLiteral<ui32>(*memberIndex));
  423. return TRuntimeNode(callableBuilder.Build(), false);
  424. }
  425. TRuntimeNode TProgramBuilder::Zip(const TArrayRef<const TRuntimeNode>& lists) {
  426. if (lists.empty()) {
  427. return NewEmptyList(Env.GetEmptyTupleLazy()->GetGenericType());
  428. }
  429. std::vector<TType*> tupleTypes;
  430. tupleTypes.reserve(lists.size());
  431. for (auto& list : lists) {
  432. if (list.GetStaticType()->IsEmptyList()) {
  433. tupleTypes.push_back(Env.GetTypeOfVoidLazy());
  434. continue;
  435. }
  436. AS_TYPE(TListType, list.GetStaticType());
  437. auto itemType = static_cast<const TListType&>(*list.GetStaticType()).GetItemType();
  438. tupleTypes.push_back(itemType);
  439. }
  440. auto returnType = TListType::Create(TTupleType::Create(tupleTypes.size(), tupleTypes.data(), Env), Env);
  441. TCallableBuilder callableBuilder(Env, __func__, returnType);
  442. for (auto& list : lists) {
  443. callableBuilder.Add(list);
  444. }
  445. return TRuntimeNode(callableBuilder.Build(), false);
  446. }
  447. TRuntimeNode TProgramBuilder::ZipAll(const TArrayRef<const TRuntimeNode>& lists) {
  448. if (lists.empty()) {
  449. return NewEmptyList(Env.GetEmptyTupleLazy()->GetGenericType());
  450. }
  451. std::vector<TType*> tupleTypes;
  452. tupleTypes.reserve(lists.size());
  453. for (auto& list : lists) {
  454. if (list.GetStaticType()->IsEmptyList()) {
  455. tupleTypes.push_back(TOptionalType::Create(Env.GetTypeOfVoidLazy(), Env));
  456. continue;
  457. }
  458. AS_TYPE(TListType, list.GetStaticType());
  459. auto itemType = static_cast<const TListType&>(*list.GetStaticType()).GetItemType();
  460. tupleTypes.push_back(TOptionalType::Create(itemType, Env));
  461. }
  462. auto returnType = TListType::Create(TTupleType::Create(tupleTypes.size(), tupleTypes.data(), Env), Env);
  463. TCallableBuilder callableBuilder(Env, __func__, returnType);
  464. for (auto& list : lists) {
  465. callableBuilder.Add(list);
  466. }
  467. return TRuntimeNode(callableBuilder.Build(), false);
  468. }
  469. TRuntimeNode TProgramBuilder::Enumerate(TRuntimeNode list, TRuntimeNode start, TRuntimeNode step) {
  470. const auto itemType = AS_TYPE(TListType, list.GetStaticType())->GetItemType();
  471. ThrowIfListOfVoid(itemType);
  472. MKQL_ENSURE(AS_TYPE(TDataType, start)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64 as start");
  473. MKQL_ENSURE(AS_TYPE(TDataType, step)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64 as step");
  474. const std::array<TType*, 2U> tupleTypes = {{ NewDataType(NUdf::EDataSlot::Uint64), itemType }};
  475. const auto returnType = NewListType(NewTupleType(tupleTypes));
  476. TCallableBuilder callableBuilder(Env, __func__, returnType);
  477. callableBuilder.Add(list);
  478. callableBuilder.Add(start);
  479. callableBuilder.Add(step);
  480. return TRuntimeNode(callableBuilder.Build(), false);
  481. }
  482. TRuntimeNode TProgramBuilder::Enumerate(TRuntimeNode list) {
  483. return TProgramBuilder::Enumerate(list, NewDataLiteral<ui64>(0), NewDataLiteral<ui64>(1));
  484. }
  485. TRuntimeNode TProgramBuilder::Fold(TRuntimeNode list, TRuntimeNode state, const TBinaryLambda& handler) {
  486. const auto itemType = AS_TYPE(TListType, list.GetStaticType())->GetItemType();
  487. ThrowIfListOfVoid(itemType);
  488. const auto stateNodeArg = Arg(state.GetStaticType());
  489. const auto itemArg = Arg(itemType);
  490. const auto newState = handler(itemArg, stateNodeArg);
  491. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*state.GetStaticType()), "State type is changed by the handler");
  492. TCallableBuilder callableBuilder(Env, __func__, state.GetStaticType());
  493. callableBuilder.Add(list);
  494. callableBuilder.Add(state);
  495. callableBuilder.Add(itemArg);
  496. callableBuilder.Add(stateNodeArg);
  497. callableBuilder.Add(newState);
  498. return TRuntimeNode(callableBuilder.Build(), false);
  499. }
  500. TRuntimeNode TProgramBuilder::Fold1(TRuntimeNode list, const TUnaryLambda& init, const TBinaryLambda& handler) {
  501. const auto itemType = AS_TYPE(TListType, list.GetStaticType())->GetItemType();
  502. ThrowIfListOfVoid(itemType);
  503. const auto itemArg = Arg(itemType);
  504. const auto initState = init(itemArg);
  505. const auto stateNodeArg = Arg(initState.GetStaticType());
  506. const auto newState = handler(itemArg, stateNodeArg);
  507. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*initState.GetStaticType()), "State type is changed by the handler");
  508. TCallableBuilder callableBuilder(Env, __func__, NewOptionalType(newState.GetStaticType()));
  509. callableBuilder.Add(list);
  510. callableBuilder.Add(itemArg);
  511. callableBuilder.Add(initState);
  512. callableBuilder.Add(stateNodeArg);
  513. callableBuilder.Add(newState);
  514. return TRuntimeNode(callableBuilder.Build(), false);
  515. }
  516. TRuntimeNode TProgramBuilder::Reduce(TRuntimeNode list, TRuntimeNode state1,
  517. const TBinaryLambda& handler1,
  518. const TUnaryLambda& handler2,
  519. TRuntimeNode state3,
  520. const TBinaryLambda& handler3) {
  521. const auto listType = list.GetStaticType();
  522. MKQL_ENSURE(listType->IsList() || listType->IsStream(), "Expected list or stream");
  523. const auto itemType = listType->IsList()?
  524. static_cast<const TListType&>(*listType).GetItemType():
  525. static_cast<const TStreamType&>(*listType).GetItemType();
  526. ThrowIfListOfVoid(itemType);
  527. const auto state1NodeArg = Arg(state1.GetStaticType());
  528. const auto state3NodeArg = Arg(state3.GetStaticType());
  529. const auto itemArg = Arg(itemType);
  530. const auto newState1 = handler1(itemArg, state1NodeArg);
  531. MKQL_ENSURE(newState1.GetStaticType()->IsSameType(*state1.GetStaticType()), "State 1 type is changed by the handler");
  532. const auto newState2 = handler2(state1NodeArg);
  533. TRuntimeNode itemState2Arg = Arg(newState2.GetStaticType());
  534. const auto newState3 = handler3(itemState2Arg, state3NodeArg);
  535. MKQL_ENSURE(newState3.GetStaticType()->IsSameType(*state3.GetStaticType()), "State 3 type is changed by the handler");
  536. TCallableBuilder callableBuilder(Env, __func__, newState3.GetStaticType());
  537. callableBuilder.Add(list);
  538. callableBuilder.Add(state1);
  539. callableBuilder.Add(state3);
  540. callableBuilder.Add(itemArg);
  541. callableBuilder.Add(state1NodeArg);
  542. callableBuilder.Add(newState1);
  543. callableBuilder.Add(newState2);
  544. callableBuilder.Add(itemState2Arg);
  545. callableBuilder.Add(state3NodeArg);
  546. callableBuilder.Add(newState3);
  547. return TRuntimeNode(callableBuilder.Build(), false);
  548. }
  549. TRuntimeNode TProgramBuilder::Condense(TRuntimeNode flow, TRuntimeNode state,
  550. const TBinaryLambda& switcher,
  551. const TBinaryLambda& handler, bool useCtx) {
  552. const auto flowType = flow.GetStaticType();
  553. if (flowType->IsList()) {
  554. // TODO: Native implementation for list.
  555. return Collect(Condense(ToFlow(flow), state, switcher, handler));
  556. }
  557. MKQL_ENSURE(flowType->IsFlow() || flowType->IsStream(), "Expected flow or stream.");
  558. const auto itemType = flowType->IsFlow() ?
  559. static_cast<const TFlowType&>(*flowType).GetItemType():
  560. static_cast<const TStreamType&>(*flowType).GetItemType();
  561. const auto itemArg = Arg(itemType);
  562. const auto stateArg = Arg(state.GetStaticType());
  563. const auto outSwitch = switcher(itemArg, stateArg);
  564. const auto newState = handler(itemArg, stateArg);
  565. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*state.GetStaticType()), "State type is changed by the handler");
  566. TCallableBuilder callableBuilder(Env, __func__, flowType->IsFlow() ? NewFlowType(state.GetStaticType()) : NewStreamType(state.GetStaticType()));
  567. callableBuilder.Add(flow);
  568. callableBuilder.Add(state);
  569. callableBuilder.Add(itemArg);
  570. callableBuilder.Add(stateArg);
  571. callableBuilder.Add(outSwitch);
  572. callableBuilder.Add(newState);
  573. if (useCtx) {
  574. MKQL_ENSURE(RuntimeVersion >= 30U, "Too old runtime version");
  575. callableBuilder.Add(NewDataLiteral<bool>(useCtx));
  576. }
  577. return TRuntimeNode(callableBuilder.Build(), false);
  578. }
  579. TRuntimeNode TProgramBuilder::Condense1(TRuntimeNode flow, const TUnaryLambda& init,
  580. const TBinaryLambda& switcher,
  581. const TBinaryLambda& handler, bool useCtx) {
  582. const auto flowType = flow.GetStaticType();
  583. if (flowType->IsList()) {
  584. // TODO: Native implementation for list.
  585. return Collect(Condense1(ToFlow(flow), init, switcher, handler));
  586. }
  587. MKQL_ENSURE(flowType->IsFlow() || flowType->IsStream(), "Expected flow or stream.");
  588. const auto itemType = flowType->IsFlow() ?
  589. static_cast<const TFlowType&>(*flowType).GetItemType():
  590. static_cast<const TStreamType&>(*flowType).GetItemType();
  591. const auto itemArg = Arg(itemType);
  592. const auto initState = init(itemArg);
  593. const auto stateArg = Arg(initState.GetStaticType());
  594. const auto outSwitch = switcher(itemArg, stateArg);
  595. const auto newState = handler(itemArg, stateArg);
  596. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*initState.GetStaticType()), "State type is changed by the handler");
  597. TCallableBuilder callableBuilder(Env, __func__, flowType->IsFlow() ? NewFlowType(newState.GetStaticType()) : NewStreamType(newState.GetStaticType()));
  598. callableBuilder.Add(flow);
  599. callableBuilder.Add(itemArg);
  600. callableBuilder.Add(initState);
  601. callableBuilder.Add(stateArg);
  602. callableBuilder.Add(outSwitch);
  603. callableBuilder.Add(newState);
  604. if (useCtx) {
  605. MKQL_ENSURE(RuntimeVersion >= 30U, "Too old runtime version");
  606. callableBuilder.Add(NewDataLiteral<bool>(useCtx));
  607. }
  608. return TRuntimeNode(callableBuilder.Build(), false);
  609. }
  610. TRuntimeNode TProgramBuilder::Squeeze(TRuntimeNode stream, TRuntimeNode state,
  611. const TBinaryLambda& handler,
  612. const TUnaryLambda& save,
  613. const TUnaryLambda& load) {
  614. const auto streamType = stream.GetStaticType();
  615. MKQL_ENSURE(streamType->IsStream(), "Expected stream");
  616. const auto& streamDetailedType = static_cast<const TStreamType&>(*streamType);
  617. const auto itemType = streamDetailedType.GetItemType();
  618. ThrowIfListOfVoid(itemType);
  619. const auto stateNodeArg = Arg(state.GetStaticType());
  620. const auto itemArg = Arg(itemType);
  621. const auto newState = handler(itemArg, stateNodeArg);
  622. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*state.GetStaticType()), "State type is changed by the handler");
  623. TRuntimeNode saveArg, outSave, loadArg, outLoad;
  624. if (save && load) {
  625. outSave = save(saveArg = Arg(state.GetStaticType()));
  626. outLoad = load(loadArg = Arg(outSave.GetStaticType()));
  627. MKQL_ENSURE(outLoad.GetStaticType()->IsSameType(*state.GetStaticType()), "Loaded type is changed by the load handler");
  628. } else {
  629. saveArg = outSave = loadArg = outLoad = NewVoid();
  630. }
  631. TCallableBuilder callableBuilder(Env, __func__, TStreamType::Create(state.GetStaticType(), Env));
  632. callableBuilder.Add(stream);
  633. callableBuilder.Add(state);
  634. callableBuilder.Add(itemArg);
  635. callableBuilder.Add(stateNodeArg);
  636. callableBuilder.Add(newState);
  637. callableBuilder.Add(saveArg);
  638. callableBuilder.Add(outSave);
  639. callableBuilder.Add(loadArg);
  640. callableBuilder.Add(outLoad);
  641. return TRuntimeNode(callableBuilder.Build(), false);
  642. }
  643. TRuntimeNode TProgramBuilder::Squeeze1(TRuntimeNode stream, const TUnaryLambda& init,
  644. const TBinaryLambda& handler,
  645. const TUnaryLambda& save,
  646. const TUnaryLambda& load) {
  647. const auto streamType = stream.GetStaticType();
  648. MKQL_ENSURE(streamType->IsStream(), "Expected stream");
  649. const auto& streamDetailedType = static_cast<const TStreamType&>(*streamType);
  650. const auto itemType = streamDetailedType.GetItemType();
  651. ThrowIfListOfVoid(itemType);
  652. const auto itemArg = Arg(itemType);
  653. const auto initState = init(itemArg);
  654. const auto stateNodeArg = Arg(initState.GetStaticType());
  655. const auto newState = handler(itemArg, stateNodeArg);
  656. MKQL_ENSURE(newState.GetStaticType()->IsSameType(*initState.GetStaticType()), "State type is changed by the handler");
  657. TRuntimeNode saveArg, outSave, loadArg, outLoad;
  658. if (save && load) {
  659. outSave = save(saveArg = Arg(initState.GetStaticType()));
  660. outLoad = load(loadArg = Arg(outSave.GetStaticType()));
  661. MKQL_ENSURE(outLoad.GetStaticType()->IsSameType(*initState.GetStaticType()), "Loaded type is changed by the load handler");
  662. } else {
  663. saveArg = outSave = loadArg = outLoad = NewVoid();
  664. }
  665. TCallableBuilder callableBuilder(Env, __func__, NewStreamType(newState.GetStaticType()));
  666. callableBuilder.Add(stream);
  667. callableBuilder.Add(itemArg);
  668. callableBuilder.Add(initState);
  669. callableBuilder.Add(stateNodeArg);
  670. callableBuilder.Add(newState);
  671. callableBuilder.Add(saveArg);
  672. callableBuilder.Add(outSave);
  673. callableBuilder.Add(loadArg);
  674. callableBuilder.Add(outLoad);
  675. return TRuntimeNode(callableBuilder.Build(), false);
  676. }
  677. TRuntimeNode TProgramBuilder::Discard(TRuntimeNode stream) {
  678. const auto streamType = stream.GetStaticType();
  679. MKQL_ENSURE(streamType->IsStream() || streamType->IsFlow(), "Expected stream or flow.");
  680. TCallableBuilder callableBuilder(Env, __func__, streamType);
  681. callableBuilder.Add(stream);
  682. return TRuntimeNode(callableBuilder.Build(), false);
  683. }
  684. TRuntimeNode TProgramBuilder::Map(TRuntimeNode list, const TUnaryLambda& handler) {
  685. return BuildMap(__func__, list, handler);
  686. }
  687. TRuntimeNode TProgramBuilder::OrderedMap(TRuntimeNode list, const TUnaryLambda& handler) {
  688. return BuildMap(__func__, list, handler);
  689. }
  690. TRuntimeNode TProgramBuilder::MapNext(TRuntimeNode list, const TBinaryLambda& handler) {
  691. const auto listType = list.GetStaticType();
  692. MKQL_ENSURE(listType->IsStream() || listType->IsFlow(), "Expected stream or flow");
  693. const auto itemType = listType->IsFlow() ?
  694. AS_TYPE(TFlowType, listType)->GetItemType():
  695. AS_TYPE(TStreamType, listType)->GetItemType();
  696. ThrowIfListOfVoid(itemType);
  697. TType* nextItemType = TOptionalType::Create(itemType, Env);
  698. const auto itemArg = Arg(itemType);
  699. const auto nextItemArg = Arg(nextItemType);
  700. const auto newItem = handler(itemArg, nextItemArg);
  701. const auto resultListType = listType->IsFlow() ?
  702. (TType*)TFlowType::Create(newItem.GetStaticType(), Env):
  703. (TType*)TStreamType::Create(newItem.GetStaticType(), Env);
  704. TCallableBuilder callableBuilder(Env, __func__, resultListType);
  705. callableBuilder.Add(list);
  706. callableBuilder.Add(itemArg);
  707. callableBuilder.Add(nextItemArg);
  708. callableBuilder.Add(newItem);
  709. return TRuntimeNode(callableBuilder.Build(), false);
  710. }
  711. template <bool Ordered>
  712. TRuntimeNode TProgramBuilder::BuildExtract(TRuntimeNode list, const std::string_view& name) {
  713. const auto listType = list.GetStaticType();
  714. MKQL_ENSURE(listType->IsList() || listType->IsOptional(), "Expected list or optional.");
  715. const auto itemType = listType->IsList() ?
  716. AS_TYPE(TListType, listType)->GetItemType():
  717. AS_TYPE(TOptionalType, listType)->GetItemType();
  718. const auto lambda = [&](TRuntimeNode item) {
  719. return itemType->IsStruct() ? Member(item, name) : Nth(item, ::FromString<ui32>(name));
  720. };
  721. return Ordered ? OrderedMap(list, lambda) : Map(list, lambda);
  722. }
  723. TRuntimeNode TProgramBuilder::Extract(TRuntimeNode list, const std::string_view& name) {
  724. return BuildExtract<false>(list, name);
  725. }
  726. TRuntimeNode TProgramBuilder::OrderedExtract(TRuntimeNode list, const std::string_view& name) {
  727. return BuildExtract<true>(list, name);
  728. }
  729. TRuntimeNode TProgramBuilder::ChainMap(TRuntimeNode list, TRuntimeNode state, const TBinaryLambda& handler) {
  730. return ChainMap(list, state, [&](TRuntimeNode item, TRuntimeNode state) -> TRuntimeNodePair {
  731. const auto result = handler(item, state);
  732. return {result, result};
  733. });
  734. }
  735. TRuntimeNode TProgramBuilder::ChainMap(TRuntimeNode list, TRuntimeNode state, const TBinarySplitLambda& handler) {
  736. const auto listType = list.GetStaticType();
  737. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream");
  738. const auto itemType = listType->IsFlow() ?
  739. AS_TYPE(TFlowType, listType)->GetItemType():
  740. listType->IsList() ?
  741. AS_TYPE(TListType, listType)->GetItemType():
  742. AS_TYPE(TStreamType, listType)->GetItemType();
  743. ThrowIfListOfVoid(itemType);
  744. const auto stateNodeArg = Arg(state.GetStaticType());
  745. const auto itemArg = Arg(itemType);
  746. const auto newItemAndState = handler(itemArg, stateNodeArg);
  747. MKQL_ENSURE(std::get<1U>(newItemAndState).GetStaticType()->IsSameType(*state.GetStaticType()), "State type is changed by the handler");
  748. const auto resultItemType = std::get<0U>(newItemAndState).GetStaticType();
  749. TType* resultListType = nullptr;
  750. if (listType->IsFlow()) {
  751. resultListType = TFlowType::Create(resultItemType, Env);
  752. } else if (listType->IsList()) {
  753. resultListType = TListType::Create(resultItemType, Env);
  754. } else if (listType->IsStream()) {
  755. resultListType = TStreamType::Create(resultItemType, Env);
  756. }
  757. TCallableBuilder callableBuilder(Env, __func__, resultListType);
  758. callableBuilder.Add(list);
  759. callableBuilder.Add(state);
  760. callableBuilder.Add(itemArg);
  761. callableBuilder.Add(stateNodeArg);
  762. callableBuilder.Add(std::get<0U>(newItemAndState));
  763. callableBuilder.Add(std::get<1U>(newItemAndState));
  764. return TRuntimeNode(callableBuilder.Build(), false);
  765. }
  766. TRuntimeNode TProgramBuilder::Chain1Map(TRuntimeNode list, const TUnaryLambda& init, const TBinaryLambda& handler) {
  767. return Chain1Map(list,
  768. [&](TRuntimeNode item) -> TRuntimeNodePair {
  769. const auto result = init(item);
  770. return {result, result};
  771. },
  772. [&](TRuntimeNode item, TRuntimeNode state) -> TRuntimeNodePair {
  773. const auto result = handler(item, state);
  774. return {result, result};
  775. }
  776. );
  777. }
  778. TRuntimeNode TProgramBuilder::Chain1Map(TRuntimeNode list, const TUnarySplitLambda& init, const TBinarySplitLambda& handler) {
  779. const auto listType = list.GetStaticType();
  780. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream");
  781. const auto itemType = listType->IsFlow() ?
  782. AS_TYPE(TFlowType, listType)->GetItemType():
  783. listType->IsList() ?
  784. AS_TYPE(TListType, listType)->GetItemType():
  785. AS_TYPE(TStreamType, listType)->GetItemType();
  786. ThrowIfListOfVoid(itemType);
  787. const auto itemArg = Arg(itemType);
  788. const auto initItemAndState = init(itemArg);
  789. const auto resultItemType = std::get<0U>(initItemAndState).GetStaticType();
  790. const auto stateType = std::get<1U>(initItemAndState).GetStaticType();;
  791. TType* resultListType = nullptr;
  792. if (listType->IsFlow()) {
  793. resultListType = TFlowType::Create(resultItemType, Env);
  794. } else if (listType->IsList()) {
  795. resultListType = TListType::Create(resultItemType, Env);
  796. } else if (listType->IsStream()) {
  797. resultListType = TStreamType::Create(resultItemType, Env);
  798. }
  799. const auto stateArg = Arg(stateType);
  800. const auto updateItemAndState = handler(itemArg, stateArg);
  801. MKQL_ENSURE(std::get<0U>(updateItemAndState).GetStaticType()->IsSameType(*resultItemType), "Item type is changed by the handler");
  802. MKQL_ENSURE(std::get<1U>(updateItemAndState).GetStaticType()->IsSameType(*stateType), "State type is changed by the handler");
  803. TCallableBuilder callableBuilder(Env, __func__, resultListType);
  804. callableBuilder.Add(list);
  805. callableBuilder.Add(itemArg);
  806. callableBuilder.Add(std::get<0U>(initItemAndState));
  807. callableBuilder.Add(std::get<1U>(initItemAndState));
  808. callableBuilder.Add(stateArg);
  809. callableBuilder.Add(std::get<0U>(updateItemAndState));
  810. callableBuilder.Add(std::get<1U>(updateItemAndState));
  811. return TRuntimeNode(callableBuilder.Build(), false);
  812. }
  813. TRuntimeNode TProgramBuilder::ToList(TRuntimeNode optional) {
  814. const auto optionalType = optional.GetStaticType();
  815. MKQL_ENSURE(optionalType->IsOptional(), "Expected optional");
  816. const auto& optionalDetailedType = static_cast<const TOptionalType&>(*optionalType);
  817. const auto itemType = optionalDetailedType.GetItemType();
  818. return IfPresent(optional, [&](TRuntimeNode item) { return AsList(item); }, NewEmptyList(itemType));
  819. }
  820. TRuntimeNode TProgramBuilder::Iterable(TZeroLambda lambda) {
  821. if constexpr (RuntimeVersion < 19U) {
  822. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  823. }
  824. const auto itemArg = Arg(NewNull().GetStaticType());
  825. auto lambdaRes = lambda();
  826. const auto resultType = NewListType(AS_TYPE(TStreamType, lambdaRes.GetStaticType())->GetItemType());
  827. TCallableBuilder callableBuilder(Env, __func__, resultType);
  828. callableBuilder.Add(lambdaRes);
  829. callableBuilder.Add(itemArg);
  830. return TRuntimeNode(callableBuilder.Build(), false);
  831. }
  832. TRuntimeNode TProgramBuilder::ToOptional(TRuntimeNode list) {
  833. return Head(list);
  834. }
  835. TRuntimeNode TProgramBuilder::Head(TRuntimeNode list) {
  836. const auto resultType = NewOptionalType(AS_TYPE(TListType, list.GetStaticType())->GetItemType());
  837. TCallableBuilder callableBuilder(Env, __func__, resultType);
  838. callableBuilder.Add(list);
  839. return TRuntimeNode(callableBuilder.Build(), false);
  840. }
  841. TRuntimeNode TProgramBuilder::Last(TRuntimeNode list) {
  842. const auto resultType = NewOptionalType(AS_TYPE(TListType, list.GetStaticType())->GetItemType());
  843. TCallableBuilder callableBuilder(Env, __func__, resultType);
  844. callableBuilder.Add(list);
  845. return TRuntimeNode(callableBuilder.Build(), false);
  846. }
  847. TRuntimeNode TProgramBuilder::Nanvl(TRuntimeNode data, TRuntimeNode dataIfNaN) {
  848. const std::array<TRuntimeNode, 2> args = {{ data, dataIfNaN }};
  849. return Invoke(__func__, BuildArithmeticCommonType(data.GetStaticType(), dataIfNaN.GetStaticType()), args);
  850. }
  851. TRuntimeNode TProgramBuilder::FlatMap(TRuntimeNode list, const TUnaryLambda& handler)
  852. {
  853. return BuildFlatMap(__func__, list, handler);
  854. }
  855. TRuntimeNode TProgramBuilder::OrderedFlatMap(TRuntimeNode list, const TUnaryLambda& handler)
  856. {
  857. return BuildFlatMap(__func__, list, handler);
  858. }
  859. TRuntimeNode TProgramBuilder::Filter(TRuntimeNode list, const TUnaryLambda& handler)
  860. {
  861. return BuildFilter(__func__, list, handler);
  862. }
  863. TRuntimeNode TProgramBuilder::Filter(TRuntimeNode list, TRuntimeNode limit, const TUnaryLambda& handler)
  864. {
  865. return BuildFilter(__func__, list, limit, handler);
  866. }
  867. TRuntimeNode TProgramBuilder::OrderedFilter(TRuntimeNode list, const TUnaryLambda& handler)
  868. {
  869. return BuildFilter(__func__, list, handler);
  870. }
  871. TRuntimeNode TProgramBuilder::OrderedFilter(TRuntimeNode list, TRuntimeNode limit, const TUnaryLambda& handler)
  872. {
  873. return BuildFilter(__func__, list, limit, handler);
  874. }
  875. TRuntimeNode TProgramBuilder::TakeWhile(TRuntimeNode list, const TUnaryLambda& handler)
  876. {
  877. return BuildFilter(__func__, list, handler);
  878. }
  879. TRuntimeNode TProgramBuilder::SkipWhile(TRuntimeNode list, const TUnaryLambda& handler)
  880. {
  881. return BuildFilter(__func__, list, handler);
  882. }
  883. TRuntimeNode TProgramBuilder::TakeWhileInclusive(TRuntimeNode list, const TUnaryLambda& handler)
  884. {
  885. return BuildFilter(__func__, list, handler);
  886. }
  887. TRuntimeNode TProgramBuilder::SkipWhileInclusive(TRuntimeNode list, const TUnaryLambda& handler)
  888. {
  889. return BuildFilter(__func__, list, handler);
  890. }
  891. TRuntimeNode TProgramBuilder::BuildListSort(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode ascending,
  892. const TUnaryLambda& keyExtractor)
  893. {
  894. const auto listType = list.GetStaticType();
  895. MKQL_ENSURE(listType->IsList(), "Expected list.");
  896. const auto itemType = static_cast<const TListType&>(*listType).GetItemType();
  897. ThrowIfListOfVoid(itemType);
  898. const auto ascendingType = ascending.GetStaticType();
  899. const auto itemArg = Arg(itemType);
  900. auto key = keyExtractor(itemArg);
  901. if (ascendingType->IsTuple()) {
  902. const auto ascendingTuple = AS_TYPE(TTupleType, ascendingType);
  903. if (ascendingTuple->GetElementsCount() == 0) {
  904. return list;
  905. }
  906. if (ascendingTuple->GetElementsCount() == 1) {
  907. ascending = Nth(ascending, 0);
  908. key = Nth(key, 0);
  909. }
  910. }
  911. TCallableBuilder callableBuilder(Env, callableName, listType);
  912. callableBuilder.Add(list);
  913. callableBuilder.Add(itemArg);
  914. callableBuilder.Add(key);
  915. callableBuilder.Add(ascending);
  916. return TRuntimeNode(callableBuilder.Build(), false);
  917. }
  918. TRuntimeNode TProgramBuilder::BuildListNth(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode n, TRuntimeNode ascending,
  919. const TUnaryLambda& keyExtractor)
  920. {
  921. const auto listType = list.GetStaticType();
  922. MKQL_ENSURE(listType->IsList(), "Expected list.");
  923. const auto itemType = static_cast<const TListType&>(*listType).GetItemType();
  924. ThrowIfListOfVoid(itemType);
  925. MKQL_ENSURE(n.GetStaticType()->IsData(), "Expected data");
  926. MKQL_ENSURE(static_cast<const TDataType&>(*n.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  927. const auto ascendingType = ascending.GetStaticType();
  928. const auto itemArg = Arg(itemType);
  929. auto key = keyExtractor(itemArg);
  930. if (ascendingType->IsTuple()) {
  931. const auto ascendingTuple = AS_TYPE(TTupleType, ascendingType);
  932. if (ascendingTuple->GetElementsCount() == 0) {
  933. return Take(list, n);
  934. }
  935. if (ascendingTuple->GetElementsCount() == 1) {
  936. ascending = Nth(ascending, 0);
  937. key = Nth(key, 0);
  938. }
  939. }
  940. TCallableBuilder callableBuilder(Env, callableName, listType);
  941. callableBuilder.Add(list);
  942. callableBuilder.Add(n);
  943. callableBuilder.Add(itemArg);
  944. callableBuilder.Add(key);
  945. callableBuilder.Add(ascending);
  946. return TRuntimeNode(callableBuilder.Build(), false);
  947. }
  948. TRuntimeNode TProgramBuilder::BuildSort(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode ascending,
  949. const TUnaryLambda& keyExtractor)
  950. {
  951. if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
  952. const bool newVersion = RuntimeVersion >= 25U && flowType->IsFlow();
  953. const auto condense = newVersion ?
  954. SqueezeToList(Map(flow, [&](TRuntimeNode item) { return Pickle(item); }), NewEmptyOptionalDataLiteral(NUdf::TDataType<ui64>::Id)) :
  955. Condense1(flow,
  956. [this](TRuntimeNode item) { return AsList(item); },
  957. [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
  958. [this](TRuntimeNode item, TRuntimeNode state) { return Append(state, item); }
  959. );
  960. const auto finalKeyExtractor = newVersion ? [&](TRuntimeNode item) {
  961. auto itemType = AS_TYPE(TFlowType, flowType)->GetItemType();
  962. return keyExtractor(Unpickle(itemType, item));
  963. } : keyExtractor;
  964. return FlatMap(condense, [&](TRuntimeNode list) {
  965. auto stealed = RuntimeVersion >= 27U ? Steal(list) : list;
  966. auto sorted = BuildSort(RuntimeVersion >= 26U ? "UnstableSort" : callableName, stealed, ascending, finalKeyExtractor);
  967. return newVersion ? Map(LazyList(sorted), [&](TRuntimeNode item) {
  968. auto itemType = AS_TYPE(TFlowType, flowType)->GetItemType();
  969. return Unpickle(itemType, item);
  970. }) : sorted;
  971. });
  972. }
  973. return BuildListSort(callableName, flow, ascending, keyExtractor);
  974. }
  975. TRuntimeNode TProgramBuilder::BuildNth(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode n, TRuntimeNode ascending,
  976. const TUnaryLambda& keyExtractor)
  977. {
  978. if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
  979. return FlatMap(Condense1(flow,
  980. [this](TRuntimeNode item) { return AsList(item); },
  981. [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
  982. [this](TRuntimeNode item, TRuntimeNode state) { return Append(state, item); }
  983. ),
  984. [&](TRuntimeNode list) { return BuildNth(callableName, list, n, ascending, keyExtractor); }
  985. );
  986. }
  987. return BuildListNth(callableName, flow, n, ascending, keyExtractor);
  988. }
  989. TRuntimeNode TProgramBuilder::BuildTake(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count) {
  990. const auto listType = flow.GetStaticType();
  991. TType* itemType = nullptr;
  992. if (listType->IsFlow()) {
  993. itemType = AS_TYPE(TFlowType, listType)->GetItemType();
  994. } else if (listType->IsList()) {
  995. itemType = AS_TYPE(TListType, listType)->GetItemType();
  996. } else if (listType->IsStream()) {
  997. itemType = AS_TYPE(TStreamType, listType)->GetItemType();
  998. }
  999. MKQL_ENSURE(itemType, "Expected flow, list or stream.");
  1000. ThrowIfListOfVoid(itemType);
  1001. MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data");
  1002. MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  1003. TCallableBuilder callableBuilder(Env, callableName, listType);
  1004. callableBuilder.Add(flow);
  1005. callableBuilder.Add(count);
  1006. return TRuntimeNode(callableBuilder.Build(), false);
  1007. }
  1008. template<bool IsFilter, bool OnStruct>
  1009. TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list) {
  1010. const auto listType = list.GetStaticType();
  1011. TType* itemType;
  1012. if (listType->IsFlow()) {
  1013. itemType = AS_TYPE(TFlowType, listType)->GetItemType();
  1014. } else if (listType->IsList()) {
  1015. itemType = AS_TYPE(TListType, listType)->GetItemType();
  1016. } else if (listType->IsStream()) {
  1017. itemType = AS_TYPE(TStreamType, listType)->GetItemType();
  1018. } else if (listType->IsOptional()) {
  1019. itemType = AS_TYPE(TOptionalType, listType)->GetItemType();
  1020. } else {
  1021. THROW yexception() << "Expected flow or list or stream or optional of " << (OnStruct ? "struct." : "tuple.");
  1022. }
  1023. std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>> filteredItems;
  1024. std::vector<std::conditional_t<OnStruct, std::string_view, ui32>> members;
  1025. const bool multiOptional = CollectOptionalElements<IsFilter>(itemType, members, filteredItems);
  1026. const auto predicate = [=](TRuntimeNode item) {
  1027. std::vector<TRuntimeNode> checkMembers;
  1028. checkMembers.reserve(members.size());
  1029. std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers),
  1030. [=](const auto& i){ return Exists(Element(item, i)); });
  1031. return And(checkMembers);
  1032. };
  1033. auto resultType = listType;
  1034. if constexpr (IsFilter) {
  1035. if (const auto filteredItemType = NewArrayType(filteredItems); multiOptional) {
  1036. return BuildFilterNulls<OnStruct>(list, members, filteredItems);
  1037. } else {
  1038. resultType = listType->IsFlow() ?
  1039. NewFlowType(filteredItemType):
  1040. listType->IsList() ?
  1041. NewListType(filteredItemType):
  1042. listType->IsStream() ? NewStreamType(filteredItemType) : NewOptionalType(filteredItemType);
  1043. }
  1044. }
  1045. return Filter(list, predicate, resultType);
  1046. }
  1047. template<bool IsFilter, bool OnStruct>
  1048. TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, const TArrayRef<std::conditional_t<OnStruct, const std::string_view, const ui32>>& members) {
  1049. if (members.empty()) {
  1050. return list;
  1051. }
  1052. const auto listType = list.GetStaticType();
  1053. TType* itemType;
  1054. if (listType->IsFlow()) {
  1055. itemType = AS_TYPE(TFlowType, listType)->GetItemType();
  1056. } else if (listType->IsList()) {
  1057. itemType = AS_TYPE(TListType, listType)->GetItemType();
  1058. } else if (listType->IsStream()) {
  1059. itemType = AS_TYPE(TStreamType, listType)->GetItemType();
  1060. } else if (listType->IsOptional()) {
  1061. itemType = AS_TYPE(TOptionalType, listType)->GetItemType();
  1062. } else {
  1063. THROW yexception() << "Expected flow or list or stream or optional of struct.";
  1064. }
  1065. const auto predicate = [=](TRuntimeNode item) {
  1066. TRuntimeNode::TList checkMembers;
  1067. checkMembers.reserve(members.size());
  1068. std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers),
  1069. [=](const auto& i){ return Exists(Element(item, i)); });
  1070. return And(checkMembers);
  1071. };
  1072. auto resultType = listType;
  1073. if constexpr (IsFilter) {
  1074. if (std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>> filteredItems;
  1075. ReduceOptionalElements(itemType, members, filteredItems)) {
  1076. return BuildFilterNulls<OnStruct>(list, members, filteredItems);
  1077. } else {
  1078. const auto filteredItemType = NewArrayType(filteredItems);
  1079. resultType = listType->IsFlow() ?
  1080. NewFlowType(filteredItemType):
  1081. listType->IsList() ?
  1082. NewListType(filteredItemType):
  1083. listType->IsStream() ? NewStreamType(filteredItemType) : NewOptionalType(filteredItemType);
  1084. }
  1085. }
  1086. return Filter(list, predicate, resultType);
  1087. }
  1088. template<bool OnStruct>
  1089. TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, const TArrayRef<std::conditional_t<OnStruct, const std::string_view, const ui32>>& members,
  1090. const std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>>& filteredItems) {
  1091. return FlatMap(list, [&](TRuntimeNode item) {
  1092. TRuntimeNode::TList checkMembers;
  1093. checkMembers.reserve(members.size());
  1094. std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers),
  1095. [=](const auto& i){ return Element(item, i); });
  1096. return IfPresent(checkMembers, [&](TRuntimeNode::TList items) {
  1097. std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TRuntimeNode>>, TRuntimeNode::TList> row;
  1098. row.reserve(filteredItems.size());
  1099. auto j = 0U;
  1100. if constexpr (OnStruct) {
  1101. std::transform(filteredItems.cbegin(), filteredItems.cend(), std::back_inserter(row),
  1102. [&](const std::pair<std::string_view, TType*>& i) {
  1103. const auto& member = i.first;
  1104. const bool passtrought = members.cend() == std::find(members.cbegin(), members.cend(), member);
  1105. return std::make_pair(member, passtrought ? Element(item, member) : items[j++]);
  1106. }
  1107. );
  1108. return NewOptional(NewStruct(row));
  1109. } else {
  1110. auto i = 0U;
  1111. std::generate_n(std::back_inserter(row), filteredItems.size(),
  1112. [&]() {
  1113. const auto index = i++;
  1114. const bool passtrought = members.cend() == std::find(members.cbegin(), members.cend(), index);
  1115. return passtrought ? Element(item, index) : items[j++];
  1116. }
  1117. );
  1118. return NewOptional(NewTuple(row));
  1119. }
  1120. }, NewEmptyOptional(NewOptionalType(NewArrayType(filteredItems))));
  1121. });
  1122. }
  1123. TRuntimeNode TProgramBuilder::SkipNullMembers(TRuntimeNode list) {
  1124. return BuildFilterNulls<false, true>(list);
  1125. }
  1126. TRuntimeNode TProgramBuilder::FilterNullMembers(TRuntimeNode list) {
  1127. return BuildFilterNulls<true, true>(list);
  1128. }
  1129. TRuntimeNode TProgramBuilder::SkipNullMembers(TRuntimeNode list, const TArrayRef<const std::string_view>& members) {
  1130. return BuildFilterNulls<false, true>(list, members);
  1131. }
  1132. TRuntimeNode TProgramBuilder::FilterNullMembers(TRuntimeNode list, const TArrayRef<const std::string_view>& members) {
  1133. return BuildFilterNulls<true, true>(list, members);
  1134. }
  1135. TRuntimeNode TProgramBuilder::FilterNullElements(TRuntimeNode list) {
  1136. return BuildFilterNulls<true, false>(list);
  1137. }
  1138. TRuntimeNode TProgramBuilder::SkipNullElements(TRuntimeNode list) {
  1139. return BuildFilterNulls<false, false>(list);
  1140. }
  1141. TRuntimeNode TProgramBuilder::FilterNullElements(TRuntimeNode list, const TArrayRef<const ui32>& elements) {
  1142. return BuildFilterNulls<true, false>(list, elements);
  1143. }
  1144. TRuntimeNode TProgramBuilder::SkipNullElements(TRuntimeNode list, const TArrayRef<const ui32>& elements) {
  1145. return BuildFilterNulls<false, false>(list, elements);
  1146. }
  1147. template <typename ResultType>
  1148. TRuntimeNode TProgramBuilder::BuildContainerProperty(const std::string_view& callableName, TRuntimeNode listOrDict) {
  1149. const auto type = listOrDict.GetStaticType();
  1150. MKQL_ENSURE(type->IsList() || type->IsDict() || type->IsEmptyList() || type->IsEmptyDict(), "Expected list or dict.");
  1151. if (type->IsList()) {
  1152. const auto itemType = AS_TYPE(TListType, type)->GetItemType();
  1153. ThrowIfListOfVoid(itemType);
  1154. }
  1155. TCallableBuilder callableBuilder(Env, callableName, NewDataType(NUdf::TDataType<ResultType>::Id));
  1156. callableBuilder.Add(listOrDict);
  1157. return TRuntimeNode(callableBuilder.Build(), false);
  1158. }
  1159. TRuntimeNode TProgramBuilder::Length(TRuntimeNode listOrDict) {
  1160. return BuildContainerProperty<ui64>(__func__, listOrDict);
  1161. }
  1162. TRuntimeNode TProgramBuilder::Iterator(TRuntimeNode list, const TArrayRef<const TRuntimeNode>& dependentNodes) {
  1163. const auto streamType = NewStreamType(AS_TYPE(TListType, list.GetStaticType())->GetItemType());
  1164. TCallableBuilder callableBuilder(Env, __func__, streamType);
  1165. callableBuilder.Add(list);
  1166. for (auto node : dependentNodes) {
  1167. callableBuilder.Add(node);
  1168. }
  1169. return TRuntimeNode(callableBuilder.Build(), false);
  1170. }
  1171. TRuntimeNode TProgramBuilder::EmptyIterator(TType* streamType) {
  1172. MKQL_ENSURE(streamType->IsStream() || streamType->IsFlow(), "Expected stream or flow.");
  1173. if (RuntimeVersion < 7U && streamType->IsFlow()) {
  1174. return ToFlow(EmptyIterator(NewStreamType(AS_TYPE(TFlowType, streamType)->GetItemType())));
  1175. }
  1176. TCallableBuilder callableBuilder(Env, __func__, streamType);
  1177. return TRuntimeNode(callableBuilder.Build(), false);
  1178. }
  1179. TRuntimeNode TProgramBuilder::Collect(TRuntimeNode flow) {
  1180. const auto seqType = flow.GetStaticType();
  1181. TType* itemType = nullptr;
  1182. if (seqType->IsFlow()) {
  1183. itemType = AS_TYPE(TFlowType, seqType)->GetItemType();
  1184. } else if (seqType->IsList()) {
  1185. itemType = AS_TYPE(TListType, seqType)->GetItemType();
  1186. } else if (seqType->IsStream()) {
  1187. itemType = AS_TYPE(TStreamType, seqType)->GetItemType();
  1188. } else {
  1189. THROW yexception() << "Expected flow, list or stream.";
  1190. }
  1191. TCallableBuilder callableBuilder(Env, __func__, NewListType(itemType));
  1192. callableBuilder.Add(flow);
  1193. return TRuntimeNode(callableBuilder.Build(), false);
  1194. }
  1195. TRuntimeNode TProgramBuilder::LazyList(TRuntimeNode list) {
  1196. const auto type = list.GetStaticType();
  1197. bool isOptional;
  1198. const auto listType = UnpackOptional(type, isOptional);
  1199. MKQL_ENSURE(listType->IsList(), "Expected list");
  1200. TCallableBuilder callableBuilder(Env, __func__, type);
  1201. callableBuilder.Add(list);
  1202. return TRuntimeNode(callableBuilder.Build(), false);
  1203. }
  1204. TRuntimeNode TProgramBuilder::ForwardList(TRuntimeNode stream) {
  1205. const auto type = stream.GetStaticType();
  1206. MKQL_ENSURE(type->IsStream() || type->IsFlow(), "Expected flow or stream.");
  1207. if constexpr (RuntimeVersion < 10U) {
  1208. if (type->IsFlow()) {
  1209. return ForwardList(FromFlow(stream));
  1210. }
  1211. }
  1212. TCallableBuilder callableBuilder(Env, __func__, NewListType(type->IsFlow() ? AS_TYPE(TFlowType, stream)->GetItemType() : AS_TYPE(TStreamType, stream)->GetItemType()));
  1213. callableBuilder.Add(stream);
  1214. return TRuntimeNode(callableBuilder.Build(), false);
  1215. }
  1216. TRuntimeNode TProgramBuilder::ToFlow(TRuntimeNode stream) {
  1217. const auto type = stream.GetStaticType();
  1218. MKQL_ENSURE(type->IsStream() || type->IsList() || type->IsOptional(), "Expected stream, list or optional.");
  1219. const auto itemType = type->IsStream() ? AS_TYPE(TStreamType, stream)->GetItemType() :
  1220. type->IsList() ? AS_TYPE(TListType, stream)->GetItemType() : AS_TYPE(TOptionalType, stream)->GetItemType();
  1221. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(itemType));
  1222. callableBuilder.Add(stream);
  1223. return TRuntimeNode(callableBuilder.Build(), false);
  1224. }
  1225. TRuntimeNode TProgramBuilder::FromFlow(TRuntimeNode flow) {
  1226. MKQL_ENSURE(flow.GetStaticType()->IsFlow(), "Expected flow.");
  1227. TCallableBuilder callableBuilder(Env, __func__, NewStreamType(AS_TYPE(TFlowType, flow)->GetItemType()));
  1228. callableBuilder.Add(flow);
  1229. return TRuntimeNode(callableBuilder.Build(), false);
  1230. }
  1231. TRuntimeNode TProgramBuilder::Steal(TRuntimeNode input) {
  1232. if constexpr (RuntimeVersion < 27U) {
  1233. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1234. }
  1235. TCallableBuilder callableBuilder(Env, __func__, input.GetStaticType(), true);
  1236. callableBuilder.Add(input);
  1237. return TRuntimeNode(callableBuilder.Build(), false);
  1238. }
  1239. TRuntimeNode TProgramBuilder::ToBlocks(TRuntimeNode flow) {
  1240. auto* flowType = AS_TYPE(TFlowType, flow.GetStaticType());
  1241. auto* blockType = NewBlockType(flowType->GetItemType(), TBlockType::EShape::Many);
  1242. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(blockType));
  1243. callableBuilder.Add(flow);
  1244. return TRuntimeNode(callableBuilder.Build(), false);
  1245. }
  1246. TType* TProgramBuilder::BuildWideBlockType(const TArrayRef<TType* const>& wideComponents) {
  1247. std::vector<TType*> blockItems;
  1248. blockItems.reserve(wideComponents.size());
  1249. for (size_t i = 0; i < wideComponents.size(); i++) {
  1250. blockItems.push_back(NewBlockType(wideComponents[i], TBlockType::EShape::Many));
  1251. }
  1252. blockItems.push_back(NewBlockType(NewDataType(NUdf::TDataType<ui64>::Id), TBlockType::EShape::Scalar));
  1253. return NewMultiType(blockItems);
  1254. }
  1255. TRuntimeNode TProgramBuilder::WideToBlocks(TRuntimeNode stream) {
  1256. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected WideStream as input type");
  1257. if constexpr (RuntimeVersion < 58U) {
  1258. // Preserve the old behaviour for ABI compatibility.
  1259. // Emit (FromFlow (WideToBlocks (ToFlow (<stream>)))) to
  1260. // process the flow in favor to the given stream following
  1261. // the older MKQL ABI.
  1262. // FIXME: Drop the branch below, when the time comes.
  1263. const auto inputFlow = ToFlow(stream);
  1264. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, inputFlow.GetStaticType()));
  1265. TType* outputMultiType = BuildWideBlockType(wideComponents);
  1266. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(outputMultiType));
  1267. callableBuilder.Add(inputFlow);
  1268. const auto outputFlow = TRuntimeNode(callableBuilder.Build(), false);
  1269. return FromFlow(outputFlow);
  1270. }
  1271. const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, stream.GetStaticType()));
  1272. TType* outputMultiType = BuildWideBlockType(wideComponents);
  1273. TCallableBuilder callableBuilder(Env, __func__, NewStreamType(outputMultiType));
  1274. callableBuilder.Add(stream);
  1275. return TRuntimeNode(callableBuilder.Build(), false);
  1276. }
  1277. TRuntimeNode TProgramBuilder::FromBlocks(TRuntimeNode flow) {
  1278. auto* flowType = AS_TYPE(TFlowType, flow.GetStaticType());
  1279. auto* blockType = AS_TYPE(TBlockType, flowType->GetItemType());
  1280. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(blockType->GetItemType()));
  1281. callableBuilder.Add(flow);
  1282. return TRuntimeNode(callableBuilder.Build(), false);
  1283. }
  1284. TRuntimeNode TProgramBuilder::WideFromBlocks(TRuntimeNode stream) {
  1285. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected WideStream as input type");
  1286. if constexpr (RuntimeVersion < 55U) {
  1287. // Preserve the old behaviour for ABI compatibility.
  1288. // Emit (FromFlow (WideFromBlocks (ToFlow (<stream>)))) to
  1289. // process the flow in favor to the given stream following
  1290. // the older MKQL ABI.
  1291. // FIXME: Drop the branch below, when the time comes.
  1292. const auto inputFlow = ToFlow(stream);
  1293. auto outputItems = ValidateBlockFlowType(inputFlow.GetStaticType());
  1294. outputItems.pop_back();
  1295. TType* outputMultiType = NewMultiType(outputItems);
  1296. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(outputMultiType));
  1297. callableBuilder.Add(inputFlow);
  1298. const auto outputFlow = TRuntimeNode(callableBuilder.Build(), false);
  1299. return FromFlow(outputFlow);
  1300. }
  1301. auto outputItems = ValidateBlockStreamType(stream.GetStaticType());
  1302. outputItems.pop_back();
  1303. TType* outputMultiType = NewMultiType(outputItems);
  1304. TCallableBuilder callableBuilder(Env, __func__, NewStreamType(outputMultiType));
  1305. callableBuilder.Add(stream);
  1306. return TRuntimeNode(callableBuilder.Build(), false);
  1307. }
  1308. TRuntimeNode TProgramBuilder::WideSkipBlocks(TRuntimeNode flow, TRuntimeNode count) {
  1309. return BuildWideSkipTakeBlocks(__func__, flow, count);
  1310. }
  1311. TRuntimeNode TProgramBuilder::WideTakeBlocks(TRuntimeNode flow, TRuntimeNode count) {
  1312. return BuildWideSkipTakeBlocks(__func__, flow, count);
  1313. }
  1314. TRuntimeNode TProgramBuilder::WideTopBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
  1315. return BuildWideTopOrSort(__func__, flow, count, keys);
  1316. }
  1317. TRuntimeNode TProgramBuilder::WideTopSortBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
  1318. return BuildWideTopOrSort(__func__, flow, count, keys);
  1319. }
  1320. TRuntimeNode TProgramBuilder::WideSortBlocks(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
  1321. return BuildWideTopOrSort(__func__, flow, Nothing(), keys);
  1322. }
  1323. TRuntimeNode TProgramBuilder::AsScalar(TRuntimeNode value) {
  1324. TCallableBuilder callableBuilder(Env, __func__, NewBlockType(value.GetStaticType(), TBlockType::EShape::Scalar));
  1325. callableBuilder.Add(value);
  1326. return TRuntimeNode(callableBuilder.Build(), false);
  1327. }
  1328. TRuntimeNode TProgramBuilder::ReplicateScalar(TRuntimeNode value, TRuntimeNode count) {
  1329. if constexpr (RuntimeVersion < 43U) {
  1330. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1331. }
  1332. auto valueType = AS_TYPE(TBlockType, value.GetStaticType());
  1333. auto countType = AS_TYPE(TBlockType, count.GetStaticType());
  1334. MKQL_ENSURE(valueType->GetShape() == TBlockType::EShape::Scalar, "Expecting scalar as first arguemnt");
  1335. MKQL_ENSURE(countType->GetShape() == TBlockType::EShape::Scalar, "Expecting scalar as second arguemnt");
  1336. MKQL_ENSURE(countType->GetItemType()->IsData(), "Expected scalar data as second argument");
  1337. MKQL_ENSURE(AS_TYPE(TDataType, countType->GetItemType())->GetSchemeType() ==
  1338. NUdf::TDataType<ui64>::Id, "Expected scalar ui64 as second argument");
  1339. auto outputType = NewBlockType(valueType->GetItemType(), TBlockType::EShape::Many);
  1340. TCallableBuilder callableBuilder(Env, __func__, outputType);
  1341. callableBuilder.Add(value);
  1342. callableBuilder.Add(count);
  1343. return TRuntimeNode(callableBuilder.Build(), false);
  1344. }
  1345. TRuntimeNode TProgramBuilder::BlockCompress(TRuntimeNode flow, ui32 bitmapIndex) {
  1346. auto blockItemTypes = ValidateBlockFlowType(flow.GetStaticType());
  1347. MKQL_ENSURE(blockItemTypes.size() >= 2, "Expected at least two input columns");
  1348. MKQL_ENSURE(bitmapIndex < blockItemTypes.size() - 1, "Invalid bitmap index");
  1349. MKQL_ENSURE(AS_TYPE(TDataType, blockItemTypes[bitmapIndex])->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected Bool as bitmap column type");
  1350. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  1351. MKQL_ENSURE(wideComponents.size() == blockItemTypes.size(), "Unexpected tuple size");
  1352. std::vector<TType*> flowItems;
  1353. for (size_t i = 0; i < wideComponents.size(); ++i) {
  1354. if (i == bitmapIndex) {
  1355. continue;
  1356. }
  1357. flowItems.push_back(wideComponents[i]);
  1358. }
  1359. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(flowItems)));
  1360. callableBuilder.Add(flow);
  1361. callableBuilder.Add(NewDataLiteral<ui32>(bitmapIndex));
  1362. return TRuntimeNode(callableBuilder.Build(), false);
  1363. }
  1364. TRuntimeNode TProgramBuilder::BlockExpandChunked(TRuntimeNode comp) {
  1365. if (comp.GetStaticType()->IsStream()) {
  1366. ValidateBlockStreamType(comp.GetStaticType());
  1367. } else {
  1368. ValidateBlockFlowType(comp.GetStaticType());
  1369. }
  1370. TCallableBuilder callableBuilder(Env, __func__, comp.GetStaticType());
  1371. callableBuilder.Add(comp);
  1372. return TRuntimeNode(callableBuilder.Build(), false);
  1373. }
  1374. TRuntimeNode TProgramBuilder::BlockCoalesce(TRuntimeNode first, TRuntimeNode second) {
  1375. auto firstType = AS_TYPE(TBlockType, first.GetStaticType());
  1376. auto secondType = AS_TYPE(TBlockType, second.GetStaticType());
  1377. auto firstItemType = firstType->GetItemType();
  1378. auto secondItemType = secondType->GetItemType();
  1379. MKQL_ENSURE(firstItemType->IsOptional() || firstItemType->IsPg(), "Expecting Optional or Pg type as first argument");
  1380. if (!firstItemType->IsSameType(*secondItemType)) {
  1381. bool firstOptional;
  1382. firstItemType = UnpackOptional(firstItemType, firstOptional);
  1383. MKQL_ENSURE(firstItemType->IsSameType(*secondItemType), "Uncompatible arguemnt types");
  1384. }
  1385. auto outputType = NewBlockType(secondType->GetItemType(), GetResultShape({firstType, secondType}));
  1386. TCallableBuilder callableBuilder(Env, __func__, outputType);
  1387. callableBuilder.Add(first);
  1388. callableBuilder.Add(second);
  1389. return TRuntimeNode(callableBuilder.Build(), false);
  1390. }
  1391. TRuntimeNode TProgramBuilder::BlockExists(TRuntimeNode data) {
  1392. auto dataType = AS_TYPE(TBlockType, data.GetStaticType());
  1393. auto outputType = NewBlockType(NewDataType(NUdf::TDataType<bool>::Id), dataType->GetShape());
  1394. TCallableBuilder callableBuilder(Env, __func__, outputType);
  1395. callableBuilder.Add(data);
  1396. return TRuntimeNode(callableBuilder.Build(), false);
  1397. }
  1398. TRuntimeNode TProgramBuilder::BlockMember(TRuntimeNode structObj, const std::string_view& memberName) {
  1399. auto blockType = AS_TYPE(TBlockType, structObj.GetStaticType());
  1400. bool isOptional;
  1401. const auto type = AS_TYPE(TStructType, UnpackOptional(blockType->GetItemType(), isOptional));
  1402. const auto memberIndex = type->GetMemberIndex(memberName);
  1403. auto memberType = type->GetMemberType(memberIndex);
  1404. if (isOptional && !memberType->IsOptional() && !memberType->IsNull() && !memberType->IsPg()) {
  1405. memberType = NewOptionalType(memberType);
  1406. }
  1407. auto returnType = NewBlockType(memberType, blockType->GetShape());
  1408. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1409. callableBuilder.Add(structObj);
  1410. callableBuilder.Add(NewDataLiteral<ui32>(memberIndex));
  1411. return TRuntimeNode(callableBuilder.Build(), false);
  1412. }
  1413. TRuntimeNode TProgramBuilder::BlockNth(TRuntimeNode tuple, ui32 index) {
  1414. auto blockType = AS_TYPE(TBlockType, tuple.GetStaticType());
  1415. bool isOptional;
  1416. const auto type = AS_TYPE(TTupleType, UnpackOptional(blockType->GetItemType(), isOptional));
  1417. MKQL_ENSURE(index < type->GetElementsCount(), "Index out of range: " << index <<
  1418. " is not less than " << type->GetElementsCount());
  1419. auto itemType = type->GetElementType(index);
  1420. if (isOptional && !itemType->IsOptional() && !itemType->IsNull() && !itemType->IsPg()) {
  1421. itemType = TOptionalType::Create(itemType, Env);
  1422. }
  1423. auto returnType = NewBlockType(itemType, blockType->GetShape());
  1424. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1425. callableBuilder.Add(tuple);
  1426. callableBuilder.Add(NewDataLiteral<ui32>(index));
  1427. return TRuntimeNode(callableBuilder.Build(), false);
  1428. }
  1429. TRuntimeNode TProgramBuilder::BlockAsStruct(const TArrayRef<std::pair<std::string_view, TRuntimeNode>>& args) {
  1430. MKQL_ENSURE(!args.empty(), "Expected at least one argument");
  1431. TBlockType::EShape resultShape = TBlockType::EShape::Scalar;
  1432. TVector<std::pair<std::string_view, TType*>> members;
  1433. for (const auto& x : args) {
  1434. auto blockType = AS_TYPE(TBlockType, x.second.GetStaticType());
  1435. members.emplace_back(x.first, blockType->GetItemType());
  1436. if (blockType->GetShape() == TBlockType::EShape::Many) {
  1437. resultShape = TBlockType::EShape::Many;
  1438. }
  1439. }
  1440. auto returnType = NewBlockType(NewStructType(members), resultShape);
  1441. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1442. for (const auto& x : args) {
  1443. callableBuilder.Add(x.second);
  1444. }
  1445. return TRuntimeNode(callableBuilder.Build(), false);
  1446. }
  1447. TRuntimeNode TProgramBuilder::BlockAsTuple(const TArrayRef<const TRuntimeNode>& args) {
  1448. MKQL_ENSURE(!args.empty(), "Expected at least one argument");
  1449. TBlockType::EShape resultShape = TBlockType::EShape::Scalar;
  1450. TVector<TType*> types;
  1451. for (const auto& x : args) {
  1452. auto blockType = AS_TYPE(TBlockType, x.GetStaticType());
  1453. types.push_back(blockType->GetItemType());
  1454. if (blockType->GetShape() == TBlockType::EShape::Many) {
  1455. resultShape = TBlockType::EShape::Many;
  1456. }
  1457. }
  1458. auto tupleType = NewTupleType(types);
  1459. auto returnType = NewBlockType(tupleType, resultShape);
  1460. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1461. for (const auto& x : args) {
  1462. callableBuilder.Add(x);
  1463. }
  1464. return TRuntimeNode(callableBuilder.Build(), false);
  1465. }
  1466. TRuntimeNode TProgramBuilder::BlockToPg(TRuntimeNode input, TType* returnType) {
  1467. if constexpr (RuntimeVersion < 37U) {
  1468. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1469. }
  1470. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1471. callableBuilder.Add(input);
  1472. return TRuntimeNode(callableBuilder.Build(), false);
  1473. }
  1474. TRuntimeNode TProgramBuilder::BlockFromPg(TRuntimeNode input, TType* returnType) {
  1475. if constexpr (RuntimeVersion < 37U) {
  1476. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1477. }
  1478. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1479. callableBuilder.Add(input);
  1480. return TRuntimeNode(callableBuilder.Build(), false);
  1481. }
  1482. TRuntimeNode TProgramBuilder::BlockNot(TRuntimeNode data) {
  1483. auto dataType = AS_TYPE(TBlockType, data.GetStaticType());
  1484. bool isOpt;
  1485. MKQL_ENSURE(UnpackOptionalData(dataType->GetItemType(), isOpt)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  1486. TCallableBuilder callableBuilder(Env, __func__, data.GetStaticType());
  1487. callableBuilder.Add(data);
  1488. return TRuntimeNode(callableBuilder.Build(), false);
  1489. }
  1490. TRuntimeNode TProgramBuilder::BlockAnd(TRuntimeNode first, TRuntimeNode second) {
  1491. return BuildBlockLogical(__func__, first, second);
  1492. }
  1493. TRuntimeNode TProgramBuilder::BlockOr(TRuntimeNode first, TRuntimeNode second) {
  1494. return BuildBlockLogical(__func__, first, second);
  1495. }
  1496. TRuntimeNode TProgramBuilder::BlockXor(TRuntimeNode first, TRuntimeNode second) {
  1497. return BuildBlockLogical(__func__, first, second);
  1498. }
  1499. TRuntimeNode TProgramBuilder::BlockDecimalDiv(TRuntimeNode first, TRuntimeNode second) {
  1500. return BuildBlockDecimalBinary(__func__, first, second);
  1501. }
  1502. TRuntimeNode TProgramBuilder::BlockDecimalMod(TRuntimeNode first, TRuntimeNode second) {
  1503. return BuildBlockDecimalBinary(__func__, first, second);
  1504. }
  1505. TRuntimeNode TProgramBuilder::BlockDecimalMul(TRuntimeNode first, TRuntimeNode second) {
  1506. return BuildBlockDecimalBinary(__func__, first, second);
  1507. }
  1508. TRuntimeNode TProgramBuilder::ListFromRange(TRuntimeNode start, TRuntimeNode end, TRuntimeNode step) {
  1509. MKQL_ENSURE(start.GetStaticType()->IsData(), "Expected data");
  1510. MKQL_ENSURE(end.GetStaticType()->IsSameType(*start.GetStaticType()), "Mismatch type");
  1511. if constexpr (RuntimeVersion < 24U) {
  1512. MKQL_ENSURE(IsNumericType(AS_TYPE(TDataType, start)->GetSchemeType()), "Expected numeric");
  1513. } else {
  1514. MKQL_ENSURE(IsNumericType(AS_TYPE(TDataType, start)->GetSchemeType()) ||
  1515. IsDateType(AS_TYPE(TDataType, start)->GetSchemeType()) ||
  1516. IsTzDateType(AS_TYPE(TDataType, start)->GetSchemeType()) ||
  1517. IsIntervalType(AS_TYPE(TDataType, start)->GetSchemeType()),
  1518. "Expected numeric, date or tzdate");
  1519. if (IsNumericType(AS_TYPE(TDataType, start)->GetSchemeType())) {
  1520. MKQL_ENSURE(IsNumericType(AS_TYPE(TDataType, step)->GetSchemeType()), "Expected numeric");
  1521. } else {
  1522. MKQL_ENSURE(IsIntervalType(AS_TYPE(TDataType, step)->GetSchemeType()), "Expected interval");
  1523. }
  1524. }
  1525. TCallableBuilder callableBuilder(Env, __func__, TListType::Create(start.GetStaticType(), Env));
  1526. callableBuilder.Add(start);
  1527. callableBuilder.Add(end);
  1528. callableBuilder.Add(step);
  1529. return TRuntimeNode(callableBuilder.Build(), false);
  1530. }
  1531. TRuntimeNode TProgramBuilder::Switch(TRuntimeNode stream,
  1532. const TArrayRef<const TSwitchInput>& handlerInputs,
  1533. std::function<TRuntimeNode(ui32 index, TRuntimeNode item)> handler,
  1534. ui64 memoryLimitBytes, TType* returnType) {
  1535. MKQL_ENSURE(stream.GetStaticType()->IsStream() || stream.GetStaticType()->IsFlow(), "Expected stream or flow.");
  1536. std::vector<TRuntimeNode> argNodes(handlerInputs.size());
  1537. std::vector<TRuntimeNode> outputNodes(handlerInputs.size());
  1538. for (ui32 i = 0; i < handlerInputs.size(); ++i) {
  1539. TRuntimeNode arg = Arg(handlerInputs[i].InputType);
  1540. argNodes[i] = arg;
  1541. outputNodes[i] = handler(i, arg);
  1542. }
  1543. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1544. callableBuilder.Add(stream);
  1545. callableBuilder.Add(NewDataLiteral<ui64>(memoryLimitBytes));
  1546. for (ui32 i = 0; i < handlerInputs.size(); ++i) {
  1547. std::vector<TRuntimeNode> tupleElems;
  1548. for (auto index : handlerInputs[i].Indicies) {
  1549. tupleElems.push_back(NewDataLiteral<ui32>(index));
  1550. }
  1551. auto indiciesTuple = NewTuple(tupleElems);
  1552. callableBuilder.Add(indiciesTuple);
  1553. callableBuilder.Add(argNodes[i]);
  1554. callableBuilder.Add(outputNodes[i]);
  1555. if (!handlerInputs[i].ResultVariantOffset) {
  1556. callableBuilder.Add(NewVoid());
  1557. } else {
  1558. callableBuilder.Add(NewDataLiteral<ui32>(*handlerInputs[i].ResultVariantOffset));
  1559. }
  1560. }
  1561. return TRuntimeNode(callableBuilder.Build(), false);
  1562. }
  1563. TRuntimeNode TProgramBuilder::HasItems(TRuntimeNode listOrDict) {
  1564. return BuildContainerProperty<bool>(__func__, listOrDict);
  1565. }
  1566. TRuntimeNode TProgramBuilder::Reverse(TRuntimeNode list) {
  1567. bool isOptional = false;
  1568. const auto listType = UnpackOptional(list, isOptional);
  1569. if (isOptional) {
  1570. return Map(list, [&](TRuntimeNode unpacked) { return Reverse(unpacked); } );
  1571. }
  1572. const auto listDetailedType = AS_TYPE(TListType, listType);
  1573. const auto itemType = listDetailedType->GetItemType();
  1574. ThrowIfListOfVoid(itemType);
  1575. TCallableBuilder callableBuilder(Env, __func__, listType);
  1576. callableBuilder.Add(list);
  1577. return TRuntimeNode(callableBuilder.Build(), false);
  1578. }
  1579. TRuntimeNode TProgramBuilder::Skip(TRuntimeNode list, TRuntimeNode count) {
  1580. return BuildTake(__func__, list, count);
  1581. }
  1582. TRuntimeNode TProgramBuilder::Take(TRuntimeNode list, TRuntimeNode count) {
  1583. return BuildTake(__func__, list, count);
  1584. }
  1585. TRuntimeNode TProgramBuilder::Sort(TRuntimeNode list, TRuntimeNode ascending, const TUnaryLambda& keyExtractor)
  1586. {
  1587. return BuildSort(__func__, list, ascending, keyExtractor);
  1588. }
  1589. TRuntimeNode TProgramBuilder::WideTop(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys)
  1590. {
  1591. return BuildWideTopOrSort(__func__, flow, count, keys);
  1592. }
  1593. TRuntimeNode TProgramBuilder::WideTopSort(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys)
  1594. {
  1595. return BuildWideTopOrSort(__func__, flow, count, keys);
  1596. }
  1597. TRuntimeNode TProgramBuilder::WideSort(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys)
  1598. {
  1599. return BuildWideTopOrSort(__func__, flow, Nothing(), keys);
  1600. }
  1601. TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callableName, TRuntimeNode flow, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
  1602. if (count) {
  1603. if constexpr (RuntimeVersion < 33U) {
  1604. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << callableName;
  1605. }
  1606. } else {
  1607. if constexpr (RuntimeVersion < 34U) {
  1608. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << callableName;
  1609. }
  1610. }
  1611. const auto width = GetWideComponentsCount(AS_TYPE(TFlowType, flow.GetStaticType()));
  1612. MKQL_ENSURE(!keys.empty() && keys.size() <= width, "Unexpected keys count: " << keys.size());
  1613. TCallableBuilder callableBuilder(Env, callableName, flow.GetStaticType());
  1614. callableBuilder.Add(flow);
  1615. if (count) {
  1616. callableBuilder.Add(*count);
  1617. }
  1618. std::for_each(keys.cbegin(), keys.cend(), [&](const std::pair<ui32, TRuntimeNode>& key) {
  1619. MKQL_ENSURE(key.first < width, "Key index too large: " << key.first);
  1620. callableBuilder.Add(NewDataLiteral(key.first));
  1621. callableBuilder.Add(key.second);
  1622. });
  1623. return TRuntimeNode(callableBuilder.Build(), false);
  1624. }
  1625. TRuntimeNode TProgramBuilder::Top(TRuntimeNode flow, TRuntimeNode count, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) {
  1626. if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
  1627. const TUnaryLambda getKey = [&](TRuntimeNode item) { return Nth(item, 0U); };
  1628. const TUnaryLambda getItem = [&](TRuntimeNode item) { return Nth(item, 1U); };
  1629. const TUnaryLambda cacheKeyExtractor = [&](TRuntimeNode item) {
  1630. return NewTuple({keyExtractor(item), item});
  1631. };
  1632. return FlatMap(Condense1(Map(flow, cacheKeyExtractor),
  1633. [&](TRuntimeNode item) { return AsList(item); },
  1634. [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
  1635. [&](TRuntimeNode item, TRuntimeNode state) {
  1636. return KeepTop(count, state, item, ascending, getKey);
  1637. }
  1638. ),
  1639. [&](TRuntimeNode list) { return Map(Top(list, count, ascending, getKey), getItem); }
  1640. );
  1641. }
  1642. return BuildListNth(__func__, flow, count, ascending, keyExtractor);
  1643. }
  1644. TRuntimeNode TProgramBuilder::TopSort(TRuntimeNode flow, TRuntimeNode count, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) {
  1645. if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
  1646. const TUnaryLambda getKey = [&](TRuntimeNode item) { return Nth(item, 0U); };
  1647. const TUnaryLambda getItem = [&](TRuntimeNode item) { return Nth(item, 1U); };
  1648. const TUnaryLambda cacheKeyExtractor = [&](TRuntimeNode item) {
  1649. return NewTuple({keyExtractor(item), item});
  1650. };
  1651. return FlatMap(Condense1(Map(flow, cacheKeyExtractor),
  1652. [&](TRuntimeNode item) { return AsList(item); },
  1653. [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
  1654. [&](TRuntimeNode item, TRuntimeNode state) {
  1655. return KeepTop(count, state, item, ascending, getKey);
  1656. }
  1657. ),
  1658. [&](TRuntimeNode list) { return Map(TopSort(list, count, ascending, getKey), getItem); }
  1659. );
  1660. }
  1661. if constexpr (RuntimeVersion >= 25U)
  1662. return BuildListNth(__func__, flow, count, ascending, keyExtractor);
  1663. else
  1664. return BuildListSort("Sort", BuildListNth("Top", flow, count, ascending, keyExtractor), ascending, keyExtractor);
  1665. }
  1666. TRuntimeNode TProgramBuilder::KeepTop(TRuntimeNode count, TRuntimeNode list, TRuntimeNode item, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) {
  1667. const auto listType = list.GetStaticType();
  1668. MKQL_ENSURE(listType->IsList(), "Expected list.");
  1669. const auto itemType = static_cast<const TListType&>(*listType).GetItemType();
  1670. ThrowIfListOfVoid(itemType);
  1671. MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data");
  1672. MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  1673. MKQL_ENSURE(itemType->IsSameType(*item.GetStaticType()), "Types of list and item are different.");
  1674. const auto ascendingType = ascending.GetStaticType();
  1675. const auto itemArg = Arg(itemType);
  1676. auto key = keyExtractor(itemArg);
  1677. const auto hotkey = Arg(key.GetStaticType());
  1678. if (ascendingType->IsTuple()) {
  1679. const auto ascendingTuple = AS_TYPE(TTupleType, ascendingType);
  1680. if (ascendingTuple->GetElementsCount() == 0) {
  1681. return If(AggrLess(Length(list), count), Append(list, item), list);
  1682. }
  1683. if (ascendingTuple->GetElementsCount() == 1) {
  1684. ascending = Nth(ascending, 0);
  1685. key = Nth(key, 0);
  1686. }
  1687. }
  1688. TCallableBuilder callableBuilder(Env, __func__, listType);
  1689. callableBuilder.Add(count);
  1690. callableBuilder.Add(list);
  1691. callableBuilder.Add(item);
  1692. callableBuilder.Add(itemArg);
  1693. callableBuilder.Add(key);
  1694. callableBuilder.Add(ascending);
  1695. callableBuilder.Add(hotkey);
  1696. return TRuntimeNode(callableBuilder.Build(), false);
  1697. }
  1698. TRuntimeNode TProgramBuilder::Contains(TRuntimeNode dict, TRuntimeNode key) {
  1699. if constexpr (RuntimeVersion >= 25U)
  1700. if (!dict.GetStaticType()->IsDict())
  1701. return DataCompare(__func__, dict, key);
  1702. const auto keyType = AS_TYPE(TDictType, dict.GetStaticType())->GetKeyType();
  1703. MKQL_ENSURE(keyType->IsSameType(*key.GetStaticType()), "Key type mismatch. Requred: " << *keyType << ", but got: " << *key.GetStaticType());
  1704. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<bool>::Id));
  1705. callableBuilder.Add(dict);
  1706. callableBuilder.Add(key);
  1707. return TRuntimeNode(callableBuilder.Build(), false);
  1708. }
  1709. TRuntimeNode TProgramBuilder::Lookup(TRuntimeNode dict, TRuntimeNode key) {
  1710. const auto dictType = AS_TYPE(TDictType, dict.GetStaticType());
  1711. const auto keyType = dictType->GetKeyType();
  1712. MKQL_ENSURE(keyType->IsSameType(*key.GetStaticType()), "Key type mismatch. Requred: " << *keyType << ", but got: " << *key.GetStaticType());
  1713. TCallableBuilder callableBuilder(Env, __func__, NewOptionalType(dictType->GetPayloadType()));
  1714. callableBuilder.Add(dict);
  1715. callableBuilder.Add(key);
  1716. return TRuntimeNode(callableBuilder.Build(), false);
  1717. }
  1718. TRuntimeNode TProgramBuilder::DictItems(TRuntimeNode dict, EDictItems mode) {
  1719. const auto dictTypeChecked = AS_TYPE(TDictType, dict.GetStaticType());
  1720. TType* itemType;
  1721. switch (mode) {
  1722. case EDictItems::Both: {
  1723. const std::array<TType*, 2U> tupleTypes = {{ dictTypeChecked->GetKeyType(), dictTypeChecked->GetPayloadType() }};
  1724. itemType = NewTupleType(tupleTypes);
  1725. break;
  1726. }
  1727. case EDictItems::Keys: itemType = dictTypeChecked->GetKeyType(); break;
  1728. case EDictItems::Payloads: itemType = dictTypeChecked->GetPayloadType(); break;
  1729. }
  1730. TCallableBuilder callableBuilder(Env, __func__, NewListType(itemType));
  1731. callableBuilder.Add(dict);
  1732. callableBuilder.Add(NewDataLiteral((ui32)mode));
  1733. return TRuntimeNode(callableBuilder.Build(), false);
  1734. }
  1735. TRuntimeNode TProgramBuilder::DictItems(TRuntimeNode dict) {
  1736. if constexpr (RuntimeVersion < 6U) {
  1737. return DictItems(dict, EDictItems::Both);
  1738. }
  1739. const auto dictTypeChecked = AS_TYPE(TDictType, dict.GetStaticType());
  1740. const auto itemType = NewTupleType({ dictTypeChecked->GetKeyType(), dictTypeChecked->GetPayloadType() });
  1741. TCallableBuilder callableBuilder(Env, __func__, NewListType(itemType));
  1742. callableBuilder.Add(dict);
  1743. return TRuntimeNode(callableBuilder.Build(), false);
  1744. }
  1745. TRuntimeNode TProgramBuilder::DictKeys(TRuntimeNode dict) {
  1746. if constexpr (RuntimeVersion < 6U) {
  1747. return DictItems(dict, EDictItems::Keys);
  1748. }
  1749. const auto dictTypeChecked = AS_TYPE(TDictType, dict.GetStaticType());
  1750. TCallableBuilder callableBuilder(Env, __func__, NewListType(dictTypeChecked->GetKeyType()));
  1751. callableBuilder.Add(dict);
  1752. return TRuntimeNode(callableBuilder.Build(), false);
  1753. }
  1754. TRuntimeNode TProgramBuilder::DictPayloads(TRuntimeNode dict) {
  1755. if constexpr (RuntimeVersion < 6U) {
  1756. return DictItems(dict, EDictItems::Payloads);
  1757. }
  1758. const auto dictTypeChecked = AS_TYPE(TDictType, dict.GetStaticType());
  1759. TCallableBuilder callableBuilder(Env, __func__, NewListType(dictTypeChecked->GetPayloadType()));
  1760. callableBuilder.Add(dict);
  1761. return TRuntimeNode(callableBuilder.Build(), false);
  1762. }
  1763. TRuntimeNode TProgramBuilder::ToIndexDict(TRuntimeNode list) {
  1764. const auto itemType = AS_TYPE(TListType, list.GetStaticType())->GetItemType();
  1765. ThrowIfListOfVoid(itemType);
  1766. const auto keyType = NewDataType(NUdf::TDataType<ui64>::Id);
  1767. const auto dictType = NewDictType(keyType, itemType, false);
  1768. TCallableBuilder callableBuilder(Env, __func__, dictType);
  1769. callableBuilder.Add(list);
  1770. return TRuntimeNode(callableBuilder.Build(), false);
  1771. }
  1772. TRuntimeNode TProgramBuilder::JoinDict(TRuntimeNode dict1, bool isMulti1, TRuntimeNode dict2, bool isMulti2, EJoinKind joinKind) {
  1773. const auto dict1type = AS_TYPE(TDictType, dict1);
  1774. const auto dict2type = AS_TYPE(TDictType, dict2);
  1775. MKQL_ENSURE(dict1type->GetKeyType()->IsSameType(*dict2type->GetKeyType()), "Dict key types must be the same");
  1776. if (joinKind == EJoinKind::RightOnly || joinKind == EJoinKind::RightSemi)
  1777. MKQL_ENSURE(dict1type->GetPayloadType()->IsVoid(), "Void required for first dict payload.");
  1778. else if (isMulti1)
  1779. MKQL_ENSURE(dict1type->GetPayloadType()->IsList(), "List required for first dict payload.");
  1780. if (joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::LeftSemi)
  1781. MKQL_ENSURE(dict2type->GetPayloadType()->IsVoid(), "Void required for second dict payload.");
  1782. else if (isMulti2)
  1783. MKQL_ENSURE(dict2type->GetPayloadType()->IsList(), "List required for second dict payload.");
  1784. std::array<TType*, 2> tupleItems = {{ dict1type->GetPayloadType(), dict2type->GetPayloadType() }};
  1785. if (isMulti1 && tupleItems.front()->IsList())
  1786. tupleItems.front() = AS_TYPE(TListType, tupleItems.front())->GetItemType();
  1787. if (isMulti2 && tupleItems.back()->IsList())
  1788. tupleItems.back() = AS_TYPE(TListType, tupleItems.back())->GetItemType();
  1789. if (IsLeftOptional(joinKind))
  1790. tupleItems.front() = NewOptionalType(tupleItems.front());
  1791. if (IsRightOptional(joinKind))
  1792. tupleItems.back() = NewOptionalType(tupleItems.back());
  1793. TType* itemType;
  1794. if (joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::LeftSemi)
  1795. itemType = tupleItems.front();
  1796. else if (joinKind == EJoinKind::RightOnly || joinKind == EJoinKind::RightSemi)
  1797. itemType = tupleItems.back();
  1798. else
  1799. itemType = NewTupleType(tupleItems);
  1800. const auto returnType = NewListType(itemType);
  1801. TCallableBuilder callableBuilder(Env, __func__, returnType);
  1802. callableBuilder.Add(dict1);
  1803. callableBuilder.Add(dict2);
  1804. callableBuilder.Add(NewDataLiteral(isMulti1));
  1805. callableBuilder.Add(NewDataLiteral(isMulti2));
  1806. callableBuilder.Add(NewDataLiteral(ui32(joinKind)));
  1807. return TRuntimeNode(callableBuilder.Build(), false);
  1808. }
  1809. TRuntimeNode TProgramBuilder::GraceJoinCommon(const TStringBuf& funcName, TRuntimeNode flowLeft, TRuntimeNode flowRight, EJoinKind joinKind,
  1810. const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns,
  1811. const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, TType* returnType, EAnyJoinSettings anyJoinSettings ) {
  1812. MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified");
  1813. if (flowRight) {
  1814. MKQL_ENSURE(!rightKeyColumns.empty(), "At least one key column must be specified");
  1815. }
  1816. TRuntimeNode::TList leftKeyColumnsNodes, rightKeyColumnsNodes, leftRenamesNodes, rightRenamesNodes;
  1817. leftKeyColumnsNodes.reserve(leftKeyColumns.size());
  1818. std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(), std::back_inserter(leftKeyColumnsNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  1819. rightKeyColumnsNodes.reserve(rightKeyColumns.size());
  1820. std::transform(rightKeyColumns.cbegin(), rightKeyColumns.cend(), std::back_inserter(rightKeyColumnsNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  1821. leftRenamesNodes.reserve(leftRenames.size());
  1822. std::transform(leftRenames.cbegin(), leftRenames.cend(), std::back_inserter(leftRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  1823. rightRenamesNodes.reserve(rightRenames.size());
  1824. std::transform(rightRenames.cbegin(), rightRenames.cend(), std::back_inserter(rightRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  1825. TCallableBuilder callableBuilder(Env, funcName, returnType);
  1826. callableBuilder.Add(flowLeft);
  1827. if (flowRight) {
  1828. callableBuilder.Add(flowRight);
  1829. }
  1830. callableBuilder.Add(NewDataLiteral((ui32)joinKind));
  1831. callableBuilder.Add(NewTuple(leftKeyColumnsNodes));
  1832. callableBuilder.Add(NewTuple(rightKeyColumnsNodes));
  1833. callableBuilder.Add(NewTuple(leftRenamesNodes));
  1834. callableBuilder.Add(NewTuple(rightRenamesNodes));
  1835. callableBuilder.Add(NewDataLiteral((ui32)anyJoinSettings));
  1836. return TRuntimeNode(callableBuilder.Build(), false);
  1837. }
  1838. TRuntimeNode TProgramBuilder::GraceJoin(TRuntimeNode flowLeft, TRuntimeNode flowRight, EJoinKind joinKind,
  1839. const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns,
  1840. const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, TType* returnType, EAnyJoinSettings anyJoinSettings ) {
  1841. return GraceJoinCommon(__func__, flowLeft, flowRight, joinKind, leftKeyColumns, rightKeyColumns, leftRenames, rightRenames, returnType, anyJoinSettings);
  1842. }
  1843. TRuntimeNode TProgramBuilder::GraceSelfJoin(TRuntimeNode flowLeft, EJoinKind joinKind, const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns,
  1844. const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, TType* returnType, EAnyJoinSettings anyJoinSettings ) {
  1845. if constexpr (RuntimeVersion < 40U) {
  1846. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1847. }
  1848. return GraceJoinCommon(__func__, flowLeft, {}, joinKind, leftKeyColumns, rightKeyColumns, leftRenames, rightRenames, returnType, anyJoinSettings);
  1849. }
  1850. TRuntimeNode TProgramBuilder::ToSortedDict(TRuntimeNode list, bool all, const TUnaryLambda& keySelector,
  1851. const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1852. return ToDict(list, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1853. }
  1854. TRuntimeNode TProgramBuilder::ToHashedDict(TRuntimeNode list, bool all, const TUnaryLambda& keySelector,
  1855. const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1856. return ToDict(list, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1857. }
  1858. TRuntimeNode TProgramBuilder::SqueezeToSortedDict(TRuntimeNode stream, bool all, const TUnaryLambda& keySelector,
  1859. const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1860. return SqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1861. }
  1862. TRuntimeNode TProgramBuilder::SqueezeToHashedDict(TRuntimeNode stream, bool all, const TUnaryLambda& keySelector,
  1863. const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1864. return SqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1865. }
  1866. TRuntimeNode TProgramBuilder::NarrowSqueezeToSortedDict(TRuntimeNode stream, bool all, const TNarrowLambda& keySelector,
  1867. const TNarrowLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1868. return NarrowSqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1869. }
  1870. TRuntimeNode TProgramBuilder::NarrowSqueezeToHashedDict(TRuntimeNode stream, bool all, const TNarrowLambda& keySelector,
  1871. const TNarrowLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) {
  1872. return NarrowSqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint);
  1873. }
  1874. TRuntimeNode TProgramBuilder::SqueezeToList(TRuntimeNode flow, TRuntimeNode limit) {
  1875. if constexpr (RuntimeVersion < 25U) {
  1876. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  1877. }
  1878. const auto itemType = AS_TYPE(TFlowType, flow.GetStaticType())->GetItemType();
  1879. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewListType(itemType)));
  1880. callableBuilder.Add(flow);
  1881. callableBuilder.Add(limit);
  1882. return TRuntimeNode(callableBuilder.Build(), false);
  1883. }
  1884. TRuntimeNode TProgramBuilder::Append(TRuntimeNode list, TRuntimeNode item) {
  1885. auto listType = list.GetStaticType();
  1886. AS_TYPE(TListType, listType);
  1887. const auto& listDetailedType = static_cast<const TListType&>(*listType);
  1888. auto itemType = item.GetStaticType();
  1889. MKQL_ENSURE(itemType->IsSameType(*listDetailedType.GetItemType()), "Types of list and item are different");
  1890. TCallableBuilder callableBuilder(Env, __func__, listType);
  1891. callableBuilder.Add(list);
  1892. callableBuilder.Add(item);
  1893. return TRuntimeNode(callableBuilder.Build(), false);
  1894. }
  1895. TRuntimeNode TProgramBuilder::Prepend(TRuntimeNode item, TRuntimeNode list) {
  1896. auto listType = list.GetStaticType();
  1897. AS_TYPE(TListType, listType);
  1898. const auto& listDetailedType = static_cast<const TListType&>(*listType);
  1899. auto itemType = item.GetStaticType();
  1900. MKQL_ENSURE(itemType->IsSameType(*listDetailedType.GetItemType()), "Types of list and item are different");
  1901. TCallableBuilder callableBuilder(Env, __func__, listType);
  1902. callableBuilder.Add(item);
  1903. callableBuilder.Add(list);
  1904. return TRuntimeNode(callableBuilder.Build(), false);
  1905. }
  1906. TRuntimeNode TProgramBuilder::BuildExtend(const std::string_view& callableName, const TArrayRef<const TRuntimeNode>& lists) {
  1907. MKQL_ENSURE(lists.size() > 0, "Expected at least 1 list or flow");
  1908. if (lists.size() == 1) {
  1909. return lists.front();
  1910. }
  1911. auto listType = lists.front().GetStaticType();
  1912. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected either flow, list or stream");
  1913. for (ui32 i = 1; i < lists.size(); ++i) {
  1914. auto listType2 = lists[i].GetStaticType();
  1915. MKQL_ENSURE(listType->IsSameType(*listType2), "Types of flows are different, left: " <<
  1916. PrintNode(listType, true) << ", right: " <<
  1917. PrintNode(listType2, true));
  1918. }
  1919. TCallableBuilder callableBuilder(Env, callableName, listType);
  1920. for (auto list : lists) {
  1921. callableBuilder.Add(list);
  1922. }
  1923. return TRuntimeNode(callableBuilder.Build(), false);
  1924. }
  1925. TRuntimeNode TProgramBuilder::Extend(const TArrayRef<const TRuntimeNode>& lists) {
  1926. return BuildExtend(__func__, lists);
  1927. }
  1928. TRuntimeNode TProgramBuilder::OrderedExtend(const TArrayRef<const TRuntimeNode>& lists) {
  1929. return BuildExtend(__func__, lists);
  1930. }
  1931. template<>
  1932. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::String>(const NUdf::TStringRef& data) const {
  1933. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<const char*>::Id, Env), true);
  1934. }
  1935. template<>
  1936. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Utf8>(const NUdf::TStringRef& data) const {
  1937. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TUtf8>::Id, Env), true);
  1938. }
  1939. template<>
  1940. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Yson>(const NUdf::TStringRef& data) const {
  1941. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TYson>::Id, Env), true);
  1942. }
  1943. template<>
  1944. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Json>(const NUdf::TStringRef& data) const {
  1945. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TJson>::Id, Env), true);
  1946. }
  1947. template<>
  1948. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::JsonDocument>(const NUdf::TStringRef& data) const {
  1949. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TJsonDocument>::Id, Env), true);
  1950. }
  1951. template<>
  1952. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Uuid>(const NUdf::TStringRef& data) const {
  1953. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TUuid>::Id, Env), true);
  1954. }
  1955. template<>
  1956. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Date>(const NUdf::TStringRef& data) const {
  1957. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDate>::Id, Env), true);
  1958. }
  1959. template<>
  1960. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Datetime>(const NUdf::TStringRef& data) const {
  1961. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDatetime>::Id, Env), true);
  1962. }
  1963. template<>
  1964. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Timestamp>(const NUdf::TStringRef& data) const {
  1965. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TTimestamp>::Id, Env), true);
  1966. }
  1967. template<>
  1968. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Interval>(const NUdf::TStringRef& data) const {
  1969. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TInterval>::Id, Env), true);
  1970. }
  1971. template<>
  1972. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::DyNumber>(const NUdf::TStringRef& data) const {
  1973. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDyNumber>::Id, Env), true);
  1974. }
  1975. TRuntimeNode TProgramBuilder::NewDecimalLiteral(NYql::NDecimal::TInt128 data, ui8 precision, ui8 scale) const {
  1976. return TRuntimeNode(TDataLiteral::Create(NUdf::TUnboxedValuePod(data), TDataDecimalType::Create(precision, scale, Env), Env), true);
  1977. }
  1978. template<>
  1979. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Date32>(const NUdf::TStringRef& data) const {
  1980. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDate32>::Id, Env), true);
  1981. }
  1982. template<>
  1983. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Datetime64>(const NUdf::TStringRef& data) const {
  1984. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDatetime64>::Id, Env), true);
  1985. }
  1986. template<>
  1987. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Timestamp64>(const NUdf::TStringRef& data) const {
  1988. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TTimestamp64>::Id, Env), true);
  1989. }
  1990. template<>
  1991. TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Interval64>(const NUdf::TStringRef& data) const {
  1992. return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TInterval64>::Id, Env), true);
  1993. }
  1994. TRuntimeNode TProgramBuilder::NewOptional(TRuntimeNode data) {
  1995. auto type = TOptionalType::Create(data.GetStaticType(), Env);
  1996. return TRuntimeNode(TOptionalLiteral::Create(data, type, Env), true);
  1997. }
  1998. TRuntimeNode TProgramBuilder::NewOptional(TType* optionalType, TRuntimeNode data) {
  1999. auto type = AS_TYPE(TOptionalType, optionalType);
  2000. return TRuntimeNode(TOptionalLiteral::Create(data, type, Env), true);
  2001. }
  2002. TRuntimeNode TProgramBuilder::NewVoid() {
  2003. return TRuntimeNode(Env.GetVoidLazy(), true);
  2004. }
  2005. TRuntimeNode TProgramBuilder::NewEmptyListOfVoid() {
  2006. return TRuntimeNode(Env.GetListOfVoidLazy(), true);
  2007. }
  2008. TRuntimeNode TProgramBuilder::NewEmptyOptional(TType* optionalOrPgType) {
  2009. MKQL_ENSURE(optionalOrPgType->IsOptional() || optionalOrPgType->IsPg(), "Expected optional or pg type");
  2010. if (optionalOrPgType->IsOptional()) {
  2011. return TRuntimeNode(TOptionalLiteral::Create(static_cast<TOptionalType*>(optionalOrPgType), Env), true);
  2012. }
  2013. return PgCast(NewNull(), optionalOrPgType);
  2014. }
  2015. TRuntimeNode TProgramBuilder::NewEmptyOptionalDataLiteral(NUdf::TDataTypeId schemeType) {
  2016. return TRuntimeNode(BuildEmptyOptionalDataLiteral(schemeType, Env), true);
  2017. }
  2018. TRuntimeNode TProgramBuilder::NewEmptyStruct() {
  2019. return TRuntimeNode(Env.GetEmptyStructLazy(), true);
  2020. }
  2021. TRuntimeNode TProgramBuilder::NewStruct(const TArrayRef<const std::pair<std::string_view, TRuntimeNode>>& members) {
  2022. if (members.empty()) {
  2023. return NewEmptyStruct();
  2024. }
  2025. TStructLiteralBuilder builder(Env);
  2026. for (auto x : members) {
  2027. builder.Add(x.first, x.second);
  2028. }
  2029. return TRuntimeNode(builder.Build(), true);
  2030. }
  2031. TRuntimeNode TProgramBuilder::NewStruct(TType* structType, const TArrayRef<const std::pair<std::string_view, TRuntimeNode>>& members) {
  2032. const auto detailedStructType = AS_TYPE(TStructType, structType);
  2033. MKQL_ENSURE(members.size() == detailedStructType->GetMembersCount(), "Mismatch count of members");
  2034. if (members.empty()) {
  2035. return NewEmptyStruct();
  2036. }
  2037. std::vector<TRuntimeNode> values(detailedStructType->GetMembersCount());
  2038. for (ui32 i = 0; i < detailedStructType->GetMembersCount(); ++i) {
  2039. const auto& name = members[i].first;
  2040. ui32 index = detailedStructType->GetMemberIndex(name);
  2041. MKQL_ENSURE(!values[index], "Duplicate of member: " << name);
  2042. values[index] = members[i].second;
  2043. }
  2044. return TRuntimeNode(TStructLiteral::Create(values.size(), values.data(), detailedStructType, Env), true);
  2045. }
  2046. TRuntimeNode TProgramBuilder::NewEmptyList() {
  2047. return TRuntimeNode(Env.GetEmptyListLazy(), true);
  2048. }
  2049. TRuntimeNode TProgramBuilder::NewEmptyList(TType* itemType) {
  2050. TListLiteralBuilder builder(Env, itemType);
  2051. return TRuntimeNode(builder.Build(), true);
  2052. }
  2053. TRuntimeNode TProgramBuilder::NewList(TType* itemType, const TArrayRef<const TRuntimeNode>& items) {
  2054. TListLiteralBuilder builder(Env, itemType);
  2055. for (auto item : items) {
  2056. builder.Add(item);
  2057. }
  2058. return TRuntimeNode(builder.Build(), true);
  2059. }
  2060. TRuntimeNode TProgramBuilder::NewEmptyDict() {
  2061. return TRuntimeNode(Env.GetEmptyDictLazy(), true);
  2062. }
  2063. TRuntimeNode TProgramBuilder::NewDict(TType* dictType, const TArrayRef<const std::pair<TRuntimeNode, TRuntimeNode>>& items) {
  2064. MKQL_ENSURE(dictType->IsDict(), "Expected dict type");
  2065. return TRuntimeNode(TDictLiteral::Create(items.size(), items.data(), static_cast<TDictType*>(dictType), Env), true);
  2066. }
  2067. TRuntimeNode TProgramBuilder::NewEmptyTuple() {
  2068. return TRuntimeNode(Env.GetEmptyTupleLazy(), true);
  2069. }
  2070. TRuntimeNode TProgramBuilder::NewTuple(TType* tupleType, const TArrayRef<const TRuntimeNode>& elements) {
  2071. MKQL_ENSURE(tupleType->IsTuple(), "Expected tuple type");
  2072. return TRuntimeNode(TTupleLiteral::Create(elements.size(), elements.data(), static_cast<TTupleType*>(tupleType), Env), true);
  2073. }
  2074. TRuntimeNode TProgramBuilder::NewTuple(const TArrayRef<const TRuntimeNode>& elements) {
  2075. std::vector<TType*> types;
  2076. types.reserve(elements.size());
  2077. for (auto elem : elements) {
  2078. types.push_back(elem.GetStaticType());
  2079. }
  2080. return NewTuple(NewTupleType(types), elements);
  2081. }
  2082. TRuntimeNode TProgramBuilder::NewVariant(TRuntimeNode item, ui32 index, TType* variantType) {
  2083. const auto type = AS_TYPE(TVariantType, variantType);
  2084. MKQL_ENSURE(type->GetUnderlyingType()->IsTuple(), "Expected tuple as underlying type");
  2085. return TRuntimeNode(TVariantLiteral::Create(item, index, type, Env), true);
  2086. }
  2087. TRuntimeNode TProgramBuilder::NewVariant(TRuntimeNode item, const std::string_view& member, TType* variantType) {
  2088. const auto type = AS_TYPE(TVariantType, variantType);
  2089. MKQL_ENSURE(type->GetUnderlyingType()->IsStruct(), "Expected struct as underlying type");
  2090. ui32 index = AS_TYPE(TStructType, type->GetUnderlyingType())->GetMemberIndex(member);
  2091. return TRuntimeNode(TVariantLiteral::Create(item, index, type, Env), true);
  2092. }
  2093. TRuntimeNode TProgramBuilder::Coalesce(TRuntimeNode data, TRuntimeNode defaultData) {
  2094. bool isOptional = false;
  2095. const auto dataType = UnpackOptional(data, isOptional);
  2096. if (!isOptional && !data.GetStaticType()->IsPg()) {
  2097. MKQL_ENSURE(data.GetStaticType()->IsSameType(*defaultData.GetStaticType()), "Mismatch operand types");
  2098. return data;
  2099. }
  2100. if (!dataType->IsSameType(*defaultData.GetStaticType())) {
  2101. bool isOptionalDefault;
  2102. const auto defaultDataType = UnpackOptional(defaultData, isOptionalDefault);
  2103. MKQL_ENSURE(dataType->IsSameType(*defaultDataType), "Mismatch operand types");
  2104. }
  2105. TCallableBuilder callableBuilder(Env, __func__, defaultData.GetStaticType());
  2106. callableBuilder.Add(data);
  2107. callableBuilder.Add(defaultData);
  2108. return TRuntimeNode(callableBuilder.Build(), false);
  2109. }
  2110. TRuntimeNode TProgramBuilder::Unwrap(TRuntimeNode optional, TRuntimeNode message, const std::string_view& file, ui32 row, ui32 column) {
  2111. bool isOptional;
  2112. auto underlyingType = UnpackOptional(optional, isOptional);
  2113. MKQL_ENSURE(isOptional, "Expected optional");
  2114. const auto& messageType = message.GetStaticType();
  2115. MKQL_ENSURE(messageType->IsData(), "Expected data");
  2116. const auto& messageTypeData = static_cast<const TDataType&>(*messageType);
  2117. MKQL_ENSURE(messageTypeData.GetSchemeType() == NUdf::TDataType<char*>::Id || messageTypeData.GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected string or utf8.");
  2118. TCallableBuilder callableBuilder(Env, __func__, underlyingType);
  2119. callableBuilder.Add(optional);
  2120. callableBuilder.Add(message);
  2121. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  2122. callableBuilder.Add(NewDataLiteral(row));
  2123. callableBuilder.Add(NewDataLiteral(column));
  2124. return TRuntimeNode(callableBuilder.Build(), false);
  2125. }
  2126. TRuntimeNode TProgramBuilder::Increment(TRuntimeNode data) {
  2127. const std::array<TRuntimeNode, 1> args = {{ data }};
  2128. bool isOptional;
  2129. const auto type = UnpackOptionalData(data, isOptional);
  2130. if (type->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2131. return Invoke(__func__, data.GetStaticType(), args);
  2132. return Invoke(TString("Inc_") += ::ToString(static_cast<TDataDecimalType*>(type)->GetParams().first), data.GetStaticType(), args);
  2133. }
  2134. TRuntimeNode TProgramBuilder::Decrement(TRuntimeNode data) {
  2135. const std::array<TRuntimeNode, 1> args = {{ data }};
  2136. bool isOptional;
  2137. const auto type = UnpackOptionalData(data, isOptional);
  2138. if (type->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2139. return Invoke(__func__, data.GetStaticType(), args);
  2140. return Invoke(TString("Dec_") += ::ToString(static_cast<TDataDecimalType*>(type)->GetParams().first), data.GetStaticType(), args);
  2141. }
  2142. TRuntimeNode TProgramBuilder::Abs(TRuntimeNode data) {
  2143. const std::array<TRuntimeNode, 1> args = {{ data }};
  2144. return Invoke(__func__, data.GetStaticType(), args);
  2145. }
  2146. TRuntimeNode TProgramBuilder::Plus(TRuntimeNode data) {
  2147. const std::array<TRuntimeNode, 1> args = {{ data }};
  2148. return Invoke(__func__, data.GetStaticType(), args);
  2149. }
  2150. TRuntimeNode TProgramBuilder::Minus(TRuntimeNode data) {
  2151. const std::array<TRuntimeNode, 1> args = {{ data }};
  2152. return Invoke(__func__, data.GetStaticType(), args);
  2153. }
  2154. TRuntimeNode TProgramBuilder::Add(TRuntimeNode data1, TRuntimeNode data2) {
  2155. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2156. bool isOptionalLeft;
  2157. const auto leftType = UnpackOptionalData(data1, isOptionalLeft);
  2158. if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2159. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2160. const auto decimalType = static_cast<TDataDecimalType*>(leftType);
  2161. bool isOptionalRight;
  2162. const auto rightType = static_cast<TDataDecimalType*>(UnpackOptionalData(data2, isOptionalRight));
  2163. MKQL_ENSURE(rightType->IsSameType(*decimalType), "Operands type mismatch");
  2164. const auto resultType = isOptionalLeft || isOptionalRight ? NewOptionalType(decimalType) : decimalType;
  2165. return Invoke(TString("Add_") += ::ToString(decimalType->GetParams().first), resultType, args);
  2166. }
  2167. TRuntimeNode TProgramBuilder::Sub(TRuntimeNode data1, TRuntimeNode data2) {
  2168. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2169. bool isOptionalLeft;
  2170. const auto leftType = UnpackOptionalData(data1, isOptionalLeft);
  2171. if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2172. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2173. const auto decimalType = static_cast<TDataDecimalType*>(leftType);
  2174. bool isOptionalRight;
  2175. const auto rightType = static_cast<TDataDecimalType*>(UnpackOptionalData(data2, isOptionalRight));
  2176. MKQL_ENSURE(rightType->IsSameType(*decimalType), "Operands type mismatch");
  2177. const auto resultType = isOptionalLeft || isOptionalRight ? NewOptionalType(decimalType) : decimalType;
  2178. return Invoke(TString("Sub_") += ::ToString(decimalType->GetParams().first), resultType, args);
  2179. }
  2180. TRuntimeNode TProgramBuilder::Mul(TRuntimeNode data1, TRuntimeNode data2) {
  2181. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2182. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2183. }
  2184. TRuntimeNode TProgramBuilder::Div(TRuntimeNode data1, TRuntimeNode data2) {
  2185. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2186. auto resultType = BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType());
  2187. if (resultType->IsData() && !(NUdf::GetDataTypeInfo(*static_cast<TDataType*>(resultType)->GetDataSlot()).Features & (NUdf::EDataTypeFeatures::FloatType | NUdf::EDataTypeFeatures::DecimalType))) {
  2188. resultType = NewOptionalType(resultType);
  2189. }
  2190. return Invoke(__func__, resultType, args);
  2191. }
  2192. TRuntimeNode TProgramBuilder::DecimalDiv(TRuntimeNode data1, TRuntimeNode data2) {
  2193. bool isOptionalLeft, isOptionalRight;
  2194. const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(data1, isOptionalLeft));
  2195. const auto rightType = UnpackOptionalData(data2, isOptionalRight);
  2196. if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id)
  2197. MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch");
  2198. else
  2199. MKQL_ENSURE(NUdf::GetDataTypeInfo(*rightType->GetDataSlot()).Features & NUdf::IntegralType, "Operands type mismatch");
  2200. const auto returnType = isOptionalLeft || isOptionalRight ? NewOptionalType(leftType) : leftType;
  2201. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2202. callableBuilder.Add(data1);
  2203. callableBuilder.Add(data2);
  2204. return TRuntimeNode(callableBuilder.Build(), false);
  2205. }
  2206. TRuntimeNode TProgramBuilder::DecimalMod(TRuntimeNode data1, TRuntimeNode data2) {
  2207. bool isOptionalLeft, isOptionalRight;
  2208. const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(data1, isOptionalLeft));
  2209. const auto rightType = UnpackOptionalData(data2, isOptionalRight);
  2210. if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id)
  2211. MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch");
  2212. else
  2213. MKQL_ENSURE(NUdf::GetDataTypeInfo(*rightType->GetDataSlot()).Features & NUdf::IntegralType, "Operands type mismatch");
  2214. const auto returnType = isOptionalLeft || isOptionalRight ? NewOptionalType(leftType) : leftType;
  2215. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2216. callableBuilder.Add(data1);
  2217. callableBuilder.Add(data2);
  2218. return TRuntimeNode(callableBuilder.Build(), false);
  2219. }
  2220. TRuntimeNode TProgramBuilder::DecimalMul(TRuntimeNode data1, TRuntimeNode data2) {
  2221. bool isOptionalLeft, isOptionalRight;
  2222. const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(data1, isOptionalLeft));
  2223. const auto rightType = UnpackOptionalData(data2, isOptionalRight);
  2224. if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id)
  2225. MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch");
  2226. else
  2227. MKQL_ENSURE(NUdf::GetDataTypeInfo(*rightType->GetDataSlot()).Features & NUdf::IntegralType, "Operands type mismatch");
  2228. const auto returnType = isOptionalLeft || isOptionalRight ? NewOptionalType(leftType) : leftType;
  2229. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2230. callableBuilder.Add(data1);
  2231. callableBuilder.Add(data2);
  2232. return TRuntimeNode(callableBuilder.Build(), false);
  2233. }
  2234. TRuntimeNode TProgramBuilder::AllOf(TRuntimeNode list, const TUnaryLambda& predicate) {
  2235. return Not(NotAllOf(list, predicate));
  2236. }
  2237. TRuntimeNode TProgramBuilder::NotAllOf(TRuntimeNode list, const TUnaryLambda& predicate) {
  2238. return Exists(ToOptional(SkipWhile(list, predicate)));
  2239. }
  2240. TRuntimeNode TProgramBuilder::BitNot(TRuntimeNode data) {
  2241. const std::array<TRuntimeNode, 1> args = {{ data }};
  2242. return Invoke(__func__, data.GetStaticType(), args);
  2243. }
  2244. TRuntimeNode TProgramBuilder::CountBits(TRuntimeNode data) {
  2245. const std::array<TRuntimeNode, 1> args = {{ data }};
  2246. return Invoke(__func__, data.GetStaticType(), args);
  2247. }
  2248. TRuntimeNode TProgramBuilder::BitAnd(TRuntimeNode data1, TRuntimeNode data2) {
  2249. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2250. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2251. }
  2252. TRuntimeNode TProgramBuilder::BitOr(TRuntimeNode data1, TRuntimeNode data2) {
  2253. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2254. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2255. }
  2256. TRuntimeNode TProgramBuilder::BitXor(TRuntimeNode data1, TRuntimeNode data2) {
  2257. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2258. return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args);
  2259. }
  2260. TRuntimeNode TProgramBuilder::ShiftLeft(TRuntimeNode arg, TRuntimeNode bits) {
  2261. const std::array<TRuntimeNode, 2> args = {{ arg, bits }};
  2262. return Invoke(__func__, arg.GetStaticType(), args);
  2263. }
  2264. TRuntimeNode TProgramBuilder::RotLeft(TRuntimeNode arg, TRuntimeNode bits) {
  2265. const std::array<TRuntimeNode, 2> args = {{ arg, bits }};
  2266. return Invoke(__func__, arg.GetStaticType(), args);
  2267. }
  2268. TRuntimeNode TProgramBuilder::ShiftRight(TRuntimeNode arg, TRuntimeNode bits) {
  2269. const std::array<TRuntimeNode, 2> args = {{ arg, bits }};
  2270. return Invoke(__func__, arg.GetStaticType(), args);
  2271. }
  2272. TRuntimeNode TProgramBuilder::RotRight(TRuntimeNode arg, TRuntimeNode bits) {
  2273. const std::array<TRuntimeNode, 2> args = {{ arg, bits }};
  2274. return Invoke(__func__, arg.GetStaticType(), args);
  2275. }
  2276. TRuntimeNode TProgramBuilder::Mod(TRuntimeNode data1, TRuntimeNode data2) {
  2277. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2278. auto resultType = BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType());
  2279. if (resultType->IsData() && !(NUdf::GetDataTypeInfo(*static_cast<TDataType*>(resultType)->GetDataSlot()).Features & (NUdf::EDataTypeFeatures::FloatType | NUdf::EDataTypeFeatures::DecimalType))) {
  2280. resultType = NewOptionalType(resultType);
  2281. }
  2282. return Invoke(__func__, resultType, args);
  2283. }
  2284. TRuntimeNode TProgramBuilder::BuildMinMax(const std::string_view& callableName, const TRuntimeNode* data, size_t size) {
  2285. switch (size) {
  2286. case 0U: return NewNull();
  2287. case 1U: return *data;
  2288. case 2U: return InvokeBinary(callableName, ChooseCommonType(data[0U].GetStaticType(), data[1U].GetStaticType()), data[0U], data[1U]);
  2289. default: break;
  2290. }
  2291. const auto half = size >> 1U;
  2292. const std::array<TRuntimeNode, 2U> args = {{ BuildMinMax(callableName, data, half), BuildMinMax(callableName, data + half, size - half) }};
  2293. return BuildMinMax(callableName, args.data(), args.size());
  2294. }
  2295. TRuntimeNode TProgramBuilder::BuildWideSkipTakeBlocks(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count) {
  2296. ValidateBlockFlowType(flow.GetStaticType());
  2297. MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data");
  2298. MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  2299. TCallableBuilder callableBuilder(Env, callableName, flow.GetStaticType());
  2300. callableBuilder.Add(flow);
  2301. callableBuilder.Add(count);
  2302. return TRuntimeNode(callableBuilder.Build(), false);
  2303. }
  2304. TRuntimeNode TProgramBuilder::BuildBlockLogical(const std::string_view& callableName, TRuntimeNode first, TRuntimeNode second) {
  2305. auto firstType = AS_TYPE(TBlockType, first.GetStaticType());
  2306. auto secondType = AS_TYPE(TBlockType, second.GetStaticType());
  2307. bool isOpt1, isOpt2;
  2308. MKQL_ENSURE(UnpackOptionalData(firstType->GetItemType(), isOpt1)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  2309. MKQL_ENSURE(UnpackOptionalData(secondType->GetItemType(), isOpt2)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  2310. const auto itemType = NewDataType(NUdf::TDataType<bool>::Id, isOpt1 || isOpt2);
  2311. auto outputType = NewBlockType(itemType, GetResultShape({firstType, secondType}));
  2312. TCallableBuilder callableBuilder(Env, callableName, outputType);
  2313. callableBuilder.Add(first);
  2314. callableBuilder.Add(second);
  2315. return TRuntimeNode(callableBuilder.Build(), false);
  2316. }
  2317. TRuntimeNode TProgramBuilder::BuildBlockDecimalBinary(const std::string_view& callableName, TRuntimeNode first, TRuntimeNode second) {
  2318. auto firstType = AS_TYPE(TBlockType, first.GetStaticType());
  2319. auto secondType = AS_TYPE(TBlockType, second.GetStaticType());
  2320. bool isOpt1, isOpt2;
  2321. auto* leftDataType = UnpackOptionalData(firstType->GetItemType(), isOpt1);
  2322. UnpackOptionalData(secondType->GetItemType(), isOpt2);
  2323. MKQL_ENSURE(leftDataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id, "Requires decimal args.");
  2324. const auto& lParams = static_cast<TDataDecimalType*>(leftDataType)->GetParams();
  2325. auto [precision, scale] = lParams;
  2326. TType* outputType = TDataDecimalType::Create(precision, scale, Env);
  2327. if (isOpt1 || isOpt2) {
  2328. outputType = TOptionalType::Create(outputType, Env);
  2329. }
  2330. outputType = NewBlockType(outputType, TBlockType::EShape::Many);
  2331. TCallableBuilder callableBuilder(Env, callableName, outputType);
  2332. callableBuilder.Add(first);
  2333. callableBuilder.Add(second);
  2334. return TRuntimeNode(callableBuilder.Build(), false);
  2335. }
  2336. TRuntimeNode TProgramBuilder::Min(const TArrayRef<const TRuntimeNode>& args) {
  2337. return BuildMinMax(__func__, args.data(), args.size());
  2338. }
  2339. TRuntimeNode TProgramBuilder::Max(const TArrayRef<const TRuntimeNode>& args) {
  2340. return BuildMinMax(__func__, args.data(), args.size());
  2341. }
  2342. TRuntimeNode TProgramBuilder::Min(TRuntimeNode data1, TRuntimeNode data2) {
  2343. const std::array<TRuntimeNode, 2U> args = {{ data1, data2 }};
  2344. return Min(args);
  2345. }
  2346. TRuntimeNode TProgramBuilder::Max(TRuntimeNode data1, TRuntimeNode data2) {
  2347. const std::array<TRuntimeNode, 2U> args = {{ data1, data2 }};
  2348. return Max(args);
  2349. }
  2350. TRuntimeNode TProgramBuilder::Equals(TRuntimeNode data1, TRuntimeNode data2) {
  2351. return DataCompare(__func__, data1, data2);
  2352. }
  2353. TRuntimeNode TProgramBuilder::NotEquals(TRuntimeNode data1, TRuntimeNode data2) {
  2354. return DataCompare(__func__, data1, data2);
  2355. }
  2356. TRuntimeNode TProgramBuilder::Less(TRuntimeNode data1, TRuntimeNode data2) {
  2357. return DataCompare(__func__, data1, data2);
  2358. }
  2359. TRuntimeNode TProgramBuilder::LessOrEqual(TRuntimeNode data1, TRuntimeNode data2) {
  2360. return DataCompare(__func__, data1, data2);
  2361. }
  2362. TRuntimeNode TProgramBuilder::Greater(TRuntimeNode data1, TRuntimeNode data2) {
  2363. return DataCompare(__func__, data1, data2);
  2364. }
  2365. TRuntimeNode TProgramBuilder::GreaterOrEqual(TRuntimeNode data1, TRuntimeNode data2) {
  2366. return DataCompare(__func__, data1, data2);
  2367. }
  2368. TRuntimeNode TProgramBuilder::InvokeBinary(const std::string_view& callableName, TType* type, TRuntimeNode data1, TRuntimeNode data2) {
  2369. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2370. return Invoke(callableName, type, args);
  2371. }
  2372. TRuntimeNode TProgramBuilder::AggrCompare(const std::string_view& callableName, TRuntimeNode data1, TRuntimeNode data2) {
  2373. return InvokeBinary(callableName, NewDataType(NUdf::TDataType<bool>::Id), data1, data2);
  2374. }
  2375. TRuntimeNode TProgramBuilder::DataCompare(const std::string_view& callableName, TRuntimeNode left, TRuntimeNode right) {
  2376. bool isOptionalLeft, isOptionalRight;
  2377. const auto leftType = UnpackOptionalData(left, isOptionalLeft);
  2378. const auto rightType = UnpackOptionalData(right, isOptionalRight);
  2379. const auto lId = leftType->GetSchemeType();
  2380. const auto rId = rightType->GetSchemeType();
  2381. if (lId == NUdf::TDataType<NUdf::TDecimal>::Id && rId == NUdf::TDataType<NUdf::TDecimal>::Id) {
  2382. const auto& lDec = static_cast<TDataDecimalType*>(leftType)->GetParams();
  2383. const auto& rDec = static_cast<TDataDecimalType*>(rightType)->GetParams();
  2384. if (lDec.second < rDec.second) {
  2385. left = ToDecimal(left, std::min<ui8>(lDec.first + rDec.second - lDec.second, NYql::NDecimal::MaxPrecision), rDec.second);
  2386. } else if (lDec.second > rDec.second) {
  2387. right = ToDecimal(right, std::min<ui8>(rDec.first + lDec.second - rDec.second, NYql::NDecimal::MaxPrecision), lDec.second);
  2388. }
  2389. } else if (lId == NUdf::TDataType<NUdf::TDecimal>::Id && NUdf::GetDataTypeInfo(NUdf::GetDataSlot(rId)).Features & NUdf::EDataTypeFeatures::IntegralType) {
  2390. const auto scale = static_cast<TDataDecimalType*>(leftType)->GetParams().second;
  2391. right = ToDecimal(right, std::min<ui8>(NYql::NDecimal::MaxPrecision, NUdf::GetDataTypeInfo(NUdf::GetDataSlot(rId)).DecimalDigits + scale), scale);
  2392. } else if (rId == NUdf::TDataType<NUdf::TDecimal>::Id && NUdf::GetDataTypeInfo(NUdf::GetDataSlot(lId)).Features & NUdf::EDataTypeFeatures::IntegralType) {
  2393. const auto scale = static_cast<TDataDecimalType*>(rightType)->GetParams().second;
  2394. left = ToDecimal(left, std::min<ui8>(NYql::NDecimal::MaxPrecision, NUdf::GetDataTypeInfo(NUdf::GetDataSlot(lId)).DecimalDigits + scale), scale);
  2395. }
  2396. const std::array<TRuntimeNode, 2> args = {{ left, right }};
  2397. const auto resultType = isOptionalLeft || isOptionalRight ? NewOptionalType(NewDataType(NUdf::TDataType<bool>::Id)) : NewDataType(NUdf::TDataType<bool>::Id);
  2398. return Invoke(callableName, resultType, args);
  2399. }
  2400. TRuntimeNode TProgramBuilder::BuildRangeLogical(const std::string_view& callableName, const TArrayRef<const TRuntimeNode>& lists) {
  2401. MKQL_ENSURE(!lists.empty(), "Expecting at least one argument");
  2402. for (auto& list : lists) {
  2403. MKQL_ENSURE(list.GetStaticType()->IsList(), "Expecting lists");
  2404. MKQL_ENSURE(list.GetStaticType()->IsSameType(*lists.front().GetStaticType()), "Expecting arguments of same type");
  2405. }
  2406. TCallableBuilder callableBuilder(Env, callableName, lists.front().GetStaticType());
  2407. for (auto& list : lists) {
  2408. callableBuilder.Add(list);
  2409. }
  2410. return TRuntimeNode(callableBuilder.Build(), false);
  2411. }
  2412. TRuntimeNode TProgramBuilder::AggrEquals(TRuntimeNode data1, TRuntimeNode data2) {
  2413. return AggrCompare(__func__, data1, data2);
  2414. }
  2415. TRuntimeNode TProgramBuilder::AggrNotEquals(TRuntimeNode data1, TRuntimeNode data2) {
  2416. return AggrCompare(__func__, data1, data2);
  2417. }
  2418. TRuntimeNode TProgramBuilder::AggrLess(TRuntimeNode data1, TRuntimeNode data2) {
  2419. return AggrCompare(__func__, data1, data2);
  2420. }
  2421. TRuntimeNode TProgramBuilder::AggrLessOrEqual(TRuntimeNode data1, TRuntimeNode data2) {
  2422. return AggrCompare(__func__, data1, data2);
  2423. }
  2424. TRuntimeNode TProgramBuilder::AggrGreater(TRuntimeNode data1, TRuntimeNode data2) {
  2425. return AggrCompare(__func__, data1, data2);
  2426. }
  2427. TRuntimeNode TProgramBuilder::AggrGreaterOrEqual(TRuntimeNode data1, TRuntimeNode data2) {
  2428. return AggrCompare(__func__, data1, data2);
  2429. }
  2430. TRuntimeNode TProgramBuilder::If(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch) {
  2431. bool condOpt, thenOpt, elseOpt;
  2432. const auto conditionType = UnpackOptionalData(condition, condOpt);
  2433. MKQL_ENSURE(conditionType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool");
  2434. const auto thenUnpacked = UnpackOptional(thenBranch, thenOpt);
  2435. const auto elseUnpacked = UnpackOptional(elseBranch, elseOpt);
  2436. MKQL_ENSURE(thenUnpacked->IsSameType(*elseUnpacked), "Different return types in branches.");
  2437. const bool isOptional = condOpt || thenOpt || elseOpt;
  2438. TCallableBuilder callableBuilder(Env, __func__, isOptional ? NewOptionalType(thenUnpacked) : thenUnpacked);
  2439. callableBuilder.Add(condition);
  2440. callableBuilder.Add(thenBranch);
  2441. callableBuilder.Add(elseBranch);
  2442. return TRuntimeNode(callableBuilder.Build(), false);
  2443. }
  2444. TRuntimeNode TProgramBuilder::If(const TArrayRef<const TRuntimeNode>& args) {
  2445. MKQL_ENSURE(args.size() % 2U, "Expected odd arguments.");
  2446. MKQL_ENSURE(args.size() >= 3U, "Expected at least three arguments.");
  2447. return If(args.front(), args[1U], 3U == args.size() ? args.back() : If(args.last(args.size() - 2U)));
  2448. }
  2449. TRuntimeNode TProgramBuilder::If(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch, TType* resultType) {
  2450. bool condOpt;
  2451. const auto conditionType = UnpackOptionalData(condition, condOpt);
  2452. MKQL_ENSURE(conditionType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool");
  2453. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2454. callableBuilder.Add(condition);
  2455. callableBuilder.Add(thenBranch);
  2456. callableBuilder.Add(elseBranch);
  2457. return TRuntimeNode(callableBuilder.Build(), false);
  2458. }
  2459. TRuntimeNode TProgramBuilder::Ensure(TRuntimeNode value, TRuntimeNode predicate, TRuntimeNode message, const std::string_view& file, ui32 row, ui32 column) {
  2460. bool isOptional;
  2461. const auto unpackedType = UnpackOptionalData(predicate, isOptional);
  2462. MKQL_ENSURE(unpackedType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool");
  2463. const auto& messageType = message.GetStaticType();
  2464. MKQL_ENSURE(messageType->IsData(), "Expected data");
  2465. const auto& messageTypeData = static_cast<const TDataType&>(*messageType);
  2466. MKQL_ENSURE(messageTypeData.GetSchemeType() == NUdf::TDataType<char*>::Id || messageTypeData.GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected string or utf8.");
  2467. TCallableBuilder callableBuilder(Env, __func__, value.GetStaticType());
  2468. callableBuilder.Add(value);
  2469. callableBuilder.Add(predicate);
  2470. callableBuilder.Add(message);
  2471. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  2472. callableBuilder.Add(NewDataLiteral(row));
  2473. callableBuilder.Add(NewDataLiteral(column));
  2474. return TRuntimeNode(callableBuilder.Build(), false);
  2475. }
  2476. TRuntimeNode TProgramBuilder::SourceOf(TType* returnType) {
  2477. MKQL_ENSURE(returnType->IsFlow() || returnType->IsStream(), "Expected flow or stream.");
  2478. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2479. return TRuntimeNode(callableBuilder.Build(), false);
  2480. }
  2481. TRuntimeNode TProgramBuilder::Source() {
  2482. if constexpr (RuntimeVersion < 18U) {
  2483. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  2484. }
  2485. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType({})));
  2486. return TRuntimeNode(callableBuilder.Build(), false);
  2487. }
  2488. TRuntimeNode TProgramBuilder::IfPresent(TRuntimeNode optional, const TUnaryLambda& thenBranch, TRuntimeNode elseBranch) {
  2489. bool isOptional;
  2490. const auto unpackedType = UnpackOptional(optional, isOptional);
  2491. if (!isOptional) {
  2492. return thenBranch(optional);
  2493. }
  2494. const auto itemArg = Arg(unpackedType);
  2495. const auto then = thenBranch(itemArg);
  2496. bool thenOpt, elseOpt;
  2497. const auto thenUnpacked = UnpackOptional(then, thenOpt);
  2498. const auto elseUnpacked = UnpackOptional(elseBranch, elseOpt);
  2499. MKQL_ENSURE(thenUnpacked->IsSameType(*elseUnpacked), "Different return types in branches.");
  2500. TCallableBuilder callableBuilder(Env, __func__, (thenOpt || elseOpt) ? NewOptionalType(thenUnpacked) : thenUnpacked);
  2501. callableBuilder.Add(optional);
  2502. callableBuilder.Add(itemArg);
  2503. callableBuilder.Add(then);
  2504. callableBuilder.Add(elseBranch);
  2505. return TRuntimeNode(callableBuilder.Build(), false);
  2506. }
  2507. TRuntimeNode TProgramBuilder::IfPresent(TRuntimeNode::TList optionals, const TNarrowLambda& thenBranch, TRuntimeNode elseBranch) {
  2508. switch (optionals.size()) {
  2509. case 0U:
  2510. return thenBranch({});
  2511. case 1U:
  2512. return IfPresent(optionals.front(), [&](TRuntimeNode unwrap){ return thenBranch({unwrap}); }, elseBranch);
  2513. default:
  2514. break;
  2515. }
  2516. const auto first = optionals.front();
  2517. optionals.erase(optionals.cbegin());
  2518. return IfPresent(first,
  2519. [&](TRuntimeNode head) {
  2520. return IfPresent(optionals,
  2521. [&](TRuntimeNode::TList tail) {
  2522. tail.insert(tail.cbegin(), head);
  2523. return thenBranch(tail);
  2524. },
  2525. elseBranch
  2526. );
  2527. },
  2528. elseBranch
  2529. );
  2530. }
  2531. TRuntimeNode TProgramBuilder::Not(TRuntimeNode data) {
  2532. return UnaryDataFunction(data, __func__, TDataFunctionFlags::CommonOptionalResult | TDataFunctionFlags::RequiresBooleanArgs | TDataFunctionFlags::AllowOptionalArgs);
  2533. }
  2534. TRuntimeNode TProgramBuilder::BuildBinaryLogical(const std::string_view& callableName, TRuntimeNode data1, TRuntimeNode data2) {
  2535. bool isOpt1, isOpt2;
  2536. MKQL_ENSURE(UnpackOptionalData(data1, isOpt1)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  2537. MKQL_ENSURE(UnpackOptionalData(data2, isOpt2)->GetSchemeType() == NUdf::TDataType<bool>::Id, "Requires boolean args.");
  2538. const auto resultType = NewDataType(NUdf::TDataType<bool>::Id, isOpt1 || isOpt2);
  2539. TCallableBuilder callableBuilder(Env, callableName, resultType);
  2540. callableBuilder.Add(data1);
  2541. callableBuilder.Add(data2);
  2542. return TRuntimeNode(callableBuilder.Build(), false);
  2543. }
  2544. TRuntimeNode TProgramBuilder::BuildLogical(const std::string_view& callableName, const TArrayRef<const TRuntimeNode>& args) {
  2545. MKQL_ENSURE(!args.empty(), "Empty logical args.");
  2546. switch (args.size()) {
  2547. case 1U: return args.front();
  2548. case 2U: return BuildBinaryLogical(callableName, args.front(), args.back());
  2549. }
  2550. const auto half = (args.size() + 1U) >> 1U;
  2551. const TArrayRef<const TRuntimeNode> one(args.data(), half), two(args.data() + half, args.size() - half);
  2552. return BuildBinaryLogical(callableName, BuildLogical(callableName, one), BuildLogical(callableName, two));
  2553. }
  2554. TRuntimeNode TProgramBuilder::And(const TArrayRef<const TRuntimeNode>& args) {
  2555. return BuildLogical(__func__, args);
  2556. }
  2557. TRuntimeNode TProgramBuilder::Or(const TArrayRef<const TRuntimeNode>& args) {
  2558. return BuildLogical(__func__, args);
  2559. }
  2560. TRuntimeNode TProgramBuilder::Xor(const TArrayRef<const TRuntimeNode>& args) {
  2561. return BuildLogical(__func__, args);
  2562. }
  2563. TRuntimeNode TProgramBuilder::Exists(TRuntimeNode data) {
  2564. const auto& nodeType = data.GetStaticType();
  2565. if (nodeType->IsVoid()) {
  2566. return NewDataLiteral(false);
  2567. }
  2568. if (!nodeType->IsOptional() && !nodeType->IsPg()) {
  2569. return NewDataLiteral(true);
  2570. }
  2571. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<bool>::Id));
  2572. callableBuilder.Add(data);
  2573. return TRuntimeNode(callableBuilder.Build(), false);
  2574. }
  2575. TRuntimeNode TProgramBuilder::NewMTRand(TRuntimeNode seed) {
  2576. auto seedData = AS_TYPE(TDataType, seed);
  2577. MKQL_ENSURE(seedData->GetSchemeType() == NUdf::TDataType<ui64>::Id, "seed must be ui64");
  2578. TCallableBuilder callableBuilder(Env, __func__, NewResourceType(RandomMTResource), true);
  2579. callableBuilder.Add(seed);
  2580. return TRuntimeNode(callableBuilder.Build(), false);
  2581. }
  2582. TRuntimeNode TProgramBuilder::NextMTRand(TRuntimeNode rand) {
  2583. auto resType = AS_TYPE(TResourceType, rand);
  2584. MKQL_ENSURE(resType->GetTag() == RandomMTResource, "Expected MTRand resource");
  2585. const std::array<TType*, 2U> tupleTypes = {{ NewDataType(NUdf::TDataType<ui64>::Id), rand.GetStaticType() }};
  2586. auto returnType = NewTupleType(tupleTypes);
  2587. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2588. callableBuilder.Add(rand);
  2589. return TRuntimeNode(callableBuilder.Build(), false);
  2590. }
  2591. TRuntimeNode TProgramBuilder::AggrCountInit(TRuntimeNode value) {
  2592. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<ui64>::Id));
  2593. callableBuilder.Add(value);
  2594. return TRuntimeNode(callableBuilder.Build(), false);
  2595. }
  2596. TRuntimeNode TProgramBuilder::AggrCountUpdate(TRuntimeNode value, TRuntimeNode state) {
  2597. MKQL_ENSURE(AS_TYPE(TDataType, state)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64 type");
  2598. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<ui64>::Id));
  2599. callableBuilder.Add(value);
  2600. callableBuilder.Add(state);
  2601. return TRuntimeNode(callableBuilder.Build(), false);
  2602. }
  2603. TRuntimeNode TProgramBuilder::AggrMin(TRuntimeNode data1, TRuntimeNode data2) {
  2604. const auto type = data1.GetStaticType();
  2605. MKQL_ENSURE(type->IsSameType(*data2.GetStaticType()), "Must be same type.");
  2606. return InvokeBinary(__func__, type, data1, data2);
  2607. }
  2608. TRuntimeNode TProgramBuilder::AggrMax(TRuntimeNode data1, TRuntimeNode data2) {
  2609. const auto type = data1.GetStaticType();
  2610. MKQL_ENSURE(type->IsSameType(*data2.GetStaticType()), "Must be same type.");
  2611. return InvokeBinary(__func__, type, data1, data2);
  2612. }
  2613. TRuntimeNode TProgramBuilder::AggrAdd(TRuntimeNode data1, TRuntimeNode data2) {
  2614. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  2615. bool isOptionalLeft;
  2616. const auto leftType = UnpackOptionalData(data1, isOptionalLeft);
  2617. if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id)
  2618. return Invoke(__func__, data1.GetStaticType(), args);
  2619. const auto decimalType = static_cast<TDataDecimalType*>(leftType);
  2620. bool isOptionalRight;
  2621. const auto rightType = static_cast<TDataDecimalType*>(UnpackOptionalData(data2, isOptionalRight));
  2622. MKQL_ENSURE(rightType->IsSameType(*decimalType), "Operands type mismatch");
  2623. return Invoke(TString("AggrAdd_") += ::ToString(decimalType->GetParams().first), data1.GetStaticType(), args);
  2624. }
  2625. TRuntimeNode TProgramBuilder::QueueCreate(TRuntimeNode initCapacity, TRuntimeNode initSize, const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) {
  2626. auto resType = AS_TYPE(TResourceType, returnType);
  2627. const auto tag = resType->GetTag();
  2628. if (initCapacity.GetStaticType()->IsVoid()) {
  2629. MKQL_ENSURE(RuntimeVersion >= 13, "Unbounded queue is not supported in runtime version " << RuntimeVersion);
  2630. } else {
  2631. auto initCapacityType = AS_TYPE(TDataType, initCapacity);
  2632. MKQL_ENSURE(initCapacityType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "init capcity must be ui64");
  2633. }
  2634. auto initSizeType = AS_TYPE(TDataType, initSize);
  2635. MKQL_ENSURE(initSizeType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "init size must be ui64");
  2636. TCallableBuilder callableBuilder(Env, __func__, returnType, true);
  2637. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(tag));
  2638. callableBuilder.Add(initCapacity);
  2639. callableBuilder.Add(initSize);
  2640. for (auto node : dependentNodes) {
  2641. callableBuilder.Add(node);
  2642. }
  2643. return TRuntimeNode(callableBuilder.Build(), false);
  2644. }
  2645. TRuntimeNode TProgramBuilder::QueuePush(TRuntimeNode resource, TRuntimeNode value) {
  2646. auto resType = AS_TYPE(TResourceType, resource);
  2647. const auto tag = resType->GetTag();
  2648. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "Expected Queue resource");
  2649. TCallableBuilder callableBuilder(Env, __func__, resource.GetStaticType());
  2650. callableBuilder.Add(resource);
  2651. callableBuilder.Add(value);
  2652. return TRuntimeNode(callableBuilder.Build(), false);
  2653. }
  2654. TRuntimeNode TProgramBuilder::QueuePop(TRuntimeNode resource) {
  2655. auto resType = AS_TYPE(TResourceType, resource);
  2656. const auto tag = resType->GetTag();
  2657. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "Expected Queue resource");
  2658. TCallableBuilder callableBuilder(Env, __func__, resource.GetStaticType());
  2659. callableBuilder.Add(resource);
  2660. return TRuntimeNode(callableBuilder.Build(), false);
  2661. }
  2662. TRuntimeNode TProgramBuilder::QueuePeek(TRuntimeNode resource, TRuntimeNode index, const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) {
  2663. MKQL_ENSURE(returnType->IsOptional(), "Expected optional type as result of QueuePeek");
  2664. auto resType = AS_TYPE(TResourceType, resource);
  2665. auto indexType = AS_TYPE(TDataType, index);
  2666. MKQL_ENSURE(indexType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "index size must be ui64");
  2667. const auto tag = resType->GetTag();
  2668. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "Expected Queue resource");
  2669. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2670. callableBuilder.Add(resource);
  2671. callableBuilder.Add(index);
  2672. for (auto node : dependentNodes) {
  2673. callableBuilder.Add(node);
  2674. }
  2675. return TRuntimeNode(callableBuilder.Build(), false);
  2676. }
  2677. TRuntimeNode TProgramBuilder::QueueRange(TRuntimeNode resource, TRuntimeNode begin, TRuntimeNode end, const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) {
  2678. MKQL_ENSURE(RuntimeVersion >= 14, "QueueRange is not supported in runtime version " << RuntimeVersion);
  2679. MKQL_ENSURE(returnType->IsList(), "Expected list type as result of QueueRange");
  2680. auto resType = AS_TYPE(TResourceType, resource);
  2681. auto beginType = AS_TYPE(TDataType, begin);
  2682. MKQL_ENSURE(beginType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "begin index must be ui64");
  2683. auto endType = AS_TYPE(TDataType, end);
  2684. MKQL_ENSURE(endType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "end index must be ui64");
  2685. const auto tag = resType->GetTag();
  2686. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "Expected Queue resource");
  2687. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2688. callableBuilder.Add(resource);
  2689. callableBuilder.Add(begin);
  2690. callableBuilder.Add(end);
  2691. for (auto node : dependentNodes) {
  2692. callableBuilder.Add(node);
  2693. }
  2694. return TRuntimeNode(callableBuilder.Build(), false);
  2695. }
  2696. TRuntimeNode TProgramBuilder::PreserveStream(TRuntimeNode stream, TRuntimeNode queue, TRuntimeNode outpace) {
  2697. auto streamType = AS_TYPE(TStreamType, stream);
  2698. auto resType = AS_TYPE(TResourceType, queue);
  2699. auto outpaceType = AS_TYPE(TDataType, outpace);
  2700. MKQL_ENSURE(outpaceType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "PreserveStream: outpace size must be ui64");
  2701. const auto tag = resType->GetTag();
  2702. MKQL_ENSURE(tag.StartsWith(ResourceQueuePrefix), "PreserveStream: Expected Queue resource");
  2703. TCallableBuilder callableBuilder(Env, __func__, streamType);
  2704. callableBuilder.Add(stream);
  2705. callableBuilder.Add(queue);
  2706. callableBuilder.Add(outpace);
  2707. return TRuntimeNode(callableBuilder.Build(), false);
  2708. }
  2709. TRuntimeNode TProgramBuilder::Seq(const TArrayRef<const TRuntimeNode>& args, TType* returnType) {
  2710. MKQL_ENSURE(RuntimeVersion >= 15, "Seq is not supported in runtime version " << RuntimeVersion);
  2711. TCallableBuilder callableBuilder(Env, __func__, returnType);
  2712. for (auto node : args) {
  2713. callableBuilder.Add(node);
  2714. }
  2715. return TRuntimeNode(callableBuilder.Build(), false);
  2716. }
  2717. TRuntimeNode TProgramBuilder::FromYsonSimpleType(TRuntimeNode input, NUdf::TDataTypeId schemeType) {
  2718. auto type = input.GetStaticType();
  2719. if (type->IsOptional()) {
  2720. type = static_cast<const TOptionalType&>(*type).GetItemType();
  2721. }
  2722. MKQL_ENSURE(type->IsData(), "Expected data type");
  2723. auto resDataType = NewDataType(schemeType);
  2724. auto resultType = NewOptionalType(resDataType);
  2725. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2726. callableBuilder.Add(input);
  2727. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(schemeType)));
  2728. return TRuntimeNode(callableBuilder.Build(), false);
  2729. }
  2730. TRuntimeNode TProgramBuilder::TryWeakMemberFromDict(TRuntimeNode other, TRuntimeNode rest, NUdf::TDataTypeId schemeType, const std::string_view& memberName) {
  2731. auto resDataType = NewDataType(schemeType);
  2732. auto resultType = NewOptionalType(resDataType);
  2733. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2734. callableBuilder.Add(other);
  2735. callableBuilder.Add(rest);
  2736. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(schemeType)));
  2737. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(memberName));
  2738. return TRuntimeNode(callableBuilder.Build(), false);
  2739. }
  2740. TRuntimeNode TProgramBuilder::TimezoneId(TRuntimeNode name) {
  2741. bool isOptional;
  2742. auto dataType = UnpackOptionalData(name, isOptional);
  2743. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expected string");
  2744. auto resultType = NewOptionalType(NewDataType(NUdf::EDataSlot::Uint16));
  2745. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2746. callableBuilder.Add(name);
  2747. return TRuntimeNode(callableBuilder.Build(), false);
  2748. }
  2749. TRuntimeNode TProgramBuilder::TimezoneName(TRuntimeNode id) {
  2750. bool isOptional;
  2751. auto dataType = UnpackOptionalData(id, isOptional);
  2752. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<ui16>::Id, "Expected ui32");
  2753. auto resultType = NewOptionalType(NewDataType(NUdf::EDataSlot::String));
  2754. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2755. callableBuilder.Add(id);
  2756. return TRuntimeNode(callableBuilder.Build(), false);
  2757. }
  2758. TRuntimeNode TProgramBuilder::AddTimezone(TRuntimeNode utc, TRuntimeNode id) {
  2759. bool isOptional1;
  2760. auto dataType1 = UnpackOptionalData(utc, isOptional1);
  2761. MKQL_ENSURE(NUdf::GetDataTypeInfo(*dataType1->GetDataSlot()).Features & NUdf::DateType, "Expected date type");
  2762. bool isOptional2;
  2763. auto dataType2 = UnpackOptionalData(id, isOptional2);
  2764. MKQL_ENSURE(dataType2->GetSchemeType() == NUdf::TDataType<ui16>::Id, "Expected ui16");
  2765. NUdf::EDataSlot tzType;
  2766. switch (*dataType1->GetDataSlot()) {
  2767. case NUdf::EDataSlot::Date: tzType = NUdf::EDataSlot::TzDate; break;
  2768. case NUdf::EDataSlot::Datetime: tzType = NUdf::EDataSlot::TzDatetime; break;
  2769. case NUdf::EDataSlot::Timestamp: tzType = NUdf::EDataSlot::TzTimestamp; break;
  2770. case NUdf::EDataSlot::Date32: tzType = NUdf::EDataSlot::TzDate32; break;
  2771. case NUdf::EDataSlot::Datetime64: tzType = NUdf::EDataSlot::TzDatetime64; break;
  2772. case NUdf::EDataSlot::Timestamp64: tzType = NUdf::EDataSlot::TzTimestamp64; break;
  2773. default:
  2774. ythrow yexception() << "Unknown date type: " << *dataType1->GetDataSlot();
  2775. }
  2776. auto resultType = NewOptionalType(NewDataType(tzType));
  2777. TCallableBuilder callableBuilder(Env, __func__, resultType);
  2778. callableBuilder.Add(utc);
  2779. callableBuilder.Add(id);
  2780. return TRuntimeNode(callableBuilder.Build(), false);
  2781. }
  2782. TRuntimeNode TProgramBuilder::RemoveTimezone(TRuntimeNode local) {
  2783. bool isOptional1;
  2784. const auto dataType1 = UnpackOptionalData(local, isOptional1);
  2785. MKQL_ENSURE((NUdf::GetDataTypeInfo(*dataType1->GetDataSlot()).Features & NUdf::TzDateType), "Expected date with timezone type");
  2786. NUdf::EDataSlot type;
  2787. switch (*dataType1->GetDataSlot()) {
  2788. case NUdf::EDataSlot::TzDate: type = NUdf::EDataSlot::Date; break;
  2789. case NUdf::EDataSlot::TzDatetime: type = NUdf::EDataSlot::Datetime; break;
  2790. case NUdf::EDataSlot::TzTimestamp: type = NUdf::EDataSlot::Timestamp; break;
  2791. case NUdf::EDataSlot::TzDate32: type = NUdf::EDataSlot::Date32; break;
  2792. case NUdf::EDataSlot::TzDatetime64: type = NUdf::EDataSlot::Datetime64; break;
  2793. case NUdf::EDataSlot::TzTimestamp64: type = NUdf::EDataSlot::Timestamp64; break;
  2794. default:
  2795. ythrow yexception() << "Unknown date with timezone type: " << *dataType1->GetDataSlot();
  2796. }
  2797. return Convert(local, NewDataType(type, isOptional1));
  2798. }
  2799. TRuntimeNode TProgramBuilder::Nth(TRuntimeNode tuple, ui32 index) {
  2800. bool isOptional;
  2801. const auto type = AS_TYPE(TTupleType, UnpackOptional(tuple.GetStaticType(), isOptional));
  2802. MKQL_ENSURE(index < type->GetElementsCount(), "Index out of range: " << index <<
  2803. " is not less than " << type->GetElementsCount());
  2804. auto itemType = type->GetElementType(index);
  2805. if (isOptional && !itemType->IsOptional() && !itemType->IsNull() && !itemType->IsPg()) {
  2806. itemType = TOptionalType::Create(itemType, Env);
  2807. }
  2808. TCallableBuilder callableBuilder(Env, __func__, itemType);
  2809. callableBuilder.Add(tuple);
  2810. callableBuilder.Add(NewDataLiteral<ui32>(index));
  2811. return TRuntimeNode(callableBuilder.Build(), false);
  2812. }
  2813. TRuntimeNode TProgramBuilder::Element(TRuntimeNode tuple, ui32 index) {
  2814. return Nth(tuple, index);
  2815. }
  2816. TRuntimeNode TProgramBuilder::Guess(TRuntimeNode variant, ui32 tupleIndex) {
  2817. bool isOptional;
  2818. auto unpacked = UnpackOptional(variant, isOptional);
  2819. auto type = AS_TYPE(TVariantType, unpacked);
  2820. auto underlyingType = AS_TYPE(TTupleType, type->GetUnderlyingType());
  2821. MKQL_ENSURE(tupleIndex < underlyingType->GetElementsCount(), "Wrong tuple index");
  2822. auto resType = TOptionalType::Create(underlyingType->GetElementType(tupleIndex), Env);
  2823. TCallableBuilder callableBuilder(Env, __func__, resType);
  2824. callableBuilder.Add(variant);
  2825. callableBuilder.Add(NewDataLiteral<ui32>(tupleIndex));
  2826. return TRuntimeNode(callableBuilder.Build(), false);
  2827. }
  2828. TRuntimeNode TProgramBuilder::Guess(TRuntimeNode variant, const std::string_view& memberName) {
  2829. bool isOptional;
  2830. auto unpacked = UnpackOptional(variant, isOptional);
  2831. auto type = AS_TYPE(TVariantType, unpacked);
  2832. auto underlyingType = AS_TYPE(TStructType, type->GetUnderlyingType());
  2833. auto structIndex = underlyingType->GetMemberIndex(memberName);
  2834. auto resType = TOptionalType::Create(underlyingType->GetMemberType(structIndex), Env);
  2835. TCallableBuilder callableBuilder(Env, __func__, resType);
  2836. callableBuilder.Add(variant);
  2837. callableBuilder.Add(NewDataLiteral<ui32>(structIndex));
  2838. return TRuntimeNode(callableBuilder.Build(), false);
  2839. }
  2840. TRuntimeNode TProgramBuilder::Way(TRuntimeNode variant) {
  2841. bool isOptional;
  2842. auto unpacked = UnpackOptional(variant, isOptional);
  2843. auto type = AS_TYPE(TVariantType, unpacked);
  2844. auto underlyingType = type->GetUnderlyingType();
  2845. auto dataType = NewDataType(underlyingType->IsTuple() ? NUdf::EDataSlot::Uint32 : NUdf::EDataSlot::Utf8);
  2846. auto resType = isOptional ? TOptionalType::Create(dataType, Env) : dataType;
  2847. TCallableBuilder callableBuilder(Env, __func__, resType);
  2848. callableBuilder.Add(variant);
  2849. return TRuntimeNode(callableBuilder.Build(), false);
  2850. }
  2851. TRuntimeNode TProgramBuilder::VariantItem(TRuntimeNode variant) {
  2852. bool isOptional;
  2853. auto unpacked = UnpackOptional(variant, isOptional);
  2854. auto type = AS_TYPE(TVariantType, unpacked);
  2855. auto underlyingType = type->GetAlternativeType(0);
  2856. auto resType = isOptional ? TOptionalType::Create(underlyingType, Env) : underlyingType;
  2857. TCallableBuilder callableBuilder(Env, __func__, resType);
  2858. callableBuilder.Add(variant);
  2859. return TRuntimeNode(callableBuilder.Build(), false);
  2860. }
  2861. TRuntimeNode TProgramBuilder::DynamicVariant(TRuntimeNode item, TRuntimeNode index, TType* variantType) {
  2862. if constexpr (RuntimeVersion < 56U) {
  2863. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  2864. }
  2865. auto type = AS_TYPE(TVariantType, variantType);
  2866. auto expectedIndexSlot = type->GetUnderlyingType()->IsTuple() ? NUdf::EDataSlot::Uint32 : NUdf::EDataSlot::Utf8;
  2867. bool isOptional;
  2868. auto indexType = UnpackOptionalData(index.GetStaticType(), isOptional);
  2869. MKQL_ENSURE(indexType->GetDataSlot() == expectedIndexSlot, "Mismatch type of index");
  2870. auto resType = TOptionalType::Create(type, Env);
  2871. TCallableBuilder callableBuilder(Env, __func__, resType);
  2872. callableBuilder.Add(item);
  2873. callableBuilder.Add(index);
  2874. callableBuilder.Add(TRuntimeNode(variantType, true));
  2875. return TRuntimeNode(callableBuilder.Build(), false);
  2876. }
  2877. TRuntimeNode TProgramBuilder::VisitAll(TRuntimeNode variant, std::function<TRuntimeNode(ui32, TRuntimeNode)> handler) {
  2878. const auto type = AS_TYPE(TVariantType, variant);
  2879. std::vector<TRuntimeNode> items;
  2880. std::vector<TRuntimeNode> newItems;
  2881. for (ui32 i = 0; i < type->GetAlternativesCount(); ++i) {
  2882. const auto itemType = type->GetAlternativeType(i);
  2883. const auto itemArg = Arg(itemType);
  2884. const auto res = handler(i, itemArg);
  2885. items.emplace_back(itemArg);
  2886. newItems.emplace_back(res);
  2887. }
  2888. bool hasOptional;
  2889. const auto firstUnpacked = UnpackOptional(newItems.front(), hasOptional);
  2890. bool allOptional = hasOptional;
  2891. for (size_t i = 1U; i < newItems.size(); ++i) {
  2892. bool isOptional;
  2893. const auto unpacked = UnpackOptional(newItems[i].GetStaticType(), isOptional);
  2894. MKQL_ENSURE(unpacked->IsSameType(*firstUnpacked), "Different return types in branches.");
  2895. hasOptional = hasOptional || isOptional;
  2896. allOptional = allOptional && isOptional;
  2897. }
  2898. if (hasOptional && !allOptional) {
  2899. for (auto& item : newItems) {
  2900. if (!item.GetStaticType()->IsOptional()) {
  2901. item = NewOptional(item);
  2902. }
  2903. }
  2904. }
  2905. TCallableBuilder callableBuilder(Env, __func__, newItems.front().GetStaticType());
  2906. callableBuilder.Add(variant);
  2907. for (ui32 i = 0; i < type->GetAlternativesCount(); ++i) {
  2908. callableBuilder.Add(items[i]);
  2909. callableBuilder.Add(newItems[i]);
  2910. }
  2911. return TRuntimeNode(callableBuilder.Build(), false);
  2912. }
  2913. TRuntimeNode TProgramBuilder::UnaryDataFunction(TRuntimeNode data, const std::string_view& callableName, ui32 flags) {
  2914. bool isOptional;
  2915. auto type = UnpackOptionalData(data, isOptional);
  2916. if (!(flags & TDataFunctionFlags::AllowOptionalArgs)) {
  2917. MKQL_ENSURE(!isOptional, "Optional data is not allowed");
  2918. }
  2919. auto schemeType = type->GetSchemeType();
  2920. if (flags & TDataFunctionFlags::RequiresBooleanArgs) {
  2921. MKQL_ENSURE(schemeType == NUdf::TDataType<bool>::Id, "Boolean data is required");
  2922. } else if (flags & TDataFunctionFlags::RequiresStringArgs) {
  2923. MKQL_ENSURE(schemeType == NUdf::TDataType<char*>::Id, "String data is required");
  2924. }
  2925. if (!schemeType) {
  2926. MKQL_ENSURE((flags & TDataFunctionFlags::AllowNull) != 0, "Null is not allowed");
  2927. }
  2928. TType* resultType;
  2929. if (flags & TDataFunctionFlags::HasBooleanResult) {
  2930. resultType = TDataType::Create(NUdf::TDataType<bool>::Id, Env);
  2931. } else if (flags & TDataFunctionFlags::HasUi32Result) {
  2932. resultType = TDataType::Create(NUdf::TDataType<ui32>::Id, Env);
  2933. } else if (flags & TDataFunctionFlags::HasStringResult) {
  2934. resultType = TDataType::Create(NUdf::TDataType<char*>::Id, Env);
  2935. } else if (flags & TDataFunctionFlags::HasOptionalResult) {
  2936. resultType = TOptionalType::Create(type, Env);
  2937. } else {
  2938. resultType = type;
  2939. }
  2940. if ((flags & TDataFunctionFlags::CommonOptionalResult) && isOptional) {
  2941. resultType = TOptionalType::Create(resultType, Env);
  2942. }
  2943. TCallableBuilder callableBuilder(Env, callableName, resultType);
  2944. callableBuilder.Add(data);
  2945. return TRuntimeNode(callableBuilder.Build(), false);
  2946. }
  2947. TRuntimeNode TProgramBuilder::ToDict(TRuntimeNode list, bool multi, const TUnaryLambda& keySelector,
  2948. const TUnaryLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint)
  2949. {
  2950. bool isOptional;
  2951. const auto type = UnpackOptional(list, isOptional);
  2952. MKQL_ENSURE(type->IsList(), "Expected list.");
  2953. if (isOptional) {
  2954. return Map(list, [&](TRuntimeNode unpacked) { return ToDict(unpacked, multi, keySelector, payloadSelector, callableName, isCompact, itemsCountHint); } );
  2955. }
  2956. const auto itemType = AS_TYPE(TListType, type)->GetItemType();
  2957. ThrowIfListOfVoid(itemType);
  2958. const auto itemArg = Arg(itemType);
  2959. const auto key = keySelector(itemArg);
  2960. const auto keyType = key.GetStaticType();
  2961. auto payload = payloadSelector(itemArg);
  2962. auto payloadType = payload.GetStaticType();
  2963. if (multi) {
  2964. payloadType = TListType::Create(payloadType, Env);
  2965. }
  2966. auto dictType = TDictType::Create(keyType, payloadType, Env);
  2967. TCallableBuilder callableBuilder(Env, callableName, dictType);
  2968. callableBuilder.Add(list);
  2969. callableBuilder.Add(itemArg);
  2970. callableBuilder.Add(key);
  2971. callableBuilder.Add(payload);
  2972. callableBuilder.Add(NewDataLiteral(multi));
  2973. callableBuilder.Add(NewDataLiteral(isCompact));
  2974. callableBuilder.Add(NewDataLiteral(itemsCountHint));
  2975. return TRuntimeNode(callableBuilder.Build(), false);
  2976. }
  2977. TRuntimeNode TProgramBuilder::SqueezeToDict(TRuntimeNode stream, bool multi, const TUnaryLambda& keySelector,
  2978. const TUnaryLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint)
  2979. {
  2980. if constexpr (RuntimeVersion < 21U) {
  2981. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  2982. }
  2983. const auto type = stream.GetStaticType();
  2984. MKQL_ENSURE(type->IsStream() || type->IsFlow(), "Expected stream or flow.");
  2985. const auto itemType = type->IsFlow() ? AS_TYPE(TFlowType, type)->GetItemType() : AS_TYPE(TStreamType, type)->GetItemType();
  2986. ThrowIfListOfVoid(itemType);
  2987. const auto itemArg = Arg(itemType);
  2988. const auto key = keySelector(itemArg);
  2989. const auto keyType = key.GetStaticType();
  2990. auto payload = payloadSelector(itemArg);
  2991. auto payloadType = payload.GetStaticType();
  2992. if (multi) {
  2993. payloadType = TListType::Create(payloadType, Env);
  2994. }
  2995. auto dictType = TDictType::Create(keyType, payloadType, Env);
  2996. auto returnType = type->IsFlow()
  2997. ? (TType*) TFlowType::Create(dictType, Env)
  2998. : (TType*) TStreamType::Create(dictType, Env);
  2999. TCallableBuilder callableBuilder(Env, callableName, returnType);
  3000. callableBuilder.Add(stream);
  3001. callableBuilder.Add(itemArg);
  3002. callableBuilder.Add(key);
  3003. callableBuilder.Add(payload);
  3004. callableBuilder.Add(NewDataLiteral(multi));
  3005. callableBuilder.Add(NewDataLiteral(isCompact));
  3006. callableBuilder.Add(NewDataLiteral(itemsCountHint));
  3007. return TRuntimeNode(callableBuilder.Build(), false);
  3008. }
  3009. TRuntimeNode TProgramBuilder::NarrowSqueezeToDict(TRuntimeNode flow, bool multi, const TNarrowLambda& keySelector,
  3010. const TNarrowLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint)
  3011. {
  3012. if constexpr (RuntimeVersion < 23U) {
  3013. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3014. }
  3015. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3016. TRuntimeNode::TList itemArgs;
  3017. itemArgs.reserve(wideComponents.size());
  3018. auto i = 0U;
  3019. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3020. const auto key = keySelector(itemArgs);
  3021. const auto keyType = key.GetStaticType();
  3022. auto payload = payloadSelector(itemArgs);
  3023. auto payloadType = payload.GetStaticType();
  3024. if (multi) {
  3025. payloadType = TListType::Create(payloadType, Env);
  3026. }
  3027. const auto dictType = TDictType::Create(keyType, payloadType, Env);
  3028. const auto returnType = TFlowType::Create(dictType, Env);
  3029. TCallableBuilder callableBuilder(Env, callableName, returnType);
  3030. callableBuilder.Add(flow);
  3031. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3032. callableBuilder.Add(key);
  3033. callableBuilder.Add(payload);
  3034. callableBuilder.Add(NewDataLiteral(multi));
  3035. callableBuilder.Add(NewDataLiteral(isCompact));
  3036. callableBuilder.Add(NewDataLiteral(itemsCountHint));
  3037. return TRuntimeNode(callableBuilder.Build(), false);
  3038. }
  3039. void TProgramBuilder::ThrowIfListOfVoid(TType* type) {
  3040. MKQL_ENSURE(!VoidWithEffects || !type->IsVoid(), "List of void is forbidden for current function");
  3041. }
  3042. TRuntimeNode TProgramBuilder::BuildFlatMap(const std::string_view& callableName, TRuntimeNode list, const TUnaryLambda& handler)
  3043. {
  3044. const auto listType = list.GetStaticType();
  3045. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsOptional() || listType->IsStream(), "Expected flow, list, stream or optional");
  3046. if (listType->IsOptional()) {
  3047. const auto itemArg = Arg(AS_TYPE(TOptionalType, listType)->GetItemType());
  3048. const auto newList = handler(itemArg);
  3049. const auto type = newList.GetStaticType();
  3050. MKQL_ENSURE(type->IsList() || type->IsOptional() || type->IsStream() || type->IsFlow(), "Expected flow, list, stream or optional");
  3051. return IfPresent(list, [&](TRuntimeNode item) {
  3052. return handler(item);
  3053. }, type->IsOptional() ? NewEmptyOptional(type) : type->IsList() ? NewEmptyList(AS_TYPE(TListType, type)->GetItemType()) : EmptyIterator(type));
  3054. }
  3055. const auto itemType = listType->IsFlow() ?
  3056. AS_TYPE(TFlowType, listType)->GetItemType():
  3057. listType->IsList() ?
  3058. AS_TYPE(TListType, listType)->GetItemType():
  3059. AS_TYPE(TStreamType, listType)->GetItemType();
  3060. ThrowIfListOfVoid(itemType);
  3061. const auto itemArg = Arg(itemType);
  3062. const auto newList = handler(itemArg);
  3063. const auto type = newList.GetStaticType();
  3064. TType* retItemType = nullptr;
  3065. if (type->IsOptional()) {
  3066. retItemType = AS_TYPE(TOptionalType, type)->GetItemType();
  3067. } else if (type->IsFlow()) {
  3068. retItemType = AS_TYPE(TFlowType, type)->GetItemType();
  3069. } else if (type->IsList()) {
  3070. retItemType = AS_TYPE(TListType, type)->GetItemType();
  3071. } else if (type->IsStream()) {
  3072. retItemType = AS_TYPE(TStreamType, type)->GetItemType();
  3073. } else {
  3074. THROW yexception() << "Expected flow, list or stream.";
  3075. }
  3076. const auto resultListType = listType->IsFlow() || type->IsFlow() ?
  3077. TFlowType::Create(retItemType, Env):
  3078. listType->IsList() ?
  3079. (TType*)TListType::Create(retItemType, Env):
  3080. (TType*)TStreamType::Create(retItemType, Env);
  3081. TCallableBuilder callableBuilder(Env, callableName, resultListType);
  3082. callableBuilder.Add(list);
  3083. callableBuilder.Add(itemArg);
  3084. callableBuilder.Add(newList);
  3085. return TRuntimeNode(callableBuilder.Build(), false);
  3086. }
  3087. TRuntimeNode TProgramBuilder::MultiMap(TRuntimeNode list, const TExpandLambda& handler)
  3088. {
  3089. if constexpr (RuntimeVersion < 16U) {
  3090. const auto single = [=](TRuntimeNode item) -> TRuntimeNode {
  3091. const auto newList = handler(item);
  3092. const auto retItemType = newList.front().GetStaticType();
  3093. MKQL_ENSURE(retItemType->IsSameType(*newList.back().GetStaticType()), "Must be same type.");
  3094. return NewList(retItemType, newList);
  3095. };
  3096. return OrderedFlatMap(list, single);
  3097. }
  3098. const auto listType = list.GetStaticType();
  3099. MKQL_ENSURE(listType->IsFlow() || listType->IsList(), "Expected flow, list, stream or optional");
  3100. const auto itemType = listType->IsFlow() ? AS_TYPE(TFlowType, listType)->GetItemType() : AS_TYPE(TListType, listType)->GetItemType();
  3101. const auto itemArg = Arg(itemType);
  3102. const auto newList = handler(itemArg);
  3103. MKQL_ENSURE(newList.size() > 1U, "Expected many items.");
  3104. const auto retItemType = newList.front().GetStaticType();
  3105. MKQL_ENSURE(retItemType->IsSameType(*newList.back().GetStaticType()), "Must be same type.");
  3106. const auto resultListType = listType->IsFlow() ?
  3107. (TType*)TFlowType::Create(retItemType, Env) : (TType*)TListType::Create(retItemType, Env);
  3108. TCallableBuilder callableBuilder(Env, __func__, resultListType);
  3109. callableBuilder.Add(list);
  3110. callableBuilder.Add(itemArg);
  3111. std::for_each(newList.cbegin(), newList.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3112. return TRuntimeNode(callableBuilder.Build(), false);
  3113. }
  3114. TRuntimeNode TProgramBuilder::NarrowMultiMap(TRuntimeNode flow, const TWideLambda& handler) {
  3115. if constexpr (RuntimeVersion < 18U) {
  3116. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3117. }
  3118. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3119. TRuntimeNode::TList itemArgs;
  3120. itemArgs.reserve(wideComponents.size());
  3121. auto i = 0U;
  3122. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3123. const auto newList = handler(itemArgs);
  3124. MKQL_ENSURE(newList.size() > 1U, "Expected many items.");
  3125. const auto retItemType = newList.front().GetStaticType();
  3126. MKQL_ENSURE(retItemType->IsSameType(*newList.back().GetStaticType()), "Must be same type.");
  3127. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(newList.front().GetStaticType()));
  3128. callableBuilder.Add(flow);
  3129. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3130. std::for_each(newList.cbegin(), newList.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3131. return TRuntimeNode(callableBuilder.Build(), false);
  3132. }
  3133. TRuntimeNode TProgramBuilder::ExpandMap(TRuntimeNode flow, const TExpandLambda& handler) {
  3134. if constexpr (RuntimeVersion < 18U) {
  3135. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3136. }
  3137. const auto itemType = AS_TYPE(TFlowType, flow.GetStaticType())->GetItemType();
  3138. const auto itemArg = Arg(itemType);
  3139. const auto newItems = handler(itemArg);
  3140. std::vector<TType*> tupleItems;
  3141. tupleItems.reserve(newItems.size());
  3142. std::transform(newItems.cbegin(), newItems.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  3143. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  3144. callableBuilder.Add(flow);
  3145. callableBuilder.Add(itemArg);
  3146. std::for_each(newItems.cbegin(), newItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3147. return TRuntimeNode(callableBuilder.Build(), false);
  3148. }
  3149. TRuntimeNode TProgramBuilder::WideMap(TRuntimeNode flow, const TWideLambda& handler) {
  3150. if constexpr (RuntimeVersion < 18U) {
  3151. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3152. }
  3153. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3154. TRuntimeNode::TList itemArgs;
  3155. itemArgs.reserve(wideComponents.size());
  3156. auto i = 0U;
  3157. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3158. const auto newItems = handler(itemArgs);
  3159. std::vector<TType*> tupleItems;
  3160. tupleItems.reserve(newItems.size());
  3161. std::transform(newItems.cbegin(), newItems.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  3162. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  3163. callableBuilder.Add(flow);
  3164. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3165. std::for_each(newItems.cbegin(), newItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3166. return TRuntimeNode(callableBuilder.Build(), false);
  3167. }
  3168. TRuntimeNode TProgramBuilder::WideChain1Map(TRuntimeNode flow, const TWideLambda& init, const TBinaryWideLambda& update) {
  3169. if constexpr (RuntimeVersion < 23U) {
  3170. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3171. }
  3172. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3173. TRuntimeNode::TList inputArgs;
  3174. inputArgs.reserve(wideComponents.size());
  3175. auto i = 0U;
  3176. std::generate_n(std::back_inserter(inputArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3177. const auto initItems = init(inputArgs);
  3178. std::vector<TType*> tupleItems;
  3179. tupleItems.reserve(initItems.size());
  3180. std::transform(initItems.cbegin(), initItems.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  3181. TRuntimeNode::TList outputArgs;
  3182. outputArgs.reserve(tupleItems.size());
  3183. std::transform(tupleItems.cbegin(), tupleItems.cend(), std::back_inserter(outputArgs), std::bind(&TProgramBuilder::Arg, this, std::placeholders::_1));
  3184. const auto updateItems = update(inputArgs, outputArgs);
  3185. MKQL_ENSURE(initItems.size() == updateItems.size(), "Expected same width.");
  3186. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  3187. callableBuilder.Add(flow);
  3188. std::for_each(inputArgs.cbegin(), inputArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3189. std::for_each(initItems.cbegin(), initItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3190. std::for_each(outputArgs.cbegin(), outputArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3191. std::for_each(updateItems.cbegin(), updateItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3192. return TRuntimeNode(callableBuilder.Build(), false);
  3193. }
  3194. TRuntimeNode TProgramBuilder::NarrowMap(TRuntimeNode flow, const TNarrowLambda& handler) {
  3195. if constexpr (RuntimeVersion < 18U) {
  3196. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3197. }
  3198. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3199. TRuntimeNode::TList itemArgs;
  3200. itemArgs.reserve(wideComponents.size());
  3201. auto i = 0U;
  3202. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3203. const auto newItem = handler(itemArgs);
  3204. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(newItem.GetStaticType()));
  3205. callableBuilder.Add(flow);
  3206. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3207. callableBuilder.Add(newItem);
  3208. return TRuntimeNode(callableBuilder.Build(), false);
  3209. }
  3210. TRuntimeNode TProgramBuilder::NarrowFlatMap(TRuntimeNode flow, const TNarrowLambda& handler) {
  3211. if constexpr (RuntimeVersion < 18U) {
  3212. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3213. }
  3214. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3215. TRuntimeNode::TList itemArgs;
  3216. itemArgs.reserve(wideComponents.size());
  3217. auto i = 0U;
  3218. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3219. const auto newList = handler(itemArgs);
  3220. const auto type = newList.GetStaticType();
  3221. TType* retItemType = nullptr;
  3222. if (type->IsOptional()) {
  3223. retItemType = AS_TYPE(TOptionalType, type)->GetItemType();
  3224. } else if (type->IsFlow()) {
  3225. retItemType = AS_TYPE(TFlowType, type)->GetItemType();
  3226. } else if (type->IsList()) {
  3227. retItemType = AS_TYPE(TListType, type)->GetItemType();
  3228. } else if (type->IsStream()) {
  3229. retItemType = AS_TYPE(TStreamType, type)->GetItemType();
  3230. } else {
  3231. THROW yexception() << "Expected flow, list or stream.";
  3232. }
  3233. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(retItemType));
  3234. callableBuilder.Add(flow);
  3235. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3236. callableBuilder.Add(newList);
  3237. return TRuntimeNode(callableBuilder.Build(), false);
  3238. }
  3239. TRuntimeNode TProgramBuilder::BuildWideFilter(const std::string_view& callableName, TRuntimeNode flow, const TNarrowLambda& handler) {
  3240. if constexpr (RuntimeVersion < 18U) {
  3241. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3242. }
  3243. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3244. TRuntimeNode::TList itemArgs;
  3245. itemArgs.reserve(wideComponents.size());
  3246. auto i = 0U;
  3247. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3248. const auto predicate = handler(itemArgs);
  3249. TCallableBuilder callableBuilder(Env, callableName, flow.GetStaticType());
  3250. callableBuilder.Add(flow);
  3251. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3252. callableBuilder.Add(predicate);
  3253. return TRuntimeNode(callableBuilder.Build(), false);
  3254. }
  3255. TRuntimeNode TProgramBuilder::WideFilter(TRuntimeNode flow, const TNarrowLambda& handler) {
  3256. return BuildWideFilter(__func__, flow, handler);
  3257. }
  3258. TRuntimeNode TProgramBuilder::WideTakeWhile(TRuntimeNode flow, const TNarrowLambda& handler) {
  3259. return BuildWideFilter(__func__, flow, handler);
  3260. }
  3261. TRuntimeNode TProgramBuilder::WideSkipWhile(TRuntimeNode flow, const TNarrowLambda& handler) {
  3262. return BuildWideFilter(__func__, flow, handler);
  3263. }
  3264. TRuntimeNode TProgramBuilder::WideTakeWhileInclusive(TRuntimeNode flow, const TNarrowLambda& handler) {
  3265. return BuildWideFilter(__func__, flow, handler);
  3266. }
  3267. TRuntimeNode TProgramBuilder::WideSkipWhileInclusive(TRuntimeNode flow, const TNarrowLambda& handler) {
  3268. return BuildWideFilter(__func__, flow, handler);
  3269. }
  3270. TRuntimeNode TProgramBuilder::WideFilter(TRuntimeNode flow, TRuntimeNode limit, const TNarrowLambda& handler) {
  3271. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3272. TRuntimeNode::TList itemArgs;
  3273. itemArgs.reserve(wideComponents.size());
  3274. auto i = 0U;
  3275. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3276. const auto predicate = handler(itemArgs);
  3277. TCallableBuilder callableBuilder(Env, __func__, flow.GetStaticType());
  3278. callableBuilder.Add(flow);
  3279. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3280. callableBuilder.Add(predicate);
  3281. callableBuilder.Add(limit);
  3282. return TRuntimeNode(callableBuilder.Build(), false);
  3283. }
  3284. TRuntimeNode TProgramBuilder::BuildFilter(const std::string_view& callableName, TRuntimeNode list, const TUnaryLambda& handler, TType* resultType)
  3285. {
  3286. const auto listType = list.GetStaticType();
  3287. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream.");
  3288. const auto outputType = resultType ? resultType : listType;
  3289. const auto itemType = listType->IsFlow() ?
  3290. AS_TYPE(TFlowType, listType)->GetItemType():
  3291. listType->IsList() ?
  3292. AS_TYPE(TListType, listType)->GetItemType():
  3293. AS_TYPE(TStreamType, listType)->GetItemType();
  3294. ThrowIfListOfVoid(itemType);
  3295. const auto itemArg = Arg(itemType);
  3296. const auto predicate = handler(itemArg);
  3297. MKQL_ENSURE(predicate.GetStaticType()->IsData(), "Expected boolean data");
  3298. const auto& detailedPredicateType = static_cast<const TDataType&>(*predicate.GetStaticType());
  3299. MKQL_ENSURE(detailedPredicateType.GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected boolean data");
  3300. TCallableBuilder callableBuilder(Env, callableName, outputType);
  3301. callableBuilder.Add(list);
  3302. callableBuilder.Add(itemArg);
  3303. callableBuilder.Add(predicate);
  3304. return TRuntimeNode(callableBuilder.Build(), false);
  3305. }
  3306. TRuntimeNode TProgramBuilder::BuildFilter(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode limit, const TUnaryLambda& handler, TType* resultType)
  3307. {
  3308. if constexpr (RuntimeVersion < 4U) {
  3309. return Take(BuildFilter(callableName, list, handler, resultType), limit);
  3310. }
  3311. const auto listType = list.GetStaticType();
  3312. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream.");
  3313. MKQL_ENSURE(limit.GetStaticType()->IsData(), "Expected data");
  3314. const auto outputType = resultType ? resultType : listType;
  3315. const auto itemType = listType->IsFlow() ?
  3316. AS_TYPE(TFlowType, listType)->GetItemType():
  3317. listType->IsList() ?
  3318. AS_TYPE(TListType, listType)->GetItemType():
  3319. AS_TYPE(TStreamType, listType)->GetItemType();
  3320. ThrowIfListOfVoid(itemType);
  3321. const auto itemArg = Arg(itemType);
  3322. const auto predicate = handler(itemArg);
  3323. MKQL_ENSURE(predicate.GetStaticType()->IsData(), "Expected boolean data");
  3324. const auto& detailedPredicateType = static_cast<const TDataType&>(*predicate.GetStaticType());
  3325. MKQL_ENSURE(detailedPredicateType.GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected boolean data");
  3326. TCallableBuilder callableBuilder(Env, callableName, outputType);
  3327. callableBuilder.Add(list);
  3328. callableBuilder.Add(limit);
  3329. callableBuilder.Add(itemArg);
  3330. callableBuilder.Add(predicate);
  3331. return TRuntimeNode(callableBuilder.Build(), false);
  3332. }
  3333. TRuntimeNode TProgramBuilder::Filter(TRuntimeNode list, const TUnaryLambda& handler, TType* resultType)
  3334. {
  3335. const auto type = list.GetStaticType();
  3336. if (type->IsOptional()) {
  3337. return
  3338. IfPresent(list,
  3339. [&](TRuntimeNode item) {
  3340. return If(handler(item), item, NewEmptyOptional(resultType), resultType);
  3341. },
  3342. NewEmptyOptional(resultType)
  3343. );
  3344. }
  3345. return BuildFilter(__func__, list, handler, resultType);
  3346. }
  3347. TRuntimeNode TProgramBuilder::BuildHeap(const std::string_view& callableName, TRuntimeNode list, const TBinaryLambda& comparator) {
  3348. const auto listType = list.GetStaticType();
  3349. MKQL_ENSURE(listType->IsList(), "Expected list.");
  3350. const auto itemType = AS_TYPE(TListType, listType)->GetItemType();
  3351. const auto leftArg = Arg(itemType);
  3352. const auto rightArg = Arg(itemType);
  3353. const auto predicate = comparator(leftArg, rightArg);
  3354. TCallableBuilder callableBuilder(Env, callableName, listType);
  3355. callableBuilder.Add(list);
  3356. callableBuilder.Add(leftArg);
  3357. callableBuilder.Add(rightArg);
  3358. callableBuilder.Add(predicate);
  3359. return TRuntimeNode(callableBuilder.Build(), false);
  3360. }
  3361. TRuntimeNode TProgramBuilder::BuildNth(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode n, const TBinaryLambda& comparator) {
  3362. const auto listType = list.GetStaticType();
  3363. MKQL_ENSURE(listType->IsList(), "Expected list.");
  3364. const auto itemType = AS_TYPE(TListType, listType)->GetItemType();
  3365. MKQL_ENSURE(n.GetStaticType()->IsData(), "Expected data");
  3366. MKQL_ENSURE(static_cast<const TDataType&>(*n.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  3367. const auto leftArg = Arg(itemType);
  3368. const auto rightArg = Arg(itemType);
  3369. const auto predicate = comparator(leftArg, rightArg);
  3370. TCallableBuilder callableBuilder(Env, callableName, listType);
  3371. callableBuilder.Add(list);
  3372. callableBuilder.Add(n);
  3373. callableBuilder.Add(leftArg);
  3374. callableBuilder.Add(rightArg);
  3375. callableBuilder.Add(predicate);
  3376. return TRuntimeNode(callableBuilder.Build(), false);
  3377. }
  3378. TRuntimeNode TProgramBuilder::MakeHeap(TRuntimeNode list, const TBinaryLambda& comparator) {
  3379. return BuildHeap(__func__, list, std::move(comparator));
  3380. }
  3381. TRuntimeNode TProgramBuilder::PushHeap(TRuntimeNode list, const TBinaryLambda& comparator) {
  3382. return BuildHeap(__func__, list, std::move(comparator));
  3383. }
  3384. TRuntimeNode TProgramBuilder::PopHeap(TRuntimeNode list, const TBinaryLambda& comparator) {
  3385. return BuildHeap(__func__, list, std::move(comparator));
  3386. }
  3387. TRuntimeNode TProgramBuilder::SortHeap(TRuntimeNode list, const TBinaryLambda& comparator) {
  3388. return BuildHeap(__func__, list, std::move(comparator));
  3389. }
  3390. TRuntimeNode TProgramBuilder::StableSort(TRuntimeNode list, const TBinaryLambda& comparator) {
  3391. return BuildHeap(__func__, list, std::move(comparator));
  3392. }
  3393. TRuntimeNode TProgramBuilder::NthElement(TRuntimeNode list, TRuntimeNode n, const TBinaryLambda& comparator) {
  3394. return BuildNth(__func__, list, n, std::move(comparator));
  3395. }
  3396. TRuntimeNode TProgramBuilder::PartialSort(TRuntimeNode list, TRuntimeNode n, const TBinaryLambda& comparator) {
  3397. return BuildNth(__func__, list, n, std::move(comparator));
  3398. }
  3399. TRuntimeNode TProgramBuilder::BuildMap(const std::string_view& callableName, TRuntimeNode list, const TUnaryLambda& handler)
  3400. {
  3401. const auto listType = list.GetStaticType();
  3402. MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream() || listType->IsOptional(), "Expected flow, list, stream or optional");
  3403. if (listType->IsOptional()) {
  3404. const auto itemArg = Arg(AS_TYPE(TOptionalType, listType)->GetItemType());
  3405. const auto newItem = handler(itemArg);
  3406. return IfPresent(list,
  3407. [&](TRuntimeNode item) { return NewOptional(handler(item)); },
  3408. NewEmptyOptional(NewOptionalType(newItem.GetStaticType()))
  3409. );
  3410. }
  3411. const auto itemType = listType->IsFlow() ?
  3412. AS_TYPE(TFlowType, listType)->GetItemType():
  3413. listType->IsList() ?
  3414. AS_TYPE(TListType, listType)->GetItemType():
  3415. AS_TYPE(TStreamType, listType)->GetItemType();
  3416. ThrowIfListOfVoid(itemType);
  3417. const auto itemArg = Arg(itemType);
  3418. const auto newItem = handler(itemArg);
  3419. const auto resultListType = listType->IsFlow() ?
  3420. (TType*)TFlowType::Create(newItem.GetStaticType(), Env):
  3421. listType->IsList() ?
  3422. (TType*)TListType::Create(newItem.GetStaticType(), Env):
  3423. (TType*)TStreamType::Create(newItem.GetStaticType(), Env);
  3424. TCallableBuilder callableBuilder(Env, callableName, resultListType);
  3425. callableBuilder.Add(list);
  3426. callableBuilder.Add(itemArg);
  3427. callableBuilder.Add(newItem);
  3428. return TRuntimeNode(callableBuilder.Build(), false);
  3429. }
  3430. TRuntimeNode TProgramBuilder::Invoke(const std::string_view& funcName, TType* resultType, const TArrayRef<const TRuntimeNode>& args) {
  3431. MKQL_ENSURE(args.size() >= 1U && args.size() <= 3U, "Expected from one to three arguments.");
  3432. std::array<TArgType, 4U> argTypes;
  3433. argTypes.front().first = UnpackOptionalData(resultType, argTypes.front().second)->GetSchemeType();
  3434. auto i = 0U;
  3435. for (const auto& arg : args) {
  3436. ++i;
  3437. argTypes[i].first = UnpackOptionalData(arg, argTypes[i].second)->GetSchemeType();
  3438. }
  3439. FunctionRegistry.GetBuiltins()->GetBuiltin(funcName, argTypes.data(), 1U + args.size());
  3440. TCallableBuilder callableBuilder(Env, __func__, resultType);
  3441. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(funcName));
  3442. for (const auto& arg : args) {
  3443. callableBuilder.Add(arg);
  3444. }
  3445. return TRuntimeNode(callableBuilder.Build(), false);
  3446. }
  3447. TRuntimeNode TProgramBuilder::Udf(
  3448. const std::string_view& funcName,
  3449. TRuntimeNode runConfig,
  3450. TType* userType,
  3451. const std::string_view& typeConfig
  3452. )
  3453. {
  3454. TRuntimeNode userTypeNode = userType ? TRuntimeNode(userType, true) : TRuntimeNode(Env.GetVoidLazy()->GetType(), true);
  3455. const ui32 flags = NUdf::IUdfModule::TFlags::TypesOnly;
  3456. if (!TypeInfoHelper) {
  3457. TypeInfoHelper = new TTypeInfoHelper();
  3458. }
  3459. TFunctionTypeInfo funcInfo;
  3460. TStatus status = FunctionRegistry.FindFunctionTypeInfo(
  3461. Env, TypeInfoHelper, nullptr, funcName, userType, typeConfig, flags, {}, nullptr, &funcInfo);
  3462. MKQL_ENSURE(status.IsOk(), status.GetError());
  3463. auto runConfigType = funcInfo.RunConfigType;
  3464. if (runConfig) {
  3465. bool typesMatch = runConfigType->IsSameType(*runConfig.GetStaticType());
  3466. MKQL_ENSURE(typesMatch, "RunConfig type mismatch");
  3467. } else {
  3468. MKQL_ENSURE(runConfigType->IsVoid() || runConfigType->IsOptional(), "RunConfig must be void or optional");
  3469. if (runConfigType->IsVoid()) {
  3470. runConfig = NewVoid();
  3471. } else {
  3472. runConfig = NewEmptyOptional(const_cast<TType*>(runConfigType));
  3473. }
  3474. }
  3475. auto funNameNode = NewDataLiteral<NUdf::EDataSlot::String>(funcName);
  3476. auto typeConfigNode = NewDataLiteral<NUdf::EDataSlot::String>(typeConfig);
  3477. TCallableBuilder callableBuilder(Env, __func__, funcInfo.FunctionType);
  3478. callableBuilder.Add(funNameNode);
  3479. callableBuilder.Add(userTypeNode);
  3480. callableBuilder.Add(typeConfigNode);
  3481. callableBuilder.Add(runConfig);
  3482. return TRuntimeNode(callableBuilder.Build(), false);
  3483. }
  3484. TRuntimeNode TProgramBuilder::TypedUdf(
  3485. const std::string_view& funcName,
  3486. TType* funcType,
  3487. TRuntimeNode runConfig,
  3488. TType* userType,
  3489. const std::string_view& typeConfig,
  3490. const std::string_view& file,
  3491. ui32 row,
  3492. ui32 column)
  3493. {
  3494. auto funNameNode = NewDataLiteral<NUdf::EDataSlot::String>(funcName);
  3495. auto typeConfigNode = NewDataLiteral<NUdf::EDataSlot::String>(typeConfig);
  3496. TRuntimeNode userTypeNode = userType ? TRuntimeNode(userType, true) : TRuntimeNode(Env.GetVoidLazy(), true);
  3497. TCallableBuilder callableBuilder(Env, "Udf", funcType);
  3498. callableBuilder.Add(funNameNode);
  3499. callableBuilder.Add(userTypeNode);
  3500. callableBuilder.Add(typeConfigNode);
  3501. callableBuilder.Add(runConfig);
  3502. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  3503. callableBuilder.Add(NewDataLiteral(row));
  3504. callableBuilder.Add(NewDataLiteral(column));
  3505. return TRuntimeNode(callableBuilder.Build(), false);
  3506. }
  3507. TRuntimeNode TProgramBuilder::ScriptUdf(
  3508. const std::string_view& moduleName,
  3509. const std::string_view& funcName,
  3510. TType* funcType,
  3511. TRuntimeNode script,
  3512. const std::string_view& file,
  3513. ui32 row,
  3514. ui32 column)
  3515. {
  3516. MKQL_ENSURE(funcType, "UDF callable type must not be empty");
  3517. MKQL_ENSURE(funcType->IsCallable(), "type must be callable");
  3518. auto scriptType = NKikimr::NMiniKQL::ScriptTypeFromStr(moduleName);
  3519. MKQL_ENSURE(scriptType != EScriptType::Unknown, "unknown script type '" << moduleName << "'");
  3520. EnsureScriptSpecificTypes(scriptType, static_cast<TCallableType*>(funcType), Env);
  3521. auto scriptTypeStr = IsCustomPython(scriptType) ? moduleName : ScriptTypeAsStr(CanonizeScriptType(scriptType));
  3522. TStringBuilder name;
  3523. name.reserve(scriptTypeStr.size() + funcName.size() + 1);
  3524. name << scriptTypeStr << '.' << funcName;
  3525. auto funcNameNode = NewDataLiteral<NUdf::EDataSlot::String>(name);
  3526. TRuntimeNode userTypeNode(funcType, true);
  3527. auto typeConfigNode = NewDataLiteral<NUdf::EDataSlot::String>("");
  3528. TCallableBuilder callableBuilder(Env, __func__, funcType);
  3529. callableBuilder.Add(funcNameNode);
  3530. callableBuilder.Add(userTypeNode);
  3531. callableBuilder.Add(typeConfigNode);
  3532. callableBuilder.Add(script);
  3533. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  3534. callableBuilder.Add(NewDataLiteral(row));
  3535. callableBuilder.Add(NewDataLiteral(column));
  3536. return TRuntimeNode(callableBuilder.Build(), false);
  3537. }
  3538. TRuntimeNode TProgramBuilder::Apply(TRuntimeNode callableNode, const TArrayRef<const TRuntimeNode>& args,
  3539. const std::string_view& file, ui32 row, ui32 column, ui32 dependentCount) {
  3540. MKQL_ENSURE(dependentCount <= args.size(), "Too many dependent nodes");
  3541. ui32 usedArgs = args.size() - dependentCount;
  3542. MKQL_ENSURE(!callableNode.IsImmediate() && callableNode.GetNode()->GetType()->IsCallable(),
  3543. "Expected callable");
  3544. auto callable = static_cast<TCallable*>(callableNode.GetNode());
  3545. TType* returnType = callable->GetType()->GetReturnType();
  3546. MKQL_ENSURE(returnType->IsCallable(), "Expected callable as return type");
  3547. auto callableType = static_cast<TCallableType*>(returnType);
  3548. MKQL_ENSURE(usedArgs <= callableType->GetArgumentsCount(), "Too many arguments");
  3549. MKQL_ENSURE(usedArgs >= callableType->GetArgumentsCount() - callableType->GetOptionalArgumentsCount(), "Too few arguments");
  3550. for (ui32 i = 0; i < usedArgs; i++) {
  3551. TType* argType = callableType->GetArgumentType(i);
  3552. TRuntimeNode arg = args[i];
  3553. MKQL_ENSURE(arg.GetStaticType()->IsConvertableTo(*argType),
  3554. "Argument type mismatch for argument " << i << ": runtime " << argType->GetKindAsStr()
  3555. << " with static " << arg.GetStaticType()->GetKindAsStr());
  3556. }
  3557. TCallableBuilder callableBuilder(Env, RuntimeVersion >= 8 ? "Apply2" : "Apply", callableType->GetReturnType());
  3558. callableBuilder.Add(callableNode);
  3559. callableBuilder.Add(NewDataLiteral<ui32>(dependentCount));
  3560. if constexpr (RuntimeVersion >= 8) {
  3561. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  3562. callableBuilder.Add(NewDataLiteral(row));
  3563. callableBuilder.Add(NewDataLiteral(column));
  3564. }
  3565. for (const auto& arg: args) {
  3566. callableBuilder.Add(arg);
  3567. }
  3568. return TRuntimeNode(callableBuilder.Build(), false);
  3569. }
  3570. TRuntimeNode TProgramBuilder::Apply(
  3571. TRuntimeNode callableNode,
  3572. const TArrayRef<const TRuntimeNode>& args,
  3573. ui32 dependentCount) {
  3574. return Apply(callableNode, args, {}, 0, 0, dependentCount);
  3575. }
  3576. TRuntimeNode TProgramBuilder::Callable(TType* callableType, const TArrayLambda& handler) {
  3577. auto castedCallableType = AS_TYPE(TCallableType, callableType);
  3578. std::vector<TRuntimeNode> args;
  3579. args.reserve(castedCallableType->GetArgumentsCount());
  3580. for (ui32 i = 0; i < castedCallableType->GetArgumentsCount(); ++i) {
  3581. args.push_back(Arg(castedCallableType->GetArgumentType(i)));
  3582. }
  3583. auto res = handler(args);
  3584. TCallableBuilder callableBuilder(Env, __func__, callableType);
  3585. for (ui32 i = 0; i < castedCallableType->GetArgumentsCount(); ++i) {
  3586. callableBuilder.Add(args[i]);
  3587. }
  3588. callableBuilder.Add(res);
  3589. return TRuntimeNode(callableBuilder.Build(), false);
  3590. }
  3591. TRuntimeNode TProgramBuilder::NewNull() {
  3592. if (!UseNullType || RuntimeVersion < 11) {
  3593. TCallableBuilder callableBuilder(Env, "Null", NewOptionalType(Env.GetVoidLazy()->GetType()));
  3594. return TRuntimeNode(callableBuilder.Build(), false);
  3595. } else {
  3596. return TRuntimeNode(Env.GetNullLazy(), true);
  3597. }
  3598. }
  3599. TRuntimeNode TProgramBuilder::Concat(TRuntimeNode data1, TRuntimeNode data2) {
  3600. bool isOpt1, isOpt2;
  3601. const auto type1 = UnpackOptionalData(data1, isOpt1)->GetSchemeType();
  3602. const auto type2 = UnpackOptionalData(data2, isOpt2)->GetSchemeType();
  3603. const auto resultType = NewDataType(type1 == type2 ? type1 : NUdf::TDataType<char*>::Id);
  3604. return InvokeBinary(__func__, isOpt1 || isOpt2 ? NewOptionalType(resultType) : resultType, data1, data2);
  3605. }
  3606. TRuntimeNode TProgramBuilder::AggrConcat(TRuntimeNode data1, TRuntimeNode data2) {
  3607. MKQL_ENSURE(data1.GetStaticType()->IsSameType(*data2.GetStaticType()), "Operands type mismatch.");
  3608. const std::array<TRuntimeNode, 2> args = {{ data1, data2 }};
  3609. return Invoke(__func__, data1.GetStaticType(), args);
  3610. }
  3611. TRuntimeNode TProgramBuilder::Substring(TRuntimeNode data, TRuntimeNode start, TRuntimeNode count) {
  3612. const std::array<TRuntimeNode, 3U> args = {{ data, start, count }};
  3613. return Invoke(__func__, data.GetStaticType(), args);
  3614. }
  3615. TRuntimeNode TProgramBuilder::Find(TRuntimeNode haystack, TRuntimeNode needle, TRuntimeNode pos) {
  3616. const std::array<TRuntimeNode, 3U> args = {{ haystack, needle, pos }};
  3617. return Invoke(__func__, NewOptionalType(NewDataType(NUdf::TDataType<ui32>::Id)), args);
  3618. }
  3619. TRuntimeNode TProgramBuilder::RFind(TRuntimeNode haystack, TRuntimeNode needle, TRuntimeNode pos) {
  3620. const std::array<TRuntimeNode, 3U> args = {{ haystack, needle, pos }};
  3621. return Invoke(__func__, NewOptionalType(NewDataType(NUdf::TDataType<ui32>::Id)), args);
  3622. }
  3623. TRuntimeNode TProgramBuilder::StartsWith(TRuntimeNode string, TRuntimeNode prefix) {
  3624. if constexpr (RuntimeVersion < 19U) {
  3625. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3626. }
  3627. return DataCompare(__func__, string, prefix);
  3628. }
  3629. TRuntimeNode TProgramBuilder::EndsWith(TRuntimeNode string, TRuntimeNode suffix) {
  3630. if constexpr (RuntimeVersion < 19U) {
  3631. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3632. }
  3633. return DataCompare(__func__, string, suffix);
  3634. }
  3635. TRuntimeNode TProgramBuilder::StringContains(TRuntimeNode string, TRuntimeNode pattern) {
  3636. bool isOpt1, isOpt2;
  3637. TDataType* type1 = UnpackOptionalData(string, isOpt1);
  3638. TDataType* type2 = UnpackOptionalData(pattern, isOpt2);
  3639. MKQL_ENSURE(type1->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id ||
  3640. type1->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expecting string as first argument");
  3641. MKQL_ENSURE(type2->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id ||
  3642. type2->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expecting string as second argument");
  3643. if constexpr (RuntimeVersion < 32U) {
  3644. auto stringCasted = (type1->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id) ? ToString(string) : string;
  3645. auto patternCasted = (type2->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id) ? ToString(pattern) : pattern;
  3646. auto found = Exists(Find(stringCasted, patternCasted, NewDataLiteral(ui32(0))));
  3647. if (!isOpt1 && !isOpt2) {
  3648. return found;
  3649. }
  3650. TVector<TRuntimeNode> predicates;
  3651. if (isOpt1) {
  3652. predicates.push_back(Exists(string));
  3653. }
  3654. if (isOpt2) {
  3655. predicates.push_back(Exists(pattern));
  3656. }
  3657. TRuntimeNode argsNotNull = (predicates.size() == 1) ? predicates.front() : And(predicates);
  3658. return If(argsNotNull, NewOptional(found), NewEmptyOptionalDataLiteral(NUdf::TDataType<bool>::Id));
  3659. }
  3660. return DataCompare(__func__, string, pattern);
  3661. }
  3662. TRuntimeNode TProgramBuilder::ByteAt(TRuntimeNode data, TRuntimeNode index) {
  3663. const std::array<TRuntimeNode, 2U> args = {{ data, index }};
  3664. return Invoke(__func__, NewOptionalType(NewDataType(NUdf::TDataType<ui8>::Id)), args);
  3665. }
  3666. TRuntimeNode TProgramBuilder::Size(TRuntimeNode data) {
  3667. return UnaryDataFunction(data, __func__, TDataFunctionFlags::HasUi32Result | TDataFunctionFlags::AllowNull | TDataFunctionFlags::AllowOptionalArgs | TDataFunctionFlags::CommonOptionalResult);
  3668. }
  3669. template <bool Utf8>
  3670. TRuntimeNode TProgramBuilder::ToString(TRuntimeNode data) {
  3671. bool isOptional;
  3672. UnpackOptionalData(data, isOptional);
  3673. const auto resultType = NewDataType(Utf8 ? NUdf::EDataSlot::Utf8 : NUdf::EDataSlot::String, isOptional);
  3674. TCallableBuilder callableBuilder(Env, __func__, resultType);
  3675. callableBuilder.Add(data);
  3676. return TRuntimeNode(callableBuilder.Build(), false);
  3677. }
  3678. TRuntimeNode TProgramBuilder::FromString(TRuntimeNode data, TType* type) {
  3679. bool isOptional;
  3680. const auto sourceType = UnpackOptionalData(data, isOptional);
  3681. const auto targetType = UnpackOptionalData(type, isOptional);
  3682. MKQL_ENSURE(sourceType->GetSchemeType() == NUdf::TDataType<char*>::Id || sourceType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected String");
  3683. MKQL_ENSURE(targetType->GetSchemeType() != 0, "Null is not allowed");
  3684. TCallableBuilder callableBuilder(Env, __func__, type);
  3685. callableBuilder.Add(data);
  3686. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(targetType->GetSchemeType())));
  3687. if (targetType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3688. const auto& params = static_cast<const TDataDecimalType*>(targetType)->GetParams();
  3689. callableBuilder.Add(NewDataLiteral(params.first));
  3690. callableBuilder.Add(NewDataLiteral(params.second));
  3691. }
  3692. return TRuntimeNode(callableBuilder.Build(), false);
  3693. }
  3694. TRuntimeNode TProgramBuilder::StrictFromString(TRuntimeNode data, TType* type) {
  3695. bool isOptional;
  3696. const auto sourceType = UnpackOptionalData(data, isOptional);
  3697. const auto targetType = UnpackOptionalData(type, isOptional);
  3698. MKQL_ENSURE(sourceType->GetSchemeType() == NUdf::TDataType<char*>::Id || sourceType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected String");
  3699. MKQL_ENSURE(targetType->GetSchemeType() != 0, "Null is not allowed");
  3700. TCallableBuilder callableBuilder(Env, __func__, type);
  3701. callableBuilder.Add(data);
  3702. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(targetType->GetSchemeType())));
  3703. if (targetType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3704. const auto& params = static_cast<const TDataDecimalType*>(targetType)->GetParams();
  3705. callableBuilder.Add(NewDataLiteral(params.first));
  3706. callableBuilder.Add(NewDataLiteral(params.second));
  3707. }
  3708. return TRuntimeNode(callableBuilder.Build(), false);
  3709. }
  3710. TRuntimeNode TProgramBuilder::ToBytes(TRuntimeNode data) {
  3711. return UnaryDataFunction(data, __func__, TDataFunctionFlags::HasStringResult | TDataFunctionFlags::AllowOptionalArgs | TDataFunctionFlags::CommonOptionalResult);
  3712. }
  3713. TRuntimeNode TProgramBuilder::FromBytes(TRuntimeNode data, TType* targetType) {
  3714. auto type = data.GetStaticType();
  3715. bool isOptional;
  3716. auto dataType = UnpackOptionalData(type, isOptional);
  3717. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expected String");
  3718. auto resultType = NewOptionalType(targetType);
  3719. TCallableBuilder callableBuilder(Env, __func__, resultType);
  3720. callableBuilder.Add(data);
  3721. auto targetDataType = AS_TYPE(TDataType, targetType);
  3722. callableBuilder.Add(NewDataLiteral(static_cast<ui32>(targetDataType->GetSchemeType())));
  3723. if (targetDataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3724. const auto& params = static_cast<const TDataDecimalType*>(targetType)->GetParams();
  3725. callableBuilder.Add(NewDataLiteral(params.first));
  3726. callableBuilder.Add(NewDataLiteral(params.second));
  3727. }
  3728. return TRuntimeNode(callableBuilder.Build(), false);
  3729. }
  3730. TRuntimeNode TProgramBuilder::InversePresortString(TRuntimeNode data) {
  3731. const std::array<TRuntimeNode, 1U> args = {{ data }};
  3732. return Invoke(__func__, NewDataType(NUdf::TDataType<char*>::Id), args);
  3733. }
  3734. TRuntimeNode TProgramBuilder::InverseString(TRuntimeNode data) {
  3735. const std::array<TRuntimeNode, 1U> args = {{ data }};
  3736. return Invoke(__func__, NewDataType(NUdf::TDataType<char*>::Id), args);
  3737. }
  3738. TRuntimeNode TProgramBuilder::Random(const TArrayRef<const TRuntimeNode>& dependentNodes) {
  3739. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<double>::Id));
  3740. for (auto& x : dependentNodes) {
  3741. callableBuilder.Add(x);
  3742. }
  3743. return TRuntimeNode(callableBuilder.Build(), false);
  3744. }
  3745. TRuntimeNode TProgramBuilder::RandomNumber(const TArrayRef<const TRuntimeNode>& dependentNodes) {
  3746. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<ui64>::Id));
  3747. for (auto& x : dependentNodes) {
  3748. callableBuilder.Add(x);
  3749. }
  3750. return TRuntimeNode(callableBuilder.Build(), false);
  3751. }
  3752. TRuntimeNode TProgramBuilder::RandomUuid(const TArrayRef<const TRuntimeNode>& dependentNodes) {
  3753. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<NUdf::TUuid>::Id));
  3754. for (auto& x : dependentNodes) {
  3755. callableBuilder.Add(x);
  3756. }
  3757. return TRuntimeNode(callableBuilder.Build(), false);
  3758. }
  3759. TRuntimeNode TProgramBuilder::Now(const TArrayRef<const TRuntimeNode>& args) {
  3760. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::TDataType<ui64>::Id));
  3761. for (const auto& x : args) {
  3762. callableBuilder.Add(x);
  3763. }
  3764. return TRuntimeNode(callableBuilder.Build(), false);
  3765. }
  3766. TRuntimeNode TProgramBuilder::CurrentUtcDate(const TArrayRef<const TRuntimeNode>& args) {
  3767. return Cast(CurrentUtcTimestamp(args), NewDataType(NUdf::TDataType<NUdf::TDate>::Id));
  3768. }
  3769. TRuntimeNode TProgramBuilder::CurrentUtcDatetime(const TArrayRef<const TRuntimeNode>& args) {
  3770. return Cast(CurrentUtcTimestamp(args), NewDataType(NUdf::TDataType<NUdf::TDatetime>::Id));
  3771. }
  3772. TRuntimeNode TProgramBuilder::CurrentUtcTimestamp(const TArrayRef<const TRuntimeNode>& args) {
  3773. return Coalesce(ToIntegral(Now(args), NewDataType(NUdf::TDataType<NUdf::TTimestamp>::Id, true)),
  3774. TRuntimeNode(BuildDataLiteral(NUdf::TUnboxedValuePod(ui64(NUdf::MAX_TIMESTAMP - 1ULL)), NUdf::TDataType<NUdf::TTimestamp>::Id, Env), true));
  3775. }
  3776. TRuntimeNode TProgramBuilder::Pickle(TRuntimeNode data) {
  3777. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::EDataSlot::String));
  3778. callableBuilder.Add(data);
  3779. return TRuntimeNode(callableBuilder.Build(), false);
  3780. }
  3781. TRuntimeNode TProgramBuilder::StablePickle(TRuntimeNode data) {
  3782. TCallableBuilder callableBuilder(Env, __func__, NewDataType(NUdf::EDataSlot::String));
  3783. callableBuilder.Add(data);
  3784. return TRuntimeNode(callableBuilder.Build(), false);
  3785. }
  3786. TRuntimeNode TProgramBuilder::Unpickle(TType* type, TRuntimeNode serialized) {
  3787. MKQL_ENSURE(AS_TYPE(TDataType, serialized)->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expected String");
  3788. TCallableBuilder callableBuilder(Env, __func__, type);
  3789. callableBuilder.Add(TRuntimeNode(type, true));
  3790. callableBuilder.Add(serialized);
  3791. return TRuntimeNode(callableBuilder.Build(), false);
  3792. }
  3793. TRuntimeNode TProgramBuilder::Ascending(TRuntimeNode data) {
  3794. auto dataType = NewDataType(NUdf::EDataSlot::String);
  3795. TCallableBuilder callableBuilder(Env, __func__, dataType);
  3796. callableBuilder.Add(data);
  3797. return TRuntimeNode(callableBuilder.Build(), false);
  3798. }
  3799. TRuntimeNode TProgramBuilder::Descending(TRuntimeNode data) {
  3800. auto dataType = NewDataType(NUdf::EDataSlot::String);
  3801. TCallableBuilder callableBuilder(Env, __func__, dataType);
  3802. callableBuilder.Add(data);
  3803. return TRuntimeNode(callableBuilder.Build(), false);
  3804. }
  3805. TRuntimeNode TProgramBuilder::Convert(TRuntimeNode data, TType* type) {
  3806. if (data.GetStaticType()->IsSameType(*type)) {
  3807. return data;
  3808. }
  3809. bool isOptional;
  3810. const auto dataType = UnpackOptionalData(data, isOptional);
  3811. const std::array<TRuntimeNode, 1> args = {{ data }};
  3812. if (dataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3813. const auto targetSchemeType = UnpackOptionalData(type, isOptional)->GetSchemeType();
  3814. TStringStream str;
  3815. str << "To" << NUdf::GetDataTypeInfo(NUdf::GetDataSlot(targetSchemeType)).Name
  3816. << '_' << ::ToString(static_cast<const TDataDecimalType*>(dataType)->GetParams().second);
  3817. return Invoke(str.Str().c_str(), type, args);
  3818. }
  3819. return Invoke(__func__, type, args);
  3820. }
  3821. TRuntimeNode TProgramBuilder::ToDecimal(TRuntimeNode data, ui8 precision, ui8 scale) {
  3822. bool isOptional;
  3823. auto dataType = UnpackOptionalData(data, isOptional);
  3824. TType* decimal = TDataDecimalType::Create(precision, scale, Env);
  3825. if (isOptional)
  3826. decimal = TOptionalType::Create(decimal, Env);
  3827. const std::array<TRuntimeNode, 1> args = {{ data }};
  3828. if (dataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3829. const auto& params = static_cast<const TDataDecimalType*>(dataType)->GetParams();
  3830. if (precision - scale < params.first - params.second && scale != params.second) {
  3831. return ToDecimal(ToDecimal(data, precision - scale + params.second, params.second), precision, scale);
  3832. } else if (params.second < scale) {
  3833. return Invoke("ScaleUp_" + ::ToString(scale - params.second), decimal, args);
  3834. } else if (params.second > scale) {
  3835. TRuntimeNode scaled = Invoke("ScaleDown_" + ::ToString(params.second - scale), decimal, args);
  3836. return Invoke("CheckBounds_" + ::ToString(precision), decimal, {{ scaled }});
  3837. } else if (precision < params.first) {
  3838. return Invoke("CheckBounds_" + ::ToString(precision), decimal, args);
  3839. } else if (precision > params.first) {
  3840. return Invoke("Plus", decimal, args);
  3841. } else {
  3842. return data;
  3843. }
  3844. } else {
  3845. const auto digits = NUdf::GetDataTypeInfo(*dataType->GetDataSlot()).DecimalDigits;
  3846. MKQL_ENSURE(digits, "Can't cast into Decimal.");
  3847. if (digits <= precision && !scale)
  3848. return Invoke(__func__, decimal, args);
  3849. else
  3850. return ToDecimal(ToDecimal(data, digits, 0), precision, scale);
  3851. }
  3852. }
  3853. TRuntimeNode TProgramBuilder::ToIntegral(TRuntimeNode data, TType* type) {
  3854. bool isOptional;
  3855. auto dataType = UnpackOptionalData(data, isOptional);
  3856. if (dataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) {
  3857. const auto& params = static_cast<const TDataDecimalType*>(dataType)->GetParams();
  3858. if (params.second)
  3859. return ToIntegral(ToDecimal(data, params.first - params.second, 0), type);
  3860. }
  3861. const std::array<TRuntimeNode, 1> args = {{ data }};
  3862. return Invoke(__func__, type, args);
  3863. }
  3864. TRuntimeNode TProgramBuilder::ListIf(TRuntimeNode predicate, TRuntimeNode item) {
  3865. return If(predicate, NewList(item.GetStaticType(), {item}), NewEmptyList(item.GetStaticType()));
  3866. }
  3867. TRuntimeNode TProgramBuilder::AsList(TRuntimeNode item) {
  3868. TListLiteralBuilder builder(Env, item.GetStaticType());
  3869. builder.Add(item);
  3870. return TRuntimeNode(builder.Build(), true);
  3871. }
  3872. TRuntimeNode TProgramBuilder::AsList(const TArrayRef<const TRuntimeNode>& items) {
  3873. MKQL_ENSURE(!items.empty(), "required not empty list of items");
  3874. TListLiteralBuilder builder(Env, items[0].GetStaticType());
  3875. for (auto item : items) {
  3876. builder.Add(item);
  3877. }
  3878. return TRuntimeNode(builder.Build(), true);
  3879. }
  3880. TRuntimeNode TProgramBuilder::MapJoinCore(TRuntimeNode flow, TRuntimeNode dict, EJoinKind joinKind,
  3881. const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& leftRenames,
  3882. const TArrayRef<const ui32>& rightRenames, TType* returnType) {
  3883. MKQL_ENSURE(joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left || joinKind == EJoinKind::LeftSemi || joinKind == EJoinKind::LeftOnly, "Unsupported join kind");
  3884. MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified");
  3885. MKQL_ENSURE(leftRenames.size() % 2U == 0U, "Expected even count");
  3886. MKQL_ENSURE(rightRenames.size() % 2U == 0U, "Expected even count");
  3887. TRuntimeNode::TList leftKeyColumnsNodes, leftRenamesNodes, rightRenamesNodes;
  3888. leftKeyColumnsNodes.reserve(leftKeyColumns.size());
  3889. std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(), std::back_inserter(leftKeyColumnsNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  3890. leftRenamesNodes.reserve(leftRenames.size());
  3891. std::transform(leftRenames.cbegin(), leftRenames.cend(), std::back_inserter(leftRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  3892. rightRenamesNodes.reserve(rightRenames.size());
  3893. std::transform(rightRenames.cbegin(), rightRenames.cend(), std::back_inserter(rightRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); });
  3894. TCallableBuilder callableBuilder(Env, __func__, returnType);
  3895. callableBuilder.Add(flow);
  3896. callableBuilder.Add(dict);
  3897. callableBuilder.Add(NewDataLiteral((ui32)joinKind));
  3898. callableBuilder.Add(NewTuple(leftKeyColumnsNodes));
  3899. callableBuilder.Add(NewTuple(leftRenamesNodes));
  3900. callableBuilder.Add(NewTuple(rightRenamesNodes));
  3901. return TRuntimeNode(callableBuilder.Build(), false);
  3902. }
  3903. TRuntimeNode TProgramBuilder::CommonJoinCore(TRuntimeNode flow, EJoinKind joinKind,
  3904. const TArrayRef<const ui32>& leftColumns, const TArrayRef<const ui32>& rightColumns,
  3905. const TArrayRef<const ui32>& requiredColumns, const TArrayRef<const ui32>& keyColumns,
  3906. ui64 memLimit, std::optional<ui32> sortedTableOrder,
  3907. EAnyJoinSettings anyJoinSettings, const ui32 tableIndexField, TType* returnType) {
  3908. if constexpr (RuntimeVersion < 17U) {
  3909. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3910. }
  3911. MKQL_ENSURE(leftColumns.size() % 2U == 0U, "Expected even count");
  3912. MKQL_ENSURE(rightColumns.size() % 2U == 0U, "Expected even count");
  3913. TRuntimeNode::TList leftInputColumnsNodes, rightInputColumnsNodes, requiredColumnsNodes,
  3914. leftOutputColumnsNodes, rightOutputColumnsNodes, keyColumnsNodes;
  3915. bool s = false;
  3916. for (const auto idx : leftColumns) {
  3917. ((s = !s) ? leftInputColumnsNodes : leftOutputColumnsNodes).emplace_back(NewDataLiteral(idx));
  3918. }
  3919. for (const auto idx : rightColumns) {
  3920. ((s = !s) ? rightInputColumnsNodes : rightOutputColumnsNodes).emplace_back(NewDataLiteral(idx));
  3921. }
  3922. const std::unordered_set<ui32> requiredIndices(requiredColumns.cbegin(), requiredColumns.cend());
  3923. MKQL_ENSURE(requiredIndices.size() == requiredColumns.size(), "Duplication of requred columns.");
  3924. requiredColumnsNodes.reserve(requiredColumns.size());
  3925. std::transform(requiredColumns.cbegin(), requiredColumns.cend(), std::back_inserter(requiredColumnsNodes),
  3926. std::bind(&TProgramBuilder::NewDataLiteral<ui32>, this, std::placeholders::_1));
  3927. const std::unordered_set<ui32> keyIndices(keyColumns.cbegin(), keyColumns.cend());
  3928. MKQL_ENSURE(keyIndices.size() == keyColumns.size(), "Duplication of key columns.");
  3929. keyColumnsNodes.reserve(keyColumns.size());
  3930. std::transform(keyColumns.cbegin(), keyColumns.cend(), std::back_inserter(keyColumnsNodes),
  3931. std::bind(&TProgramBuilder::NewDataLiteral<ui32>, this, std::placeholders::_1));
  3932. TCallableBuilder callableBuilder(Env, __func__, returnType);
  3933. callableBuilder.Add(flow);
  3934. callableBuilder.Add(NewDataLiteral((ui32)joinKind));
  3935. callableBuilder.Add(NewTuple(leftInputColumnsNodes));
  3936. callableBuilder.Add(NewTuple(rightInputColumnsNodes));
  3937. callableBuilder.Add(NewTuple(requiredColumnsNodes));
  3938. callableBuilder.Add(NewTuple(leftOutputColumnsNodes));
  3939. callableBuilder.Add(NewTuple(rightOutputColumnsNodes));
  3940. callableBuilder.Add(NewTuple(keyColumnsNodes));
  3941. callableBuilder.Add(NewDataLiteral(memLimit));
  3942. callableBuilder.Add(sortedTableOrder ? NewDataLiteral(*sortedTableOrder) : NewVoid());
  3943. callableBuilder.Add(NewDataLiteral((ui32)anyJoinSettings));
  3944. callableBuilder.Add(NewDataLiteral(tableIndexField));
  3945. return TRuntimeNode(callableBuilder.Build(), false);
  3946. }
  3947. TRuntimeNode TProgramBuilder::WideCombiner(TRuntimeNode flow, i64 memLimit, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) {
  3948. if constexpr (RuntimeVersion < 18U) {
  3949. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  3950. }
  3951. if (memLimit < 0) {
  3952. if constexpr (RuntimeVersion < 46U) {
  3953. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__ << " with limit " << memLimit;
  3954. }
  3955. }
  3956. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  3957. TRuntimeNode::TList itemArgs;
  3958. itemArgs.reserve(wideComponents.size());
  3959. auto i = 0U;
  3960. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  3961. const auto keys = extractor(itemArgs);
  3962. TRuntimeNode::TList keyArgs;
  3963. keyArgs.reserve(keys.size());
  3964. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  3965. const auto first = init(keyArgs, itemArgs);
  3966. TRuntimeNode::TList stateArgs;
  3967. stateArgs.reserve(first.size());
  3968. std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  3969. const auto next = update(keyArgs, itemArgs, stateArgs);
  3970. MKQL_ENSURE(next.size() == first.size(), "Mismatch init and update state size.");
  3971. TRuntimeNode::TList finishKeyArgs;
  3972. finishKeyArgs.reserve(keys.size());
  3973. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(finishKeyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  3974. TRuntimeNode::TList finishStateArgs;
  3975. finishStateArgs.reserve(next.size());
  3976. std::transform(next.cbegin(), next.cend(), std::back_inserter(finishStateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  3977. const auto output = finish(finishKeyArgs, finishStateArgs);
  3978. std::vector<TType*> tupleItems;
  3979. tupleItems.reserve(output.size());
  3980. std::transform(output.cbegin(), output.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  3981. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  3982. callableBuilder.Add(flow);
  3983. if constexpr (RuntimeVersion < 46U)
  3984. callableBuilder.Add(NewDataLiteral(ui64(memLimit)));
  3985. else
  3986. callableBuilder.Add(NewDataLiteral(memLimit));
  3987. callableBuilder.Add(NewDataLiteral(ui32(keyArgs.size())));
  3988. callableBuilder.Add(NewDataLiteral(ui32(stateArgs.size())));
  3989. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3990. std::for_each(keys.cbegin(), keys.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3991. std::for_each(keyArgs.cbegin(), keyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3992. std::for_each(first.cbegin(), first.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3993. std::for_each(stateArgs.cbegin(), stateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3994. std::for_each(next.cbegin(), next.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3995. std::for_each(finishKeyArgs.cbegin(), finishKeyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3996. std::for_each(finishStateArgs.cbegin(), finishStateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3997. std::for_each(output.cbegin(), output.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  3998. return TRuntimeNode(callableBuilder.Build(), false);
  3999. }
  4000. TRuntimeNode TProgramBuilder::WideLastCombinerCommon(const TStringBuf& funcName, TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) {
  4001. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  4002. TRuntimeNode::TList itemArgs;
  4003. itemArgs.reserve(wideComponents.size());
  4004. auto i = 0U;
  4005. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  4006. const auto keys = extractor(itemArgs);
  4007. TRuntimeNode::TList keyArgs;
  4008. keyArgs.reserve(keys.size());
  4009. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  4010. const auto first = init(keyArgs, itemArgs);
  4011. TRuntimeNode::TList stateArgs;
  4012. stateArgs.reserve(first.size());
  4013. std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  4014. const auto next = update(keyArgs, itemArgs, stateArgs);
  4015. MKQL_ENSURE(next.size() == first.size(), "Mismatch init and update state size.");
  4016. TRuntimeNode::TList finishKeyArgs;
  4017. finishKeyArgs.reserve(keys.size());
  4018. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(finishKeyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  4019. TRuntimeNode::TList finishStateArgs;
  4020. finishStateArgs.reserve(next.size());
  4021. std::transform(next.cbegin(), next.cend(), std::back_inserter(finishStateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  4022. const auto output = finish(finishKeyArgs, finishStateArgs);
  4023. std::vector<TType*> tupleItems;
  4024. tupleItems.reserve(output.size());
  4025. std::transform(output.cbegin(), output.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  4026. TCallableBuilder callableBuilder(Env, funcName, NewFlowType(NewMultiType(tupleItems)));
  4027. callableBuilder.Add(flow);
  4028. callableBuilder.Add(NewDataLiteral(ui32(keyArgs.size())));
  4029. callableBuilder.Add(NewDataLiteral(ui32(stateArgs.size())));
  4030. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4031. std::for_each(keys.cbegin(), keys.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4032. std::for_each(keyArgs.cbegin(), keyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4033. std::for_each(first.cbegin(), first.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4034. std::for_each(stateArgs.cbegin(), stateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4035. std::for_each(next.cbegin(), next.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4036. std::for_each(finishKeyArgs.cbegin(), finishKeyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4037. std::for_each(finishStateArgs.cbegin(), finishStateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4038. std::for_each(output.cbegin(), output.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4039. return TRuntimeNode(callableBuilder.Build(), false);
  4040. }
  4041. TRuntimeNode TProgramBuilder::WideLastCombiner(TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) {
  4042. if constexpr (RuntimeVersion < 29U) {
  4043. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4044. }
  4045. return WideLastCombinerCommon(__func__, flow, extractor, init, update, finish);
  4046. }
  4047. TRuntimeNode TProgramBuilder::WideLastCombinerWithSpilling(TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) {
  4048. if constexpr (RuntimeVersion < 49U) {
  4049. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4050. }
  4051. return WideLastCombinerCommon(__func__, flow, extractor, init, update, finish);
  4052. }
  4053. TRuntimeNode TProgramBuilder::WideCondense1(TRuntimeNode flow, const TWideLambda& init, const TWideSwitchLambda& switcher, const TBinaryWideLambda& update, bool useCtx) {
  4054. if constexpr (RuntimeVersion < 18U) {
  4055. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4056. }
  4057. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  4058. TRuntimeNode::TList itemArgs;
  4059. itemArgs.reserve(wideComponents.size());
  4060. auto i = 0U;
  4061. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  4062. const auto first = init(itemArgs);
  4063. TRuntimeNode::TList stateArgs;
  4064. stateArgs.reserve(first.size());
  4065. std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } );
  4066. const auto chop = switcher(itemArgs, stateArgs);
  4067. const auto next = update(itemArgs, stateArgs);
  4068. MKQL_ENSURE(next.size() == first.size(), "Mismatch init and update state size.");
  4069. std::vector<TType*> tupleItems;
  4070. tupleItems.reserve(next.size());
  4071. std::transform(next.cbegin(), next.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1));
  4072. TCallableBuilder callableBuilder(Env, __func__, NewFlowType(NewMultiType(tupleItems)));
  4073. callableBuilder.Add(flow);
  4074. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4075. std::for_each(first.cbegin(), first.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4076. std::for_each(stateArgs.cbegin(), stateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4077. callableBuilder.Add(chop);
  4078. std::for_each(next.cbegin(), next.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4079. if (useCtx) {
  4080. MKQL_ENSURE(RuntimeVersion >= 30U, "Too old runtime version");
  4081. callableBuilder.Add(NewDataLiteral<bool>(useCtx));
  4082. }
  4083. return TRuntimeNode(callableBuilder.Build(), false);
  4084. }
  4085. TRuntimeNode TProgramBuilder::CombineCore(TRuntimeNode stream,
  4086. const TUnaryLambda& keyExtractor,
  4087. const TBinaryLambda& init,
  4088. const TTernaryLambda& update,
  4089. const TBinaryLambda& finish,
  4090. ui64 memLimit)
  4091. {
  4092. if constexpr (RuntimeVersion < 3U) {
  4093. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4094. }
  4095. const bool isStream = stream.GetStaticType()->IsStream();
  4096. const auto itemType = isStream ? AS_TYPE(TStreamType, stream)->GetItemType() : AS_TYPE(TFlowType, stream)->GetItemType();
  4097. const auto itemArg = Arg(itemType);
  4098. const auto key = keyExtractor(itemArg);
  4099. const auto keyType = key.GetStaticType();
  4100. const auto keyArg = Arg(keyType);
  4101. const auto stateInit = init(keyArg, itemArg);
  4102. const auto stateType = stateInit.GetStaticType();
  4103. const auto stateArg = Arg(stateType);
  4104. const auto stateUpdate = update(keyArg, itemArg, stateArg);
  4105. const auto finishItem = finish(keyArg, stateArg);
  4106. const auto finishType = finishItem.GetStaticType();
  4107. MKQL_ENSURE(finishType->IsList() || finishType->IsStream() || finishType->IsOptional(), "Expected list, stream or optional");
  4108. TType* retItemType = nullptr;
  4109. if (finishType->IsOptional()) {
  4110. retItemType = AS_TYPE(TOptionalType, finishType)->GetItemType();
  4111. } else if (finishType->IsList()) {
  4112. retItemType = AS_TYPE(TListType, finishType)->GetItemType();
  4113. } else if (finishType->IsStream()) {
  4114. retItemType = AS_TYPE(TStreamType, finishType)->GetItemType();
  4115. }
  4116. const auto resultStreamType = isStream ? NewStreamType(retItemType) : NewFlowType(retItemType);
  4117. TCallableBuilder callableBuilder(Env, __func__, resultStreamType);
  4118. callableBuilder.Add(stream);
  4119. callableBuilder.Add(itemArg);
  4120. callableBuilder.Add(key);
  4121. callableBuilder.Add(keyArg);
  4122. callableBuilder.Add(stateInit);
  4123. callableBuilder.Add(stateArg);
  4124. callableBuilder.Add(stateUpdate);
  4125. callableBuilder.Add(finishItem);
  4126. callableBuilder.Add(NewDataLiteral(memLimit));
  4127. return TRuntimeNode(callableBuilder.Build(), false);
  4128. }
  4129. TRuntimeNode TProgramBuilder::GroupingCore(TRuntimeNode stream,
  4130. const TBinaryLambda& groupSwitch,
  4131. const TUnaryLambda& keyExtractor,
  4132. const TUnaryLambda& handler)
  4133. {
  4134. if (handler && RuntimeVersion < 20U) {
  4135. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__ << " with handler";
  4136. }
  4137. auto itemType = AS_TYPE(TStreamType, stream)->GetItemType();
  4138. TRuntimeNode keyExtractorItemArg = Arg(itemType);
  4139. TRuntimeNode keyExtractorResult = keyExtractor(keyExtractorItemArg);
  4140. TRuntimeNode groupSwitchKeyArg = Arg(keyExtractorResult.GetStaticType());
  4141. TRuntimeNode groupSwitchItemArg = Arg(itemType);
  4142. TRuntimeNode groupSwitchResult = groupSwitch(groupSwitchKeyArg, groupSwitchItemArg);
  4143. MKQL_ENSURE(AS_TYPE(TDataType, groupSwitchResult)->GetSchemeType() == NUdf::TDataType<bool>::Id,
  4144. "Expected bool type");
  4145. TRuntimeNode handlerItemArg;
  4146. TRuntimeNode handlerResult;
  4147. if (handler) {
  4148. handlerItemArg = Arg(itemType);
  4149. handlerResult = handler(handlerItemArg);
  4150. itemType = handlerResult.GetStaticType();
  4151. }
  4152. const std::array<TType*, 2U> tupleItems = {{ keyExtractorResult.GetStaticType(), NewStreamType(itemType) }};
  4153. const auto finishType = NewStreamType(NewTupleType(tupleItems));
  4154. TCallableBuilder callableBuilder(Env, __func__, finishType);
  4155. callableBuilder.Add(stream);
  4156. callableBuilder.Add(keyExtractorResult);
  4157. callableBuilder.Add(groupSwitchResult);
  4158. callableBuilder.Add(keyExtractorItemArg);
  4159. callableBuilder.Add(groupSwitchKeyArg);
  4160. callableBuilder.Add(groupSwitchItemArg);
  4161. if (handler) {
  4162. callableBuilder.Add(handlerResult);
  4163. callableBuilder.Add(handlerItemArg);
  4164. }
  4165. return TRuntimeNode(callableBuilder.Build(), false);
  4166. }
  4167. TRuntimeNode TProgramBuilder::Chopper(TRuntimeNode flow, const TUnaryLambda& keyExtractor, const TBinaryLambda& groupSwitch, const TBinaryLambda& groupHandler) {
  4168. const auto flowType = flow.GetStaticType();
  4169. MKQL_ENSURE(flowType->IsFlow() || flowType->IsStream(), "Expected flow or stream.");
  4170. if constexpr (RuntimeVersion < 9U) {
  4171. return FlatMap(GroupingCore(flow, groupSwitch, keyExtractor),
  4172. [&](TRuntimeNode item) -> TRuntimeNode { return groupHandler(Nth(item, 0U), Nth(item, 1U)); }
  4173. );
  4174. }
  4175. const bool isStream = flowType->IsStream();
  4176. const auto itemType = isStream ? AS_TYPE(TStreamType, flow)->GetItemType() : AS_TYPE(TFlowType, flow)->GetItemType();
  4177. const auto itemArg = Arg(itemType);
  4178. const auto keyExtractorResult = keyExtractor(itemArg);
  4179. const auto keyArg = Arg(keyExtractorResult.GetStaticType());
  4180. const auto groupSwitchResult = groupSwitch(keyArg, itemArg);
  4181. const auto input = Arg(flowType);
  4182. const auto output = groupHandler(keyArg, input);
  4183. TCallableBuilder callableBuilder(Env, __func__, output.GetStaticType());
  4184. callableBuilder.Add(flow);
  4185. callableBuilder.Add(itemArg);
  4186. callableBuilder.Add(keyExtractorResult);
  4187. callableBuilder.Add(keyArg);
  4188. callableBuilder.Add(groupSwitchResult);
  4189. callableBuilder.Add(input);
  4190. callableBuilder.Add(output);
  4191. return TRuntimeNode(callableBuilder.Build(), false);
  4192. }
  4193. TRuntimeNode TProgramBuilder::WideChopper(TRuntimeNode flow, const TWideLambda& extractor, const TWideSwitchLambda& groupSwitch,
  4194. const std::function<TRuntimeNode (TRuntimeNode::TList, TRuntimeNode)>& groupHandler
  4195. ) {
  4196. if constexpr (RuntimeVersion < 18U) {
  4197. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4198. }
  4199. const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType()));
  4200. TRuntimeNode::TList itemArgs, keyArgs;
  4201. itemArgs.reserve(wideComponents.size());
  4202. auto i = 0U;
  4203. std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); });
  4204. const auto keys = extractor(itemArgs);
  4205. keyArgs.reserve(keys.size());
  4206. std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } );
  4207. const auto groupSwitchResult = groupSwitch(keyArgs, itemArgs);
  4208. const auto input = WideFlowArg(flow.GetStaticType());
  4209. const auto output = groupHandler(keyArgs, input);
  4210. TCallableBuilder callableBuilder(Env, __func__, output.GetStaticType());
  4211. callableBuilder.Add(flow);
  4212. std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4213. std::for_each(keys.cbegin(), keys.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4214. std::for_each(keyArgs.cbegin(), keyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1));
  4215. callableBuilder.Add(groupSwitchResult);
  4216. callableBuilder.Add(input);
  4217. callableBuilder.Add(output);
  4218. return TRuntimeNode(callableBuilder.Build(), false);
  4219. }
  4220. TRuntimeNode TProgramBuilder::HoppingCore(TRuntimeNode list,
  4221. const TUnaryLambda& timeExtractor,
  4222. const TUnaryLambda& init,
  4223. const TBinaryLambda& update,
  4224. const TUnaryLambda& save,
  4225. const TUnaryLambda& load,
  4226. const TBinaryLambda& merge,
  4227. const TBinaryLambda& finish,
  4228. TRuntimeNode hop, TRuntimeNode interval, TRuntimeNode delay)
  4229. {
  4230. auto streamType = AS_TYPE(TStreamType, list);
  4231. auto itemType = AS_TYPE(TStructType, streamType->GetItemType());
  4232. auto timestampType = TOptionalType::Create(TDataType::Create(NUdf::TDataType<NUdf::TTimestamp>::Id, Env), Env);
  4233. TRuntimeNode itemArg = Arg(itemType);
  4234. auto outTime = timeExtractor(itemArg);
  4235. auto outStateInit = init(itemArg);
  4236. auto stateType = outStateInit.GetStaticType();
  4237. TRuntimeNode stateArg = Arg(stateType);
  4238. auto outStateUpdate = update(itemArg, stateArg);
  4239. auto hasSaveLoad = (bool)save;
  4240. TRuntimeNode saveArg, outSave, loadArg, outLoad;
  4241. if (hasSaveLoad) {
  4242. saveArg = Arg(stateType);
  4243. outSave = save(saveArg);
  4244. loadArg = Arg(outSave.GetStaticType());
  4245. outLoad = load(loadArg);
  4246. MKQL_ENSURE(outLoad.GetStaticType()->IsSameType(*stateType), "Loaded type is changed by the load handler");
  4247. } else {
  4248. saveArg = outSave = loadArg = outLoad = NewVoid();
  4249. }
  4250. TRuntimeNode state2Arg = Arg(stateType);
  4251. TRuntimeNode timeArg = Arg(timestampType);
  4252. auto outStateMerge = merge(stateArg, state2Arg);
  4253. auto outItemFinish = finish(stateArg, timeArg);
  4254. auto finishType = outItemFinish.GetStaticType();
  4255. MKQL_ENSURE(finishType->IsStruct(), "Expected struct type as finish lambda output");
  4256. auto resultType = TStreamType::Create(outItemFinish.GetStaticType(), Env);
  4257. TCallableBuilder callableBuilder(Env, __func__, resultType);
  4258. callableBuilder.Add(list);
  4259. callableBuilder.Add(itemArg);
  4260. callableBuilder.Add(stateArg);
  4261. callableBuilder.Add(state2Arg);
  4262. callableBuilder.Add(timeArg);
  4263. callableBuilder.Add(saveArg);
  4264. callableBuilder.Add(loadArg);
  4265. callableBuilder.Add(outTime);
  4266. callableBuilder.Add(outStateInit);
  4267. callableBuilder.Add(outStateUpdate);
  4268. callableBuilder.Add(outSave);
  4269. callableBuilder.Add(outLoad);
  4270. callableBuilder.Add(outStateMerge);
  4271. callableBuilder.Add(outItemFinish);
  4272. callableBuilder.Add(hop);
  4273. callableBuilder.Add(interval);
  4274. callableBuilder.Add(delay);
  4275. return TRuntimeNode(callableBuilder.Build(), false);
  4276. }
  4277. TRuntimeNode TProgramBuilder::MultiHoppingCore(TRuntimeNode list,
  4278. const TUnaryLambda& keyExtractor,
  4279. const TUnaryLambda& timeExtractor,
  4280. const TUnaryLambda& init,
  4281. const TBinaryLambda& update,
  4282. const TUnaryLambda& save,
  4283. const TUnaryLambda& load,
  4284. const TBinaryLambda& merge,
  4285. const TTernaryLambda& finish,
  4286. TRuntimeNode hop, TRuntimeNode interval, TRuntimeNode delay,
  4287. TRuntimeNode dataWatermarks, TRuntimeNode watermarksMode)
  4288. {
  4289. if constexpr (RuntimeVersion < 22U) {
  4290. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4291. }
  4292. auto streamType = AS_TYPE(TStreamType, list);
  4293. auto itemType = AS_TYPE(TStructType, streamType->GetItemType());
  4294. auto timestampType = TOptionalType::Create(TDataType::Create(NUdf::TDataType<NUdf::TTimestamp>::Id, Env), Env);
  4295. TRuntimeNode itemArg = Arg(itemType);
  4296. auto keyExtract = keyExtractor(itemArg);
  4297. auto keyType = keyExtract.GetStaticType();
  4298. TRuntimeNode keyArg = Arg(keyType);
  4299. auto outTime = timeExtractor(itemArg);
  4300. auto outStateInit = init(itemArg);
  4301. auto stateType = outStateInit.GetStaticType();
  4302. TRuntimeNode stateArg = Arg(stateType);
  4303. auto outStateUpdate = update(itemArg, stateArg);
  4304. auto hasSaveLoad = (bool)save;
  4305. TRuntimeNode saveArg, outSave, loadArg, outLoad;
  4306. if (hasSaveLoad) {
  4307. saveArg = Arg(stateType);
  4308. outSave = save(saveArg);
  4309. loadArg = Arg(outSave.GetStaticType());
  4310. outLoad = load(loadArg);
  4311. MKQL_ENSURE(outLoad.GetStaticType()->IsSameType(*stateType), "Loaded type is changed by the load handler");
  4312. } else {
  4313. saveArg = outSave = loadArg = outLoad = NewVoid();
  4314. }
  4315. TRuntimeNode state2Arg = Arg(stateType);
  4316. TRuntimeNode timeArg = Arg(timestampType);
  4317. auto outStateMerge = merge(stateArg, state2Arg);
  4318. auto outItemFinish = finish(keyArg, stateArg, timeArg);
  4319. auto finishType = outItemFinish.GetStaticType();
  4320. MKQL_ENSURE(finishType->IsStruct(), "Expected struct type as finish lambda output");
  4321. auto resultType = TStreamType::Create(outItemFinish.GetStaticType(), Env);
  4322. TCallableBuilder callableBuilder(Env, __func__, resultType);
  4323. callableBuilder.Add(list);
  4324. callableBuilder.Add(itemArg);
  4325. callableBuilder.Add(keyArg);
  4326. callableBuilder.Add(stateArg);
  4327. callableBuilder.Add(state2Arg);
  4328. callableBuilder.Add(timeArg);
  4329. callableBuilder.Add(saveArg);
  4330. callableBuilder.Add(loadArg);
  4331. callableBuilder.Add(keyExtract);
  4332. callableBuilder.Add(outTime);
  4333. callableBuilder.Add(outStateInit);
  4334. callableBuilder.Add(outStateUpdate);
  4335. callableBuilder.Add(outSave);
  4336. callableBuilder.Add(outLoad);
  4337. callableBuilder.Add(outStateMerge);
  4338. callableBuilder.Add(outItemFinish);
  4339. callableBuilder.Add(hop);
  4340. callableBuilder.Add(interval);
  4341. callableBuilder.Add(delay);
  4342. callableBuilder.Add(dataWatermarks);
  4343. callableBuilder.Add(watermarksMode);
  4344. return TRuntimeNode(callableBuilder.Build(), false);
  4345. }
  4346. TRuntimeNode TProgramBuilder::Default(TType* type) {
  4347. bool isOptional;
  4348. const auto targetType = UnpackOptionalData(type, isOptional);
  4349. if (isOptional) {
  4350. return NewOptional(Default(targetType));
  4351. }
  4352. const auto scheme = targetType->GetSchemeType();
  4353. const auto value = scheme == NUdf::TDataType<NUdf::TUuid>::Id ?
  4354. Env.NewStringValue("\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"sv) :
  4355. scheme == NUdf::TDataType<NUdf::TDyNumber>::Id ? NUdf::TUnboxedValuePod::Embedded("\1") : NUdf::TUnboxedValuePod::Zero();
  4356. return TRuntimeNode(TDataLiteral::Create(value, targetType, Env), true);
  4357. }
  4358. TRuntimeNode TProgramBuilder::Cast(TRuntimeNode arg, TType* type) {
  4359. if (arg.GetStaticType()->IsSameType(*type)) {
  4360. return arg;
  4361. }
  4362. bool isOptional;
  4363. const auto targetType = UnpackOptionalData(type, isOptional);
  4364. const auto sourceType = UnpackOptionalData(arg, isOptional);
  4365. const auto sId = sourceType->GetSchemeType();
  4366. const auto tId = targetType->GetSchemeType();
  4367. if (sId == NUdf::TDataType<char*>::Id) {
  4368. if (tId != NUdf::TDataType<char*>::Id) {
  4369. return FromString(arg, type);
  4370. } else {
  4371. return arg;
  4372. }
  4373. }
  4374. if (sId == NUdf::TDataType<NUdf::TUtf8>::Id) {
  4375. if (tId != NUdf::TDataType<char*>::Id) {
  4376. return FromString(arg, type);
  4377. } else {
  4378. return ToString(arg);
  4379. }
  4380. }
  4381. if (tId == NUdf::TDataType<char*>::Id) {
  4382. return ToString(arg);
  4383. }
  4384. if (tId == NUdf::TDataType<NUdf::TUtf8>::Id) {
  4385. return ToString<true>(arg);
  4386. }
  4387. if (tId == NUdf::TDataType<NUdf::TDecimal>::Id) {
  4388. const auto& params = static_cast<const TDataDecimalType*>(targetType)->GetParams();
  4389. return ToDecimal(arg, params.first, params.second);
  4390. }
  4391. const auto options = NKikimr::NUdf::GetCastResult(*sourceType->GetDataSlot(), *targetType->GetDataSlot());
  4392. MKQL_ENSURE((*options & NKikimr::NUdf::ECastOptions::Undefined) ||
  4393. !(*options & NKikimr::NUdf::ECastOptions::Impossible),
  4394. "Impossible to cast " << *static_cast<TType*>(sourceType) << " into " << *static_cast<TType*>(targetType));
  4395. const bool useToIntegral = (*options & NKikimr::NUdf::ECastOptions::Undefined) ||
  4396. (*options & NKikimr::NUdf::ECastOptions::MayFail);
  4397. return useToIntegral ? ToIntegral(arg, type) : Convert(arg, type);
  4398. }
  4399. TRuntimeNode TProgramBuilder::RangeCreate(TRuntimeNode list) {
  4400. MKQL_ENSURE(list.GetStaticType()->IsList(), "Expecting list");
  4401. auto itemType = static_cast<TListType*>(list.GetStaticType())->GetItemType();
  4402. MKQL_ENSURE(itemType->IsTuple(), "Expecting list of tuples");
  4403. auto tupleType = static_cast<TTupleType*>(itemType);
  4404. MKQL_ENSURE(tupleType->GetElementsCount() == 2,
  4405. "Expecting list ot 2-element tuples, got: " << tupleType->GetElementsCount() << " elements");
  4406. MKQL_ENSURE(tupleType->GetElementType(0)->IsSameType(*tupleType->GetElementType(1)),
  4407. "Expecting list ot 2-element tuples of same type");
  4408. MKQL_ENSURE(tupleType->GetElementType(0)->IsTuple(),
  4409. "Expecting range boundary to be tuple");
  4410. auto boundaryType = static_cast<TTupleType*>(tupleType->GetElementType(0));
  4411. MKQL_ENSURE(boundaryType->GetElementsCount() >= 2,
  4412. "Range boundary should have at least 2 components, got: " << boundaryType->GetElementsCount());
  4413. auto lastComp = boundaryType->GetElementType(boundaryType->GetElementsCount() - 1);
  4414. std::vector<TType*> outputComponents;
  4415. for (ui32 i = 0; i < boundaryType->GetElementsCount() - 1; ++i) {
  4416. outputComponents.push_back(lastComp);
  4417. outputComponents.push_back(boundaryType->GetElementType(i));
  4418. }
  4419. outputComponents.push_back(lastComp);
  4420. auto outputBoundary = TTupleType::Create(outputComponents.size(), &outputComponents.front(), Env);
  4421. std::vector<TType*> outputRangeComps(2, outputBoundary);
  4422. auto outputRange = TTupleType::Create(outputRangeComps.size(), &outputRangeComps.front(), Env);
  4423. TCallableBuilder callableBuilder(Env, __func__, TListType::Create(outputRange, Env));
  4424. callableBuilder.Add(list);
  4425. return TRuntimeNode(callableBuilder.Build(), false);
  4426. }
  4427. TRuntimeNode TProgramBuilder::RangeUnion(const TArrayRef<const TRuntimeNode>& lists) {
  4428. return BuildRangeLogical(__func__, lists);
  4429. }
  4430. TRuntimeNode TProgramBuilder::RangeIntersect(const TArrayRef<const TRuntimeNode>& lists) {
  4431. return BuildRangeLogical(__func__, lists);
  4432. }
  4433. TRuntimeNode TProgramBuilder::RangeMultiply(const TArrayRef<const TRuntimeNode>& args) {
  4434. MKQL_ENSURE(args.size() >= 2, "Expecting at least two arguments");
  4435. bool unlimited = false;
  4436. if (args.front().GetStaticType()->IsVoid()) {
  4437. unlimited = true;
  4438. } else {
  4439. MKQL_ENSURE(args.front().GetStaticType()->IsData() &&
  4440. static_cast<TDataType*>(args.front().GetStaticType())->GetSchemeType() == NUdf::TDataType<ui64>::Id,
  4441. "Expected ui64 as first argument");
  4442. }
  4443. std::vector<TType*> outputComponents;
  4444. for (size_t i = 1; i < args.size(); ++i) {
  4445. const auto& list = args[i];
  4446. MKQL_ENSURE(list.GetStaticType()->IsList(), "Expecting list");
  4447. auto listItemType = static_cast<TListType*>(list.GetStaticType())->GetItemType();
  4448. MKQL_ENSURE(listItemType->IsTuple(), "Expecting list of tuples");
  4449. auto rangeType = static_cast<TTupleType*>(listItemType);
  4450. MKQL_ENSURE(rangeType->GetElementsCount() == 2, "Expecting list of 2-element tuples");
  4451. MKQL_ENSURE(rangeType->GetElementType(0)->IsTuple(), "Range boundary should be tuple");
  4452. auto boundaryType = static_cast<TTupleType*>(rangeType->GetElementType(0));
  4453. ui32 elementsCount = boundaryType->GetElementsCount();
  4454. MKQL_ENSURE(elementsCount >= 3 && elementsCount % 2 == 1, "Range boundary should have odd number components (at least 3)");
  4455. for (size_t j = 0; j < elementsCount - 1; ++j) {
  4456. outputComponents.push_back(boundaryType->GetElementType(j));
  4457. }
  4458. }
  4459. outputComponents.push_back(TDataType::Create(NUdf::TDataType<i32>::Id, Env));
  4460. auto outputBoundary = TTupleType::Create(outputComponents.size(), &outputComponents.front(), Env);
  4461. std::vector<TType*> outputRangeComps(2, outputBoundary);
  4462. auto outputRange = TTupleType::Create(outputRangeComps.size(), &outputRangeComps.front(), Env);
  4463. TCallableBuilder callableBuilder(Env, __func__, TListType::Create(outputRange, Env));
  4464. if (unlimited) {
  4465. callableBuilder.Add(NewDataLiteral<ui64>(std::numeric_limits<ui64>::max()));
  4466. } else {
  4467. callableBuilder.Add(args[0]);
  4468. }
  4469. for (size_t i = 1; i < args.size(); ++i) {
  4470. callableBuilder.Add(args[i]);
  4471. }
  4472. return TRuntimeNode(callableBuilder.Build(), false);
  4473. }
  4474. TRuntimeNode TProgramBuilder::RangeFinalize(TRuntimeNode list) {
  4475. MKQL_ENSURE(list.GetStaticType()->IsList(), "Expecting list");
  4476. auto listItemType = static_cast<TListType*>(list.GetStaticType())->GetItemType();
  4477. MKQL_ENSURE(listItemType->IsTuple(), "Expecting list of tuples");
  4478. auto rangeType = static_cast<TTupleType*>(listItemType);
  4479. MKQL_ENSURE(rangeType->GetElementsCount() == 2, "Expecting list of 2-element tuples");
  4480. MKQL_ENSURE(rangeType->GetElementType(0)->IsTuple(), "Range boundary should be tuple");
  4481. auto boundaryType = static_cast<TTupleType*>(rangeType->GetElementType(0));
  4482. ui32 elementsCount = boundaryType->GetElementsCount();
  4483. MKQL_ENSURE(elementsCount >= 3 && elementsCount % 2 == 1, "Range boundary should have odd number components (at least 3)");
  4484. std::vector<TType*> outputComponents;
  4485. for (ui32 i = 0; i < elementsCount; ++i) {
  4486. if (i % 2 == 1 || i + 1 == elementsCount) {
  4487. outputComponents.push_back(boundaryType->GetElementType(i));
  4488. }
  4489. }
  4490. auto outputBoundary = TTupleType::Create(outputComponents.size(), &outputComponents.front(), Env);
  4491. std::vector<TType*> outputRangeComps(2, outputBoundary);
  4492. auto outputRange = TTupleType::Create(outputRangeComps.size(), &outputRangeComps.front(), Env);
  4493. TCallableBuilder callableBuilder(Env, __func__, TListType::Create(outputRange, Env));
  4494. callableBuilder.Add(list);
  4495. return TRuntimeNode(callableBuilder.Build(), false);
  4496. }
  4497. TRuntimeNode TProgramBuilder::Round(const std::string_view& callableName, TRuntimeNode source, TType* targetType) {
  4498. const auto sourceType = source.GetStaticType();
  4499. MKQL_ENSURE(sourceType->IsData(), "Expecting first arg to be of Data type");
  4500. MKQL_ENSURE(targetType->IsData(), "Expecting second arg to be Data type");
  4501. const auto ss = *static_cast<TDataType*>(sourceType)->GetDataSlot();
  4502. const auto ts = *static_cast<TDataType*>(targetType)->GetDataSlot();
  4503. const auto options = NKikimr::NUdf::GetCastResult(ss, ts);
  4504. MKQL_ENSURE(!(*options & NKikimr::NUdf::ECastOptions::Impossible),
  4505. "Impossible to cast " << *sourceType << " into " << *targetType);
  4506. MKQL_ENSURE(*options & (NKikimr::NUdf::ECastOptions::MayFail |
  4507. NKikimr::NUdf::ECastOptions::MayLoseData |
  4508. NKikimr::NUdf::ECastOptions::AnywayLoseData),
  4509. "Rounding from " << *sourceType << " to " << *targetType << " is trivial");
  4510. TCallableBuilder callableBuilder(Env, callableName, TOptionalType::Create(targetType, Env));
  4511. callableBuilder.Add(source);
  4512. return TRuntimeNode(callableBuilder.Build(), false);
  4513. }
  4514. TRuntimeNode TProgramBuilder::NextValue(TRuntimeNode value) {
  4515. const auto valueType = value.GetStaticType();
  4516. MKQL_ENSURE(valueType->IsData(), "Expecting argument of Data type");
  4517. const auto slot = *static_cast<TDataType*>(valueType)->GetDataSlot();
  4518. MKQL_ENSURE(slot == NUdf::EDataSlot::String || slot == NUdf::EDataSlot::Utf8,
  4519. "Unsupported type: " << *valueType);
  4520. TCallableBuilder callableBuilder(Env, __func__, TOptionalType::Create(valueType, Env));
  4521. callableBuilder.Add(value);
  4522. return TRuntimeNode(callableBuilder.Build(), false);
  4523. }
  4524. TRuntimeNode TProgramBuilder::Nop(TRuntimeNode value, TType* returnType) {
  4525. if constexpr (RuntimeVersion < 35U) {
  4526. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4527. }
  4528. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4529. callableBuilder.Add(value);
  4530. return TRuntimeNode(callableBuilder.Build(), false);
  4531. }
  4532. bool TProgramBuilder::IsNull(TRuntimeNode arg) {
  4533. return arg.GetStaticType()->IsSameType(*NewNull().GetStaticType()); // TODO ->IsNull();
  4534. }
  4535. TRuntimeNode TProgramBuilder::Replicate(TRuntimeNode item, TRuntimeNode count, const std::string_view& file, ui32 row, ui32 column) {
  4536. MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data");
  4537. MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  4538. const auto listType = TListType::Create(item.GetStaticType(), Env);
  4539. TCallableBuilder callableBuilder(Env, __func__, listType);
  4540. callableBuilder.Add(item);
  4541. callableBuilder.Add(count);
  4542. if constexpr (RuntimeVersion >= 2) {
  4543. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(file));
  4544. callableBuilder.Add(NewDataLiteral(row));
  4545. callableBuilder.Add(NewDataLiteral(column));
  4546. }
  4547. return TRuntimeNode(callableBuilder.Build(), false);
  4548. }
  4549. TRuntimeNode TProgramBuilder::PgConst(TPgType* pgType, const std::string_view& value, TRuntimeNode typeMod) {
  4550. if constexpr (RuntimeVersion < 30U) {
  4551. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4552. }
  4553. TCallableBuilder callableBuilder(Env, __func__, pgType);
  4554. callableBuilder.Add(NewDataLiteral(pgType->GetTypeId()));
  4555. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(value));
  4556. if (typeMod) {
  4557. callableBuilder.Add(typeMod);
  4558. }
  4559. return TRuntimeNode(callableBuilder.Build(), false);
  4560. }
  4561. TRuntimeNode TProgramBuilder::PgResolvedCall(bool useContext, const std::string_view& name,
  4562. ui32 id, const TArrayRef<const TRuntimeNode>& args,
  4563. TType* returnType, bool rangeFunction) {
  4564. if constexpr (RuntimeVersion < 45U) {
  4565. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4566. }
  4567. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4568. callableBuilder.Add(NewDataLiteral(useContext));
  4569. callableBuilder.Add(NewDataLiteral(rangeFunction));
  4570. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(name));
  4571. callableBuilder.Add(NewDataLiteral(id));
  4572. for (const auto& arg : args) {
  4573. callableBuilder.Add(arg);
  4574. }
  4575. return TRuntimeNode(callableBuilder.Build(), false);
  4576. }
  4577. TRuntimeNode TProgramBuilder::BlockPgResolvedCall(const std::string_view& name, ui32 id,
  4578. const TArrayRef<const TRuntimeNode>& args, TType* returnType) {
  4579. if constexpr (RuntimeVersion < 30U) {
  4580. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4581. }
  4582. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4583. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(name));
  4584. callableBuilder.Add(NewDataLiteral(id));
  4585. for (const auto& arg : args) {
  4586. callableBuilder.Add(arg);
  4587. }
  4588. return TRuntimeNode(callableBuilder.Build(), false);
  4589. }
  4590. TRuntimeNode TProgramBuilder::PgArray(const TArrayRef<const TRuntimeNode>& args, TType* returnType) {
  4591. if constexpr (RuntimeVersion < 30U) {
  4592. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4593. }
  4594. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4595. for (const auto& arg : args) {
  4596. callableBuilder.Add(arg);
  4597. }
  4598. return TRuntimeNode(callableBuilder.Build(), false);
  4599. }
  4600. TRuntimeNode TProgramBuilder::PgTableContent(
  4601. const std::string_view& cluster,
  4602. const std::string_view& table,
  4603. TType* returnType) {
  4604. if constexpr (RuntimeVersion < 47U) {
  4605. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4606. }
  4607. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4608. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(cluster));
  4609. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(table));
  4610. return TRuntimeNode(callableBuilder.Build(), false);
  4611. }
  4612. TRuntimeNode TProgramBuilder::PgToRecord(TRuntimeNode input, const TArrayRef<std::pair<std::string_view, std::string_view>>& members) {
  4613. if constexpr (RuntimeVersion < 48U) {
  4614. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4615. }
  4616. MKQL_ENSURE(input.GetStaticType()->IsStruct(), "Expected struct");
  4617. auto structType = AS_TYPE(TStructType, input.GetStaticType());
  4618. for (ui32 i = 0; i < structType->GetMembersCount(); ++i) {
  4619. auto itemType = structType->GetMemberType(i);
  4620. MKQL_ENSURE(itemType->IsNull() || itemType->IsPg(), "Expected null or pg");
  4621. }
  4622. auto returnType = NewPgType(NYql::NPg::LookupType("record").TypeId);
  4623. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4624. callableBuilder.Add(input);
  4625. TVector<TRuntimeNode> names;
  4626. for (const auto& x : members) {
  4627. names.push_back(NewDataLiteral<NUdf::EDataSlot::String>(x.first));
  4628. names.push_back(NewDataLiteral<NUdf::EDataSlot::String>(x.second));
  4629. }
  4630. callableBuilder.Add(NewTuple(names));
  4631. return TRuntimeNode(callableBuilder.Build(), false);
  4632. }
  4633. TRuntimeNode TProgramBuilder::PgCast(TRuntimeNode input, TType* returnType, TRuntimeNode typeMod) {
  4634. if constexpr (RuntimeVersion < 30U) {
  4635. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4636. }
  4637. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4638. callableBuilder.Add(input);
  4639. if (typeMod) {
  4640. callableBuilder.Add(typeMod);
  4641. }
  4642. return TRuntimeNode(callableBuilder.Build(), false);
  4643. }
  4644. TRuntimeNode TProgramBuilder::FromPg(TRuntimeNode input, TType* returnType) {
  4645. if constexpr (RuntimeVersion < 30U) {
  4646. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4647. }
  4648. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4649. callableBuilder.Add(input);
  4650. return TRuntimeNode(callableBuilder.Build(), false);
  4651. }
  4652. TRuntimeNode TProgramBuilder::ToPg(TRuntimeNode input, TType* returnType) {
  4653. if constexpr (RuntimeVersion < 30U) {
  4654. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4655. }
  4656. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4657. callableBuilder.Add(input);
  4658. return TRuntimeNode(callableBuilder.Build(), false);
  4659. }
  4660. TRuntimeNode TProgramBuilder::PgClone(TRuntimeNode input, const TArrayRef<const TRuntimeNode>& dependentNodes) {
  4661. if constexpr (RuntimeVersion < 38U) {
  4662. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4663. }
  4664. TCallableBuilder callableBuilder(Env, __func__, input.GetStaticType());
  4665. callableBuilder.Add(input);
  4666. for (const auto& node : dependentNodes) {
  4667. callableBuilder.Add(node);
  4668. }
  4669. return TRuntimeNode(callableBuilder.Build(), false);
  4670. }
  4671. TRuntimeNode TProgramBuilder::WithContext(TRuntimeNode input, const std::string_view& contextType) {
  4672. if constexpr (RuntimeVersion < 30U) {
  4673. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4674. }
  4675. TCallableBuilder callableBuilder(Env, __func__, input.GetStaticType());
  4676. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(contextType));
  4677. callableBuilder.Add(input);
  4678. return TRuntimeNode(callableBuilder.Build(), false);
  4679. }
  4680. TRuntimeNode TProgramBuilder::PgInternal0(TType* returnType) {
  4681. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4682. return TRuntimeNode(callableBuilder.Build(), false);
  4683. }
  4684. TRuntimeNode TProgramBuilder::BlockIf(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch) {
  4685. const auto conditionType = AS_TYPE(TBlockType, condition.GetStaticType());
  4686. MKQL_ENSURE(AS_TYPE(TDataType, conditionType->GetItemType())->GetSchemeType() == NUdf::TDataType<bool>::Id,
  4687. "Expected bool as first argument");
  4688. const auto thenType = AS_TYPE(TBlockType, thenBranch.GetStaticType());
  4689. const auto elseType = AS_TYPE(TBlockType, elseBranch.GetStaticType());
  4690. MKQL_ENSURE(thenType->GetItemType()->IsSameType(*elseType->GetItemType()), "Different return types in branches.");
  4691. auto returnType = NewBlockType(thenType->GetItemType(), GetResultShape({conditionType, thenType, elseType}));
  4692. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4693. callableBuilder.Add(condition);
  4694. callableBuilder.Add(thenBranch);
  4695. callableBuilder.Add(elseBranch);
  4696. return TRuntimeNode(callableBuilder.Build(), false);
  4697. }
  4698. TRuntimeNode TProgramBuilder::BlockJust(TRuntimeNode data) {
  4699. const auto initialType = AS_TYPE(TBlockType, data.GetStaticType());
  4700. auto returnType = NewBlockType(NewOptionalType(initialType->GetItemType()), initialType->GetShape());
  4701. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4702. callableBuilder.Add(data);
  4703. return TRuntimeNode(callableBuilder.Build(), false);
  4704. }
  4705. TRuntimeNode TProgramBuilder::BlockFunc(const std::string_view& funcName, TType* returnType, const TArrayRef<const TRuntimeNode>& args) {
  4706. for (const auto& arg : args) {
  4707. MKQL_ENSURE(arg.GetStaticType()->IsBlock(), "Expected Block type");
  4708. }
  4709. TCallableBuilder builder(Env, __func__, returnType);
  4710. builder.Add(NewDataLiteral<NUdf::EDataSlot::String>(funcName));
  4711. for (const auto& arg : args) {
  4712. builder.Add(arg);
  4713. }
  4714. return TRuntimeNode(builder.Build(), false);
  4715. }
  4716. TRuntimeNode TProgramBuilder::BuildBlockCombineAll(const std::string_view& callableName, TRuntimeNode input, std::optional<ui32> filterColumn,
  4717. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4718. const auto inputType = input.GetStaticType();
  4719. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type");
  4720. MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type");
  4721. TCallableBuilder builder(Env, callableName, returnType);
  4722. builder.Add(input);
  4723. if (!filterColumn) {
  4724. builder.Add(NewEmptyOptionalDataLiteral(NUdf::TDataType<ui32>::Id));
  4725. } else {
  4726. builder.Add(NewOptional(NewDataLiteral<ui32>(*filterColumn)));
  4727. }
  4728. TVector<TRuntimeNode> aggsNodes;
  4729. for (const auto& agg : aggs) {
  4730. TVector<TRuntimeNode> params;
  4731. params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name));
  4732. for (const auto& col : agg.ArgsColumns) {
  4733. params.push_back(NewDataLiteral<ui32>(col));
  4734. }
  4735. aggsNodes.push_back(NewTuple(params));
  4736. }
  4737. builder.Add(NewTuple(aggsNodes));
  4738. return TRuntimeNode(builder.Build(), false);
  4739. }
  4740. TRuntimeNode TProgramBuilder::BlockCombineAll(TRuntimeNode stream, std::optional<ui32> filterColumn,
  4741. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4742. if constexpr (RuntimeVersion < 31U) {
  4743. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4744. }
  4745. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type");
  4746. MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type");
  4747. if constexpr (RuntimeVersion < 52U) {
  4748. const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType());
  4749. return FromFlow(BuildBlockCombineAll(__func__, ToFlow(stream), filterColumn, aggs, flowReturnType));
  4750. } else {
  4751. return BuildBlockCombineAll(__func__, stream, filterColumn, aggs, returnType);
  4752. }
  4753. }
  4754. TRuntimeNode TProgramBuilder::BuildBlockCombineHashed(const std::string_view& callableName, TRuntimeNode input, std::optional<ui32> filterColumn,
  4755. const TArrayRef<ui32>& keys, const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4756. const auto inputType = input.GetStaticType();
  4757. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type");
  4758. MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type");
  4759. TCallableBuilder builder(Env, callableName, returnType);
  4760. builder.Add(input);
  4761. if (!filterColumn) {
  4762. builder.Add(NewEmptyOptionalDataLiteral(NUdf::TDataType<ui32>::Id));
  4763. } else {
  4764. builder.Add(NewOptional(NewDataLiteral<ui32>(*filterColumn)));
  4765. }
  4766. TVector<TRuntimeNode> keyNodes;
  4767. for (const auto& key : keys) {
  4768. keyNodes.push_back(NewDataLiteral<ui32>(key));
  4769. }
  4770. builder.Add(NewTuple(keyNodes));
  4771. TVector<TRuntimeNode> aggsNodes;
  4772. for (const auto& agg : aggs) {
  4773. TVector<TRuntimeNode> params;
  4774. params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name));
  4775. for (const auto& col : agg.ArgsColumns) {
  4776. params.push_back(NewDataLiteral<ui32>(col));
  4777. }
  4778. aggsNodes.push_back(NewTuple(params));
  4779. }
  4780. builder.Add(NewTuple(aggsNodes));
  4781. return TRuntimeNode(builder.Build(), false);
  4782. }
  4783. TRuntimeNode TProgramBuilder::BlockCombineHashed(TRuntimeNode stream, std::optional<ui32> filterColumn, const TArrayRef<ui32>& keys,
  4784. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4785. if constexpr (RuntimeVersion < 31U) {
  4786. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4787. }
  4788. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type");
  4789. MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type");
  4790. if constexpr (RuntimeVersion < 52U) {
  4791. const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType());
  4792. return FromFlow(BuildBlockCombineHashed(__func__, ToFlow(stream), filterColumn, keys, aggs, flowReturnType));
  4793. } else {
  4794. return BuildBlockCombineHashed(__func__, stream, filterColumn, keys, aggs, returnType);
  4795. }
  4796. }
  4797. TRuntimeNode TProgramBuilder::BuildBlockMergeFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys,
  4798. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4799. const auto inputType = input.GetStaticType();
  4800. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type");
  4801. MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type");
  4802. TCallableBuilder builder(Env, callableName, returnType);
  4803. builder.Add(input);
  4804. TVector<TRuntimeNode> keyNodes;
  4805. for (const auto& key : keys) {
  4806. keyNodes.push_back(NewDataLiteral<ui32>(key));
  4807. }
  4808. builder.Add(NewTuple(keyNodes));
  4809. TVector<TRuntimeNode> aggsNodes;
  4810. for (const auto& agg : aggs) {
  4811. TVector<TRuntimeNode> params;
  4812. params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name));
  4813. for (const auto& col : agg.ArgsColumns) {
  4814. params.push_back(NewDataLiteral<ui32>(col));
  4815. }
  4816. aggsNodes.push_back(NewTuple(params));
  4817. }
  4818. builder.Add(NewTuple(aggsNodes));
  4819. return TRuntimeNode(builder.Build(), false);
  4820. }
  4821. TRuntimeNode TProgramBuilder::BlockMergeFinalizeHashed(TRuntimeNode stream, const TArrayRef<ui32>& keys,
  4822. const TArrayRef<const TAggInfo>& aggs, TType* returnType) {
  4823. if constexpr (RuntimeVersion < 31U) {
  4824. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4825. }
  4826. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type");
  4827. MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type");
  4828. if constexpr (RuntimeVersion < 52U) {
  4829. const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType());
  4830. return FromFlow(BuildBlockMergeFinalizeHashed(__func__, ToFlow(stream), keys, aggs, flowReturnType));
  4831. } else {
  4832. return BuildBlockMergeFinalizeHashed(__func__, stream, keys, aggs, returnType);
  4833. }
  4834. }
  4835. TRuntimeNode TProgramBuilder::BuildBlockMergeManyFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys,
  4836. const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType) {
  4837. const auto inputType = input.GetStaticType();
  4838. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type");
  4839. MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type");
  4840. TCallableBuilder builder(Env, callableName, returnType);
  4841. builder.Add(input);
  4842. TVector<TRuntimeNode> keyNodes;
  4843. for (const auto& key : keys) {
  4844. keyNodes.push_back(NewDataLiteral<ui32>(key));
  4845. }
  4846. builder.Add(NewTuple(keyNodes));
  4847. TVector<TRuntimeNode> aggsNodes;
  4848. for (const auto& agg : aggs) {
  4849. TVector<TRuntimeNode> params;
  4850. params.push_back(NewDataLiteral<NUdf::EDataSlot::String>(agg.Name));
  4851. for (const auto& col : agg.ArgsColumns) {
  4852. params.push_back(NewDataLiteral<ui32>(col));
  4853. }
  4854. aggsNodes.push_back(NewTuple(params));
  4855. }
  4856. builder.Add(NewTuple(aggsNodes));
  4857. builder.Add(NewDataLiteral<ui32>(streamIndex));
  4858. TVector<TRuntimeNode> streamsNodes;
  4859. for (const auto& s : streams) {
  4860. TVector<TRuntimeNode> streamNodes;
  4861. for (const auto& i : s) {
  4862. streamNodes.push_back(NewDataLiteral<ui32>(i));
  4863. }
  4864. streamsNodes.push_back(NewTuple(streamNodes));
  4865. }
  4866. builder.Add(NewTuple(streamsNodes));
  4867. return TRuntimeNode(builder.Build(), false);
  4868. }
  4869. TRuntimeNode TProgramBuilder::BlockMergeManyFinalizeHashed(TRuntimeNode stream, const TArrayRef<ui32>& keys,
  4870. const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType) {
  4871. if constexpr (RuntimeVersion < 31U) {
  4872. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4873. }
  4874. MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type");
  4875. MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type");
  4876. if constexpr (RuntimeVersion < 52U) {
  4877. const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType());
  4878. return FromFlow(BuildBlockMergeManyFinalizeHashed(__func__, ToFlow(stream), keys, aggs, streamIndex, streams, flowReturnType));
  4879. } else {
  4880. return BuildBlockMergeManyFinalizeHashed(__func__, stream, keys, aggs, streamIndex, streams, returnType);
  4881. }
  4882. }
  4883. TRuntimeNode TProgramBuilder::ScalarApply(const TArrayRef<const TRuntimeNode>& args, const TArrayLambda& handler) {
  4884. if constexpr (RuntimeVersion < 39U) {
  4885. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4886. }
  4887. MKQL_ENSURE(!args.empty(), "Required at least one argument");
  4888. TVector<TRuntimeNode> lambdaArgs;
  4889. bool scalarOnly = true;
  4890. std::shared_ptr<arrow::DataType> arrowType;
  4891. for (const auto& arg : args) {
  4892. auto blockType = AS_TYPE(TBlockType, arg.GetStaticType());
  4893. scalarOnly = scalarOnly && blockType->GetShape() == TBlockType::EShape::Scalar;
  4894. MKQL_ENSURE(ConvertArrowType(blockType->GetItemType(), arrowType), "Unsupported arrow type");
  4895. lambdaArgs.emplace_back(Arg(blockType->GetItemType()));
  4896. }
  4897. auto ret = handler(lambdaArgs);
  4898. MKQL_ENSURE(ConvertArrowType(ret.GetStaticType(), arrowType), "Unsupported arrow type");
  4899. auto returnType = NewBlockType(ret.GetStaticType(), scalarOnly ? TBlockType::EShape::Scalar : TBlockType::EShape::Many);
  4900. TCallableBuilder builder(Env, __func__, returnType);
  4901. for (const auto& arg : args) {
  4902. builder.Add(arg);
  4903. }
  4904. for (const auto& arg : lambdaArgs) {
  4905. builder.Add(arg);
  4906. }
  4907. builder.Add(ret);
  4908. return TRuntimeNode(builder.Build(), false);
  4909. }
  4910. TRuntimeNode TProgramBuilder::BlockStorage(TRuntimeNode stream, TType* returnType) {
  4911. if constexpr (RuntimeVersion < 59U) {
  4912. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4913. }
  4914. ValidateBlockStreamType(stream.GetStaticType());
  4915. MKQL_ENSURE(returnType->IsResource(), "Expected Resource as a result type");
  4916. auto returnResourceType = AS_TYPE(TResourceType, returnType);
  4917. MKQL_ENSURE(returnResourceType->GetTag().StartsWith(BlockStorageResourcePrefix), "Expected block storage resource");
  4918. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4919. callableBuilder.Add(stream);
  4920. return TRuntimeNode(callableBuilder.Build(), false);
  4921. }
  4922. TRuntimeNode TProgramBuilder::BlockMapJoinIndex(TRuntimeNode blockStorage, TType* streamItemType, const TArrayRef<const ui32>& keyColumns, bool any, TType* returnType) {
  4923. if constexpr (RuntimeVersion < 59U) {
  4924. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4925. }
  4926. MKQL_ENSURE(blockStorage.GetStaticType()->IsResource(), "Expected Resource as an input type");
  4927. auto blockStorageType = AS_TYPE(TResourceType, blockStorage.GetStaticType());
  4928. MKQL_ENSURE(blockStorageType->GetTag().StartsWith(BlockStorageResourcePrefix), "Expected block storage resource");
  4929. ValidateBlockType(streamItemType);
  4930. MKQL_ENSURE(returnType->IsResource(), "Expected Resource as a result type");
  4931. auto returnResourceType = AS_TYPE(TResourceType, returnType);
  4932. MKQL_ENSURE(returnResourceType->GetTag().StartsWith(BlockMapJoinIndexResourcePrefix), "Expected block map join index resource");
  4933. TRuntimeNode::TList keyColumnsNodes;
  4934. keyColumnsNodes.reserve(keyColumns.size());
  4935. std::transform(keyColumns.cbegin(), keyColumns.cend(),
  4936. std::back_inserter(keyColumnsNodes), [this](const ui32 idx) {
  4937. return NewDataLiteral(idx);
  4938. });
  4939. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4940. callableBuilder.Add(blockStorage);
  4941. callableBuilder.Add(TRuntimeNode(streamItemType, true));
  4942. callableBuilder.Add(NewTuple(keyColumnsNodes));
  4943. callableBuilder.Add(NewDataLiteral((bool)any));
  4944. return TRuntimeNode(callableBuilder.Build(), false);
  4945. }
  4946. TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode leftStream, TRuntimeNode rightBlockStorage, TType* rightStreamItemType, EJoinKind joinKind,
  4947. const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& leftKeyDrops,
  4948. const TArrayRef<const ui32>& rightKeyColumns, const TArrayRef<const ui32>& rightKeyDrops, TType* returnType
  4949. ) {
  4950. if constexpr (RuntimeVersion < 59U) {
  4951. THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
  4952. }
  4953. MKQL_ENSURE(rightBlockStorage.GetStaticType()->IsResource(), "Expected Resource as an input type");
  4954. auto rightBlockStorageType = AS_TYPE(TResourceType, rightBlockStorage.GetStaticType());
  4955. if (joinKind != EJoinKind::Cross) {
  4956. MKQL_ENSURE(rightBlockStorageType->GetTag().StartsWith(BlockMapJoinIndexResourcePrefix), "Expected block map join index resource");
  4957. } else {
  4958. MKQL_ENSURE(rightBlockStorageType->GetTag().StartsWith(BlockStorageResourcePrefix), "Expected block storage resource");
  4959. }
  4960. MKQL_ENSURE(joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left ||
  4961. joinKind == EJoinKind::LeftSemi || joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::Cross,
  4962. "Unsupported join kind");
  4963. MKQL_ENSURE(leftKeyColumns.size() == rightKeyColumns.size(), "Key column count mismatch");
  4964. if (joinKind != EJoinKind::Cross) {
  4965. MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified");
  4966. } else {
  4967. MKQL_ENSURE(leftKeyColumns.empty(), "Specifying key columns is not allowed for cross join");
  4968. }
  4969. ValidateBlockStreamType(leftStream.GetStaticType());
  4970. ValidateBlockType(rightStreamItemType);
  4971. ValidateBlockStreamType(returnType);
  4972. TRuntimeNode::TList leftKeyColumnsNodes;
  4973. leftKeyColumnsNodes.reserve(leftKeyColumns.size());
  4974. std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(),
  4975. std::back_inserter(leftKeyColumnsNodes), [this](const ui32 idx) {
  4976. return NewDataLiteral(idx);
  4977. });
  4978. TRuntimeNode::TList leftKeyDropsNodes;
  4979. leftKeyDropsNodes.reserve(leftKeyDrops.size());
  4980. std::transform(leftKeyDrops.cbegin(), leftKeyDrops.cend(),
  4981. std::back_inserter(leftKeyDropsNodes), [this](const ui32 idx) {
  4982. return NewDataLiteral(idx);
  4983. });
  4984. TRuntimeNode::TList rightKeyColumnsNodes;
  4985. rightKeyColumnsNodes.reserve(rightKeyColumns.size());
  4986. std::transform(rightKeyColumns.cbegin(), rightKeyColumns.cend(),
  4987. std::back_inserter(rightKeyColumnsNodes), [this](const ui32 idx) {
  4988. return NewDataLiteral(idx);
  4989. });
  4990. TRuntimeNode::TList rightKeyDropsNodes;
  4991. rightKeyDropsNodes.reserve(leftKeyDrops.size());
  4992. std::transform(rightKeyDrops.cbegin(), rightKeyDrops.cend(),
  4993. std::back_inserter(rightKeyDropsNodes), [this](const ui32 idx) {
  4994. return NewDataLiteral(idx);
  4995. });
  4996. TCallableBuilder callableBuilder(Env, __func__, returnType);
  4997. callableBuilder.Add(leftStream);
  4998. callableBuilder.Add(rightBlockStorage);
  4999. callableBuilder.Add(TRuntimeNode(rightStreamItemType, true));
  5000. callableBuilder.Add(NewDataLiteral((ui32)joinKind));
  5001. callableBuilder.Add(NewTuple(leftKeyColumnsNodes));
  5002. callableBuilder.Add(NewTuple(leftKeyDropsNodes));
  5003. callableBuilder.Add(NewTuple(rightKeyColumnsNodes));
  5004. callableBuilder.Add(NewTuple(rightKeyDropsNodes));
  5005. return TRuntimeNode(callableBuilder.Build(), false);
  5006. }
  5007. namespace {
  5008. using namespace NYql::NMatchRecognize;
  5009. TRuntimeNode PatternToRuntimeNode(const TRowPattern& pattern, const TProgramBuilder& programBuilder) {
  5010. const auto& env = programBuilder.GetTypeEnvironment();
  5011. TTupleLiteralBuilder patternBuilder(env);
  5012. for (const auto& term: pattern) {
  5013. TTupleLiteralBuilder termBuilder(env);
  5014. for (const auto& factor: term) {
  5015. TTupleLiteralBuilder factorBuilder(env);
  5016. factorBuilder.Add(std::visit(TOverloaded {
  5017. [&](const TString& s) {
  5018. return programBuilder.NewDataLiteral<NUdf::EDataSlot::String>(s);
  5019. },
  5020. [&](const TRowPattern& pattern) {
  5021. return PatternToRuntimeNode(pattern, programBuilder);
  5022. },
  5023. }, factor.Primary));
  5024. factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMin));
  5025. factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMax));
  5026. factorBuilder.Add(programBuilder.NewDataLiteral(factor.Greedy));
  5027. factorBuilder.Add(programBuilder.NewDataLiteral(factor.Output));
  5028. factorBuilder.Add(programBuilder.NewDataLiteral(factor.Unused));
  5029. termBuilder.Add({factorBuilder.Build(), true});
  5030. }
  5031. patternBuilder.Add({termBuilder.Build(), true});
  5032. }
  5033. return {patternBuilder.Build(), true};
  5034. };
  5035. } //namespace
  5036. TRuntimeNode TProgramBuilder::MatchRecognizeCore(
  5037. TRuntimeNode inputStream,
  5038. const TUnaryLambda& getPartitionKeySelectorNode,
  5039. const TArrayRef<TStringBuf>& partitionColumnNames,
  5040. const TVector<TStringBuf>& measureColumnNames,
  5041. const TVector<TBinaryLambda>& getMeasures,
  5042. const NYql::NMatchRecognize::TRowPattern& pattern,
  5043. const TVector<TStringBuf>& defineVarNames,
  5044. const TVector<TTernaryLambda>& getDefines,
  5045. bool streamingMode,
  5046. const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo,
  5047. NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch
  5048. ) {
  5049. MKQL_ENSURE(RuntimeVersion >= 42, "MatchRecognize is not supported in runtime version " << RuntimeVersion);
  5050. const auto inputRowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputStream.GetStaticType())->GetItemType());
  5051. const auto inputRowArg = Arg(inputRowType);
  5052. const auto partitionKeySelectorNode = getPartitionKeySelectorNode(inputRowArg);
  5053. const auto partitionColumnTypes = AS_TYPE(TTupleType, partitionKeySelectorNode.GetStaticType())->GetElements();
  5054. const auto rangeList = NewListType(NewStructType({
  5055. {"From", NewDataType(NUdf::EDataSlot::Uint64)},
  5056. {"To", NewDataType(NUdf::EDataSlot::Uint64)}
  5057. }));
  5058. TStructTypeBuilder matchedVarsTypeBuilder(Env);
  5059. for (const auto& var: GetPatternVars(pattern)) {
  5060. matchedVarsTypeBuilder.Add(var, rangeList);
  5061. }
  5062. const auto matchedVarsType = matchedVarsTypeBuilder.Build();
  5063. TRuntimeNode matchedVarsArg = Arg(matchedVarsType);
  5064. //---These vars may be empty in case of no measures
  5065. TRuntimeNode measureInputDataArg;
  5066. std::vector<TRuntimeNode> specialColumnIndexesInMeasureInputDataRow;
  5067. TVector<TRuntimeNode> measures;
  5068. //---
  5069. if (getMeasures.empty()) {
  5070. measureInputDataArg = Arg(Env.GetTypeOfVoidLazy());
  5071. } else {
  5072. measures.reserve(getMeasures.size());
  5073. specialColumnIndexesInMeasureInputDataRow.resize(static_cast<size_t>(NYql::NMatchRecognize::EMeasureInputDataSpecialColumns::Last));
  5074. TStructTypeBuilder measureInputDataRowTypeBuilder(Env);
  5075. for (ui32 i = 0; i < inputRowType->GetMembersCount(); ++i) {
  5076. measureInputDataRowTypeBuilder.Add(inputRowType->GetMemberName(i), inputRowType->GetMemberType(i));
  5077. }
  5078. measureInputDataRowTypeBuilder.Add(
  5079. MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::Classifier),
  5080. NewDataType(NUdf::EDataSlot::Utf8)
  5081. );
  5082. measureInputDataRowTypeBuilder.Add(
  5083. MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::MatchNumber),
  5084. NewDataType(NUdf::EDataSlot::Uint64)
  5085. );
  5086. const auto measureInputDataRowType = measureInputDataRowTypeBuilder.Build();
  5087. for (ui32 i = 0; i < measureInputDataRowType->GetMembersCount(); ++i) {
  5088. //assume a few, if grows, it's better to use a lookup table here
  5089. static_assert(static_cast<size_t>(EMeasureInputDataSpecialColumns::Last) < 5);
  5090. for (size_t j = 0; j != static_cast<size_t>(EMeasureInputDataSpecialColumns::Last); ++j) {
  5091. if (measureInputDataRowType->GetMemberName(i) ==
  5092. NYql::NMatchRecognize::MeasureInputDataSpecialColumnName(static_cast<EMeasureInputDataSpecialColumns>(j)))
  5093. specialColumnIndexesInMeasureInputDataRow[j] = NewDataLiteral(i);
  5094. }
  5095. }
  5096. measureInputDataArg = Arg(NewListType(measureInputDataRowType));
  5097. for (size_t i = 0; i != getMeasures.size(); ++i) {
  5098. measures.push_back(getMeasures[i](measureInputDataArg, matchedVarsArg));
  5099. }
  5100. }
  5101. TStructTypeBuilder outputRowTypeBuilder(Env);
  5102. THashMap<TStringBuf, size_t> partitionColumnLookup;
  5103. THashMap<TStringBuf, size_t> measureColumnLookup;
  5104. THashMap<TStringBuf, size_t> otherColumnLookup;
  5105. for (size_t i = 0; i < measureColumnNames.size(); ++i) {
  5106. const auto name = measureColumnNames[i];
  5107. measureColumnLookup.emplace(name, i);
  5108. outputRowTypeBuilder.Add(name, measures[i].GetStaticType());
  5109. }
  5110. switch (rowsPerMatch) {
  5111. case NYql::NMatchRecognize::ERowsPerMatch::OneRow:
  5112. for (size_t i = 0; i < partitionColumnNames.size(); ++i) {
  5113. const auto name = partitionColumnNames[i];
  5114. partitionColumnLookup.emplace(name, i);
  5115. outputRowTypeBuilder.Add(name, partitionColumnTypes[i]);
  5116. }
  5117. break;
  5118. case NYql::NMatchRecognize::ERowsPerMatch::AllRows:
  5119. for (size_t i = 0; i < inputRowType->GetMembersCount(); ++i) {
  5120. const auto name = inputRowType->GetMemberName(i);
  5121. otherColumnLookup.emplace(name, i);
  5122. outputRowTypeBuilder.Add(name, inputRowType->GetMemberType(i));
  5123. }
  5124. break;
  5125. }
  5126. auto outputRowType = outputRowTypeBuilder.Build();
  5127. std::vector<TRuntimeNode> partitionColumnIndexes(partitionColumnLookup.size());
  5128. std::vector<TRuntimeNode> measureColumnIndexes(measureColumnLookup.size());
  5129. TVector<TRuntimeNode> outputColumnOrder(NDetail::TReserveTag{outputRowType->GetMembersCount()});
  5130. for (ui32 i = 0; i < outputRowType->GetMembersCount(); ++i) {
  5131. const auto name = outputRowType->GetMemberName(i);
  5132. if (auto iter = partitionColumnLookup.find(name);
  5133. iter != partitionColumnLookup.end()) {
  5134. partitionColumnIndexes[iter->second] = NewDataLiteral(i);
  5135. outputColumnOrder.push_back(NewStruct({
  5136. std::pair{"Index", NewDataLiteral(iter->second)},
  5137. std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::PartitionKey))},
  5138. }));
  5139. } else if (auto iter = measureColumnLookup.find(name);
  5140. iter != measureColumnLookup.end()) {
  5141. measureColumnIndexes[iter->second] = NewDataLiteral(i);
  5142. outputColumnOrder.push_back(NewStruct({
  5143. std::pair{"Index", NewDataLiteral(iter->second)},
  5144. std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Measure))},
  5145. }));
  5146. } else if (auto iter = otherColumnLookup.find(name);
  5147. iter != otherColumnLookup.end()) {
  5148. outputColumnOrder.push_back(NewStruct({
  5149. std::pair{"Index", NewDataLiteral(iter->second)},
  5150. std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Other))},
  5151. }));
  5152. }
  5153. }
  5154. const auto outputType = NewFlowType(outputRowType);
  5155. THashMap<TStringBuf, size_t> patternVarLookup;
  5156. for (ui32 i = 0; i < matchedVarsType->GetMembersCount(); ++i) {
  5157. patternVarLookup[matchedVarsType->GetMemberName(i)] = i;
  5158. }
  5159. THashMap<TStringBuf, size_t> defineLookup;
  5160. for (size_t i = 0; i < defineVarNames.size(); ++i) {
  5161. const auto name = defineVarNames[i];
  5162. defineLookup[name] = i;
  5163. }
  5164. TVector<TRuntimeNode> defineNames(patternVarLookup.size());
  5165. TVector<TRuntimeNode> defineNodes(patternVarLookup.size());
  5166. const auto inputDataArg = Arg(NewListType(inputRowType));
  5167. const auto currentRowIndexArg = Arg(NewDataType(NUdf::EDataSlot::Uint64));
  5168. for (const auto& [v, i]: patternVarLookup) {
  5169. defineNames[i] = NewDataLiteral<NUdf::EDataSlot::String>(v);
  5170. if (auto iter = defineLookup.find(v);
  5171. iter != defineLookup.end()) {
  5172. defineNodes[i] = getDefines[iter->second](inputDataArg, matchedVarsArg, currentRowIndexArg);
  5173. } else if ("$" == v || "^" == v) {
  5174. //DO nothing, //will be handled in a specific way
  5175. } else { // a var without a predicate matches any row
  5176. defineNodes[i] = NewDataLiteral(true);
  5177. }
  5178. }
  5179. TCallableBuilder callableBuilder(GetTypeEnvironment(), "MatchRecognizeCore", outputType);
  5180. const auto indexType = NewDataType(NUdf::EDataSlot::Uint32);
  5181. const auto outputColumnEntryType = NewStructType({
  5182. {"Index", NewDataType(NUdf::EDataSlot::Uint64)},
  5183. {"SourceType", NewDataType(NUdf::EDataSlot::Int32)},
  5184. });
  5185. callableBuilder.Add(inputStream);
  5186. callableBuilder.Add(inputRowArg);
  5187. callableBuilder.Add(partitionKeySelectorNode);
  5188. callableBuilder.Add(NewList(indexType, partitionColumnIndexes));
  5189. callableBuilder.Add(measureInputDataArg);
  5190. callableBuilder.Add(NewList(indexType, specialColumnIndexesInMeasureInputDataRow));
  5191. callableBuilder.Add(NewDataLiteral(inputRowType->GetMembersCount()));
  5192. callableBuilder.Add(matchedVarsArg);
  5193. callableBuilder.Add(NewList(indexType, measureColumnIndexes));
  5194. for (const auto& m: measures) {
  5195. callableBuilder.Add(m);
  5196. }
  5197. callableBuilder.Add(PatternToRuntimeNode(pattern, *this));
  5198. callableBuilder.Add(currentRowIndexArg);
  5199. callableBuilder.Add(inputDataArg);
  5200. callableBuilder.Add(NewList(NewDataType(NUdf::EDataSlot::String), defineNames));
  5201. for (const auto& d: defineNodes) {
  5202. callableBuilder.Add(d);
  5203. }
  5204. callableBuilder.Add(NewDataLiteral(streamingMode));
  5205. if constexpr (RuntimeVersion >= 52U) {
  5206. callableBuilder.Add(NewDataLiteral(static_cast<i32>(skipTo.To)));
  5207. callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(skipTo.Var));
  5208. }
  5209. if constexpr (RuntimeVersion >= 54U) {
  5210. callableBuilder.Add(NewDataLiteral(static_cast<i32>(rowsPerMatch)));
  5211. callableBuilder.Add(NewList(outputColumnEntryType, outputColumnOrder));
  5212. }
  5213. return TRuntimeNode(callableBuilder.Build(), false);
  5214. }
  5215. TRuntimeNode TProgramBuilder::TimeOrderRecover(
  5216. TRuntimeNode inputStream,
  5217. const TUnaryLambda& getTimeExtractor,
  5218. TRuntimeNode delay,
  5219. TRuntimeNode ahead,
  5220. TRuntimeNode rowLimit
  5221. )
  5222. {
  5223. MKQL_ENSURE(RuntimeVersion >= 44, "TimeOrderRecover is not supported in runtime version " << RuntimeVersion);
  5224. auto& inputRowType = *static_cast<TStructType*>(AS_TYPE(TStructType, AS_TYPE(TFlowType, inputStream.GetStaticType())->GetItemType()));
  5225. const auto inputRowArg = Arg(&inputRowType);
  5226. TStructTypeBuilder outputRowTypeBuilder(Env);
  5227. outputRowTypeBuilder.Reserve(inputRowType.GetMembersCount() + 1);
  5228. const ui32 inputRowColumnCount = inputRowType.GetMembersCount();
  5229. for (ui32 i = 0; i != inputRowColumnCount; ++i) {
  5230. outputRowTypeBuilder.Add(inputRowType.GetMemberName(i), inputRowType.GetMemberType(i));
  5231. }
  5232. using NYql::NTimeOrderRecover::OUT_OF_ORDER_MARKER;
  5233. outputRowTypeBuilder.Add(OUT_OF_ORDER_MARKER, TDataType::Create(NUdf::TDataType<bool>::Id, Env));
  5234. const auto outputRowType = outputRowTypeBuilder.Build();
  5235. const auto outOfOrderColumnIndex = outputRowType->GetMemberIndex(OUT_OF_ORDER_MARKER);
  5236. TCallableBuilder callableBuilder(GetTypeEnvironment(), "TimeOrderRecover", TFlowType::Create(outputRowType, Env));
  5237. callableBuilder.Add(inputStream);
  5238. callableBuilder.Add(inputRowArg);
  5239. callableBuilder.Add(getTimeExtractor(inputRowArg));
  5240. callableBuilder.Add(NewDataLiteral(inputRowColumnCount));
  5241. callableBuilder.Add(NewDataLiteral(outOfOrderColumnIndex));
  5242. callableBuilder.Add(delay),
  5243. callableBuilder.Add(ahead),
  5244. callableBuilder.Add(rowLimit);
  5245. return TRuntimeNode(callableBuilder.Build(), false);
  5246. }
  5247. bool CanExportType(TType* type, const TTypeEnvironment& env) {
  5248. if (type->GetKind() == TType::EKind::Type) {
  5249. return false; // Type of Type
  5250. }
  5251. TExploringNodeVisitor explorer;
  5252. explorer.Walk(type, env);
  5253. bool canExport = true;
  5254. for (auto& node : explorer.GetNodes()) {
  5255. switch (static_cast<TType*>(node)->GetKind()) {
  5256. case TType::EKind::Void:
  5257. node->SetCookie(1);
  5258. break;
  5259. case TType::EKind::Data:
  5260. node->SetCookie(1);
  5261. break;
  5262. case TType::EKind::Pg:
  5263. node->SetCookie(1);
  5264. break;
  5265. case TType::EKind::Optional: {
  5266. auto optionalType = static_cast<TOptionalType*>(node);
  5267. if (!optionalType->GetItemType()->GetCookie()) {
  5268. canExport = false;
  5269. } else {
  5270. node->SetCookie(1);
  5271. }
  5272. break;
  5273. }
  5274. case TType::EKind::List: {
  5275. auto listType = static_cast<TListType*>(node);
  5276. if (!listType->GetItemType()->GetCookie()) {
  5277. canExport = false;
  5278. } else {
  5279. node->SetCookie(1);
  5280. }
  5281. break;
  5282. }
  5283. case TType::EKind::Struct: {
  5284. auto structType = static_cast<TStructType*>(node);
  5285. for (ui32 index = 0; index < structType->GetMembersCount(); ++index) {
  5286. if (!structType->GetMemberType(index)->GetCookie()) {
  5287. canExport = false;
  5288. break;
  5289. }
  5290. }
  5291. if (canExport) {
  5292. node->SetCookie(1);
  5293. }
  5294. break;
  5295. }
  5296. case TType::EKind::Tuple: {
  5297. auto tupleType = static_cast<TTupleType*>(node);
  5298. for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) {
  5299. if (!tupleType->GetElementType(index)->GetCookie()) {
  5300. canExport = false;
  5301. break;
  5302. }
  5303. }
  5304. if (canExport) {
  5305. node->SetCookie(1);
  5306. }
  5307. break;
  5308. }
  5309. case TType::EKind::Dict: {
  5310. auto dictType = static_cast<TDictType*>(node);
  5311. if (!dictType->GetKeyType()->GetCookie() || !dictType->GetPayloadType()->GetCookie()) {
  5312. canExport = false;
  5313. } else {
  5314. node->SetCookie(1);
  5315. }
  5316. break;
  5317. }
  5318. case TType::EKind::Variant: {
  5319. auto variantType = static_cast<TVariantType*>(node);
  5320. TType* innerType = variantType->GetUnderlyingType();
  5321. if (innerType->IsStruct()) {
  5322. auto structType = static_cast<TStructType*>(innerType);
  5323. for (ui32 index = 0; index < structType->GetMembersCount(); ++index) {
  5324. if (!structType->GetMemberType(index)->GetCookie()) {
  5325. canExport = false;
  5326. break;
  5327. }
  5328. }
  5329. }
  5330. if (innerType->IsTuple()) {
  5331. auto tupleType = static_cast<TTupleType*>(innerType);
  5332. for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) {
  5333. if (!tupleType->GetElementType(index)->GetCookie()) {
  5334. canExport = false;
  5335. break;
  5336. }
  5337. }
  5338. }
  5339. if (canExport) {
  5340. node->SetCookie(1);
  5341. }
  5342. break;
  5343. }
  5344. case TType::EKind::Type:
  5345. break;
  5346. default:
  5347. canExport = false;
  5348. }
  5349. if (!canExport) {
  5350. break;
  5351. }
  5352. }
  5353. for (auto& node : explorer.GetNodes()) {
  5354. node->SetCookie(0);
  5355. }
  5356. return canExport;
  5357. }
  5358. }
  5359. }