// Unit tests for the list pool and the compact hash containers declared in
// "compact_hash.h" (suite name: TCompactHashTest).
//
// NOTE(review): in this copy of the file the physical line breaks are collapsed
// and everything that would sit inside angle brackets is missing: the #include
// targets, the template parameter lists ("template void ...") and the template
// arguments ("TListPool", "THashSet lists", "TestHashImpl()"). Because of the
// collapsed lines, each inline "//" comment also swallows the code that follows
// it on the same physical line. Restore the file from version control before
// building; the summary below describes only what the visible statements do.
//
// List pool helpers
//   TestListPoolPagesImpl(listSize, countOfLists, expectedListCapacity, expectedMark)
//     - requires countOfLists > 1;
//     - takes countOfLists / 2 lists of listSize items and checks, for each one,
//       the mark, the part list size, the list capacity, and that it lies on
//       the same page as the first list (TAlignedPagePool::GetPageStart);
//     - returns all lists but one, then takes countOfLists - 1 more with the
//       same per-list checks, so countOfLists lists are held in total;
//     - returns every list (the set of pointers is intentionally not cleared)
//       and takes countOfLists lists again, asserting that each is on the same
//       page and is one of the previously returned pointers;
//     - asserts that one further list comes from a different page, then returns
//       everything.
//   TestListPoolLargeImpl()
//     - takes a list of GetMaxListSize() items and calls IncrementList;
//       expects LARGE_MARK and a part list size of listSize + 1;
//     - sets the part list size to the list capacity and increments again;
//       expects a part list size of 1 and a full list size of 1 + capacity.
//   The Byte / Ui64 / Obj test cases call these helpers for the small, medium
//   and large cases. The item type of each case is not visible here; the suffix
//   presumably names it, and "struct TItem { ui8 A[256]; }" is declared right
//   before the Obj cases - TODO confirm against the original file.
//   NOTE(review): "TestListPoolLargPages*" looks like a misspelling of "Large".
//
// Hash helpers
//   TItemHash: hash functor returning num % 13.
//   TestHashImpl(): inserts the shuffled keys 0..31 with value key + 20, moves
//     the table into hash2, swaps it into hash3 and copy-assigns it back, then
//     checks Has/Find for every key, two absent keys, Size/UniqSize == 32 and
//     the key/value sums seen through Iterate().
//   TestMultiHashImpl(): for the shuffled keys 0..9 inserts values 0..k under
//     key k, performs the same move/swap/assign round trip, then checks
//     Count(k) == k + 1 with pairwise distinct values (TDynBitMap), absent
//     keys, Size == keysCount * (keysCount + 1) / 2, UniqSize == keysCount and
//     the sums collected via MakeCurrentKeyIter() on every key change. The
//     "Test large lists" part inserts GetLargePageCapacity() values under key
//     keysCount, checks that Find() yields them in insertion order, inserts
//     capacity + 1 more, repeats the round trip and re-checks the order and
//     Count == 2 * capacity + 1.
//   TestSetImpl(): same scheme as TestHashImpl for TCompactHashSet (keys only).
//
// Stress test
//   TStressHash: hash functor returning num % Param.
//   TestStressSmallLists: for listSize in xrange(2, 17, 1) computes
//     backets = GetSmallPageCapacity(listSize) and
//     elementsCount = backets * listSize; for each count in
//     xrange(1, elementsCount + 1, elementsCount / 16) builds a set with
//     TStressHash(backets), inserts 0..count-1 and checks Size/UniqSize == count.
//   NOTE(review): "backets" is presumably meant to be "buckets".
#include "compact_hash.h" #include #include #include #include #include #include #include using namespace NKikimr; using namespace NKikimr::NCHash; Y_UNIT_TEST_SUITE(TCompactHashTest) { template void TestListPoolPagesImpl(size_t listSize, ui16 countOfLists, size_t expectedListCapacity, ui32 expectedMark) { using TPool = TListPool; TAlignedPagePool pagePool(__LOCATION__); TPool pool(pagePool); UNIT_ASSERT(countOfLists > 1); THashSet lists; void* pageAddr = nullptr; for (size_t i = 0; i < countOfLists / 2; ++i) { TItem* l = pool.template GetList(listSize); UNIT_ASSERT_VALUES_EQUAL(expectedMark, TPool::GetMark(l)); UNIT_ASSERT_VALUES_EQUAL(listSize, TListPoolBase::GetPartListSize(l)); UNIT_ASSERT_VALUES_EQUAL(expectedListCapacity, TListPoolBase::GetListCapacity(l)); if (0 == i) { pageAddr = TAlignedPagePool::GetPageStart(l); } else { // All lists are from the same page UNIT_ASSERT_VALUES_EQUAL(pageAddr, TAlignedPagePool::GetPageStart(l)); } UNIT_ASSERT(lists.insert(l).second); } // Return all lists except one while (lists.size() > 1) { auto it = lists.begin(); pool.ReturnList(*it); lists.erase(it); } for (size_t i = 1; i < countOfLists; ++i) { TItem* l = pool.template GetList(listSize); UNIT_ASSERT_VALUES_EQUAL(expectedMark, TPool::GetMark(l)); UNIT_ASSERT_VALUES_EQUAL(listSize, TListPoolBase::GetPartListSize(l)); UNIT_ASSERT_VALUES_EQUAL(expectedListCapacity, TListPoolBase::GetListCapacity(l)); UNIT_ASSERT_VALUES_EQUAL(pageAddr, TAlignedPagePool::GetPageStart(l)); UNIT_ASSERT(lists.insert(l).second); } for (auto l: lists) { pool.ReturnList(l); } THashSet lists2; for (size_t i = 0; i < countOfLists; ++i) { TItem* l = pool.template GetList(listSize); // All lists are from the same page UNIT_ASSERT_VALUES_EQUAL(pageAddr, TAlignedPagePool::GetPageStart(l)); UNIT_ASSERT(1 == lists.erase(l)); UNIT_ASSERT(lists2.insert(l).second); } TItem* l = pool.template GetList(listSize); // New page UNIT_ASSERT_VALUES_UNEQUAL(pageAddr, TAlignedPagePool::GetPageStart(l)); 
UNIT_ASSERT_VALUES_EQUAL(listSize, TListPoolBase::GetPartListSize(l)); UNIT_ASSERT_VALUES_EQUAL(expectedListCapacity, TListPoolBase::GetListCapacity(l)); pool.ReturnList(l); for (auto l: lists2) { pool.ReturnList(l); } } template void TestListPoolLargeImpl() { using TPool = TListPool; TAlignedPagePool pagePool(__LOCATION__); TPool pool(pagePool); const size_t listSize = TListPoolBase::GetMaxListSize(); TItem* l = pool.template GetList(listSize); pool.template IncrementList(l); UNIT_ASSERT_VALUES_EQUAL((ui32)TListPoolBase::LARGE_MARK, TListPoolBase::GetMark(l)); UNIT_ASSERT_VALUES_EQUAL(listSize + 1, TListPoolBase::GetPartListSize(l)); TListPoolBase::SetPartListSize(l, TListPoolBase::GetListCapacity(l)); pool.template IncrementList(l); UNIT_ASSERT_VALUES_EQUAL(1, TListPoolBase::GetPartListSize(l)); UNIT_ASSERT_VALUES_EQUAL(1 + TListPoolBase::GetListCapacity(l), TListPoolBase::GetFullListSize(l)); pool.ReturnList(l); } Y_UNIT_TEST(TestListPoolSmallPagesByte) { ui16 count = TListPoolBase::GetSmallPageCapacity(2); TestListPoolPagesImpl(2, count, 2, TListPoolBase::SMALL_MARK); } Y_UNIT_TEST(TestListPoolMediumPagesByte) { size_t listSize = TListPoolBase::MAX_SMALL_LIST_SIZE + 1; ui16 count = TListPoolBase::GetMediumPageCapacity(listSize); TestListPoolPagesImpl(listSize, count, FastClp2(listSize), TListPoolBase::MEDIUM_MARK); } Y_UNIT_TEST(TestListPoolLargPagesByte) { TestListPoolLargeImpl(); } Y_UNIT_TEST(TestListPoolSmallPagesUi64) { ui16 count = TListPoolBase::GetSmallPageCapacity(2); TestListPoolPagesImpl(2, count, 2, TListPoolBase::SMALL_MARK); } Y_UNIT_TEST(TestListPoolMediumPagesUi64) { size_t listSize = TListPoolBase::MAX_SMALL_LIST_SIZE + 1; ui16 count = TListPoolBase::GetMediumPageCapacity(listSize); TestListPoolPagesImpl(listSize, count, FastClp2(listSize), TListPoolBase::MEDIUM_MARK); } Y_UNIT_TEST(TestListPoolLargPagesUi64) { TestListPoolLargeImpl(); } struct TItem { ui8 A[256]; }; Y_UNIT_TEST(TestListPoolSmallPagesObj) { ui16 count = 
TListPoolBase::GetSmallPageCapacity(2); TestListPoolPagesImpl(2, count, 2, TListPoolBase::SMALL_MARK); } Y_UNIT_TEST(TestListPoolMediumPagesObj) { size_t listSize = TListPoolBase::MAX_SMALL_LIST_SIZE + 1; ui16 count = TListPoolBase::GetMediumPageCapacity(listSize); TestListPoolPagesImpl(listSize, count, FastClp2(listSize), TListPoolBase::MEDIUM_MARK); } Y_UNIT_TEST(TestListPoolLargPagesObj) { TestListPoolLargeImpl(); } struct TItemHash { template size_t operator() (T num) const { return num % 13; } }; template void TestHashImpl() { const ui32 elementsCount = 32; const ui64 sumKeysTarget = elementsCount * (elementsCount - 1) / 2; const ui32 addition = 20; const ui64 sumValuesTarget = sumKeysTarget + addition * elementsCount; TAlignedPagePool pagePool(__LOCATION__); TCompactHash hash(pagePool); TVector elements(elementsCount); std::iota(elements.begin(), elements.end(), 0); Shuffle(elements.begin(), elements.end()); for (TItem i: elements) { hash.Insert(i, i + addition); } { decltype(hash) hash2(std::move(hash)); decltype(hash) hash3(pagePool); hash3.Swap(hash2); hash = hash3; } for (TItem i: elements) { UNIT_ASSERT(hash.Has(i)); UNIT_ASSERT(hash.Find(i).Ok()); UNIT_ASSERT_VALUES_EQUAL(i + addition, hash.Find(i).Get().second); } UNIT_ASSERT(!hash.Has(elementsCount + 1)); UNIT_ASSERT(!hash.Has(elementsCount + 10)); UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.Size()); UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.UniqSize()); ui64 sumKeys = 0; ui64 sumValues = 0; for (auto it = hash.Iterate(); it.Ok(); ++it) { UNIT_ASSERT_VALUES_EQUAL(it.Get().first + addition, it.Get().second); sumKeys += it.Get().first; sumValues += it.Get().second; } UNIT_ASSERT_VALUES_EQUAL(sumKeys, sumKeysTarget); UNIT_ASSERT_VALUES_EQUAL(sumValues, sumValuesTarget); } template void TestMultiHashImpl() { const ui32 keysCount = 10; const ui64 elementsCount = keysCount * (keysCount + 1) / 2; TAlignedPagePool pagePool(__LOCATION__); TCompactMultiHash hash(pagePool); TVector keys(keysCount); 
std::iota(keys.begin(), keys.end(), 0); Shuffle(keys.begin(), keys.end()); ui64 sumKeysTarget = 0; ui64 sumValuesTarget = 0; for (TItem k: keys) { sumKeysTarget += k; for (TItem i = 0; i < k + 1; ++i) { hash.Insert(k, i); sumValuesTarget += i; } } { decltype(hash) hash2(std::move(hash)); decltype(hash) hash3(pagePool); hash3.Swap(hash2); hash = hash3; } for (TItem k: keys) { UNIT_ASSERT(hash.Has(k)); UNIT_ASSERT_VALUES_EQUAL(k + 1, hash.Count(k)); auto it = hash.Find(k); UNIT_ASSERT(it.Ok()); TDynBitMap set; for (; it.Ok(); ++it) { UNIT_ASSERT_VALUES_EQUAL(k, it.GetKey()); UNIT_ASSERT(!set.Test(it.GetValue())); set.Set(it.GetValue()); } UNIT_ASSERT_VALUES_EQUAL(set.Count(), k + 1); } UNIT_ASSERT(!hash.Has(keysCount + 1)); UNIT_ASSERT(!hash.Has(keysCount + 10)); UNIT_ASSERT(!hash.Find(keysCount + 1).Ok()); UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.Size()); UNIT_ASSERT_VALUES_EQUAL(keysCount, hash.UniqSize()); ui64 sumKeys = 0; ui64 sumValues = 0; TMaybe prevKey; for (auto it = hash.Iterate(); it.Ok(); ++it) { const auto key = it.GetKey(); if (prevKey != key) { sumKeys += key; for (auto valIt = it.MakeCurrentKeyIter(); valIt.Ok(); ++valIt) { UNIT_ASSERT_VALUES_EQUAL(key, valIt.GetKey()); sumValues += valIt.GetValue(); } prevKey = key; } } UNIT_ASSERT_VALUES_EQUAL(sumKeys, sumKeysTarget); UNIT_ASSERT_VALUES_EQUAL(sumValues, sumValuesTarget); // Test large lists TItem val = 0; for (size_t i = 0; i < TListPoolBase::GetLargePageCapacity(); ++i) { hash.Insert(keysCount, val++); } TItem check = 0; for (auto i = hash.Find(keysCount); i.Ok(); ++i) { UNIT_ASSERT_VALUES_EQUAL(keysCount, i.GetKey()); UNIT_ASSERT_VALUES_EQUAL(check++, i.GetValue()); } UNIT_ASSERT_VALUES_EQUAL(check, val); UNIT_ASSERT_VALUES_EQUAL(TListPoolBase::GetLargePageCapacity(), hash.Count(keysCount)); for (size_t i = 0; i < TListPoolBase::GetLargePageCapacity() + 1; ++i) { hash.Insert(keysCount, val++); } { decltype(hash) hash2(std::move(hash)); decltype(hash) hash3(pagePool); hash3.Swap(hash2); hash = 
hash3; } check = 0; for (auto i = hash.Find(keysCount); i.Ok(); ++i) { UNIT_ASSERT_VALUES_EQUAL(keysCount, i.GetKey()); UNIT_ASSERT_VALUES_EQUAL(check++, i.GetValue()); } UNIT_ASSERT_VALUES_EQUAL(check, val); UNIT_ASSERT_VALUES_EQUAL(2 * TListPoolBase::GetLargePageCapacity() + 1, hash.Count(keysCount)); } template void TestSetImpl() { const ui32 elementsCount = 32; const ui64 sumKeysTarget = elementsCount * (elementsCount - 1) / 2; TAlignedPagePool pagePool(__LOCATION__); TCompactHashSet hash(pagePool); TVector elements(elementsCount); std::iota(elements.begin(), elements.end(), 0); Shuffle(elements.begin(), elements.end()); for (TItem i: elements) { hash.Insert(i); } { decltype(hash) hash2(std::move(hash)); decltype(hash) hash3(pagePool); hash3.Swap(hash2); hash = hash3; } for (TItem i: elements) { UNIT_ASSERT(hash.Has(i)); } UNIT_ASSERT(!hash.Has(elementsCount + 1)); UNIT_ASSERT(!hash.Has(elementsCount + 10)); UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.Size()); UNIT_ASSERT_VALUES_EQUAL(elementsCount, hash.UniqSize()); ui64 sumKeys = 0; for (auto i = hash.Iterate(); i.Ok(); ++i) { sumKeys += *i; } UNIT_ASSERT_VALUES_EQUAL(sumKeys, sumKeysTarget); } Y_UNIT_TEST(TestHashByte) { TestHashImpl(); } Y_UNIT_TEST(TestMultiHashByte) { TestMultiHashImpl(); } Y_UNIT_TEST(TestSetByte) { TestSetImpl(); } Y_UNIT_TEST(TestHashUi16) { TestHashImpl(); } Y_UNIT_TEST(TestMultiHashUi16) { TestMultiHashImpl(); } Y_UNIT_TEST(TestSetUi16) { TestSetImpl(); } Y_UNIT_TEST(TestHashUi64) { TestHashImpl(); } Y_UNIT_TEST(TestMultiHashUi64) { TestMultiHashImpl(); } Y_UNIT_TEST(TestSetUi64) { TestSetImpl(); } struct TStressHash { TStressHash(size_t param) : Param(param) { } template size_t operator() (T num) const { return num % Param; } const size_t Param; }; Y_UNIT_TEST(TestStressSmallLists) { TAlignedPagePool pagePool(__LOCATION__); for (size_t listSize: xrange(2, 17, 1)) { const size_t backets = TListPoolBase::GetSmallPageCapacity(listSize); const size_t elementsCount = backets * listSize; 
for (size_t count: xrange(1, elementsCount + 1, elementsCount / 16)) { TCompactHashSet hash(pagePool, elementsCount, TStressHash(backets)); for (auto i: xrange(count)) { hash.Insert(i); } UNIT_ASSERT_VALUES_EQUAL(count, hash.Size()); UNIT_ASSERT_VALUES_EQUAL(count, hash.UniqSize()); //hash.PrintStat(Cerr); } } } }