wait_group-inl.h 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. #pragma once
  2. #if !defined(INCLUDE_FUTURE_INL_H)
  3. #error "you should never include wait_group-inl.h directly"
  4. #endif // INCLUDE_FUTURE_INL_H
  5. #include "wait_policy.h"
  6. #include <util/generic/maybe.h>
  7. #include <util/generic/ptr.h>
  8. #include <library/cpp/threading/future/core/future.h>
  9. #include <util/system/spinlock.h>
  10. #include <atomic>
  11. #include <exception>
  12. namespace NThreading {
  13. namespace NWaitGroup::NImpl {
  14. template <class WaitPolicy>
  15. struct TState final : TAtomicRefCount<TState<WaitPolicy>> {
  16. template <class T>
  17. void Add(const TFuture<T>& future);
  18. TFuture<void> Finish();
  19. void TryPublish();
  20. void Publish();
  21. bool ShouldPublishByCount() const noexcept;
  22. bool ShouldPublishByException() const noexcept;
  23. TStateRef<WaitPolicy> SharedFromThis() noexcept {
  24. return TStateRef<WaitPolicy>{this};
  25. }
  26. enum class EPhase {
  27. Initial,
  28. Publishing,
  29. };
  30. // initially we have one imaginary discovered future which we
  31. // use for synchronization with ::Finish
  32. std::atomic<ui64> Discovered{1};
  33. std::atomic<ui64> Finished{0};
  34. std::atomic<EPhase> Phase{EPhase::Initial};
  35. TPromise<void> Subscribers = NewPromise();
  36. mutable TAdaptiveLock Mut;
  37. std::exception_ptr ExceptionInFlight;
  38. void TrySetException(std::exception_ptr eptr) noexcept {
  39. TGuard lock{Mut};
  40. if (!ExceptionInFlight) {
  41. ExceptionInFlight = std::move(eptr);
  42. }
  43. }
  44. std::exception_ptr GetExceptionInFlight() const noexcept {
  45. TGuard lock{Mut};
  46. return ExceptionInFlight;
  47. }
  48. };
  49. template <class WaitPolicy>
  50. inline TFuture<void> TState<WaitPolicy>::Finish() {
  51. Finished.fetch_add(1); // complete the imaginary future
  52. // handle empty case explicitly:
  53. if (Discovered.load() == 1) {
  54. Y_ASSERT(Phase.load() == EPhase::Initial);
  55. Publish();
  56. } else {
  57. TryPublish();
  58. }
  59. return Subscribers;
  60. }
  61. template <class WaitPolicy>
  62. template <class T>
  63. inline void TState<WaitPolicy>::Add(const TFuture<T>& future) {
  64. future.EnsureInitialized();
  65. Discovered.fetch_add(1);
  66. // NoexceptSubscribe is needed to make ::Add exception-safe
  67. future.NoexceptSubscribe([self = SharedFromThis()](auto&& future) {
  68. try {
  69. future.TryRethrow();
  70. } catch (...) {
  71. self->TrySetException(std::current_exception());
  72. }
  73. self->Finished.fetch_add(1);
  74. self->TryPublish();
  75. });
  76. }
  77. //
  78. // ============================ PublishByCount ==================================
  79. //
  80. template <class WaitPolicy>
  81. inline bool TState<WaitPolicy>::ShouldPublishByCount() const noexcept {
  82. // - safety: a) If the future incremented ::Finished, and we observe the effect, then we will observe ::Discovered as incremented by its discovery later
  83. // b) Every discovery of a future observes discovery of the imaginary future
  84. // a, b => if finishedByNow == discoveredByNow, then every future discovered in [imaginary discovered, imaginary finished] is finished
  85. //
  86. // - liveness: a) TryPublish is called after each increment of ::Finished
  87. // b) There is some last increment of ::Finished which follows all other operations with ::Finished and ::Discovered (provided that every future is eventually set)
  88. // c) For each increment of ::Discovered there is an increment of ::Finished (provided that every future is eventually set)
  89. // a, b c => some call to ShouldPublishByCount will always return true
  90. //
  91. // order of the following two operations is significant for the proof.
  92. auto finishedByNow = Finished.load();
  93. auto discoveredByNow = Discovered.load();
  94. return finishedByNow == discoveredByNow;
  95. }
  96. template <>
  97. inline bool TState<TWaitPolicy::TAny>::ShouldPublishByCount() const noexcept {
  98. auto finishedByNow = Finished.load();
  99. // note that the empty case is not handled here
  100. return finishedByNow >= 2; // at least one non-imaginary
  101. }
  102. //
  103. // ============================ PublishByException ==================================
  104. //
  105. template <>
  106. inline bool TState<TWaitPolicy::TAny>::ShouldPublishByException() const noexcept {
  107. // for TAny exceptions are handled by ShouldPublishByCount
  108. return false;
  109. }
  110. template <>
  111. inline bool TState<TWaitPolicy::TAll>::ShouldPublishByException() const noexcept {
  112. return false;
  113. }
  114. template <>
  115. inline bool TState<TWaitPolicy::TExceptionOrAll>::ShouldPublishByException() const noexcept {
  116. return GetExceptionInFlight() != nullptr;
  117. }
  //
  // ============================ Publication ==================================
  //
  121. template <class WaitPolicy>
  122. inline void TState<WaitPolicy>::TryPublish() {
  123. // the order is insignificant (without proof)
  124. bool shouldPublish = ShouldPublishByCount() || ShouldPublishByException();
  125. if (shouldPublish) {
  126. if (auto currentPhase = EPhase::Initial;
  127. Phase.compare_exchange_strong(currentPhase, EPhase::Publishing)) {
  128. Publish();
  129. }
  130. }
  131. }
  132. template <class WaitPolicy>
  133. inline void TState<WaitPolicy>::Publish() {
  134. auto eptr = GetExceptionInFlight();
  135. // can potentially throw
  136. if (eptr) {
  137. Subscribers.SetException(std::move(eptr));
  138. } else {
  139. Subscribers.SetValue();
  140. }
  141. }
  142. }
  143. template <class WaitPolicy>
  144. inline TWaitGroup<WaitPolicy>::TWaitGroup()
  145. : State_{MakeIntrusive<NWaitGroup::NImpl::TState<WaitPolicy>>()}
  146. {
  147. }
  148. template <class WaitPolicy>
  149. template <class T>
  150. inline TWaitGroup<WaitPolicy>& TWaitGroup<WaitPolicy>::Add(const TFuture<T>& future) {
  151. State_->Add(future);
  152. return *this;
  153. }
  154. template <class WaitPolicy>
  155. inline TFuture<void> TWaitGroup<WaitPolicy>::Finish() && {
  156. auto res = State_->Finish();
  157. // just to prevent nasty bugs from use-after-move
  158. State_.Reset();
  159. return res;
  160. }
  161. }