#pragma once #include "grpc_common.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include /* * This file contains low level logic for grpc * This file should not be used in high level code without special reason */ namespace NGrpc { const size_t DEFAULT_NUM_THREADS = 2; //////////////////////////////////////////////////////////////////////////////// void EnableGRpcTracing(); //////////////////////////////////////////////////////////////////////////////// struct TTcpKeepAliveSettings { bool Enabled; size_t Idle; size_t Count; size_t Interval; }; //////////////////////////////////////////////////////////////////////////////// // Common interface used to execute action from grpc cq routine class IQueueClientEvent { public: virtual ~IQueueClientEvent() = default; //! Execute an action defined by implementation virtual bool Execute(bool ok) = 0; //! Finish and destroy event virtual void Destroy() = 0; }; // Implementation of IQueueClientEvent that reduces allocations template class TQueueClientFixedEvent : private IQueueClientEvent { using TCallback = void (TSelf::*)(bool); public: TQueueClientFixedEvent(TSelf* self, TCallback callback) : Self(self) , Callback(callback) { } IQueueClientEvent* Prepare() { Self->Ref(); return this; } private: bool Execute(bool ok) override { ((*Self).*Callback)(ok); return false; } void Destroy() override { Self->UnRef(); } private: TSelf* const Self; TCallback const Callback; }; class IQueueClientContext; using IQueueClientContextPtr = std::shared_ptr; // Provider of IQueueClientContext instances class IQueueClientContextProvider { public: virtual ~IQueueClientContextProvider() = default; virtual IQueueClientContextPtr CreateContext() = 0; }; // Activity context for a low-level client class IQueueClientContext : public IQueueClientContextProvider { public: virtual ~IQueueClientContext() = default; //! Returns CompletionQueue associated with the client virtual grpc::CompletionQueue* CompletionQueue() = 0; //! Returns true if context has been cancelled virtual bool IsCancelled() const = 0; //! Tries to cancel context, calling all registered callbacks virtual bool Cancel() = 0; //! Subscribes callback to cancellation // // Note there's no way to unsubscribe, if subscription is temporary // make sure you create a new context with CreateContext and release // it as soon as it's no longer needed. virtual void SubscribeCancel(std::function callback) = 0; //! Subscribes callback to cancellation // // This alias is for compatibility with older code. 
// Represents grpc status and error message string
struct TGrpcStatus {
    TString Msg;
    TString Details;
    int GRpcStatusCode;
    bool InternalError;

    TGrpcStatus()
        : GRpcStatusCode(grpc::StatusCode::OK)
        , InternalError(false)
    { }

    TGrpcStatus(TString msg, int statusCode, bool internalError)
        : Msg(std::move(msg))
        , GRpcStatusCode(statusCode)
        , InternalError(internalError)
    { }

    TGrpcStatus(grpc::StatusCode status, TString msg, TString details = {})
        : Msg(std::move(msg))
        , Details(std::move(details))
        , GRpcStatusCode(status)
        , InternalError(false)
    { }

    TGrpcStatus(const grpc::Status& status)
        : TGrpcStatus(status.error_code(), TString(status.error_message()), TString(status.error_details()))
    { }

    TGrpcStatus& operator=(const grpc::Status& status) {
        Msg = TString(status.error_message());
        Details = TString(status.error_details());
        GRpcStatusCode = status.error_code();
        InternalError = false;
        return *this;
    }

    static TGrpcStatus Internal(TString msg) {
        return { std::move(msg), -1, true };
    }

    bool Ok() const {
        return !InternalError && GRpcStatusCode == grpc::StatusCode::OK;
    }

    TStringBuilder ToDebugString() const {
        TStringBuilder ret;
        ret << "gRpcStatusCode: " << GRpcStatusCode;
        if (!Ok())
            ret << ", Msg: " << Msg << ", Details: " << Details << ", InternalError: " << InternalError;
        return ret;
    }
};

bool inline IsGRpcStatusGood(const TGrpcStatus& status) {
    return status.Ok();
}

// Response callback type - this callback will be called when the request is finished
// (or after getting each chunk in case of streaming mode)
template<typename TResponse>
using TResponseCallback = std::function<void (TGrpcStatus&&, TResponse&&)>;

template<typename TResponse>
using TAdvancedResponseCallback = std::function<void (const grpc::ClientContext&, TGrpcStatus&&, TResponse&&)>;

// Call associated metadata
struct TCallMeta {
    std::shared_ptr<grpc::CallCredentials> CallCredentials;
    std::vector<std::pair<TString, TString>> Aux;
    std::variant<TDuration, TInstant> Timeout; // timeout as duration from now or time point in future
};

class TGRpcRequestProcessorCommon {
protected:
    void ApplyMeta(const TCallMeta& meta) {
        for (const auto& rec : meta.Aux) {
            Context.AddMetadata(rec.first, rec.second);
        }
        if (meta.CallCredentials) {
            Context.set_credentials(meta.CallCredentials);
        }
        if (const TDuration* timeout = std::get_if<TDuration>(&meta.Timeout)) {
            if (*timeout) {
                auto deadline = gpr_time_add(
                    gpr_now(GPR_CLOCK_MONOTONIC),
                    gpr_time_from_micros(timeout->MicroSeconds(), GPR_TIMESPAN));
                Context.set_deadline(deadline);
            }
        } else if (const TInstant* deadline = std::get_if<TInstant>(&meta.Timeout)) {
            if (*deadline) {
                Context.set_deadline(gpr_time_from_micros(deadline->MicroSeconds(), GPR_CLOCK_MONOTONIC));
            }
        }
    }

    void GetInitialMetadata(std::unordered_multimap<TString, TString>* metadata) {
        for (const auto& [key, value] : Context.GetServerInitialMetadata()) {
            metadata->emplace(
                TString(key.begin(), key.end()),
                TString(value.begin(), value.end())
            );
        }
    }

    grpc::Status Status;
    grpc::ClientContext Context;
    std::shared_ptr<IQueueClientContext> LocalContext;
};
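// Example: filling TCallMeta before starting a request (consumed by
// ApplyMeta above). A minimal sketch; the header name and value are
// arbitrary placeholders.
//
//     TCallMeta meta;
//     meta.Aux.emplace_back("x-request-id", "12345");
//     meta.Timeout = TDuration::Seconds(5); // deadline = now + 5s
//     // Alternatively, an absolute deadline:
//     // meta.Timeout = TInstant::Now() + TDuration::Seconds(5);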
template<typename TStub, typename TRequest, typename TResponse>
class TSimpleRequestProcessor
    : public TThrRefBase
    , public IQueueClientEvent
    , public TGRpcRequestProcessorCommon {
    using TAsyncReaderPtr = std::unique_ptr<grpc::ClientAsyncResponseReader<TResponse>>;
    template<typename> friend class TServiceConnection;
public:
    using TPtr = TIntrusivePtr<TSimpleRequestProcessor>;
    using TAsyncRequest = TAsyncReaderPtr (TStub::*)(grpc::ClientContext*, const TRequest&, grpc::CompletionQueue*);

    explicit TSimpleRequestProcessor(TResponseCallback<TResponse>&& callback)
        : Callback_(std::move(callback))
    { }

    ~TSimpleRequestProcessor() {
        if (!Replied_ && Callback_) {
            Callback_(TGrpcStatus::Internal("request left unhandled"), std::move(Reply_));
            Callback_ = nullptr; // free resources as early as possible
        }
    }

    bool Execute(bool ok) override {
        {
            std::unique_lock guard(Mutex_);
            LocalContext.reset();
        }
        TGrpcStatus status;
        if (ok) {
            status = Status;
        } else {
            status = TGrpcStatus::Internal("Unexpected error");
        }
        Replied_ = true;
        Callback_(std::move(status), std::move(Reply_));
        Callback_ = nullptr; // free resources as early as possible
        return false;
    }

    void Destroy() override {
        UnRef();
    }

private:
    IQueueClientEvent* FinishedEvent() {
        Ref();
        return this;
    }

    void Start(TStub& stub, TAsyncRequest asyncRequest, const TRequest& request, IQueueClientContextProvider* provider) {
        auto context = provider->CreateContext();
        if (!context) {
            Replied_ = true;
            Callback_(TGrpcStatus(grpc::StatusCode::CANCELLED, "Client is shutting down"), std::move(Reply_));
            Callback_ = nullptr;
            return;
        }
        {
            std::unique_lock guard(Mutex_);
            LocalContext = context;
            Reader_ = (stub.*asyncRequest)(&Context, request, context->CompletionQueue());
            Reader_->Finish(&Reply_, &Status, FinishedEvent());
        }
        context->SubscribeStop([self = TPtr(this)] {
            self->Stop();
        });
    }

    void Stop() {
        Context.TryCancel();
    }

    TResponseCallback<TResponse> Callback_;
    TResponse Reply_;
    std::mutex Mutex_;
    TAsyncReaderPtr Reader_;

    bool Replied_ = false;
};

template<typename TStub, typename TRequest, typename TResponse>
class TAdvancedRequestProcessor
    : public TThrRefBase
    , public IQueueClientEvent
    , public TGRpcRequestProcessorCommon {
    using TAsyncReaderPtr = std::unique_ptr<grpc::ClientAsyncResponseReader<TResponse>>;
    template<typename> friend class TServiceConnection;
public:
    using TPtr = TIntrusivePtr<TAdvancedRequestProcessor>;
    using TAsyncRequest = TAsyncReaderPtr (TStub::*)(grpc::ClientContext*, const TRequest&, grpc::CompletionQueue*);

    explicit TAdvancedRequestProcessor(TAdvancedResponseCallback<TResponse>&& callback)
        : Callback_(std::move(callback))
    { }

    ~TAdvancedRequestProcessor() {
        if (!Replied_ && Callback_) {
            Callback_(Context, TGrpcStatus::Internal("request left unhandled"), std::move(Reply_));
            Callback_ = nullptr; // free resources as early as possible
        }
    }

    bool Execute(bool ok) override {
        {
            std::unique_lock guard(Mutex_);
            LocalContext.reset();
        }
        TGrpcStatus status;
        if (ok) {
            status = Status;
        } else {
            status = TGrpcStatus::Internal("Unexpected error");
        }
        Replied_ = true;
        Callback_(Context, std::move(status), std::move(Reply_));
        Callback_ = nullptr; // free resources as early as possible
        return false;
    }

    void Destroy() override {
        UnRef();
    }

private:
    IQueueClientEvent* FinishedEvent() {
        Ref();
        return this;
    }

    void Start(TStub& stub, TAsyncRequest asyncRequest, const TRequest& request, IQueueClientContextProvider* provider) {
        auto context = provider->CreateContext();
        if (!context) {
            Replied_ = true;
            Callback_(Context, TGrpcStatus(grpc::StatusCode::CANCELLED, "Client is shutting down"), std::move(Reply_));
            Callback_ = nullptr;
            return;
        }
        {
            std::unique_lock guard(Mutex_);
            LocalContext = context;
            Reader_ = (stub.*asyncRequest)(&Context, request, context->CompletionQueue());
            Reader_->Finish(&Reply_, &Status, FinishedEvent());
        }
        context->SubscribeStop([self = TPtr(this)] {
            self->Stop();
        });
    }

    void Stop() {
        Context.TryCancel();
    }

    TAdvancedResponseCallback<TResponse> Callback_;
    TResponse Reply_;
    std::mutex Mutex_;
    TAsyncReaderPtr Reader_;

    bool Replied_ = false;
};
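// Example: a TResponseCallback lambda for the simple request path. A sketch;
// TMyResponse stands for any generated protobuf response type.
//
//     TResponseCallback<TMyResponse> cb =
//         [](TGrpcStatus&& status, TMyResponse&& response) {
//             if (IsGRpcStatusGood(status)) {
//                 // use response
//             } else {
//                 // status.GRpcStatusCode / status.Msg describe the failure;
//                 // status.InternalError marks client-side failures
//             }
//         };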
class IStreamRequestCtrl : public TThrRefBase {
public:
    using TPtr = TIntrusivePtr<IStreamRequestCtrl>;

    /**
     * Asynchronously cancel the request
     */
    virtual void Cancel() = 0;
};

template<class TResponse>
class IStreamRequestReadProcessor : public IStreamRequestCtrl {
public:
    using TPtr = TIntrusivePtr<IStreamRequestReadProcessor>;
    using TReadCallback = std::function<void(TGrpcStatus&&)>;

    /**
     * Schedules initial server metadata read from the stream
     */
    virtual void ReadInitialMetadata(std::unordered_multimap<TString, TString>* metadata, TReadCallback callback) = 0;

    /**
     * Schedules a response read from the stream
     * The callback will be called with the status if it failed
     * Only one Read or Finish call may be active at a time
     */
    virtual void Read(TResponse* response, TReadCallback callback) = 0;

    /**
     * Stop reading and gracefully finish the stream
     * Only one Read or Finish call may be active at a time
     */
    virtual void Finish(TReadCallback callback) = 0;

    /**
     * Additional callback to be called when the stream has finished
     */
    virtual void AddFinishedCallback(TReadCallback callback) = 0;
};

template<class TRequest, class TResponse>
class IStreamRequestReadWriteProcessor : public IStreamRequestReadProcessor<TResponse> {
public:
    using TPtr = TIntrusivePtr<IStreamRequestReadWriteProcessor>;
    using TWriteCallback = std::function<void(TGrpcStatus&&)>;

    /**
     * Schedules a request write to the stream
     */
    virtual void Write(TRequest&& request, TWriteCallback callback = { }) = 0;
};

class TGRpcKeepAliveSocketMutator;

// Class to hold stubs allocated on a channel.
// This is a poorly documented part of grpc. See KIKIMR-6109 and the comment to this commit.
// The stub holds a shared_ptr to the channel, so we can destroy this holder even if
// a request processor is still using the stub.
class TStubsHolder : public TNonCopyable {
    using TypeInfoRef = std::reference_wrapper<const std::type_info>;

    struct THasher {
        std::size_t operator()(TypeInfoRef code) const {
            return code.get().hash_code();
        }
    };

    struct TEqualTo {
        bool operator()(TypeInfoRef lhs, TypeInfoRef rhs) const {
            return lhs.get() == rhs.get();
        }
    };
public:
    TStubsHolder(std::shared_ptr<grpc::ChannelInterface> channel)
        : ChannelInterface_(channel)
    {}

    // Returns true if the channel can't be used to perform a request right now
    bool IsChannelBroken() const {
        auto state = ChannelInterface_->GetState(false);
        return state == GRPC_CHANNEL_SHUTDOWN ||
               state == GRPC_CHANNEL_TRANSIENT_FAILURE;
    }

    template<typename TStub>
    std::shared_ptr<TStub> GetOrCreateStub() {
        const auto& stubId = typeid(TStub);
        {
            std::shared_lock readGuard(RWMutex_);
            const auto it = Stubs_.find(stubId);
            if (it != Stubs_.end()) {
                return std::static_pointer_cast<TStub>(it->second);
            }
        }
        {
            std::unique_lock writeGuard(RWMutex_);
            auto it = Stubs_.emplace(stubId, nullptr);
            if (!it.second) {
                return std::static_pointer_cast<TStub>(it.first->second);
            } else {
                it.first->second = std::make_shared<TStub>(ChannelInterface_);
                return std::static_pointer_cast<TStub>(it.first->second);
            }
        }
    }

    const TInstant& GetLastUseTime() const {
        return LastUsed_;
    }

    void SetLastUseTime(const TInstant& time) {
        LastUsed_ = time;
    }
private:
    TInstant LastUsed_ = Now();
    std::shared_mutex RWMutex_;
    std::unordered_map<TypeInfoRef, std::shared_ptr<void>, THasher, TEqualTo> Stubs_;
    std::shared_ptr<grpc::ChannelInterface> ChannelInterface_;
};

class TChannelPool {
public:
    TChannelPool(const TTcpKeepAliveSettings& tcpKeepAliveSettings,
                 const TDuration& expireTime = TDuration::Minutes(6));
    // Allows creating a stub from TStubsHolder under a lock.
    // The callback is invoked only for the duration of the GetStubsHolderLocked call.
    void GetStubsHolderLocked(const TString& channelId, const TGRpcClientConfig& config, std::function<void(TStubsHolder&)> cb);
    void DeleteChannel(const TString& channelId);
    void DeleteExpiredStubsHolders();
private:
    std::shared_mutex RWMutex_;
    std::unordered_map<TString, TStubsHolder> Pool_;
    std::multimap<TInstant, TString> LastUsedQueue_;
    TTcpKeepAliveSettings TcpKeepAliveSettings_;
    TDuration ExpireTime_;
    TDuration UpdateReUseTime_;
    void EraseFromQueueByTime(const TInstant& lastUseTime, const TString& channelId);
};
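// Example: obtaining a stub through the pool. A sketch only; TMyService and
// the endpoint are placeholders, `pool` is a TChannelPool built elsewhere,
// and TGRpcClientConfig (declared in grpc_common.h) is assumed to be
// constructible from an endpoint string.
//
//     TGRpcClientConfig config("localhost:2135");
//     pool.GetStubsHolderLocked("localhost:2135", config, [](TStubsHolder& holder) {
//         auto stub = holder.GetOrCreateStub<TMyService::Stub>();
//         // use stub only while the callback is running
//     });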
template<class TResponse>
using TStreamReaderCallback = std::function<void(TGrpcStatus&&, typename IStreamRequestReadProcessor<TResponse>::TPtr)>;

template<class TStub, class TRequest, class TResponse>
class TStreamRequestReadProcessor
    : public IStreamRequestReadProcessor<TResponse>
    , public TGRpcRequestProcessorCommon {
    template<typename> friend class TServiceConnection;
public:
    using TSelf = TStreamRequestReadProcessor;
    using TAsyncReaderPtr = std::unique_ptr<grpc::ClientAsyncReader<TResponse>>;
    using TAsyncRequest = TAsyncReaderPtr (TStub::*)(grpc::ClientContext*, const TRequest&, grpc::CompletionQueue*, void*);
    using TReaderCallback = TStreamReaderCallback<TResponse>;
    using TPtr = TIntrusivePtr<TSelf>;
    using TBase = IStreamRequestReadProcessor<TResponse>;
    using TReadCallback = typename TBase::TReadCallback;

    explicit TStreamRequestReadProcessor(TReaderCallback&& callback)
        : Callback(std::move(callback))
    {
        Y_VERIFY(Callback, "Missing connected callback");
    }

    void Cancel() override {
        Context.TryCancel();

        {
            std::unique_lock guard(Mutex);
            Cancelled = true;
            if (Started && !ReadFinished) {
                if (!ReadActive) {
                    ReadFinished = true;
                }
                if (ReadFinished) {
                    Stream->Finish(&Status, OnFinishedTag.Prepare());
                }
            }
        }
    }

    void ReadInitialMetadata(std::unordered_multimap<TString, TString>* metadata, TReadCallback callback) override {
        TGrpcStatus status;

        {
            std::unique_lock guard(Mutex);
            Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected");
            if (!Finished && !HasInitialMetadata) {
                ReadActive = true;
                ReadCallback = std::move(callback);
                InitialMetadata = metadata;
                if (!ReadFinished) {
                    Stream->ReadInitialMetadata(OnReadDoneTag.Prepare());
                }
                return;
            }
            if (!HasInitialMetadata) {
                if (FinishedOk) {
                    status = Status;
                } else {
                    status = TGrpcStatus::Internal("Unexpected error");
                }
            } else {
                GetInitialMetadata(metadata);
            }
        }

        callback(std::move(status));
    }

    void Read(TResponse* message, TReadCallback callback) override {
        TGrpcStatus status;

        {
            std::unique_lock guard(Mutex);
            Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected");
            if (!Finished) {
                ReadActive = true;
                ReadCallback = std::move(callback);
                if (!ReadFinished) {
                    Stream->Read(message, OnReadDoneTag.Prepare());
                }
                return;
            }
            if (FinishedOk) {
                status = Status;
            } else {
                status = TGrpcStatus::Internal("Unexpected error");
            }
        }

        if (status.Ok()) {
            status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF");
        }

        callback(std::move(status));
    }

    void Finish(TReadCallback callback) override {
        TGrpcStatus status;

        {
            std::unique_lock guard(Mutex);
            Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected");
            if (!Finished) {
                ReadActive = true;
                FinishCallback = std::move(callback);
                if (!ReadFinished) {
                    ReadFinished = true;
                }
                Stream->Finish(&Status, OnFinishedTag.Prepare());
                return;
            }
            if (FinishedOk) {
                status = Status;
            } else {
                status = TGrpcStatus::Internal("Unexpected error");
            }
        }

        callback(std::move(status));
    }

    void AddFinishedCallback(TReadCallback callback) override {
        Y_VERIFY(callback, "Unexpected empty callback");

        TGrpcStatus status;

        {
            std::unique_lock guard(Mutex);
            if (!Finished) {
                FinishedCallbacks.emplace_back().swap(callback);
                return;
            }

            if (FinishedOk) {
                status = Status;
            } else if (Cancelled) {
                status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled");
            } else {
                status = TGrpcStatus::Internal("Unexpected error");
            }
        }

        callback(std::move(status));
    }

private:
    void Start(TStub& stub, const TRequest& request, TAsyncRequest asyncRequest, IQueueClientContextProvider* provider) {
        auto context = provider->CreateContext();
        if (!context) {
            auto callback = std::move(Callback);
            TGrpcStatus status(grpc::StatusCode::CANCELLED, "Client is shutting down");
            callback(std::move(status), nullptr);
            return;
        }

        {
            std::unique_lock guard(Mutex);
            LocalContext = context;
            Stream = (stub.*asyncRequest)(&Context, request, context->CompletionQueue(), OnStartDoneTag.Prepare());
        }

        context->SubscribeStop([self = TPtr(this)] {
            self->Cancel();
        });
    }

    void OnReadDone(bool ok) {
        TGrpcStatus status;
        TReadCallback callback;
        std::unordered_multimap<TString, TString>* initialMetadata = nullptr;

        {
            std::unique_lock guard(Mutex);
            Y_VERIFY(ReadActive, "Unexpected Read done callback");
            Y_VERIFY(!ReadFinished, "Unexpected ReadFinished flag");
            if (!ok || Cancelled) {
                ReadFinished = true;
                Stream->Finish(&Status, OnFinishedTag.Prepare());
                if (!ok) {
                    // Keep ReadActive=true, so callback is called
                    // after the call is finished with an error
                    return;
                }
            }

            callback = std::move(ReadCallback);
            ReadCallback = nullptr;
            ReadActive = false;
            initialMetadata = InitialMetadata;
            InitialMetadata = nullptr;
            HasInitialMetadata = true;
        }

        if (initialMetadata) {
            GetInitialMetadata(initialMetadata);
        }

        callback(std::move(status));
    }

    void OnStartDone(bool ok) {
        TReaderCallback callback;

        {
            std::unique_lock guard(Mutex);
            Started = true;
            if (!ok || Cancelled) {
                ReadFinished = true;
                Stream->Finish(&Status, OnFinishedTag.Prepare());
                return;
            }
            callback = std::move(Callback);
            Callback = nullptr;
        }

        callback({ }, typename TBase::TPtr(this));
    }

    void OnFinished(bool ok) {
        TGrpcStatus status;
        std::vector<TReadCallback> finishedCallbacks;
        TReaderCallback startCallback;
        TReadCallback readCallback;
        TReadCallback finishCallback;

        {
            std::unique_lock guard(Mutex);

            Finished = true;
            FinishedOk = ok;
            LocalContext.reset();

            if (ok) {
                status = Status;
            } else if (Cancelled) {
                status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled");
            } else {
                status = TGrpcStatus::Internal("Unexpected error");
            }

            finishedCallbacks.swap(FinishedCallbacks);

            if (Callback) {
                Y_VERIFY(!ReadActive);
                startCallback = std::move(Callback);
                Callback = nullptr;
            } else if (ReadActive) {
                if (ReadCallback) {
                    readCallback = std::move(ReadCallback);
                    ReadCallback = nullptr;
                } else {
                    finishCallback = std::move(FinishCallback);
                    FinishCallback = nullptr;
                }
                ReadActive = false;
            }
        }

        for (auto& finishedCallback : finishedCallbacks) {
            auto statusCopy = status;
            finishedCallback(std::move(statusCopy));
        }

        if (startCallback) {
            if (status.Ok()) {
                status = TGrpcStatus(grpc::StatusCode::UNKNOWN, "Unknown stream failure");
            }
            startCallback(std::move(status), nullptr);
        } else if (readCallback) {
            if (status.Ok()) {
                status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF");
            }
            readCallback(std::move(status));
        } else if (finishCallback) {
            finishCallback(std::move(status));
        }
    }

    TReaderCallback Callback;
    TAsyncReaderPtr Stream;
    using TFixedEvent = TQueueClientFixedEvent<TSelf>;
    std::mutex Mutex;
    TFixedEvent OnReadDoneTag = { this, &TSelf::OnReadDone };
    TFixedEvent OnStartDoneTag = { this, &TSelf::OnStartDone };
    TFixedEvent OnFinishedTag = { this, &TSelf::OnFinished };

    TReadCallback ReadCallback;
    TReadCallback FinishCallback;
    std::vector<TReadCallback> FinishedCallbacks;
    std::unordered_multimap<TString, TString>* InitialMetadata = nullptr;
    bool Started = false;
    bool HasInitialMetadata = false;
    bool ReadActive = false;
    bool ReadFinished = false;
    bool Finished = false;
    bool Cancelled = false;
    bool FinishedOk = false;
};
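// Example: a server-streaming read loop built on IStreamRequestReadProcessor.
// A sketch; TMyResponse is a placeholder, and the response buffer must stay
// alive until the read callback fires (hence the shared_ptr).
//
//     void ReadNext(IStreamRequestReadProcessor<TMyResponse>::TPtr processor,
//                   std::shared_ptr<TMyResponse> response) {
//         processor->Read(response.get(), [processor, response](TGrpcStatus&& status) {
//             if (status.Ok()) {
//                 // consume *response, then schedule the next read
//                 ReadNext(processor, response);
//             } // OUT_OF_RANGE "Read EOF" signals a graceful end of stream
//         });
//     }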
template<class TRequest, class TResponse>
using TStreamConnectedCallback = std::function<void(TGrpcStatus&&, typename IStreamRequestReadWriteProcessor<TRequest, TResponse>::TPtr)>;

template<class TStub, class TRequest, class TResponse>
class TStreamRequestReadWriteProcessor
    : public IStreamRequestReadWriteProcessor<TRequest, TResponse>
    , public TGRpcRequestProcessorCommon {
public:
    using TSelf = TStreamRequestReadWriteProcessor;
    using TBase = IStreamRequestReadWriteProcessor<TRequest, TResponse>;
    using TPtr = TIntrusivePtr<TSelf>;
    using TConnectedCallback = TStreamConnectedCallback<TRequest, TResponse>;
    using TReadCallback = typename TBase::TReadCallback;
    using TWriteCallback = typename TBase::TWriteCallback;
    using TAsyncReaderWriterPtr = std::unique_ptr<grpc::ClientAsyncReaderWriter<TRequest, TResponse>>;
    using TAsyncRequest = TAsyncReaderWriterPtr (TStub::*)(grpc::ClientContext*, grpc::CompletionQueue*, void*);

    explicit TStreamRequestReadWriteProcessor(TConnectedCallback&& callback)
        : ConnectedCallback(std::move(callback))
    {
        Y_VERIFY(ConnectedCallback, "Missing connected callback");
    }

    void Cancel() override {
        Context.TryCancel();

        {
            std::unique_lock guard(Mutex);
            Cancelled = true;
            if (Started && !(ReadFinished && WriteFinished)) {
                if (!ReadActive) {
                    ReadFinished = true;
                }
                if (!WriteActive) {
                    WriteFinished = true;
                }
                if (ReadFinished && WriteFinished) {
                    Stream->Finish(&Status, OnFinishedTag.Prepare());
                }
            }
        }
    }

    void Write(TRequest&& request, TWriteCallback callback) override {
        TGrpcStatus status;

        {
            std::unique_lock guard(Mutex);
            if (Cancelled || ReadFinished || WriteFinished) {
                status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Write request dropped");
            } else if (WriteActive) {
                auto& item = WriteQueue.emplace_back();
                item.Callback.swap(callback);
                item.Request.Swap(&request);
            } else {
                WriteActive = true;
                WriteCallback.swap(callback);
                Stream->Write(request, OnWriteDoneTag.Prepare());
            }
        }

        if (!status.Ok() && callback) {
            callback(std::move(status));
        }
    }

    void ReadInitialMetadata(std::unordered_multimap<TString, TString>* metadata, TReadCallback callback) override {
        TGrpcStatus status;

        {
            std::unique_lock guard(Mutex);
            Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected");
            if (!Finished && !HasInitialMetadata) {
                ReadActive = true;
                ReadCallback = std::move(callback);
                InitialMetadata = metadata;
                if (!ReadFinished) {
                    Stream->ReadInitialMetadata(OnReadDoneTag.Prepare());
                }
                return;
            }
            if (!HasInitialMetadata) {
                if (FinishedOk) {
                    status = Status;
                } else {
                    status = TGrpcStatus::Internal("Unexpected error");
                }
            } else {
                GetInitialMetadata(metadata);
            }
        }

        callback(std::move(status));
    }

    void Read(TResponse* message, TReadCallback callback) override {
        TGrpcStatus status;

        {
            std::unique_lock guard(Mutex);
            Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected");
            if (!Finished) {
                ReadActive = true;
                ReadCallback = std::move(callback);
                if (!ReadFinished) {
                    Stream->Read(message, OnReadDoneTag.Prepare());
                }
                return;
            }
            if (FinishedOk) {
                status = Status;
            } else {
                status = TGrpcStatus::Internal("Unexpected error");
            }
        }

        if (status.Ok()) {
            status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF");
        }

        callback(std::move(status));
    }

    void Finish(TReadCallback callback) override {
        TGrpcStatus status;

        {
            std::unique_lock guard(Mutex);
            Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected");
            if (!Finished) {
                ReadActive = true;
                FinishCallback = std::move(callback);
                if (!ReadFinished) {
                    ReadFinished = true;
                    if (!WriteActive) {
                        WriteFinished = true;
                    }
                    if (WriteFinished) {
                        Stream->Finish(&Status, OnFinishedTag.Prepare());
                    }
                }
                return;
            }
            if (FinishedOk) {
                status = Status;
            } else {
                status = TGrpcStatus::Internal("Unexpected error");
            }
        }

        callback(std::move(status));
    }

    void AddFinishedCallback(TReadCallback callback) override {
        Y_VERIFY(callback, "Unexpected empty callback");

        TGrpcStatus status;

        {
            std::unique_lock guard(Mutex);
            if (!Finished) {
                FinishedCallbacks.emplace_back().swap(callback);
                return;
            }

            if (FinishedOk) {
                status = Status;
            } else if (Cancelled) {
                status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled");
            } else {
                status = TGrpcStatus::Internal("Unexpected error");
            }
        }

        callback(std::move(status));
    }

private:
    template<typename> friend class TServiceConnection;

    void Start(TStub& stub, TAsyncRequest asyncRequest, IQueueClientContextProvider* provider) {
        auto context = provider->CreateContext();
        if (!context) {
            auto callback = std::move(ConnectedCallback);
            TGrpcStatus status(grpc::StatusCode::CANCELLED, "Client is shutting down");
            callback(std::move(status), nullptr);
            return;
        }

        {
            std::unique_lock guard(Mutex);
            LocalContext = context;
            Stream = (stub.*asyncRequest)(&Context, context->CompletionQueue(), OnConnectedTag.Prepare());
        }
        context->SubscribeStop([self = TPtr(this)] {
            self->Cancel();
        });
    }

private:
    void OnConnected(bool ok) {
        TConnectedCallback callback;

        {
            std::unique_lock guard(Mutex);
            Started = true;
            if (!ok || Cancelled) {
                ReadFinished = true;
                WriteFinished = true;
                Stream->Finish(&Status, OnFinishedTag.Prepare());
                return;
            }

            callback = std::move(ConnectedCallback);
            ConnectedCallback = nullptr;
        }

        callback({ }, typename TBase::TPtr(this));
    }

    void OnReadDone(bool ok) {
        TGrpcStatus status;
        TReadCallback callback;
        std::unordered_multimap<TString, TString>* initialMetadata = nullptr;

        {
            std::unique_lock guard(Mutex);
            Y_VERIFY(ReadActive, "Unexpected Read done callback");
            Y_VERIFY(!ReadFinished, "Unexpected ReadFinished flag");

            if (!ok || Cancelled || WriteFinished) {
                ReadFinished = true;
                if (!WriteActive) {
                    WriteFinished = true;
                }
                if (WriteFinished) {
                    Stream->Finish(&Status, OnFinishedTag.Prepare());
                }
                if (!ok) {
                    // Keep ReadActive=true, so callback is called
                    // after the call is finished with an error
                    return;
                }
            }

            callback = std::move(ReadCallback);
            ReadCallback = nullptr;
            ReadActive = false;
            initialMetadata = InitialMetadata;
            InitialMetadata = nullptr;
            HasInitialMetadata = true;
        }

        if (initialMetadata) {
            GetInitialMetadata(initialMetadata);
        }

        callback(std::move(status));
    }

    void OnWriteDone(bool ok) {
        TWriteCallback okCallback;

        {
            std::unique_lock guard(Mutex);
            Y_VERIFY(WriteActive, "Unexpected Write done callback");
            Y_VERIFY(!WriteFinished, "Unexpected WriteFinished flag");

            if (ok) {
                okCallback.swap(WriteCallback);
            } else if (WriteCallback) {
                // Put the callback back on the queue until OnFinished
                auto& item = WriteQueue.emplace_front();
                item.Callback.swap(WriteCallback);
            }

            if (!ok || Cancelled) {
                WriteActive = false;
                WriteFinished = true;
                if (!ReadActive) {
                    ReadFinished = true;
                }
                if (ReadFinished) {
                    Stream->Finish(&Status, OnFinishedTag.Prepare());
                }
            } else if (!WriteQueue.empty()) {
                WriteCallback.swap(WriteQueue.front().Callback);
                Stream->Write(WriteQueue.front().Request, OnWriteDoneTag.Prepare());
                WriteQueue.pop_front();
            } else {
                WriteActive = false;
                if (ReadFinished) {
                    WriteFinished = true;
                    Stream->Finish(&Status, OnFinishedTag.Prepare());
                }
            }
        }

        if (okCallback) {
            okCallback(TGrpcStatus());
        }
    }

    void OnFinished(bool ok) {
        TGrpcStatus status;
        std::deque<TWriteItem> writesDropped;
        std::vector<TReadCallback> finishedCallbacks;
        TConnectedCallback connectedCallback;
        TReadCallback readCallback;
        TReadCallback finishCallback;

        {
            std::unique_lock guard(Mutex);

            Finished = true;
            FinishedOk = ok;
            LocalContext.reset();

            if (ok) {
                status = Status;
            } else if (Cancelled) {
                status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled");
            } else {
                status = TGrpcStatus::Internal("Unexpected error");
            }

            writesDropped.swap(WriteQueue);
            finishedCallbacks.swap(FinishedCallbacks);

            if (ConnectedCallback) {
                Y_VERIFY(!ReadActive);
                connectedCallback = std::move(ConnectedCallback);
                ConnectedCallback = nullptr;
            } else if (ReadActive) {
                if (ReadCallback) {
                    readCallback = std::move(ReadCallback);
                    ReadCallback = nullptr;
                } else {
                    finishCallback = std::move(FinishCallback);
                    FinishCallback = nullptr;
                }
                ReadActive = false;
            }
        }

        for (auto& item : writesDropped) {
            if (item.Callback) {
                TGrpcStatus writeStatus = status;
                if (writeStatus.Ok()) {
                    writeStatus = TGrpcStatus(grpc::StatusCode::CANCELLED, "Write request dropped");
                }
                item.Callback(std::move(writeStatus));
            }
        }

        for (auto& finishedCallback : finishedCallbacks) {
            TGrpcStatus statusCopy = status;
            finishedCallback(std::move(statusCopy));
        }

        if (connectedCallback) {
            if (status.Ok()) {
                status = TGrpcStatus(grpc::StatusCode::UNKNOWN, "Unknown stream failure");
            }
            connectedCallback(std::move(status), nullptr);
        } else if (readCallback) {
            if (status.Ok()) {
                status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF");
            }
            readCallback(std::move(status));
        } else if (finishCallback) {
            finishCallback(std::move(status));
        }
    }

private:
    struct TWriteItem {
        TWriteCallback Callback;
        TRequest Request;
    };

private:
    using TFixedEvent = TQueueClientFixedEvent<TSelf>;

    TFixedEvent OnConnectedTag = { this, &TSelf::OnConnected };
    TFixedEvent OnReadDoneTag = { this, &TSelf::OnReadDone };
    TFixedEvent OnWriteDoneTag = { this, &TSelf::OnWriteDone };
    TFixedEvent OnFinishedTag = { this, &TSelf::OnFinished };

private:
    std::mutex Mutex;
    TAsyncReaderWriterPtr Stream;
    TConnectedCallback ConnectedCallback;
    TReadCallback ReadCallback;
    TReadCallback FinishCallback;
    std::vector<TReadCallback> FinishedCallbacks;
    std::deque<TWriteItem> WriteQueue;
    TWriteCallback WriteCallback;
    std::unordered_multimap<TString, TString>* InitialMetadata = nullptr;
    bool Started = false;
    bool HasInitialMetadata = false;
    bool ReadActive = false;
    bool ReadFinished = false;
    bool WriteActive = false;
    bool WriteFinished = false;
    bool Finished = false;
    bool Cancelled = false;
    bool FinishedOk = false;
};
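// Example: writing to a bidirectional stream once connected. A sketch;
// TMyRequest is a placeholder for the request type, `processor` is the
// pointer delivered to the connected callback.
//
//     TMyRequest request;
//     processor->Write(std::move(request), [](TGrpcStatus&& status) {
//         // called when this write completes; writes issued while another
//         // write is active are queued internally and sent in order
//     });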
failure"); } connectedCallback(std::move(status), nullptr); } else if (readCallback) { if (status.Ok()) { status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF"); } readCallback(std::move(status)); } else if (finishCallback) { finishCallback(std::move(status)); } } private: struct TWriteItem { TWriteCallback Callback; TRequest Request; }; private: using TFixedEvent = TQueueClientFixedEvent; TFixedEvent OnConnectedTag = { this, &TSelf::OnConnected }; TFixedEvent OnReadDoneTag = { this, &TSelf::OnReadDone }; TFixedEvent OnWriteDoneTag = { this, &TSelf::OnWriteDone }; TFixedEvent OnFinishedTag = { this, &TSelf::OnFinished }; private: std::mutex Mutex; TAsyncReaderWriterPtr Stream; TConnectedCallback ConnectedCallback; TReadCallback ReadCallback; TReadCallback FinishCallback; std::vector FinishedCallbacks; std::deque WriteQueue; TWriteCallback WriteCallback; std::unordered_multimap* InitialMetadata = nullptr; bool Started = false; bool HasInitialMetadata = false; bool ReadActive = false; bool ReadFinished = false; bool WriteActive = false; bool WriteFinished = false; bool Finished = false; bool Cancelled = false; bool FinishedOk = false; }; class TGRpcClientLow; template class TServiceConnection { using TStub = typename TGRpcService::Stub; friend class TGRpcClientLow; public: /* * Start simple request */ template void DoRequest(const TRequest& request, TResponseCallback callback, typename TSimpleRequestProcessor::TAsyncRequest asyncRequest, const TCallMeta& metas = { }, IQueueClientContextProvider* provider = nullptr) { auto processor = MakeIntrusive>(std::move(callback)); processor->ApplyMeta(metas); processor->Start(*Stub_, asyncRequest, request, provider ? provider : Provider_); } /* * Start simple request */ template void DoAdvancedRequest(const TRequest& request, TAdvancedResponseCallback callback, typename TAdvancedRequestProcessor::TAsyncRequest asyncRequest, const TCallMeta& metas = { }, IQueueClientContextProvider* provider = nullptr) { auto processor = MakeIntrusive>(std::move(callback)); processor->ApplyMeta(metas); processor->Start(*Stub_, asyncRequest, request, provider ? provider : Provider_); } /* * Start bidirectional streamming */ template void DoStreamRequest(TStreamConnectedCallback callback, typename TStreamRequestReadWriteProcessor::TAsyncRequest asyncRequest, const TCallMeta& metas = { }, IQueueClientContextProvider* provider = nullptr) { auto processor = MakeIntrusive>(std::move(callback)); processor->ApplyMeta(metas); processor->Start(*Stub_, std::move(asyncRequest), provider ? provider : Provider_); } /* * Start streaming response reading (one request, many responses) */ template void DoStreamRequest(const TRequest& request, TStreamReaderCallback callback, typename TStreamRequestReadProcessor::TAsyncRequest asyncRequest, const TCallMeta& metas = { }, IQueueClientContextProvider* provider = nullptr) { auto processor = MakeIntrusive>(std::move(callback)); processor->ApplyMeta(metas); processor->Start(*Stub_, request, std::move(asyncRequest), provider ? 
class TGRpcClientLow
    : public IQueueClientContextProvider
{
    class TContextImpl;
    friend class TContextImpl;

    enum ECqState : TAtomicBase {
        WORKING = 0,
        STOP_SILENT = 1,
        STOP_EXPLICIT = 2,
    };

public:
    explicit TGRpcClientLow(size_t numWorkerThread = DEFAULT_NUM_THREADS, bool useCompletionQueuePerThread = false);
    ~TGRpcClientLow();

    // Tries to stop all currently running requests (via their stop callbacks)
    // Will shutdown CQ and drain events once all requests have finished
    // No new requests may be started after this call
    void Stop(bool wait = false);

    // Waits until all currently running requests finish execution
    void WaitIdle();

    inline bool IsStopping() const {
        switch (GetCqState()) {
            case WORKING:
                return false;
            case STOP_SILENT:
            case STOP_EXPLICIT:
                return true;
        }

        Y_UNREACHABLE();
    }

    IQueueClientContextPtr CreateContext() override;

    template<typename TGRpcService>
    std::unique_ptr<TServiceConnection<TGRpcService>> CreateGRpcServiceConnection(const TGRpcClientConfig& config) {
        return std::unique_ptr<TServiceConnection<TGRpcService>>(new TServiceConnection<TGRpcService>(CreateChannelInterface(config), this));
    }

    template<typename TGRpcService>
    std::unique_ptr<TServiceConnection<TGRpcService>> CreateGRpcServiceConnection(TStubsHolder& holder) {
        return std::unique_ptr<TServiceConnection<TGRpcService>>(new TServiceConnection<TGRpcService>(holder, this));
    }

    // Tests only, not thread-safe
    void AddWorkerThreadForTest();

private:
    using IThreadRef = std::unique_ptr<IThreadFactory::IThread>;
    using CompletionQueueRef = std::unique_ptr<grpc::CompletionQueue>;
    void Init(size_t numWorkerThread);

    inline ECqState GetCqState() const { return (ECqState) AtomicGet(CqState_); }
    inline void SetCqState(ECqState state) { AtomicSet(CqState_, state); }

    void StopInternal(bool silent);
    void WaitInternal();

    void ForgetContext(TContextImpl* context);

private:
    bool UseCompletionQueuePerThread_;
    std::vector<CompletionQueueRef> CQS_;
    std::vector<IThreadRef> WorkerThreads_;
    TAtomic CqState_ = -1;

    std::mutex Mtx_;
    std::condition_variable ContextsEmpty_;
    std::unordered_set<TContextImpl*> Contexts_;

    std::mutex JoinMutex_;
};

} // namespace NGrpc