Browse Source

YT block map join
commit_hash:fe68272adcfefced1da829c45212b436aadae5e1

ziganshinmr 1 month ago
parent
commit
4e5542dd0e

+ 15 - 0
yql/essentials/core/expr_nodes/yql_expr_nodes.json

@@ -1612,6 +1612,21 @@
                 {"Index": 8, "Name": "RightKeysColumnNames", "Type": "TCoAtomList"}
             ]
         },
+        {
+            "Name": "TCoBlockMapJoinCore",
+            "Base": "TCallable",
+            "Match": {"Type": "Callable", "Name": "BlockMapJoinCore"},
+            "Children": [
+                {"Index": 0, "Name": "LeftInput", "Type": "TExprBase"},
+                {"Index": 1, "Name": "RightInput", "Type": "TExprBase"},
+                {"Index": 2, "Name": "JoinKind", "Type": "TCoAtom"},
+                {"Index": 3, "Name": "LeftKeyColumns", "Type": "TCoAtomList"},
+                {"Index": 4, "Name": "LeftKeyDrops", "Type": "TCoAtomList"},
+                {"Index": 5, "Name": "RightKeyColumns", "Type": "TCoAtomList"},
+                {"Index": 6, "Name": "RightKeyDrops", "Type": "TCoAtomList"},
+                {"Index": 7, "Name": "Options", "Type": "TExprList"}
+            ]
+        },
         {
             "Name": "TCoGraceJoinCore",
             "Base": "TCallable",

+ 2 - 0
yql/essentials/core/type_ann/type_ann_core.cpp

@@ -12850,6 +12850,8 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
         ExtFunctions["BlockFunc"] = &BlockFuncWrapper;
         ExtFunctions["BlockBitCast"] = &BlockBitCastWrapper;
 
+        Functions["BlockMapJoinCore"] = &BlockMapJoinCoreWrapper;
+
         ExtFunctions["AsScalar"] = &AsScalarWrapper;
         ExtFunctions["WideToBlocks"] = &WideToBlocksWrapper;
         ExtFunctions["BlockCombineAll"] = &BlockCombineAllWrapper;

+ 1 - 0
yql/essentials/core/type_ann/type_ann_impl.h

@@ -35,6 +35,7 @@ namespace NTypeAnnImpl {
     IGraphTransformer::TStatus CombineCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
     IGraphTransformer::TStatus GroupingCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
     IGraphTransformer::TStatus DecimalBinaryWrapperBase(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx, bool blocks);
+    IGraphTransformer::TStatus BlockMapJoinCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
 
     TMaybe<ui32> FindOrReportMissingMember(TStringBuf memberName, TPositionHandle pos, const TStructExprType& structType, TExprContext& ctx);
 

+ 133 - 0
yql/essentials/core/type_ann/type_ann_join.cpp

@@ -980,5 +980,138 @@ namespace NTypeAnnImpl {
         }
     }
 
+    IGraphTransformer::TStatus BlockMapJoinCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+        Y_UNUSED(output);
+
+        if (!EnsureArgsCount(*input, 8, ctx.Expr)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        TTypeAnnotationNode::TListType leftItemTypes;
+        if (!EnsureWideStreamBlockType(input->Head(), leftItemTypes, ctx.Expr)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+        leftItemTypes.pop_back();
+        auto leftItemType = input->Head().GetTypeAnn()->Cast<TStreamExprType>()->GetItemType()->Cast<TMultiExprType>();
+
+        TTypeAnnotationNode::TListType rightItemTypes;
+        if (!EnsureWideStreamBlockType(*input->Child(1), rightItemTypes, ctx.Expr)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+        rightItemTypes.pop_back();
+        auto rightItemType = input->Child(1)->GetTypeAnn()->Cast<TStreamExprType>()->GetItemType()->Cast<TMultiExprType>();
+
+        if (!EnsureAtom(*input->Child(2), ctx.Expr)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        const auto joinKind = input->Child(2)->Content();
+        if (joinKind != "Inner" && joinKind != "Left" && joinKind != "LeftSemi" && joinKind != "LeftOnly") {
+            ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Child(2)->Pos()), TStringBuilder() << "Unknown join kind: " << joinKind
+                << ", supported: Inner, Left, LeftSemi, LeftOnly"));
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        if (input->Child(3)->ChildrenSize() != input->Child(5)->ChildrenSize()) {
+            ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Child(5)->Pos()), TStringBuilder() << "Mismatch of key column count"));
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        auto checkKeyColumns = [&](std::unordered_set<ui32>& keyColumns, bool isLeft, const TExprNode& keyColumnsNode, const TMultiExprType* itemType) {
+            for (const auto& keyColumnNode : keyColumnsNode.Children()) {
+                auto position = GetWideBlockFieldPosition(*itemType, keyColumnNode->Content());
+                if (!position) {
+                    ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(keyColumnNode->Pos()), TStringBuilder() << "Unknown " << (isLeft ? "left" : "right") << " key column: " << keyColumnNode->Content()));
+                    return false;
+                }
+                keyColumns.insert(*position);
+            }
+            return true;
+        };
+
+        auto checkKeyDrops = [&](std::unordered_set<ui32>& keyDrops, bool isLeft, const std::unordered_set<ui32>& keyColumns, const TExprNode& keyDropsNode, const TMultiExprType* itemType) {
+            for (const auto& keyDropNode : keyDropsNode.Children()) {
+                auto position = GetWideBlockFieldPosition(*itemType, keyDropNode->Content());
+                if (!position) {
+                    ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(keyDropNode->Pos()), TStringBuilder() << "Unknown " << (isLeft ? "left" : "right") << " key column: " << keyDropNode->Content()));
+                    return false;
+                }
+                if (!keyColumns.contains(*position)) {
+                    ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(keyDropNode->Pos()), TStringBuilder() << "Attempted to drop " << (isLeft ? "left" : "right") << " non-key column: " << keyDropNode->Content()));
+                    return false;
+                }
+                if (!keyDrops.insert(*position).second) {
+                    ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(keyDropNode->Pos()), TStringBuilder() << "Duplicated " << (isLeft ? "left" : "right") << " key drop: " << keyDropNode->Content()));
+                    return false;
+                }
+            }
+            return true;
+        };
+
+        for (size_t childIdx = 3; childIdx <= 6; childIdx++) {
+            if (!EnsureTupleOfAtoms(*input->Child(childIdx), ctx.Expr)) {
+                return IGraphTransformer::TStatus::Error;
+            }
+        }
+
+        std::unordered_set<ui32> leftKeyColumns;
+        if (!checkKeyColumns(leftKeyColumns, true, *input->Child(3), leftItemType)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        std::unordered_set<ui32> leftKeyDrops;
+        if (!checkKeyDrops(leftKeyDrops, true, leftKeyColumns, *input->Child(4), leftItemType)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        std::unordered_set<ui32> rightKeyColumns;
+        if (!checkKeyColumns(rightKeyColumns, false, *input->Child(5), rightItemType)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        std::unordered_set<ui32> rightKeyDrops;
+        if (!checkKeyDrops(rightKeyDrops, false, rightKeyColumns, *input->Child(6), rightItemType)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        auto settingsValidator = [&](TStringBuf, TExprNode& node, TExprContext&) { return node.ChildrenSize() == 1; };
+        if (!EnsureValidSettings(input->Tail(), {"rightAny"}, settingsValidator, ctx.Expr)) {
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        std::vector<const TTypeAnnotationNode*> resultItems;
+        for (ui32 pos = 0; pos < leftItemTypes.size(); pos++) {
+            if (leftKeyDrops.contains(pos)) {
+                continue;
+            }
+
+            resultItems.push_back(ctx.Expr.MakeType<TBlockExprType>(leftItemTypes[pos]));
+        }
+
+        if (joinKind != "LeftSemi" && joinKind != "LeftOnly") {
+            for (ui32 pos = 0; pos < rightItemTypes.size(); pos++) {
+                if (rightKeyDrops.contains(pos)) {
+                    continue;
+                }
+
+                auto columnType = rightItemTypes[pos];
+                if (joinKind == "Left" && !rightItemTypes[pos]->IsOptionalOrNull()) {
+                    columnType = ctx.Expr.MakeType<TOptionalExprType>(columnType);
+                }
+
+                resultItems.push_back(ctx.Expr.MakeType<TBlockExprType>(columnType));
+            }
+        } else {
+            if (!rightKeyDrops.empty()) {
+                ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Child(6)->Pos()), TStringBuilder() << "Right key drops are not allowed for semi/only join"));
+                return IGraphTransformer::TStatus::Error;
+            }
+        }
+
+        resultItems.push_back(ctx.Expr.MakeType<TScalarExprType>(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64)));
+        input->SetTypeAnn(ctx.Expr.MakeType<TStreamExprType>(ctx.Expr.MakeType<TMultiExprType>(resultItems)));
+        return IGraphTransformer::TStatus::Ok;
+    }
+
 } // namespace NTypeAnnImpl
 } // namespace NYql

+ 7 - 0
yql/essentials/core/yql_expr_type_annotation.cpp

@@ -6035,6 +6035,13 @@ std::optional<ui32> GetFieldPosition(const TStructExprType& structType, const TS
     return std::nullopt;
 }
 
+std::optional<ui32> GetWideBlockFieldPosition(const TMultiExprType& multiType, const TStringBuf& field) {
+    YQL_ENSURE(multiType.GetSize() >= 1);
+    if (ui32 pos; TryFromString(field, pos) && pos < multiType.GetSize() - 1)
+        return {pos};
+    return std::nullopt;
+}
+
 bool ExtractPgType(const TTypeAnnotationNode* type, ui32& pgType, bool& convertToPg, TPositionHandle pos, TExprContext& ctx) {
     pgType = 0;
     convertToPg = false;

+ 1 - 0
yql/essentials/core/yql_expr_type_annotation.h

@@ -319,6 +319,7 @@ IGraphTransformer::TStatus NormalizeKeyValueTuples(const TExprNode::TPtr& input,
 std::optional<ui32> GetFieldPosition(const TMultiExprType& tupleType, const TStringBuf& field);
 std::optional<ui32> GetFieldPosition(const TTupleExprType& tupleType, const TStringBuf& field);
 std::optional<ui32> GetFieldPosition(const TStructExprType& structType, const TStringBuf& field);
+std::optional<ui32> GetWideBlockFieldPosition(const TMultiExprType& tupleType, const TStringBuf& field);
 
 bool ExtractPgType(const TTypeAnnotationNode* type, ui32& pgType, bool& convertToPg, TPositionHandle pos, TExprContext& ctx);
 bool HasContextFuncs(const TExprNode& input);

+ 4 - 2
yql/essentials/minikql/comp_nodes/mkql_block_map_join.cpp

@@ -7,6 +7,7 @@
 #include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
 #include <yql/essentials/minikql/comp_nodes/mkql_rh_hash.h>
 #include <yql/essentials/minikql/invoke_builtins/mkql_builtins.h>
+#include <yql/essentials/minikql/mkql_block_map_join_utils.h>
 #include <yql/essentials/minikql/mkql_node_cast.h>
 #include <yql/essentials/minikql/mkql_program_builder.h>
 
@@ -308,8 +309,6 @@ class TBlockIndex : public TComputationValue<TBlockIndex> {
         };
     };
 
-    static_assert(sizeof(TIndexMapValue) == 8);
-
     using TBase = TComputationValue<TBlockIndex>;
     using TIndexMap = TRobinHoodHashFixedMap<
         ui64,
@@ -319,6 +318,9 @@ class TBlockIndex : public TComputationValue<TBlockIndex> {
         TMKQLAllocator<char>
     >;
 
+    static_assert(sizeof(TIndexMapValue) == 8);
+    static_assert(std::max(TIndexMap::GetCellSize(), static_cast<ui32>(sizeof(TIndexNode))) == BlockMapJoinIndexEntrySize);
+
 public:
     class TIterator {
         enum class EIteratorType {

+ 6 - 14
yql/essentials/minikql/comp_nodes/mkql_rh_hash.h

@@ -6,6 +6,7 @@
 #include <vector>
 #include <span>
 
+#include <yql/essentials/minikql/mkql_rh_hash_utils.h>
 #include <yql/essentials/utils/prefetch.h>
 
 #include <util/digest/city.h>
@@ -109,7 +110,7 @@ public:
 
     // should be called after Insert if isNew is true
     Y_FORCE_INLINE void CheckGrow() {
-        if (Size * 2 >= Capacity) {
+        if (RHHashTableNeedsGrow(Size, Capacity)) {
             Grow();
         }
     }
@@ -124,7 +125,7 @@ public:
 
     template <typename TSink>
     Y_NO_INLINE void BatchInsert(std::span<TRobinHoodBatchRequestItem<TKey>> batchRequest, TSink&& sink) {
-        while (2 * (Size + batchRequest.size()) >= Capacity) {
+        while (RHHashTableNeedsGrow(Size + batchRequest.size(), Capacity)) {
             Grow();
         }
 
@@ -331,15 +332,7 @@ private:
     }
 
     Y_NO_INLINE void Grow() {
-        ui64 growFactor;
-        if (Capacity < 100'000) {
-            growFactor = 8;
-        } else if (Capacity < 1'000'000) {
-            growFactor = 4;
-        } else {
-            growFactor = 2;
-        }
-        auto newCapacity = Capacity * growFactor;
+        auto newCapacity = Capacity * CalculateRHHashTableGrowFactor(Capacity);
         auto newCapacityShift = 64 - MostSignificantBit(newCapacity);
         char *newData, *newDataEnd;
         Allocate(newCapacity, newData, newDataEnd);
@@ -522,8 +515,7 @@ public:
         TBase::Init();
     }
 
-
-    ui32 GetCellSize() const {
+    static constexpr ui32 GetCellSize() {
         return sizeof(typename TBase::TPSLStorage) + sizeof(TKey) + sizeof(TPayload);
     }
 
@@ -569,7 +561,7 @@ public:
         TBase::Init();
     }
 
-    ui32 GetCellSize() const {
+    static constexpr ui32 GetCellSize() {
         return sizeof(typename TBase::TPSLStorage) + sizeof(TKey);
     }
 

+ 12 - 0
yql/essentials/minikql/mkql_block_map_join_utils.cpp

@@ -0,0 +1,12 @@
+#include "mkql_block_map_join_utils.h"
+#include "mkql_rh_hash_utils.h"
+
+namespace NKikimr {
+namespace NMiniKQL {
+
+ui64 EstimateBlockMapJoinIndexSize(ui64 rowsCount)  {
+    return CalculateRHHashTableCapacity(rowsCount) * BlockMapJoinIndexEntrySize;
+}
+
+} // namespace NMiniKQL
+} // namespace NKikimr

+ 13 - 0
yql/essentials/minikql/mkql_block_map_join_utils.h

@@ -0,0 +1,13 @@
+#pragma once
+
+#include <util/system/types.h>
+
+namespace NKikimr {
+namespace NMiniKQL {
+
+constexpr ui64 BlockMapJoinIndexEntrySize = 20;
+
+ui64 EstimateBlockMapJoinIndexSize(ui64 rowsCount);
+
+} // namespace NMiniKQL
+} // namespace NKikimr

Some files were not shown because too many files changed in this diff