Browse Source

Add TCaseInsensitiveAsciiString

Followup for rXXXXXX - further optimize ascii-only case insensitive strings
1fca7889a074a191eadce12247bdd6dd18b75ab2
vadim-xd 9 months ago
parent
commit
22d59c45d8

+ 68 - 0
library/cpp/case_insensitive_string/benchmark/compare.cpp

@@ -0,0 +1,68 @@
+#include <library/cpp/case_insensitive_string/ut_gtest/util/locale_guard.h>
+#include <library/cpp/case_insensitive_string/case_insensitive_string.h>
+
+#include <benchmark/benchmark.h>
+
+#include <util/generic/string.h>
+#include <util/random/random.h>
+
+#include <array>
+
+#include <cerrno>
+#include <clocale>
+#include <cstring>
+
+namespace {
+    template <typename TStrBuf>
+    void BenchmarkCompare(benchmark::State& state, TStrBuf s1, TStrBuf s2) {
+        for (auto _ : state) {
+            benchmark::DoNotOptimize(s1);
+            benchmark::DoNotOptimize(s2);
+            auto cmp = s1.compare(s2);
+            benchmark::DoNotOptimize(cmp);
+        }
+    }
+
+    char RandomPrintableChar() {
+        while (true) {
+            unsigned char c = RandomNumber(127u);
+            if (std::isprint(c)) {
+                return c;
+            }
+        }
+    }
+
+    const std::array<const char*, 2> Locales = {
+        "C",
+        "ru_RU.CP1251",
+    };
+}
+
+template <typename TStrBuf>
+void CompareEqualStrings(benchmark::State& state) {
+    SetRandomSeed(123);
+
+    size_t n = state.range(0);
+    size_t locIndex = state.range(1);
+
+    TLocaleGuard loc(Locales[locIndex]);
+    if (loc.Error()) {
+        state.SkipWithMessage(TString::Join(Locales[locIndex], " locale is not available: ", loc.Error()));
+        return;
+    }
+    TString s1(Reserve(n)), s2(Reserve(n));
+    for (size_t i = 0; i < n; ++i) {
+        auto c = RandomPrintableChar();
+        s1.push_back(std::toupper(c));
+        s2.push_back(std::tolower(c));
+    }
+
+    BenchmarkCompare(state, TStrBuf(s1.data(), s1.size()), TStrBuf(s2.data(), s2.size()));
+}
+
+#define BENCH_ARGS ArgNames({"strlen", "locale"})->ArgsProduct({{2, 4, 8, 16, 32, 64}, {0, 1}})
+
+BENCHMARK(CompareEqualStrings<TCaseInsensitiveStringBuf>)->BENCH_ARGS;
+BENCHMARK(CompareEqualStrings<TCaseInsensitiveAsciiStringBuf>)->BENCH_ARGS;
+
+#undef BENCH_ARGS

+ 20 - 6
library/cpp/case_insensitive_string/benchmark/main.cpp → library/cpp/case_insensitive_string/benchmark/hash.cpp

@@ -1,5 +1,7 @@
-#include <benchmark/benchmark.h>
 #include <library/cpp/case_insensitive_string/case_insensitive_string.h>
+#include <library/cpp/case_insensitive_string/ut_gtest/util/locale_guard.h>
+
+#include <benchmark/benchmark.h>
 #include <library/cpp/digest/murmur/murmur.h>
 
 #include <util/generic/hash_table.h>
@@ -102,30 +104,38 @@ namespace {
     Y_FORCE_INLINE size_t DefaultHash(TCaseInsensitiveStringBuf str) {
         return ComputeHash(str);
     }
+
+    Y_FORCE_INLINE size_t DefaultHashAscii(TCaseInsensitiveAsciiStringBuf str) {
+        return ComputeHash(str);
+    }
 }
 
-template <auto Impl>
+template <auto Impl, typename TTraits = TCaseInsensitiveCharTraits>
 void CaseInsensitiveHash(benchmark::State& state) {
+    TLocaleGuard loc("C");
+    Y_ABORT_IF(loc.Error());
     SetRandomSeed(123 + state.range());
-    TCaseInsensitiveString str;
+    TBasicString<char, TTraits> str;
     for (int i = 0; i < state.range(); ++i) {
         str.push_back(RandomNumber<unsigned char>());
     }
-    Y_ENSURE(Impl(str) == NaiveHash(str));
+    Y_ENSURE(Impl(str) == NaiveHash(str), "Hashes differ: got " << Impl(str) << ", expected " <<  NaiveHash(str));
     for (auto _ : state) {
         size_t hash = Impl(str);
         benchmark::DoNotOptimize(hash);
     }
 }
 
-template <auto Impl>
+template <auto Impl, typename TTraits = TCaseInsensitiveCharTraits>
 void CaseInsensitiveHashRandomSizes(benchmark::State& state) {
+    TLocaleGuard loc("C");
+    Y_ABORT_IF(loc.Error());
     SetRandomSeed(123);
     size_t minStrLen = static_cast<size_t>(state.range(0));
     size_t maxStrLen = static_cast<size_t>(state.range(1));
     static constexpr size_t nStrings = 64;
     TVector<TString> stringStorage(Reserve(nStrings));
-    std::array<TCaseInsensitiveStringBuf, nStrings> strings;
+    std::array<TBasicStringBuf<char, TTraits>, nStrings> strings;
     for (size_t i = 0; i < nStrings; ++i) {
         auto& str = stringStorage.emplace_back();
         size_t strLen = minStrLen + RandomNumber(maxStrLen - minStrLen + 1);
@@ -153,6 +163,8 @@ BENCHMARK(CaseInsensitiveHash<OptimizedHashDuplicateTailLoopInFunc>)->BENCH_ARGS
 BENCHMARK(CaseInsensitiveHash<OptimizedHashTailLoopInFunc>)->BENCH_ARGS;
 #endif
 
+BENCHMARK(CaseInsensitiveHash<DefaultHashAscii, TCaseInsensitiveAsciiCharTraits>)->BENCH_ARGS;
+
 #undef BENCH_ARGS
 
 #define BENCH_ARGS \
@@ -164,4 +176,6 @@ BENCHMARK(CaseInsensitiveHash<OptimizedHashTailLoopInFunc>)->BENCH_ARGS;
 BENCHMARK(CaseInsensitiveHashRandomSizes<NaiveHash>)->BENCH_ARGS;
 BENCHMARK(CaseInsensitiveHashRandomSizes<DefaultHash>)->BENCH_ARGS;
 
+BENCHMARK(CaseInsensitiveHashRandomSizes<DefaultHashAscii, TCaseInsensitiveAsciiCharTraits>)->BENCH_ARGS;
+
 #undef BENCH_ARGS

+ 5 - 1
library/cpp/case_insensitive_string/benchmark/ya.make

@@ -1,15 +1,19 @@
 G_BENCHMARK()
 
+SIZE(MEDIUM)
+
 IF (NOT AUTOCHECK)
     CFLAGS(-DBENCHMARK_ALL_IMPLS)
 ENDIF()
 
 SRCS(
-    main.cpp
+    compare.cpp
+    hash.cpp
 )
 
 PEERDIR(
     library/cpp/case_insensitive_string
+    library/cpp/case_insensitive_string/ut_gtest/util
     library/cpp/digest/murmur
 )
 

+ 16 - 17
library/cpp/case_insensitive_string/case_insensitive_char_traits.cpp

@@ -1,14 +1,25 @@
 #include "case_insensitive_char_traits.h"
-#include "case_insensitive_string.h"
 
 #include <util/string/escape.h>
 
+template <typename TImpl>
+const char* ::NPrivate::TCommonCaseInsensitiveCharTraits<TImpl>::find(const char* s, std::size_t n, char a) {
+    const auto ca(TImpl::ToCommonCase(a));
+    while (n-- != 0) {
+        if (TImpl::ToCommonCase(*s) == ca)
+            return s;
+        s++;
+    }
+    return nullptr;
+}
+
 int TCaseInsensitiveCharTraits::compare(const char* s1, const char* s2, std::size_t n) {
     while (n-- != 0) {
-        if (to_upper(*s1) < to_upper(*s2)) {
+        auto c1 = ToCommonCase(*s1), c2 = ToCommonCase(*s2);
+        if (c1 < c2) {
             return -1;
         }
-        if (to_upper(*s1) > to_upper(*s2)) {
+        if (c1 > c2) {
             return 1;
         }
         ++s1;
@@ -17,17 +28,5 @@ int TCaseInsensitiveCharTraits::compare(const char* s1, const char* s2, std::siz
     return 0;
 }
 
-const char* TCaseInsensitiveCharTraits::find(const char* s, std::size_t n, char a) {
-    auto const ua(to_upper(a));
-    while (n-- != 0) {
-        if (to_upper(*s) == ua)
-            return s;
-        s++;
-    }
-    return nullptr;
-}
-
-TCaseInsensitiveString EscapeC(const TCaseInsensitiveString& str) {
-    const auto result = EscapeC(str.data(), str.size());
-    return {result.data(), result.size()};
-}
+template struct ::NPrivate::TCommonCaseInsensitiveCharTraits<TCaseInsensitiveCharTraits>;
+template struct ::NPrivate::TCommonCaseInsensitiveCharTraits<TCaseInsensitiveAsciiCharTraits>;

+ 40 - 15
library/cpp/case_insensitive_string/case_insensitive_char_traits.h

@@ -2,29 +2,54 @@
 
 #include <contrib/libs/libc_compat/string.h>
 
+#include <util/string/ascii.h>
+
 #include <string>
 
-struct TCaseInsensitiveCharTraits : private std::char_traits<char> {
-    static bool eq(char c1, char c2) {
-        return to_upper(c1) == to_upper(c2);
-    }
+namespace NPrivate {
+    template <typename TImpl>
+    struct TCommonCaseInsensitiveCharTraits : private std::char_traits<char> {
+        static bool eq(char c1, char c2) {
+            return TImpl::ToCommonCase(c1) == TImpl::ToCommonCase(c2);
+        }
 
-    static bool lt(char c1, char c2) {
-        return to_upper(c1) < to_upper(c2);
-    }
+        static bool lt(char c1, char c2) {
+            return TImpl::ToCommonCase(c1) < TImpl::ToCommonCase(c2);
+        }
 
-    static int compare(const char* s1, const char* s2, std::size_t n);
+        static const char* find(const char* s, std::size_t n, char a);
 
-    static const char* find(const char* s, std::size_t n, char a);
+        using std::char_traits<char>::assign;
+        using std::char_traits<char>::char_type;
+        using std::char_traits<char>::copy;
+        using std::char_traits<char>::length;
+        using std::char_traits<char>::move;
+    };
+} // namespace NPrivate
 
-    using std::char_traits<char>::assign;
-    using std::char_traits<char>::char_type;
-    using std::char_traits<char>::copy;
-    using std::char_traits<char>::length;
-    using std::char_traits<char>::move;
+struct TCaseInsensitiveCharTraits : public ::NPrivate::TCommonCaseInsensitiveCharTraits<TCaseInsensitiveCharTraits> {
+    static int compare(const char* s1, const char* s2, std::size_t n);
 
 private:
-    static char to_upper(char ch) {
+    friend ::NPrivate::TCommonCaseInsensitiveCharTraits<TCaseInsensitiveCharTraits>;
+
+    // XXX return unsigned char. Current impl depends on char signedness, and if char is signed,
+    // TCaseInsensitiveCharTraits::compare returns different result from std::char_traits<char>::compare for non-ascii strings.
+    static char ToCommonCase(char ch) {
         return std::toupper((unsigned char)ch);
     }
 };
+
+struct TCaseInsensitiveAsciiCharTraits : public ::NPrivate::TCommonCaseInsensitiveCharTraits<TCaseInsensitiveCharTraits> {
+    // WARN: does not work with null bytes (`compare("ab\0c", "ab\0d", 4)` returns 0).
+    static int compare(const char* s1, const char* s2, std::size_t n) {
+        return ::strncasecmp(s1, s2, n);
+    }
+
+private:
+    friend ::NPrivate::TCommonCaseInsensitiveCharTraits<TCaseInsensitiveAsciiCharTraits>;
+
+    static unsigned char ToCommonCase(char ch) {
+        return AsciiToLower(ch);
+    }
+};

+ 52 - 23
library/cpp/case_insensitive_string/case_insensitive_string.cpp

@@ -2,35 +2,64 @@
 
 #include <library/cpp/digest/murmur/murmur.h>
 
+#include <util/string/escape.h>
+
 #include <array>
 
-static size_t HashTail(TMurmurHash2A<size_t>& hash, const char* data, size_t size) {
-    for (size_t i = 0; i < size; ++i) {
-        char lower = std::tolower(data[i]);
-        hash.Update(&lower, 1);
-    }
-    return hash.Value();
-}
+namespace {
+    template <auto ToLower>
+    struct TCaseInsensitiveHash {
+        static size_t HashTail(TMurmurHash2A<size_t>& hash, const char* data, size_t size) noexcept {
+            for (size_t i = 0; i < size; ++i) {
+                char lower = ToLower(data[i]);
+                hash.Update(&lower, 1);
+            }
+            return hash.Value();
+        }
 
-size_t THash<TCaseInsensitiveStringBuf>::operator()(TCaseInsensitiveStringBuf str) const noexcept {
-    TMurmurHash2A<size_t> hash;
-    std::array<char, sizeof(size_t)> buf;
-    size_t headSize = str.size() - str.size() % buf.size();
-    for (size_t i = 0; i < headSize; i += buf.size()) {
-        for (size_t j = 0; j < buf.size(); ++j) {
-            buf[j] = std::tolower(str[i + j]);
+        static size_t ComputeHash(const char* s, size_t n) noexcept {
+            TMurmurHash2A<size_t> hash;
+            std::array<char, sizeof(size_t)> buf;
+            size_t headSize = n - n % buf.size();
+            for (size_t i = 0; i < headSize; i += buf.size()) {
+                for (size_t j = 0; j < buf.size(); ++j) {
+                    buf[j] = ToLower(s[i + j]);
+                }
+                hash.Update(buf.data(), buf.size());
+            }
+            return HashTail(hash, s + headSize, n - headSize);
         }
-        hash.Update(buf.data(), buf.size());
-    }
-    return HashTail(hash, str.data() + headSize, str.size() - headSize);
+    };
 }
 
-template <>
-void Out<TCaseInsensitiveString>(IOutputStream& o, const TCaseInsensitiveString& p) {
-    o.Write(p.data(), p.size());
+size_t CaseInsensitiveStringHash(const char* s, size_t n) noexcept {
+    return TCaseInsensitiveHash<static_cast<int(*)(int)>(std::tolower)>::ComputeHash(s, n);
 }
 
-template <>
-void Out<TCaseInsensitiveStringBuf>(IOutputStream& o, const TCaseInsensitiveStringBuf& p) {
-    o.Write(p.data(), p.size());
+size_t CaseInsensitiveAsciiStringHash(const char* s, size_t n) noexcept {
+    return TCaseInsensitiveHash<static_cast<char(*)(char)>(AsciiToLower)>::ComputeHash(s, n);
 }
+
+#define Y_DEFINE_STRING_OUT(type)                       \
+    template <>                                         \
+    void Out<type>(IOutputStream& o, const type& p) {   \
+        o.Write(p.data(), p.size());                    \
+    }
+
+Y_DEFINE_STRING_OUT(TCaseInsensitiveString);
+Y_DEFINE_STRING_OUT(TCaseInsensitiveStringBuf);
+Y_DEFINE_STRING_OUT(TCaseInsensitiveAsciiString);
+Y_DEFINE_STRING_OUT(TCaseInsensitiveAsciiStringBuf);
+
+#undef Y_DEFINE_STRING_OUT
+
+#define Y_DEFINE_STRING_ESCAPE(type)                            \
+    type EscapeC(const type& str) {                             \
+        const auto result = EscapeC(str.data(), str.size());    \
+        return {result.data(), result.size()};                  \
+    }
+
+Y_DEFINE_STRING_ESCAPE(TCaseInsensitiveString);
+Y_DEFINE_STRING_ESCAPE(TCaseInsensitiveAsciiString);
+
+#undef Y_DEFINE_STRING_ESCAPE

+ 34 - 1
library/cpp/case_insensitive_string/case_insensitive_string.h

@@ -10,14 +10,37 @@
 using TCaseInsensitiveString = TBasicString<char, TCaseInsensitiveCharTraits>;
 using TCaseInsensitiveStringBuf = TBasicStringBuf<char, TCaseInsensitiveCharTraits>;
 
+// WARN: comparison works incorrectly if strings contain null bytes (`TCaseInsensitiveAsciiString{"ab\0c", 4} == TCaseInsensitiveAsciiString{"ab\0d", 4}`).
+using TCaseInsensitiveAsciiString = TBasicString<char, TCaseInsensitiveAsciiCharTraits>;
+// WARN: comparison works incorrectly if strings contain null bytes.
+using TCaseInsensitiveAsciiStringBuf = TBasicStringBuf<char, TCaseInsensitiveAsciiCharTraits>;
+
+// Convert chars using std::tolower and hash the resulting string.
+// Locale may affect the result.
+size_t CaseInsensitiveStringHash(const char* s, size_t n) noexcept;
+// Convert chars using AsciiToLower and hash the resulting string.
+size_t CaseInsensitiveAsciiStringHash(const char* s, size_t n) noexcept;
+
 template <>
 struct THash<TCaseInsensitiveStringBuf> {
-    size_t operator()(TCaseInsensitiveStringBuf str) const noexcept;
+    size_t operator()(TCaseInsensitiveStringBuf str) const noexcept {
+        return CaseInsensitiveStringHash(str.data(), str.size());
+    }
+};
+
+template <>
+struct THash<TCaseInsensitiveAsciiStringBuf> {
+    size_t operator()(TCaseInsensitiveAsciiStringBuf str) const noexcept {
+        return CaseInsensitiveAsciiStringHash(str.data(), str.size());
+    }
 };
 
 template <>
 struct THash<TCaseInsensitiveString> : THash<TCaseInsensitiveStringBuf> {};
 
+template <>
+struct THash<TCaseInsensitiveAsciiString> : THash<TCaseInsensitiveAsciiStringBuf> {};
+
 namespace NStringSplitPrivate {
 
     template<>
@@ -35,3 +58,13 @@ namespace NStringSplitPrivate {
     };
 
 } // namespace NStringSplitPrivate
+
+template <>
+struct TEqualTo<TCaseInsensitiveString>: public TEqualTo<TCaseInsensitiveStringBuf> {
+    using is_transparent = void;
+};
+
+template <>
+struct TEqualTo<TCaseInsensitiveAsciiString>: public TEqualTo<TCaseInsensitiveAsciiStringBuf> {
+    using is_transparent = void;
+};

+ 79 - 0
library/cpp/case_insensitive_string/ut_gtest/case_insensitive_string_compare.cpp

@@ -0,0 +1,79 @@
+#include <library/cpp/case_insensitive_string/case_insensitive_string.h>
+#include <library/cpp/case_insensitive_string/ut_gtest/util/locale_guard.h>
+
+#include <library/cpp/testing/gtest/gtest.h>
+
+#include <util/generic/string.h>
+
+TEST(CaseInsensitiveString, CompareAscii) {
+    {
+        TCaseInsensitiveStringBuf s1 = "Some Text";
+        TCaseInsensitiveStringBuf s2 = "somE texT";
+        EXPECT_EQ(s1, s2);
+    }
+    {
+        TCaseInsensitiveStringBuf s1 = "aBCd";
+        TCaseInsensitiveStringBuf s2 = "AbcE";
+        EXPECT_LT(s1, s2);
+    }
+    {
+        // Also works with null bytes
+        TCaseInsensitiveStringBuf s1("aBC\0d", 5);
+        TCaseInsensitiveStringBuf s2("Abc\0E", 5);
+        EXPECT_LT(s1, s2);
+    }
+}
+
+TEST(CaseInsensitiveString, CompareLocaleDependent) {
+    TLocaleGuard loc("ru_RU.CP1251");
+    if (loc.Error()) {
+        GTEST_SKIP() << "ru_RU.CP1251 locale is not available: " << loc.Error();
+    }
+    {
+        TCaseInsensitiveStringBuf s1 = "\xc0\xc1\xc2";  // "АБВ"
+        TCaseInsensitiveStringBuf s2 = "\xe0\xe1\xe2";  // "абв"
+        EXPECT_EQ(s1, s2);
+    }
+    {
+        TCaseInsensitiveStringBuf s1 = "\xc0\xc1\xc3";  // "АБГ"
+        TCaseInsensitiveStringBuf s2 = "\xe0\xe1\xe2";  // "абв"
+        EXPECT_GT(s1, s2);
+    }
+}
+
+TEST(CaseInsensitiveAsciiString, CompareAsciiWithoutNullBytes) {
+    {
+        TCaseInsensitiveAsciiStringBuf s1 = "Some Text";
+        TCaseInsensitiveAsciiStringBuf s2 = "somE texT";
+        EXPECT_EQ(s1, s2);
+    }
+    {
+        TCaseInsensitiveAsciiStringBuf s1 = "aBCd";
+        TCaseInsensitiveAsciiStringBuf s2 = "AbcE";
+        EXPECT_LT(s1, s2);
+    }
+}
+
+TEST(CaseInsensitiveAsciiString, MaySupportLocales) {
+    TLocaleGuard loc("ru_RU.CP1251");
+    if (loc.Error()) {
+        GTEST_SKIP() << "ru_RU.CP1251 locale is not available: " << loc.Error();
+    }
+    // strncasecmp is locale-dependent, not sure about strnicmp
+    {
+        TCaseInsensitiveStringBuf s1 = "\xc0\xc1\xc2";  // "АБВ"
+        TCaseInsensitiveStringBuf s2 = "\xe0\xe1\xe2";  // "абв"
+        EXPECT_EQ(s1, s2);
+    }
+    {
+        TCaseInsensitiveStringBuf s1 = "\xc0\xc1\xc3";  // "АБГ"
+        TCaseInsensitiveStringBuf s2 = "\xe0\xe1\xe2";  // "абв"
+        EXPECT_GT(s1, s2);
+    }
+}
+
+TEST(CaseInsensitiveAsciiString, DoesNotWorkWithNullBytes) {
+    TCaseInsensitiveAsciiStringBuf s1("aBC\0d", 5);
+    TCaseInsensitiveAsciiStringBuf s2("Abc\0E", 5);
+    EXPECT_EQ(s1, s2);
+}

+ 1 - 0
library/cpp/case_insensitive_string/ut_gtest/case_insensitive_string_hash.cpp

@@ -1,4 +1,5 @@
 #include <library/cpp/case_insensitive_string/case_insensitive_string.h>
+
 #include <library/cpp/digest/murmur/murmur.h>
 #include <library/cpp/testing/gtest/gtest.h>
 

+ 1 - 0
library/cpp/case_insensitive_string/ut_gtest/util/locale_guard.cpp

@@ -0,0 +1 @@
+#include "locale_guard.h"

Some files were not shown because too many files changed in this diff