pool_ut.cpp 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. #include "pool.h"
  2. #include <library/cpp/testing/unittest/registar.h>
  3. #include <util/random/fast.h>
  4. #include <util/system/spinlock.h>
  5. #include <util/system/thread.h>
  6. #include <util/system/mutex.h>
  7. #include <util/system/condvar.h>
  8. struct TThreadPoolTest {
  9. TSpinLock Lock;
  10. long R = -1;
  11. struct TTask: public IObjectInQueue {
  12. TThreadPoolTest* Test = nullptr;
  13. long Value = 0;
  14. TTask(TThreadPoolTest* test, int value)
  15. : Test(test)
  16. , Value(value)
  17. {
  18. }
  19. void Process(void*) override {
  20. THolder<TTask> This(this);
  21. TGuard<TSpinLock> guard(Test->Lock);
  22. Test->R ^= Value;
  23. }
  24. };
  25. struct TOwnedTask: public IObjectInQueue {
  26. bool& Processed;
  27. bool& Destructed;
  28. TOwnedTask(bool& processed, bool& destructed)
  29. : Processed(processed)
  30. , Destructed(destructed)
  31. {
  32. }
  33. ~TOwnedTask() override {
  34. Destructed = true;
  35. }
  36. void Process(void*) override {
  37. Processed = true;
  38. }
  39. };
  40. inline void TestAnyQueue(IThreadPool* queue, size_t queueSize = 1000) {
  41. TReallyFastRng32 rand(17);
  42. const size_t cnt = 1000;
  43. R = 0;
  44. for (size_t i = 0; i < cnt; ++i) {
  45. R ^= (long)rand.GenRand();
  46. }
  47. queue->Start(10, queueSize);
  48. rand = TReallyFastRng32(17);
  49. for (size_t i = 0; i < cnt; ++i) {
  50. UNIT_ASSERT(queue->Add(new TTask(this, (long)rand.GenRand())));
  51. }
  52. queue->Stop();
  53. UNIT_ASSERT_EQUAL(0, R);
  54. }
  55. };
  56. class TFailAddQueue: public IThreadPool {
  57. public:
  58. bool Add(IObjectInQueue* /*obj*/) override Y_WARN_UNUSED_RESULT {
  59. return false;
  60. }
  61. void Start(size_t, size_t) override {
  62. }
  63. void Stop() noexcept override {
  64. }
  65. size_t Size() const noexcept override {
  66. return 0;
  67. }
  68. };
  69. Y_UNIT_TEST_SUITE(TThreadPoolTest) {
  70. Y_UNIT_TEST(TestTThreadPool) {
  71. TThreadPoolTest t;
  72. TThreadPool q;
  73. t.TestAnyQueue(&q);
  74. }
  75. Y_UNIT_TEST(TestTThreadPoolBlocking) {
  76. TThreadPoolTest t;
  77. TThreadPool q(TThreadPool::TParams().SetBlocking(true));
  78. t.TestAnyQueue(&q, 100);
  79. }
  80. // disabled by pg@ long time ago due to test flaps
  81. // Tried to enable: REVIEW:78772
  82. Y_UNIT_TEST(TestTAdaptiveThreadPool) {
  83. if (false) {
  84. TThreadPoolTest t;
  85. TAdaptiveThreadPool q;
  86. t.TestAnyQueue(&q);
  87. }
  88. }
  89. Y_UNIT_TEST(TestAddAndOwn) {
  90. TThreadPool q;
  91. q.Start(2);
  92. bool processed = false;
  93. bool destructed = false;
  94. q.SafeAddAndOwn(MakeHolder<TThreadPoolTest::TOwnedTask>(processed, destructed));
  95. q.Stop();
  96. UNIT_ASSERT_C(processed, "Not processed");
  97. UNIT_ASSERT_C(destructed, "Not destructed");
  98. }
  99. Y_UNIT_TEST(TestAddFunc) {
  100. TFailAddQueue queue;
  101. bool added = queue.AddFunc(
  102. []() {} // Lambda, I call him 'Lambda'!
  103. );
  104. UNIT_ASSERT_VALUES_EQUAL(added, false);
  105. }
  106. Y_UNIT_TEST(TestSafeAddFuncThrows) {
  107. TFailAddQueue queue;
  108. UNIT_CHECK_GENERATED_EXCEPTION(queue.SafeAddFunc([] {}), TThreadPoolException);
  109. }
  110. Y_UNIT_TEST(TestFunctionNotCopied) {
  111. struct TFailOnCopy {
  112. TFailOnCopy() {
  113. }
  114. TFailOnCopy(TFailOnCopy&&) {
  115. }
  116. TFailOnCopy(const TFailOnCopy&) {
  117. UNIT_FAIL("Don't copy std::function inside TThreadPool");
  118. }
  119. };
  120. TThreadPool queue(TThreadPool::TParams().SetBlocking(false).SetCatching(true));
  121. queue.Start(2);
  122. queue.SafeAddFunc([data = TFailOnCopy()]() {});
  123. queue.Stop();
  124. }
  125. Y_UNIT_TEST(TestInfoGetters) {
  126. TThreadPool queue;
  127. queue.Start(2, 7);
  128. UNIT_ASSERT_EQUAL(queue.GetThreadCountExpected(), 2);
  129. UNIT_ASSERT_EQUAL(queue.GetThreadCountReal(), 2);
  130. UNIT_ASSERT_EQUAL(queue.GetMaxQueueSize(), 7);
  131. queue.Stop();
  132. queue.Start(4, 1);
  133. UNIT_ASSERT_EQUAL(queue.GetThreadCountExpected(), 4);
  134. UNIT_ASSERT_EQUAL(queue.GetThreadCountReal(), 4);
  135. UNIT_ASSERT_EQUAL(queue.GetMaxQueueSize(), 1);
  136. queue.Stop();
  137. }
  138. void TestFixedThreadName(IThreadPool& pool, const TString& expectedName) {
  139. pool.Start(1);
  140. TString name;
  141. pool.SafeAddFunc([&name]() {
  142. name = TThread::CurrentThreadName();
  143. });
  144. pool.Stop();
  145. if (TThread::CanGetCurrentThreadName()) {
  146. UNIT_ASSERT_EQUAL(name, expectedName);
  147. UNIT_ASSERT_UNEQUAL(TThread::CurrentThreadName(), expectedName);
  148. }
  149. }
  150. Y_UNIT_TEST(TestFixedThreadName) {
  151. const TString expectedName = "HelloWorld";
  152. {
  153. TThreadPool pool(TThreadPool::TParams().SetBlocking(true).SetCatching(false).SetThreadName(expectedName));
  154. TestFixedThreadName(pool, expectedName);
  155. }
  156. {
  157. TAdaptiveThreadPool pool(TThreadPool::TParams().SetThreadName(expectedName));
  158. TestFixedThreadName(pool, expectedName);
  159. }
  160. }
  161. void TestEnumeratedThreadName(IThreadPool& pool, const THashSet<TString>& expectedNames) {
  162. pool.Start(expectedNames.size());
  163. TMutex lock;
  164. TCondVar allReady;
  165. size_t readyCount = 0;
  166. THashSet<TString> names;
  167. for (size_t i = 0; i < expectedNames.size(); ++i) {
  168. pool.SafeAddFunc([&]() {
  169. with_lock (lock) {
  170. if (++readyCount == expectedNames.size()) {
  171. allReady.BroadCast();
  172. } else {
  173. while (readyCount != expectedNames.size()) {
  174. allReady.WaitI(lock);
  175. }
  176. }
  177. names.insert(TThread::CurrentThreadName());
  178. }
  179. });
  180. }
  181. pool.Stop();
  182. if (TThread::CanGetCurrentThreadName()) {
  183. UNIT_ASSERT_EQUAL(names, expectedNames);
  184. }
  185. }
  186. Y_UNIT_TEST(TestEnumeratedThreadName) {
  187. const TString namePrefix = "HelloWorld";
  188. const THashSet<TString> expectedNames = {
  189. "HelloWorld0",
  190. "HelloWorld1",
  191. "HelloWorld2",
  192. "HelloWorld3",
  193. "HelloWorld4",
  194. "HelloWorld5",
  195. "HelloWorld6",
  196. "HelloWorld7",
  197. "HelloWorld8",
  198. "HelloWorld9",
  199. "HelloWorld10",
  200. };
  201. {
  202. TThreadPool pool(TThreadPool::TParams().SetBlocking(true).SetCatching(false).SetThreadNamePrefix(namePrefix));
  203. TestEnumeratedThreadName(pool, expectedNames);
  204. }
  205. {
  206. TAdaptiveThreadPool pool(TThreadPool::TParams().SetThreadNamePrefix(namePrefix));
  207. TestEnumeratedThreadName(pool, expectedNames);
  208. }
  209. }
  210. }