Browse Source

ListSample/ListSampleN/ListShuffle implementation
commit_hash:987b10b398caa89eee8b94b33f9ea1dc74197223

ziganshinmr 4 days ago
parent
commit
c320ff3884

+ 81 - 0
yql/essentials/core/common_opt/yql_co_simple1.cpp

@@ -3675,6 +3675,28 @@ bool IsEarlyExpandOfSkipNullAllowed(const TOptimizeContext& optCtx) {
     return optCtx.Types->OptimizerFlags.contains(skipNullFlags);
 }
 
+TExprNode::TPtr ReplaceFuncWithImpl(const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx) {
+    auto exportsPtr = optCtx.Types->Modules->GetModule("/lib/yql/core.yql");
+    YQL_ENSURE(exportsPtr);
+    const auto& exports = exportsPtr->Symbols();
+    const auto ex = exports.find(TString(node->Content()) + "Impl");
+    YQL_ENSURE(exports.cend() != ex);
+    TNodeOnNodeOwnedMap deepClones;
+    auto lambda = ctx.DeepCopy(*ex->second, exportsPtr->ExprCtx(), deepClones, true, false);
+
+    YQL_CLOG(DEBUG, Core) << "Replace " << node->Content() << " with implementation";
+    return ctx.Builder(node->Pos())
+        .Apply(lambda)
+            .Do([&node](TExprNodeReplaceBuilder& builder) -> TExprNodeReplaceBuilder& {
+                for (size_t i = 0; i < node->ChildrenSize(); i++) {
+                    builder.With(i, node->ChildPtr(i));
+                }
+                return builder;
+            })
+        .Seal()
+        .Build();
+}
+
 } // namespace
 
 void RegisterCoSimpleCallables1(TCallableOptimizerMap& map) {
@@ -4897,6 +4919,65 @@ void RegisterCoSimpleCallables1(TCallableOptimizerMap& map) {
         return node;
     };
 
+    map["ListSample"] = map["ListSampleN"] = [](const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx) {
+        if (node->Child(0)->GetTypeAnn()->GetKind() == ETypeAnnotationKind::Optional) {
+            YQL_CLOG(DEBUG, Core) << "Handle optional list in " << node->Content();
+            return ctx.Builder(node->Pos())
+                .Callable("Map")
+                    .Add(0, node->Child(0))
+                    .Lambda(1)
+                        .Param("list")
+                        .Callable(node->Content())
+                            .Arg(0, "list")
+                            .Add(1, node->Child(1))
+                            .Add(2, node->Child(2))
+                        .Seal()
+                    .Seal()
+                .Seal()
+                .Build();
+        }
+
+        if (node->Child(1)->GetTypeAnn()->GetKind() == ETypeAnnotationKind::Optional) {
+            YQL_CLOG(DEBUG, Core) << "Handle optional prob arg in " << node->Content();
+            return ctx.Builder(node->Pos())
+                .Callable("IfPresent")
+                    .Add(0, node->Child(1))
+                    .Lambda(1)
+                        .Param("probArg")
+                        .Callable(node->Content())
+                            .Add(0, node->Child(0))
+                            .Arg(1, "probArg")
+                            .Add(2, node->Child(2))
+                        .Seal()
+                    .Seal()
+                    .Add(2, node->Child(0))
+                .Seal()
+                .Build();
+        }
+
+        return ReplaceFuncWithImpl(node, ctx, optCtx);
+    };
+
+    map["ListShuffle"] = [](const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx) {
+        if (node->Head().GetTypeAnn()->GetKind() == ETypeAnnotationKind::Optional) {
+            YQL_CLOG(DEBUG, Core) << "Handle optionals args in " << node->Content();
+            return ctx.Builder(node->Pos())
+                .Callable("Map")
+                    .Add(0, node->Child(0))
+                    .Lambda(1)
+                        .Param("list")
+                        .Callable(node->Content())
+                            .Arg(0, "list")
+                            .Add(1, node->Child(1))
+                        .Seal()
+                    .Seal()
+                .Seal()
+                .Build();
+        }
+
+        return ReplaceFuncWithImpl(node, ctx, optCtx);
+    };
+
     map["OptionalReduce"] = std::bind(&RemoveOptionalReduceOverData, _1, _2);
 
     map["Fold"] = [](const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& /*optCtx*/) {

+ 3 - 0
yql/essentials/core/type_ann/type_ann_core.cpp

@@ -12594,6 +12594,9 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
         Functions["ListTopSort"] = &ListTopSortWrapper;
         Functions["ListTopSortAsc"] = &ListTopSortWrapper;
         Functions["ListTopSortDesc"] = &ListTopSortWrapper;
+        Functions["ListSample"] = &ListSampleWrapper;
+        Functions["ListSampleN"] = &ListSampleNWrapper;
+        Functions["ListShuffle"] = &ListShuffleWrapper;
 
         Functions["ExpandMap"] = &ExpandMapWrapper;
         Functions["WideMap"] = &WideMapWrapper;

+ 109 - 0
yql/essentials/core/type_ann/type_ann_list.cpp

@@ -1524,6 +1524,115 @@ namespace {
         return OptListWrapperImpl<1U>(input, output, ctx, "Collect");
     }
 
+    IGraphTransformer::TStatus ListSampleWrapperCommon(const TExprNode::TPtr& input, TExprNode::TPtr& output, NUdf::EDataSlot probArgDataType, TContext& ctx) {
+        if (!EnsureMinMaxArgsCount(*input, 2, 3, ctx.Expr)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        if (IsNull(input->Head())) {
+            output = input->HeadPtr();
+            return IGraphTransformer::TStatus::Repeat;
+        }
+
+        if (!EnsureComputable(input->Head(), ctx.Expr)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        auto type = input->Head().GetTypeAnn();
+        if (type->GetKind() == ETypeAnnotationKind::Optional) {
+            type = type->Cast<TOptionalExprType>()->GetItemType();
+        }
+
+        if (type->GetKind() != ETypeAnnotationKind::List && type->GetKind() != ETypeAnnotationKind::EmptyList) {
+            ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Head().Pos()), TStringBuilder()
+                << "Expected (empty) list or optional of (empty) list, but got: " << *input->Head().GetTypeAnn()));
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        if (type->GetKind() == ETypeAnnotationKind::EmptyList) {
+            output = input->HeadPtr();
+            return IGraphTransformer::TStatus::Repeat;
+        }
+
+        if (IsNull(*input->Child(1))) {
+            output = input->HeadPtr();
+            return IGraphTransformer::TStatus::Repeat;
+        }
+
+        if (!EnsureSpecificDataType(*input->Child(1), probArgDataType, ctx.Expr, true)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        if (input->ChildrenSize() == 2) {
+            auto children = input->ChildrenList();
+            children.push_back(ctx.Expr.NewCallable(input->Pos(), "Null", {}));
+            output = ctx.Expr.ChangeChildren(*input, std::move(children));
+            return IGraphTransformer::TStatus::Repeat;
+        }
+        YQL_ENSURE(input->ChildrenSize() == 3);
+
+        if (!EnsureComputable(*input->Child(2), ctx.Expr)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        input->SetTypeAnn(input->Head().GetTypeAnn());
+        return IGraphTransformer::TStatus::Ok;
+    }
+
+    IGraphTransformer::TStatus ListSampleWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+        return ListSampleWrapperCommon(input, output, NUdf::EDataSlot::Double, ctx);
+    }
+
+    IGraphTransformer::TStatus ListSampleNWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+        return ListSampleWrapperCommon(input, output, NUdf::EDataSlot::Uint64, ctx);
+    }
+
+    IGraphTransformer::TStatus ListShuffleWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+        if (!EnsureMinMaxArgsCount(*input, 1, 2, ctx.Expr)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        if (IsNull(input->Head())) {
+            output = input->HeadPtr();
+            return IGraphTransformer::TStatus::Repeat;
+        }
+
+        if (!EnsureComputable(input->Head(), ctx.Expr)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        auto type = input->Head().GetTypeAnn();
+        if (type->GetKind() == ETypeAnnotationKind::Optional) {
+            type = type->Cast<TOptionalExprType>()->GetItemType();
+        }
+
+        if (type->GetKind() != ETypeAnnotationKind::List && type->GetKind() != ETypeAnnotationKind::EmptyList) {
+            ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Head().Pos()), TStringBuilder()
+                << "Expected (empty) list or optional of (empty) list, but got: " << *input->Head().GetTypeAnn()));
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        if (type->GetKind() == ETypeAnnotationKind::EmptyList) {
+            output = input->HeadPtr();
+            return IGraphTransformer::TStatus::Repeat;
+        }
+
+        if (input->ChildrenSize() == 1) {
+            auto children = input->ChildrenList();
+            children.push_back(ctx.Expr.NewCallable(input->Pos(), "Null", {}));
+            output = ctx.Expr.ChangeChildren(*input, std::move(children));
+            return IGraphTransformer::TStatus::Repeat;
+        }
+        YQL_ENSURE(input->ChildrenSize() == 2);
+
+        if (!EnsureComputable(*input->Child(1), ctx.Expr)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        input->SetTypeAnn(input->Head().GetTypeAnn());
+        return IGraphTransformer::TStatus::Ok;
+    }
+
     IGraphTransformer::TStatus OptListFold1WrapperImpl(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx, TExprNode::TPtr&& updateLambda) {
         if (IsNull(input->Head())) {
             output = input->HeadPtr();

+ 3 - 0
yql/essentials/core/type_ann/type_ann_list.h

@@ -41,6 +41,9 @@ namespace NTypeAnnImpl {
     IGraphTransformer::TStatus ListTopSortWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
     IGraphTransformer::TStatus ListExtractWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
     IGraphTransformer::TStatus ListCollectWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
+    IGraphTransformer::TStatus ListSampleWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
+    IGraphTransformer::TStatus ListSampleNWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
+    IGraphTransformer::TStatus ListShuffleWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
     IGraphTransformer::TStatus FoldMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
     IGraphTransformer::TStatus Fold1MapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
     IGraphTransformer::TStatus Chain1MapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);

+ 53 - 0
yql/essentials/mount/lib/yql/core.yql

@@ -479,6 +479,56 @@ def signature(script, name):
         (lambda '() (Apply ListToTupleImpl list n)))
 ))
 
+(let ListSampleImpl (lambda '(list probability dependsOn)
+    (Filter list (lambda '(x) (< (Random (DependsOn '(x probability dependsOn))) probability)))
+))
+
+(let ListSampleNImpl (lambda '(list count dependsOn) (block '(
+    (let value_type (ListItemType (TypeOf list)))
+
+    (let UdfVectorCreate (Udf 'Vector.Create (Void) (TupleType (TupleType value_type (DataType 'Uint64)) (StructType) value_type)))
+
+    (let resource_type (TypeOf (Apply UdfVectorCreate (Uint32 '0))))
+
+    (let UdfVectorEmplace (Udf 'Vector.Emplace (Void) (TupleType (TupleType resource_type (DataType 'Uint64) value_type) (StructType) value_type)))
+    (let UdfVectorSwap (Udf 'Vector.Swap (Void) (TupleType (TupleType resource_type (DataType 'Uint64) (DataType 'Uint64)) (StructType) value_type)))
+    (let UdfVectorGetResult (Udf 'Vector.GetResult (Void) (TupleType (TupleType resource_type) (StructType) value_type)))
+
+    (return (Apply UdfVectorGetResult (Fold
+        (Skip (Enumerate list) count)
+        (Fold
+            (Take list count)
+            (NamedApply UdfVectorCreate '(count) (AsStruct) (DependsOn '(list dependsOn)))
+            (lambda '(x y) (Apply UdfVectorEmplace y count x))
+        )
+        (lambda '(x y) (block '(
+            (let pos (Coalesce (% (RandomNumber (DependsOn '(x count dependsOn))) (+ (Nth x '0) (Uint64 '1))) (Uint64 '0)))
+            (return (If (< pos count) (Apply UdfVectorEmplace y pos (Nth x '1)) y))
+        )))
+    )))
+))))
+
+(let ListShuffleImpl (lambda '(list dependsOn) (block '(
+    (let value_type (ListItemType (TypeOf list)))
+
+    (let UdfVectorCreate (Udf 'Vector.Create (Void) (TupleType (TupleType value_type (DataType 'Uint64)) (StructType) value_type)))
+
+    (let resource_type (TypeOf (Apply UdfVectorCreate (Uint32 '0))))
+
+    (let UdfVectorEmplace (Udf 'Vector.Emplace (Void) (TupleType (TupleType resource_type (DataType 'Uint64) value_type) (StructType) value_type)))
+    (let UdfVectorSwap (Udf 'Vector.Swap (Void) (TupleType (TupleType resource_type (DataType 'Uint64) (DataType 'Uint64)) (StructType) value_type)))
+    (let UdfVectorGetResult (Udf 'Vector.GetResult (Void) (TupleType (TupleType resource_type) (StructType) value_type)))
+
+    (return (Apply UdfVectorGetResult (Fold
+        (Enumerate list)
+        (NamedApply UdfVectorCreate '((Uint32 '1)) (AsStruct) (DependsOn '(list dependsOn)))
+        (lambda '(x y) (block '(
+            (let pos (Coalesce (% (RandomNumber (DependsOn '(x dependsOn))) (+ (Nth x '0) (Uint64 '1))) (Uint64 '0)))
+            (return (Apply UdfVectorSwap (Apply UdfVectorEmplace y (Nth x '0) (Nth x '1)) pos (Nth x '0)))
+        )))
+    )))
+))))
+
 (export Equals)
 (export Unequals)
 (export FindIndex)
@@ -516,4 +566,7 @@ def signature(script, name):
 (export ForceSpreadMembers)
 (export ListFromTuple)
 (export ListToTuple)
+(export ListSampleImpl)
+(export ListSampleNImpl)
+(export ListShuffleImpl)
 )

+ 3 - 0
yql/essentials/sql/v1/builtin.cpp

@@ -2916,6 +2916,9 @@ struct TBuiltinFuncData {
             {"listtopsort", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("ListTopSort", 2, 3)},
             {"listtopsortasc", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("ListTopSortAsc", 2, 3)},
             {"listtopsortdesc", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("ListTopSortDesc", 2, 3)},
+            {"listsample", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("ListSample", 2, 3)},
+            {"listsamplen", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("ListSampleN", 2, 3)},
+            {"listshuffle", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("ListShuffle", 1, 2)},
 
             // Dict builtins
             {"dictlength", BuildNamedArgcBuiltinFactoryCallback<TCallNodeImpl>("Length", 1, 1)},

+ 1 - 0
yql/essentials/tests/common/test_framework/udfs_deps/ya.make

@@ -19,6 +19,7 @@ SET(
     yql/essentials/udfs/common/url_base
     yql/essentials/udfs/common/unicode_base
     yql/essentials/udfs/common/streaming
+    yql/essentials/udfs/common/vector
     yql/essentials/udfs/examples/callables
     yql/essentials/udfs/examples/dicts
     yql/essentials/udfs/examples/dummylog

+ 42 - 0
yql/essentials/tests/sql/sql2yql/canondata/result.json

@@ -6453,6 +6453,27 @@
             "uri": "https://{canondata_backend}/1784117/d56ae82ad9d30397a41490647be1bd2124718f98/resource.tar.gz#test_sql2yql.test_expr-list_replicate_fail_/sql.yql"
         }
     ],
+    "test_sql2yql.test[expr-list_sample]": [
+        {
+            "checksum": "922f4c9c5a2fe848f40272dd15cfde42",
+            "size": 10843,
+            "uri": "https://{canondata_backend}/1924537/278b77accb7596bd976e3e218425469d4b97dcf9/resource.tar.gz#test_sql2yql.test_expr-list_sample_/sql.yql"
+        }
+    ],
+    "test_sql2yql.test[expr-list_sample_n]": [
+        {
+            "checksum": "5ce08b8b61ef8b2863f931bc1b986679",
+            "size": 7573,
+            "uri": "https://{canondata_backend}/1924537/278b77accb7596bd976e3e218425469d4b97dcf9/resource.tar.gz#test_sql2yql.test_expr-list_sample_n_/sql.yql"
+        }
+    ],
+    "test_sql2yql.test[expr-list_shuffle]": [
+        {
+            "checksum": "3cd4f632706daf9ac8962369e7d0eac3",
+            "size": 4413,
+            "uri": "https://{canondata_backend}/1777230/f0ec95d2b2a3a38fc99b00afc1f2d60d2b3e8548/resource.tar.gz#test_sql2yql.test_expr-list_shuffle_/sql.yql"
+        }
+    ],
     "test_sql2yql.test[expr-list_takeskipwhile]": [
         {
             "checksum": "827d6c45ccb33ccc641531600fa839ce",
@@ -26319,6 +26340,27 @@
             "uri": "https://{canondata_backend}/1880306/64654158d6bfb1289c66c626a8162239289559d0/resource.tar.gz#test_sql_format.test_expr-list_replicate_fail_/formatted.sql"
         }
     ],
+    "test_sql_format.test[expr-list_sample]": [
+        {
+            "checksum": "a642f47aa5488ecfa6450c114a85903d",
+            "size": 1235,
+            "uri": "https://{canondata_backend}/1942525/0302d8428323e9211161c4db74348074ea0aab49/resource.tar.gz#test_sql_format.test_expr-list_sample_/formatted.sql"
+        }
+    ],
+    "test_sql_format.test[expr-list_sample_n]": [
+        {
+            "checksum": "4b04a240db2a66eab919da4fbbf3cdea",
+            "size": 1128,
+            "uri": "https://{canondata_backend}/1942525/0302d8428323e9211161c4db74348074ea0aab49/resource.tar.gz#test_sql_format.test_expr-list_sample_n_/formatted.sql"
+        }
+    ],
+    "test_sql_format.test[expr-list_shuffle]": [
+        {
+            "checksum": "73822288846e1fc180736baa4a9548c7",
+            "size": 612,
+            "uri": "https://{canondata_backend}/1942525/0302d8428323e9211161c4db74348074ea0aab49/resource.tar.gz#test_sql_format.test_expr-list_shuffle_/formatted.sql"
+        }
+    ],
     "test_sql_format.test[expr-list_takeskipwhile]": [
         {
             "checksum": "fe413941b62655034d49cd2674f2c947",

+ 1 - 0
yql/essentials/tests/sql/suites/expr/list_sample.cfg

@@ -0,0 +1 @@
+providers yt

+ 36 - 0
yql/essentials/tests/sql/suites/expr/list_sample.sql

@@ -0,0 +1,36 @@
+/* yt can not */
+$list = ListFromRange(1, 101);
+$test = ($probability, $dependsOn) -> { 
+    $sample = ListCollect(ListSample($list, $probability, $dependsOn));
+    RETURN 
+    (
+        ListSort(DictKeys(ToSet($sample))) == ListSort($sample),
+        (ListLength($sample), $probability * 100),
+        SetIncludes(ToSet($list), $sample)
+    );
+};
+
+SELECT
+    ListSample(NULL                                               , 1.0) IS NULL AS mustBeTrue1,
+    ListSample(Nothing(OptionalType(ListType(DataType("Uint64")))), 1.0) IS NULL AS mustBeTrue2,
+    ListSample([]                                                 , 1.0) == []   AS mustBeTrue3,
+    
+    ListSample($list, NULL                                     ) == $list AS mustBeTrue4,
+    ListSample($list, Nothing(OptionalType(DataType("Double")))) == $list AS mustBeTrue5,
+
+    ListSample($list, 0.5, 123) == ListSample($list, 0.5, 123) AS mustBeTrue6,
+
+    $test(0.2, 1) AS result1,
+    $test(0.2, 2) AS result2,
+    $test(0.2, 3) AS result3,
+    $test(0.2, 4) AS result4,
+    $test(0.2, 5) AS result5,
+    $test(0.5, 6) AS result6,
+    $test(0.8, 7) AS result7,
+    $test(1.0,   8) AS result8,
+    $test(0.0,   9) AS result9,
+
+    ListSample($list      , 0.1      , 10) AS result10,
+    ListSample(Just($list), 0.1      , 11) AS result11,
+    ListSample($list      , Just(0.1), 12) AS result12,
+    ListSample(Just($list), Just(0.1), 13) AS result13;

Some files were not shown because too many files changed in this diff