condvar_ut.cpp 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. #include "mutex.h"
  2. #include "guard.h"
  3. #include "condvar.h"
  4. #include <library/cpp/testing/unittest/registar.h>
  5. #include <util/system/atomic.h>
  6. #include <util/system/atomic_ops.h>
  7. #include <util/thread/pool.h>
  8. class TCondVarTest: public TTestBase {
  9. UNIT_TEST_SUITE(TCondVarTest);
  10. UNIT_TEST(TestBasics)
  11. UNIT_TEST(TestSyncronize)
  12. UNIT_TEST_SUITE_END();
  13. struct TSharedData {
  14. TSharedData()
  15. : stopWaiting(false)
  16. , in(0)
  17. , out(0)
  18. , waited(0)
  19. , failed(false)
  20. {
  21. }
  22. TMutex mutex;
  23. TCondVar condVar1;
  24. TCondVar condVar2;
  25. TAtomic stopWaiting;
  26. TAtomic in;
  27. TAtomic out;
  28. TAtomic waited;
  29. bool failed;
  30. };
  31. class TThreadTask: public IObjectInQueue {
  32. public:
  33. using PFunc = void (TThreadTask::*)(void);
  34. TThreadTask(PFunc func, size_t id, size_t totalIds, TSharedData& data)
  35. : Func_(func)
  36. , Id_(id)
  37. , TotalIds_(totalIds)
  38. , Data_(data)
  39. {
  40. }
  41. void Process(void*) override {
  42. THolder<TThreadTask> This(this);
  43. (this->*Func_)();
  44. }
  45. #define FAIL_ASSERT(cond) \
  46. if (!(cond)) { \
  47. Data_.failed = true; \
  48. }
  49. void RunBasics() {
  50. Y_ASSERT(TotalIds_ == 3);
  51. if (Id_ < 2) {
  52. TGuard<TMutex> guard(Data_.mutex);
  53. while (!AtomicGet(Data_.stopWaiting)) {
  54. bool res = Data_.condVar1.WaitT(Data_.mutex, TDuration::Seconds(1));
  55. FAIL_ASSERT(res == true);
  56. }
  57. } else {
  58. usleep(100000);
  59. AtomicSet(Data_.stopWaiting, true);
  60. TGuard<TMutex> guard(Data_.mutex);
  61. Data_.condVar1.Signal();
  62. Data_.condVar1.Signal();
  63. }
  64. }
  65. void RunBasicsWithPredicate() {
  66. Y_ASSERT(TotalIds_ == 3);
  67. if (Id_ < 2) {
  68. TGuard<TMutex> guard(Data_.mutex);
  69. const auto res = Data_.condVar1.WaitT(Data_.mutex, TDuration::Seconds(1), [&] {
  70. return AtomicGet(Data_.stopWaiting);
  71. });
  72. FAIL_ASSERT(res == true);
  73. } else {
  74. usleep(100000);
  75. AtomicSet(Data_.stopWaiting, true);
  76. TGuard<TMutex> guard(Data_.mutex);
  77. Data_.condVar1.Signal();
  78. Data_.condVar1.Signal();
  79. }
  80. }
  81. void RunSyncronize() {
  82. for (size_t i = 0; i < 10; ++i) {
  83. TGuard<TMutex> guard(Data_.mutex);
  84. AtomicIncrement(Data_.in);
  85. if (AtomicGet(Data_.in) == TotalIds_) {
  86. AtomicSet(Data_.out, 0);
  87. Data_.condVar1.BroadCast();
  88. } else {
  89. AtomicIncrement(Data_.waited);
  90. while (AtomicGet(Data_.in) < TotalIds_) {
  91. bool res = Data_.condVar1.WaitT(Data_.mutex, TDuration::Seconds(1));
  92. FAIL_ASSERT(res == true);
  93. }
  94. }
  95. AtomicIncrement(Data_.out);
  96. if (AtomicGet(Data_.out) == TotalIds_) {
  97. AtomicSet(Data_.in, 0);
  98. Data_.condVar2.BroadCast();
  99. } else {
  100. while (AtomicGet(Data_.out) < TotalIds_) {
  101. bool res = Data_.condVar2.WaitT(Data_.mutex, TDuration::Seconds(1));
  102. FAIL_ASSERT(res == true);
  103. }
  104. }
  105. }
  106. FAIL_ASSERT(AtomicGet(Data_.waited) == (TotalIds_ - 1) * 10);
  107. }
  108. void RunSyncronizeWithPredicate() {
  109. for (size_t i = 0; i < 10; ++i) {
  110. TGuard<TMutex> guard(Data_.mutex);
  111. AtomicIncrement(Data_.in);
  112. if (AtomicGet(Data_.in) == TotalIds_) {
  113. AtomicSet(Data_.out, 0);
  114. Data_.condVar1.BroadCast();
  115. } else {
  116. AtomicIncrement(Data_.waited);
  117. const auto res = Data_.condVar1.WaitT(Data_.mutex, TDuration::Seconds(1), [&] {
  118. return AtomicGet(Data_.in) >= TotalIds_;
  119. });
  120. FAIL_ASSERT(res == true);
  121. }
  122. AtomicIncrement(Data_.out);
  123. if (AtomicGet(Data_.out) == TotalIds_) {
  124. AtomicSet(Data_.in, 0);
  125. Data_.condVar2.BroadCast();
  126. } else {
  127. const auto res = Data_.condVar2.WaitT(Data_.mutex, TDuration::Seconds(1), [&] {
  128. return AtomicGet(Data_.out) >= TotalIds_;
  129. });
  130. FAIL_ASSERT(res == true);
  131. }
  132. }
  133. FAIL_ASSERT(Data_.waited == (TotalIds_ - 1) * 10);
  134. }
  135. #undef FAIL_ASSERT
  136. private:
  137. PFunc Func_;
  138. size_t Id_;
  139. TAtomicBase TotalIds_;
  140. TSharedData& Data_;
  141. };
  142. private:
  143. #define RUN_CYCLE(what, count) \
  144. Q_.Start(count); \
  145. for (size_t i = 0; i < count; ++i) { \
  146. UNIT_ASSERT(Q_.Add(new TThreadTask(&TThreadTask::what, i, count, Data_))); \
  147. } \
  148. Q_.Stop(); \
  149. bool b = Data_.failed; \
  150. Data_.failed = false; \
  151. UNIT_ASSERT(!b);
  152. inline void TestBasics() {
  153. RUN_CYCLE(RunBasics, 3);
  154. }
  155. inline void TestBasicsWithPredicate() {
  156. RUN_CYCLE(RunBasicsWithPredicate, 3);
  157. }
  158. inline void TestSyncronize() {
  159. RUN_CYCLE(RunSyncronize, 6);
  160. }
  161. inline void TestSyncronizeWithPredicate() {
  162. RUN_CYCLE(RunSyncronizeWithPredicate, 6);
  163. }
  164. #undef RUN_CYCLE
  165. TSharedData Data_;
  166. TThreadPool Q_;
  167. };
  168. UNIT_TEST_SUITE_REGISTRATION(TCondVarTest);