future_mt_ut.cpp 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. #include "future.h"
  2. #include <library/cpp/testing/unittest/registar.h>
  3. #include <util/generic/noncopyable.h>
  4. #include <util/generic/xrange.h>
  5. #include <util/thread/pool.h>
  6. #include <atomic>
  7. #include <exception>
  8. using NThreading::NewPromise;
  9. using NThreading::TFuture;
  10. using NThreading::TPromise;
  11. using NThreading::TWaitPolicy;
  12. namespace {
  13. // Wait* implementation without optimizations, to test TWaitGroup better
  14. template <class WaitPolicy, class TContainer>
  15. TFuture<void> WaitNoOpt(const TContainer& futures) {
  16. NThreading::TWaitGroup<WaitPolicy> wg;
  17. for (const auto& fut : futures) {
  18. wg.Add(fut);
  19. }
  20. return std::move(wg).Finish();
  21. }
  22. class TRelaxedBarrier {
  23. public:
  24. explicit TRelaxedBarrier(i64 size)
  25. : Waiting_{size} {
  26. }
  27. void Arrive() {
  28. // barrier is not for synchronization, just to ensure good timings, so
  29. // std::memory_order_relaxed is enough
  30. Waiting_.fetch_add(-1, std::memory_order_relaxed);
  31. while (Waiting_.load(std::memory_order_relaxed)) {
  32. }
  33. Y_ASSERT(Waiting_.load(std::memory_order_relaxed) >= 0);
  34. }
  35. private:
  36. std::atomic<i64> Waiting_;
  37. };
  38. THolder<TThreadPool> MakePool() {
  39. auto pool = MakeHolder<TThreadPool>(TThreadPool::TParams{}.SetBlocking(false).SetCatching(false));
  40. pool->Start(8);
  41. return pool;
  42. }
  43. template <class T>
  44. TVector<TFuture<T>> ToFutures(const TVector<TPromise<T>>& promises) {
  45. TVector<TFuture<void>> futures;
  46. for (auto&& p : promises) {
  47. futures.emplace_back(p);
  48. }
  49. return futures;
  50. }
  51. struct TStateSnapshot {
  52. i64 Started = -1;
  53. i64 StartedException = -1;
  54. const TVector<TFuture<void>>* Futures = nullptr;
  55. };
  56. // note: std::memory_order_relaxed should be enough everywhere, because TFuture::SetValue must provide the
  57. // needed synchronization
  58. template <class TFactory>
  59. void RunWaitTest(TFactory global) {
  60. auto pool = MakePool();
  61. const auto exception = std::make_exception_ptr(42);
  62. for (auto numPromises : xrange(1, 5)) {
  63. for (auto loopIter : xrange(1024 * 64)) {
  64. const auto numParticipants = numPromises + 1;
  65. TRelaxedBarrier barrier{numParticipants};
  66. std::atomic<i64> started = 0;
  67. std::atomic<i64> startedException = 0;
  68. std::atomic<i64> completed = 0;
  69. TVector<TPromise<void>> promises;
  70. for (auto i : xrange(numPromises)) {
  71. Y_UNUSED(i);
  72. promises.push_back(NewPromise());
  73. }
  74. const auto futures = ToFutures(promises);
  75. auto snapshotter = [&] {
  76. return TStateSnapshot{
  77. .Started = started.load(std::memory_order_relaxed),
  78. .StartedException = startedException.load(std::memory_order_relaxed),
  79. .Futures = &futures,
  80. };
  81. };
  82. for (auto i : xrange(numPromises)) {
  83. pool->SafeAddFunc([&, i] {
  84. barrier.Arrive();
  85. // subscribers must observe effects of this operation
  86. // after .Set*
  87. started.fetch_add(1, std::memory_order_relaxed);
  88. if ((loopIter % 4 == 0) && i == 0) {
  89. startedException.fetch_add(1, std::memory_order_relaxed);
  90. promises[i].SetException(exception);
  91. } else {
  92. promises[i].SetValue();
  93. }
  94. completed.fetch_add(1, std::memory_order_release);
  95. });
  96. }
  97. pool->SafeAddFunc([&] {
  98. auto local = global(snapshotter);
  99. barrier.Arrive();
  100. local();
  101. completed.fetch_add(1, std::memory_order_release);
  102. });
  103. while (completed.load() != numParticipants) {
  104. }
  105. }
  106. }
  107. }
  108. }
  109. Y_UNIT_TEST_SUITE(TFutureMultiThreadedTest) {
  110. Y_UNIT_TEST(WaitAll) {
  111. RunWaitTest(
  112. [](auto snapshotter) {
  113. return [=]() {
  114. auto* futures = snapshotter().Futures;
  115. auto all = WaitNoOpt<TWaitPolicy::TAll>(*futures);
  116. // tests safety part
  117. all.Subscribe([=] (auto&& all) {
  118. TStateSnapshot snap = snapshotter();
  119. // value safety: all is set => every future is set
  120. UNIT_ASSERT(all.HasValue() <= ((snap.Started == (i64)snap.Futures->size()) && !snap.StartedException));
  121. // safety for hasException: all is set => every future is set and some has exception
  122. UNIT_ASSERT(all.HasException() <= ((snap.Started == (i64)snap.Futures->size()) && snap.StartedException > 0));
  123. });
  124. // test liveness
  125. all.Wait();
  126. };
  127. });
  128. }
  129. Y_UNIT_TEST(WaitAny) {
  130. RunWaitTest(
  131. [](auto snapshotter) {
  132. return [=]() {
  133. auto* futures = snapshotter().Futures;
  134. auto any = WaitNoOpt<TWaitPolicy::TAny>(*futures);
  135. // safety: any is ready => some f is ready
  136. any.Subscribe([=](auto&&) {
  137. UNIT_ASSERT(snapshotter().Started > 0);
  138. });
  139. // do we need better multithreaded liveness tests?
  140. any.Wait();
  141. };
  142. });
  143. }
  144. Y_UNIT_TEST(WaitExceptionOrAll) {
  145. RunWaitTest(
  146. [](auto snapshotter) {
  147. return [=]() {
  148. NThreading::WaitExceptionOrAll(*snapshotter().Futures)
  149. .Subscribe([=](auto&&) {
  150. auto* futures = snapshotter().Futures;
  151. auto exceptionOrAll = WaitNoOpt<TWaitPolicy::TExceptionOrAll>(*futures);
  152. exceptionOrAll.Subscribe([snapshotter](auto&& exceptionOrAll) {
  153. TStateSnapshot snap = snapshotter();
  154. // safety for hasException: exceptionOrAll has exception => some has exception
  155. UNIT_ASSERT(exceptionOrAll.HasException() ? snap.StartedException > 0 : true);
  156. // value safety: exceptionOrAll has value => all have value
  157. UNIT_ASSERT(exceptionOrAll.HasValue() == ((snap.Started == (i64)snap.Futures->size()) && !snap.StartedException));
  158. });
  159. // do we need better multithreaded liveness tests?
  160. exceptionOrAll.Wait();
  161. });
  162. };
  163. });
  164. }
  165. }