condvar_ut.cpp 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. #include "mutex.h"
  2. #include "guard.h"
  3. #include "condvar.h"
  4. #include <library/cpp/testing/unittest/registar.h>
  5. #include <util/thread/pool.h>
  6. #include <atomic>
  7. class TCondVarTest: public TTestBase {
  8. UNIT_TEST_SUITE(TCondVarTest);
  9. UNIT_TEST(TestBasics)
  10. UNIT_TEST(TestSyncronize)
  11. UNIT_TEST_SUITE_END();
  12. struct TSharedData {
  13. TMutex mutex;
  14. TCondVar condVar1;
  15. TCondVar condVar2;
  16. std::atomic<bool> stopWaiting = false;
  17. std::atomic<size_t> in = 0;
  18. std::atomic<size_t> out = 0;
  19. std::atomic<size_t> waited = 0;
  20. bool failed = false;
  21. };
  22. class TThreadTask: public IObjectInQueue {
  23. public:
  24. using PFunc = void (TThreadTask::*)(void);
  25. TThreadTask(PFunc func, size_t id, size_t totalIds, TSharedData& data)
  26. : Func_(func)
  27. , Id_(id)
  28. , TotalIds_(totalIds)
  29. , Data_(data)
  30. {
  31. }
  32. void Process(void*) override {
  33. THolder<TThreadTask> This(this);
  34. (this->*Func_)();
  35. }
  36. #define FAIL_ASSERT(cond) \
  37. if (!(cond)) { \
  38. Data_.failed = true; \
  39. }
  40. void RunBasics() {
  41. Y_ASSERT(TotalIds_ == 3);
  42. if (Id_ < 2) {
  43. TGuard<TMutex> guard(Data_.mutex);
  44. while (!Data_.stopWaiting.load()) {
  45. bool res = Data_.condVar1.WaitT(Data_.mutex, TDuration::Seconds(1));
  46. FAIL_ASSERT(res == true);
  47. }
  48. } else {
  49. usleep(100000);
  50. Data_.stopWaiting.store(true);
  51. TGuard<TMutex> guard(Data_.mutex);
  52. Data_.condVar1.Signal();
  53. Data_.condVar1.Signal();
  54. }
  55. }
  56. void RunBasicsWithPredicate() {
  57. Y_ASSERT(TotalIds_ == 3);
  58. if (Id_ < 2) {
  59. TGuard<TMutex> guard(Data_.mutex);
  60. const auto res = Data_.condVar1.WaitT(Data_.mutex, TDuration::Seconds(1), [&] {
  61. return Data_.stopWaiting.load();
  62. });
  63. FAIL_ASSERT(res == true);
  64. } else {
  65. usleep(100000);
  66. Data_.stopWaiting.store(true);
  67. TGuard<TMutex> guard(Data_.mutex);
  68. Data_.condVar1.Signal();
  69. Data_.condVar1.Signal();
  70. }
  71. }
  72. void RunSyncronize() {
  73. for (size_t i = 0; i < 10; ++i) {
  74. TGuard<TMutex> guard(Data_.mutex);
  75. ++Data_.in;
  76. if (Data_.in.load() == TotalIds_) {
  77. Data_.out.store(0);
  78. Data_.condVar1.BroadCast();
  79. } else {
  80. ++Data_.waited;
  81. while (Data_.in.load() < TotalIds_) {
  82. bool res = Data_.condVar1.WaitT(Data_.mutex, TDuration::Seconds(1));
  83. FAIL_ASSERT(res == true);
  84. }
  85. }
  86. ++Data_.out;
  87. if (Data_.out.load() == TotalIds_) {
  88. Data_.in.store(0);
  89. Data_.condVar2.BroadCast();
  90. } else {
  91. while (Data_.out.load() < TotalIds_) {
  92. bool res = Data_.condVar2.WaitT(Data_.mutex, TDuration::Seconds(1));
  93. FAIL_ASSERT(res == true);
  94. }
  95. }
  96. }
  97. FAIL_ASSERT(Data_.waited.load() == (TotalIds_ - 1) * 10);
  98. }
  99. void RunSyncronizeWithPredicate() {
  100. for (size_t i = 0; i < 10; ++i) {
  101. TGuard<TMutex> guard(Data_.mutex);
  102. ++Data_.in;
  103. if (Data_.in.load() == TotalIds_) {
  104. Data_.out.store(0);
  105. Data_.condVar1.BroadCast();
  106. } else {
  107. ++Data_.waited;
  108. const auto res = Data_.condVar1.WaitT(Data_.mutex, TDuration::Seconds(1), [&] {
  109. return Data_.in.load() >= TotalIds_;
  110. });
  111. FAIL_ASSERT(res == true);
  112. }
  113. ++Data_.out;
  114. if (Data_.out.load() == TotalIds_) {
  115. Data_.in.store(0);
  116. Data_.condVar2.BroadCast();
  117. } else {
  118. const auto res = Data_.condVar2.WaitT(Data_.mutex, TDuration::Seconds(1), [&] {
  119. return Data_.out.load() >= TotalIds_;
  120. });
  121. FAIL_ASSERT(res == true);
  122. }
  123. }
  124. FAIL_ASSERT(Data_.waited == (TotalIds_ - 1) * 10);
  125. }
  126. #undef FAIL_ASSERT
  127. private:
  128. PFunc Func_;
  129. size_t Id_;
  130. size_t TotalIds_;
  131. TSharedData& Data_;
  132. };
  133. private:
  134. #define RUN_CYCLE(what, count) \
  135. Q_.Start(count); \
  136. for (size_t i = 0; i < count; ++i) { \
  137. UNIT_ASSERT(Q_.Add(new TThreadTask(&TThreadTask::what, i, count, Data_))); \
  138. } \
  139. Q_.Stop(); \
  140. bool b = Data_.failed; \
  141. Data_.failed = false; \
  142. UNIT_ASSERT(!b);
  143. inline void TestBasics() {
  144. RUN_CYCLE(RunBasics, 3);
  145. }
  146. inline void TestBasicsWithPredicate() {
  147. RUN_CYCLE(RunBasicsWithPredicate, 3);
  148. }
  149. inline void TestSyncronize() {
  150. RUN_CYCLE(RunSyncronize, 6);
  151. }
  152. inline void TestSyncronizeWithPredicate() {
  153. RUN_CYCLE(RunSyncronizeWithPredicate, 6);
  154. }
  155. #undef RUN_CYCLE
  156. TSharedData Data_;
  157. TThreadPool Q_;
  158. };
  159. UNIT_TEST_SUITE_REGISTRATION(TCondVarTest);