#pragma once #if !defined(INCLUDE_FUTURE_INL_H) #error "you should never include wait_group-inl.h directly" #endif // INCLUDE_FUTURE_INL_H #include "wait_policy.h" #include #include #include #include #include #include namespace NThreading { namespace NWaitGroup::NImpl { template struct TState final : TAtomicRefCount> { template void Add(const TFuture& future); TFuture Finish(); void TryPublish(); void Publish(); bool ShouldPublishByCount() const noexcept; bool ShouldPublishByException() const noexcept; TStateRef SharedFromThis() noexcept { return TStateRef{this}; } enum class EPhase { Initial, Publishing, }; // initially we have one imaginary discovered future which we // use for synchronization with ::Finish std::atomic Discovered{1}; std::atomic Finished{0}; std::atomic Phase{EPhase::Initial}; TPromise Subscribers = NewPromise(); mutable TAdaptiveLock Mut; std::exception_ptr ExceptionInFlight; void TrySetException(std::exception_ptr eptr) noexcept { TGuard lock{Mut}; if (!ExceptionInFlight) { ExceptionInFlight = std::move(eptr); } } std::exception_ptr GetExceptionInFlight() const noexcept { TGuard lock{Mut}; return ExceptionInFlight; } }; template inline TFuture TState::Finish() { Finished.fetch_add(1); // complete the imaginary future // handle empty case explicitly: if (Discovered.load() == 1) { Y_ASSERT(Phase.load() == EPhase::Initial); Publish(); } else { TryPublish(); } return Subscribers; } template template inline void TState::Add(const TFuture& future) { future.EnsureInitialized(); Discovered.fetch_add(1); // NoexceptSubscribe is needed to make ::Add exception-safe future.NoexceptSubscribe([self = SharedFromThis()](auto&& future) { try { future.TryRethrow(); } catch (...) { self->TrySetException(std::current_exception()); } self->Finished.fetch_add(1); self->TryPublish(); }); } // // ============================ PublishByCount ================================== // template inline bool TState::ShouldPublishByCount() const noexcept { // - safety: a) If the future incremented ::Finished, and we observe the effect, then we will observe ::Discovered as incremented by its discovery later // b) Every discovery of a future observes discovery of the imaginary future // a, b => if finishedByNow == discoveredByNow, then every future discovered in [imaginary discovered, imaginary finished] is finished // // - liveness: a) TryPublish is called after each increment of ::Finished // b) There is some last increment of ::Finished which follows all other operations with ::Finished and ::Discovered (provided that every future is eventually set) // c) For each increment of ::Discovered there is an increment of ::Finished (provided that every future is eventually set) // a, b c => some call to ShouldPublishByCount will always return true // // order of the following two operations is significant for the proof. auto finishedByNow = Finished.load(); auto discoveredByNow = Discovered.load(); return finishedByNow == discoveredByNow; } template <> inline bool TState::ShouldPublishByCount() const noexcept { auto finishedByNow = Finished.load(); // note that the empty case is not handled here return finishedByNow >= 2; // at least one non-imaginary } // // ============================ PublishByException ================================== // template <> inline bool TState::ShouldPublishByException() const noexcept { // for TAny exceptions are handled by ShouldPublishByCount return false; } template <> inline bool TState::ShouldPublishByException() const noexcept { return false; } template <> inline bool TState::ShouldPublishByException() const noexcept { return GetExceptionInFlight() != nullptr; } // // // template inline void TState::TryPublish() { // the order is insignificant (without proof) bool shouldPublish = ShouldPublishByCount() || ShouldPublishByException(); if (shouldPublish) { if (auto currentPhase = EPhase::Initial; Phase.compare_exchange_strong(currentPhase, EPhase::Publishing)) { Publish(); } } } template inline void TState::Publish() { auto eptr = GetExceptionInFlight(); // can potentially throw if (eptr) { Subscribers.SetException(std::move(eptr)); } else { Subscribers.SetValue(); } } } template inline TWaitGroup::TWaitGroup() : State_{MakeIntrusive>()} { } template template inline TWaitGroup& TWaitGroup::Add(const TFuture& future) { State_->Add(future); return *this; } template inline TFuture TWaitGroup::Finish() && { auto res = State_->Finish(); // just to prevent nasty bugs from use-after-move State_.Reset(); return res; } }