Просмотр исходного кода

NaN-safe comparison and hashing
commit_hash:46d59ab3acbd313753d3e46f3a6f10a8ebc424d8

babenko 4 месяцев назад
Родитель
Commit
b60a78031c

+ 71 - 0
library/cpp/yt/misc/compare-inl.h

@@ -0,0 +1,71 @@
+#ifndef COMPARE_INL_H_
+#error "Direct inclusion of this file is not allowed, include compare.h"
+// For the sake of sane code completion.
+#include "compare.h"
+#endif
+
+#include "numeric_helpers.h"
+
+#include <util/generic/string.h>
+
+#include <string>
+#include <string_view>
+
+namespace NYT {
+
+////////////////////////////////////////////////////////////////////////////////
+
+template <class T> // Generic three-way comparison; uses only |operator==| and |operator<| of |T|.
+Y_FORCE_INLINE int TernaryCompare(const T& lhs, const T& rhs)
+{
+    if (lhs == rhs) {
+        return 0;
+    } else if (lhs < rhs) {
+        return -1;
+    } else {
+        return +1; // Also reached when neither |==| nor |<| holds, e.g. when an operand is NaN.
+    }
+}
+
+//! An optimized version for string types: a single |std::string_view::compare| call whose result is reduced to -1, 0, or +1 by #GetSign.
+template <class T>
+    requires
+        std::is_same_v<T, TString> ||
+        std::is_same_v<T, TStringBuf> ||
+        std::is_same_v<T, std::string> ||
+        std::is_same_v<T, std::string_view> // Only these four exact types select this overload.
+Y_FORCE_INLINE int TernaryCompare(const T& lhs, const T& rhs)
+{
+    return GetSign(std::string_view(lhs).compare(std::string_view(rhs))); // |compare| only guarantees the sign of its result, hence the normalization.
+}
+
+template <class T> // Unconstrained overload: applies no NaN handling of its own.
+Y_FORCE_INLINE int NaNSafeTernaryCompare(const T& lhs, const T& rhs)
+{
+    return TernaryCompare(lhs, rhs); // Plain delegation to #TernaryCompare.
+}
+
+template <class T>
+    requires std::is_floating_point_v<T> // Floating-point overload (e.g. |float|, |double|).
+Y_FORCE_INLINE int NaNSafeTernaryCompare(const T& lhs, const T& rhs)
+{
+    if (lhs < rhs) { // Ordered comparisons are false whenever an operand is NaN, so NaNs skip both tests below.
+        return -1;
+    } else if (lhs > rhs) {
+        return +1;
+    } else if (std::isnan(lhs)) { // From here on: the operands are equal or at least one of them is NaN.
+        if (std::isnan(rhs)) {
+            return 0; // All NaNs are treated as equal to each other.
+        } else {
+            return +1; // NaN is ordered after any non-NaN value.
+        }
+    } else if (std::isnan(rhs)) {
+        return -1; // Non-NaN |lhs| is ordered before NaN |rhs|.
+    } else {
+        return 0; // Neither operand is NaN and neither is smaller: the values are equal.
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT

+ 28 - 0
library/cpp/yt/misc/compare.h

@@ -0,0 +1,28 @@
+#pragma once
+
+#include <util/generic/string.h>
+
+#include <string_view>
+
+namespace NYT {
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Compares #lhs with #rhs;
+//! returns -1 (if |lhs < rhs|), 0 (if |lhs == rhs|), or +1 (otherwise, i.e. if |lhs > rhs| or the operands are unordered, e.g. NaN).
+template <class T>
+int TernaryCompare(const T& lhs, const T& rhs);
+
+//! Same as #TernaryCompare (returns -1, 0, or +1) but handles NaN values gracefully
+//! (assuming NaNs are larger than any non-NaN number, infinities included, and all NaNs are equal).
+//! If |T| is not a floating-point type, #NaNSafeTernaryCompare is equivalent to #TernaryCompare.
+template <class T>
+int NaNSafeTernaryCompare(const T& lhs, const T& rhs);
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT
+
+#define COMPARE_INL_H_
+#include "compare-inl.h"
+#undef COMPARE_INL_H_

+ 15 - 0
library/cpp/yt/misc/hash-inl.h

@@ -4,6 +4,8 @@
 #include "hash.h"
 #endif
 
+#include <cmath>
+
 namespace NYT {
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -29,6 +31,19 @@ void HashCombine(size_t& h, const T& k)
     HashCombine(h, THash<T>()(k));
 }
 
+template <class T> // Unconstrained overload: applies no NaN handling.
+Y_FORCE_INLINE size_t NaNSafeHash(const T& value)
+{
+    return ::THash<T>()(value); // Same result as hashing with |THash<T>| directly.
+}
+
+template <class T>
+    requires std::is_floating_point_v<T> // Floating-point overload.
+Y_FORCE_INLINE size_t NaNSafeHash(const T& value)
+{
+    return std::isnan(value) ? 0 : ::THash<T>()(value); // Every NaN hashes to the constant 0; all other values go through |THash<T>|.
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 template <class TElement, class TUnderlying>

+ 7 - 2
library/cpp/yt/misc/hash.h

@@ -17,6 +17,12 @@ void HashCombine(size_t& h, size_t k);
 template <class T>
 void HashCombine(size_t& h, const T& k);
 
+//! Computes the hash of #value handling NaN values gracefully
+//! (returning the same constant for all NaNs, whatever their payload).
+//! If |T| is not a floating-point type, #NaNSafeHash is equivalent to #THash.
+template <class T>
+size_t NaNSafeHash(const T& value);
+
 ////////////////////////////////////////////////////////////////////////////////
 
 //! Provides a hasher that randomizes the results of another one.
@@ -25,12 +31,11 @@ class TRandomizedHash
 {
 public:
     TRandomizedHash();
-    size_t operator () (const TElement& element) const;
+    size_t operator()(const TElement& element) const;
 
 private:
     size_t Seed_;
     TUnderlying Underlying_;
-
 };
 
 ////////////////////////////////////////////////////////////////////////////////

+ 15 - 0
yt/yt/core/misc/numeric_helpers-inl.h → library/cpp/yt/misc/numeric_helpers-inl.h

@@ -6,8 +6,11 @@
 
 #include <cstdlib>
 #include <cinttypes>
+#include <cmath>
 #include <algorithm>
 
+#include <util/system/compiler.h>
+
 namespace NYT {
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -39,6 +42,18 @@ T RoundDown(const T& numerator, const T& denominator)
     return (numerator / denominator) * denominator;
 }
 
+template <class T> // Three-way sign of |value|, obtained by comparing it against a literal 0.
+Y_FORCE_INLINE int GetSign(const T& value)
+{
+    if (value < 0) {
+        return -1;
+    } else if (value > 0) {
+        return +1;
+    } else {
+        return 0; // Zero, or any value for which both comparisons above are false (e.g. NaN).
+    }
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 } // namespace NYT

+ 3 - 0
yt/yt/core/misc/numeric_helpers.h → library/cpp/yt/misc/numeric_helpers.h

@@ -19,6 +19,9 @@ T RoundUp(const T& numerator, const T& denominator);
 template <class T>
 T RoundDown(const T& numerator, const T& denominator);
 
+template <class T> // Returns -1 if |value < 0|, +1 if |value > 0|, and 0 otherwise.
+int GetSign(const T& value);
+
 ////////////////////////////////////////////////////////////////////////////////
 
 } // namespace NYT

+ 62 - 0
library/cpp/yt/misc/unittests/compare_ut.cpp

@@ -0,0 +1,62 @@
+#include <library/cpp/testing/gtest/gtest.h>
+
+#include <library/cpp/yt/misc/compare.h>
+
+namespace NYT {
+namespace {
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(TCompareTest, TernaryCompare)
+{
+    EXPECT_EQ(TernaryCompare(123, 123), 0);
+    EXPECT_EQ(TernaryCompare(10, 20), -1);
+    EXPECT_EQ(TernaryCompare(20, 10), +1);
+    EXPECT_EQ(TernaryCompare(std::nan("1"), std::nan("1")), +1); // The plain comparison is not NaN-aware: a NaN operand fails both |==| and |<|, so +1 is returned.
+    EXPECT_EQ(TernaryCompare(std::nan("1"), std::nan("2")), +1);
+    EXPECT_EQ(TernaryCompare(std::nan("1"), 123.0), +1);
+    EXPECT_EQ(TernaryCompare(123.0, std::nan("1")), +1); // Swapping the arguments does not flip the sign when a NaN is involved.
+}
+
+TEST(TCompareTest, NaNSafeTernaryCompare)
+{
+    EXPECT_EQ(NaNSafeTernaryCompare(std::nan("1"), std::nan("1")), 0);
+    EXPECT_EQ(NaNSafeTernaryCompare(std::nan("1"), std::nan("2")), 0); // NaNs created from different tag strings still compare equal.
+    EXPECT_EQ(NaNSafeTernaryCompare(123.0, std::nan("1")), -1); // A regular number sorts before NaN ...
+    EXPECT_EQ(NaNSafeTernaryCompare(std::nan("1"), 123.0), +1); // ... and NaN sorts after a regular number.
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+template <class T>
+class TTernaryCompareStringTest // Empty fixture; it only carries the string type under test.
+    : public ::testing::Test
+{ };
+
+TYPED_TEST_SUITE_P(TTernaryCompareStringTest);
+
+TYPED_TEST_P(TTernaryCompareStringTest, Compare)
+{
+    EXPECT_EQ(TernaryCompare(TypeParam("abc"), TypeParam("abc")), 0); // Equal contents.
+    EXPECT_EQ(TernaryCompare(TypeParam("x"), TypeParam("y")), -1); // The result must be exactly -1 ...
+    EXPECT_EQ(TernaryCompare(TypeParam("y"), TypeParam("x")), +1); // ... or exactly +1, not merely negative or positive.
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TTernaryCompareStringTest, Compare);
+
+using TTernaryCompareStringTestTypes = ::testing::Types< // The same four types accepted by the string overload of #TernaryCompare.
+    TString,
+    TStringBuf,
+    std::string,
+    std::string_view
+>;
+
+INSTANTIATE_TYPED_TEST_SUITE_P(
+    TypeParametrized,
+    TTernaryCompareStringTest,
+    TTernaryCompareStringTestTypes);
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace
+} // namespace NYT

+ 19 - 0
library/cpp/yt/misc/unittests/hash_ut.cpp

@@ -0,0 +1,19 @@
+#include <library/cpp/testing/gtest/gtest.h>
+
+#include <library/cpp/yt/misc/hash.h>
+
+namespace NYT {
+namespace {
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(THashTest, NaNSafeHash)
+{
+    EXPECT_EQ(NaNSafeHash(123), THash<int>()(123)); // For a non-floating-point type the result matches |THash| exactly.
+    EXPECT_EQ(NaNSafeHash(std::nan("1")), NaNSafeHash(std::nan("2"))); // NaNs created from different tag strings must hash identically.
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace
+} // namespace NYT

+ 2 - 0
library/cpp/yt/misc/unittests/ya.make

@@ -4,8 +4,10 @@ INCLUDE(${ARCADIA_ROOT}/library/cpp/yt/ya_cpp.make.inc)
 
 SRCS(
     cast_ut.cpp
+    compare_ut.cpp
     enum_ut.cpp
     guid_ut.cpp
+    hash_ut.cpp
     non_null_ptr_ut.cpp
     preprocessor_ut.cpp
     strong_typedef_ut.cpp

+ 10 - 62
yt/yt/client/table_client/composite_compare.cpp

@@ -10,8 +10,11 @@
 #include <yt/yt/library/numeric/util.h>
 
 #include <library/cpp/yt/farmhash/farm_hash.h>
+
 #include <library/cpp/yt/logging/logger.h>
 
+#include <library/cpp/yt/misc/compare.h>
+
 #include <util/stream/mem.h>
 
 #include <cmath>
@@ -73,48 +76,12 @@ static constexpr ECompareClass GetCompareClass(EYsonItemType type)
     return static_cast<ECompareClass>(0x3u & (compareClassMask >> (static_cast<ui32>(type) * 2)));
 }
 
-template <typename T>
-Y_FORCE_INLINE int ComparePrimitive(T lhs, T rhs)
-{
-    if (lhs == rhs) {
-        return 0;
-    } else if (lhs < rhs) {
-        return -1;
-    } else {
-        return 1;
-    }
-}
-
-template <>
-Y_FORCE_INLINE int ComparePrimitive<double>(double lhs, double rhs)
-{
-    if (lhs < rhs) {
-        return -1;
-    } else if (lhs > rhs) {
-        return 1;
-    } else if (std::isnan(lhs)) {
-        if (std::isnan(rhs)) {
-            return 0;
-        }
-        return 1;
-    } else if (std::isnan(rhs)) {
-        return -1;
-    } else {
-        return 0;
-    }
-}
-
 [[noreturn]] static void ThrowIncomparableYsonToken(EYsonItemType tokenType)
 {
     THROW_ERROR_EXCEPTION("Incomparable YSON token %Qlv",
         tokenType);
 }
 
-Y_FORCE_INLINE static int GetSign(int x)
-{
-    return static_cast<int>(0 < x) - static_cast<int>(0 > x);
-}
-
 Y_FORCE_INLINE static EValueType MapItemTypeToValueType(EYsonItemType itemType)
 {
     static const TEnumIndexedArray<EYsonItemType, EValueType> mapping = {
@@ -149,15 +116,15 @@ Y_FORCE_INLINE static int CompareYsonItems(const TYsonItem& lhs, const TYsonItem
             case EYsonItemType::EntityValue:
                 return 0;
             case EYsonItemType::Int64Value:
-                return ComparePrimitive(lhs.UncheckedAsInt64(), rhs.UncheckedAsInt64());
+                return TernaryCompare(lhs.UncheckedAsInt64(), rhs.UncheckedAsInt64());
             case EYsonItemType::Uint64Value:
-                return ComparePrimitive(lhs.UncheckedAsUint64(), rhs.UncheckedAsUint64());
+                return TernaryCompare(lhs.UncheckedAsUint64(), rhs.UncheckedAsUint64());
             case EYsonItemType::DoubleValue:
-                return ComparePrimitive(lhs.UncheckedAsDouble(), rhs.UncheckedAsDouble());
+                return NaNSafeTernaryCompare(lhs.UncheckedAsDouble(), rhs.UncheckedAsDouble());
             case EYsonItemType::BooleanValue:
-                return ComparePrimitive(lhs.UncheckedAsBoolean(), rhs.UncheckedAsBoolean());
+                return TernaryCompare(lhs.UncheckedAsBoolean(), rhs.UncheckedAsBoolean());
             case EYsonItemType::StringValue:
-                return GetSign(TString::compare(lhs.UncheckedAsString(), rhs.UncheckedAsString()));
+                return TernaryCompare(lhs.UncheckedAsString(), rhs.UncheckedAsString());
 
             case EYsonItemType::BeginMap:
             case EYsonItemType::EndMap:
@@ -179,9 +146,9 @@ Y_FORCE_INLINE static int CompareYsonItems(const TYsonItem& lhs, const TYsonItem
     }
 
     if (lhsClass == ECompareClass::BeginValue && rhsClass == ECompareClass::BeginValue) {
-        return static_cast<int>(MapItemTypeToValueType(lhs.GetType())) - static_cast<int>(MapItemTypeToValueType(rhs.GetType()));
+        return TernaryCompare(MapItemTypeToValueType(lhs.GetType()), MapItemTypeToValueType(rhs.GetType()));
     }
-    return ComparePrimitive(static_cast<ui32>(lhsClass), static_cast<ui32>(rhsClass));
+    return TernaryCompare(lhsClass, rhsClass);
 }
 
 // Returns the minimum binary size needed to represent a potentially truncated version of the this item.
@@ -217,25 +184,6 @@ i64 GetMinResultingSize(const TYsonItem& item, bool isInsideList)
 
 } // namespace
 
-int CompareDoubleValues(double lhs, double rhs)
-{
-    if (lhs < rhs) {
-        return -1;
-    } else if (lhs > rhs) {
-        return +1;
-    } else if (std::isnan(lhs)) {
-        if (std::isnan(rhs)) {
-            return 0;
-        } else {
-            return 1;
-        }
-    } else if (std::isnan(rhs)) {
-        return -1;
-    }
-
-    return 0;
-}
-
 int CompareYsonValues(TYsonStringBuf lhs, TYsonStringBuf rhs)
 {
     YT_ASSERT(lhs.GetType() == EYsonType::Node);

Некоторые файлы не были показаны из-за большого количества измененных файлов