#include "tcp2.h" #include "details.h" #include "factory.h" #include "http_common.h" #include "neh.h" #include "utils.h" #include #include #include #include #include #include #include #include #include #include #include #include //#define DEBUG_TCP2 1 #ifdef DEBUG_TCP2 TSpinLock OUT_LOCK; #define DBGOUT(args) \ { \ TGuard m(OUT_LOCK); \ Cout << TInstant::Now().GetValue() << " " << args << Endl; \ } #else #define DBGOUT(args) #endif using namespace std::placeholders; namespace NNeh { TDuration TTcp2Options::ConnectTimeout = TDuration::MilliSeconds(300); size_t TTcp2Options::InputBufferSize = 16000; size_t TTcp2Options::AsioClientThreads = 4; size_t TTcp2Options::AsioServerThreads = 4; int TTcp2Options::Backlog = 100; bool TTcp2Options::ClientUseDirectWrite = true; bool TTcp2Options::ServerUseDirectWrite = true; TDuration TTcp2Options::ServerInputDeadline = TDuration::Seconds(3600); TDuration TTcp2Options::ServerOutputDeadline = TDuration::Seconds(10); bool TTcp2Options::Set(TStringBuf name, TStringBuf value) { #define TCP2_TRY_SET(optType, optName) \ if (name == TStringBuf(#optName)) { \ optName = FromString(value); \ } TCP2_TRY_SET(TDuration, ConnectTimeout) else TCP2_TRY_SET(size_t, InputBufferSize) else TCP2_TRY_SET(size_t, AsioClientThreads) else TCP2_TRY_SET(size_t, AsioServerThreads) else TCP2_TRY_SET(int, Backlog) else TCP2_TRY_SET(bool, ClientUseDirectWrite) else TCP2_TRY_SET(bool, ServerUseDirectWrite) else TCP2_TRY_SET(TDuration, ServerInputDeadline) else TCP2_TRY_SET(TDuration, ServerOutputDeadline) else { return false; } return true; } } namespace { namespace NNehTcp2 { using namespace NAsio; using namespace NDns; using namespace NNeh; const TString canceled = "canceled"; const TString emptyReply = "empty reply"; inline void PrepareSocket(SOCKET s) { SetNoDelay(s, true); } typedef ui64 TRequestId; #pragma pack(push, 1) //disable align struct members (structs mapped to data transmitted other network) struct TBaseHeader { enum TMessageType { Request = 1, 
Response = 2, Cancel = 3, MaxMessageType }; TBaseHeader(TRequestId id, ui32 headerLength, ui8 version, ui8 mType) : Id(id) , HeaderLength(headerLength) , Version(version) , Type(mType) { } TRequestId Id; //message id, - monotonic inc. sequence (skip nil value) ui32 HeaderLength; ui8 Version; //current version: 1 ui8 Type; //<- TMessageType (+ in future possible ForceResponse,etc) }; struct TRequestHeader: public TBaseHeader { TRequestHeader(TRequestId reqId, size_t servicePathLength, size_t dataSize) : TBaseHeader(reqId, sizeof(TRequestHeader) + servicePathLength, 1, (ui8)Request) , ContentLength(dataSize) { } ui32 ContentLength; }; struct TResponseHeader: public TBaseHeader { enum TErrorCode { Success = 0, EmptyReply = 1 //not found such service or service not sent response , MaxErrorCode }; TResponseHeader(TRequestId reqId, TErrorCode code, size_t dataSize) : TBaseHeader(reqId, sizeof(TResponseHeader), 1, (ui8)Response) , ErrorCode((ui16)code) , ContentLength(dataSize) { } TString ErrorDescription() const { if (ErrorCode == (ui16)EmptyReply) { return emptyReply; } TStringStream ss; ss << TStringBuf("tcp2 err_code=") << ErrorCode; return ss.Str(); } ui16 ErrorCode; ui32 ContentLength; }; struct TCancelHeader: public TBaseHeader { TCancelHeader(TRequestId reqId) : TBaseHeader(reqId, sizeof(TCancelHeader), 1, (ui8)Cancel) { } }; #pragma pack(pop) static const size_t maxHeaderSize = sizeof(TResponseHeader); //buffer for read input data, - header + message data struct TTcp2Message { TTcp2Message() : Loader_(&TTcp2Message::LoadBaseHeader) , RequireBytesForComplete_(sizeof(TBaseHeader)) , Header_(sizeof(TBaseHeader)) { } void Clear() { Loader_ = &TTcp2Message::LoadBaseHeader; RequireBytesForComplete_ = sizeof(TBaseHeader); Header_.Clear(); Content_.clear(); } TBuffer& Header() noexcept { return Header_; } const TString& Content() const noexcept { return Content_; } bool IsComplete() const noexcept { return RequireBytesForComplete_ == 0; } size_t LoadFrom(const char* 
buf, size_t len) { return (this->*Loader_)(buf, len); } const TBaseHeader& BaseHeader() const { return *reinterpret_cast(Header_.Data()); } const TRequestHeader& RequestHeader() const { return *reinterpret_cast(Header_.Data()); } const TResponseHeader& ResponseHeader() const { return *reinterpret_cast(Header_.Data()); } private: size_t LoadBaseHeader(const char* buf, size_t len) { size_t useBytes = Min(sizeof(TBaseHeader) - Header_.Size(), len); Header_.Append(buf, useBytes); if (Y_UNLIKELY(sizeof(TBaseHeader) > Header_.Size())) { //base header yet not complete return useBytes; } { const TBaseHeader& hdr = BaseHeader(); if (BaseHeader().HeaderLength > 32000) { //some heuristic header size limit throw yexception() << TStringBuf("to large neh/tcp2 header size: ") << BaseHeader().HeaderLength; } //header completed Header_.Reserve(hdr.HeaderLength); } const TBaseHeader& hdr = BaseHeader(); //reallocation can move Header_ data to another place, so use fresh 'hdr' if (Y_UNLIKELY(hdr.Version != 1)) { throw yexception() << TStringBuf("unsupported protocol version: ") << static_cast(hdr.Version); } RequireBytesForComplete_ = hdr.HeaderLength - sizeof(TBaseHeader); return useBytes + LoadHeader(buf + useBytes, len - useBytes); } size_t LoadHeader(const char* buf, size_t len) { size_t useBytes = Min(RequireBytesForComplete_, len); Header_.Append(buf, useBytes); RequireBytesForComplete_ -= useBytes; if (RequireBytesForComplete_) { //continue load header Loader_ = &TTcp2Message::LoadHeader; return useBytes; } const TBaseHeader& hdr = *reinterpret_cast(Header_.Data()); if (hdr.Type == TBaseHeader::Request) { if (Header_.Size() < sizeof(TRequestHeader)) { throw yexception() << TStringBuf("invalid request header size"); } InitContentLoading(RequestHeader().ContentLength); } else if (hdr.Type == TBaseHeader::Response) { if (Header_.Size() < sizeof(TResponseHeader)) { throw yexception() << TStringBuf("invalid response header size"); } 
InitContentLoading(ResponseHeader().ContentLength); } else if (hdr.Type == TBaseHeader::Cancel) { if (Header_.Size() < sizeof(TCancelHeader)) { throw yexception() << TStringBuf("invalid cancel header size"); } return useBytes; } else { throw yexception() << TStringBuf("unsupported request type: ") << static_cast(hdr.Type); } return useBytes + (this->*Loader_)(buf + useBytes, len - useBytes); } void InitContentLoading(size_t contentLength) { RequireBytesForComplete_ = contentLength; Content_.ReserveAndResize(contentLength); Loader_ = &TTcp2Message::LoadContent; } size_t LoadContent(const char* buf, size_t len) { size_t curContentSize = Content_.size() - RequireBytesForComplete_; size_t useBytes = Min(RequireBytesForComplete_, len); memcpy(Content_.begin() + curContentSize, buf, useBytes); RequireBytesForComplete_ -= useBytes; return useBytes; } private: typedef size_t (TTcp2Message::*TLoader)(const char*, size_t); TLoader Loader_; //current loader (stages - base-header/header/content) size_t RequireBytesForComplete_; TBuffer Header_; TString Content_; }; //base storage for output data class TMultiBuffers { public: TMultiBuffers() : IOVec_(nullptr, 0) , DataSize_(0) , PoolBytes_(0) { } void Clear() noexcept { Parts_.clear(); DataSize_ = 0; PoolBytes_ = 0; } bool HasFreeSpace() const noexcept { return DataSize_ < 64000 && (PoolBytes_ < (MemPoolSize_ - maxHeaderSize)); } bool HasData() const noexcept { return Parts_.size(); } TContIOVector* GetIOvec() noexcept { return &IOVec_; } protected: void AddPart(const void* buf, size_t len) { Parts_.push_back(IOutputStream::TPart(buf, len)); DataSize_ += len; } //used for allocate header (MUST be POD type) template inline T* Allocate() noexcept { size_t poolBytes = PoolBytes_; PoolBytes_ += sizeof(T); return (T*)(MemPool_ + poolBytes); } //used for allocate header (MUST be POD type) + some tail template inline T* AllocatePlus(size_t tailSize) noexcept { Y_ASSERT(tailSize <= MemPoolReserve_); size_t poolBytes = PoolBytes_; 
PoolBytes_ += sizeof(T) + tailSize; return (T*)(MemPool_ + poolBytes); } protected: TContIOVector IOVec_; TVector Parts_; static const size_t MemPoolSize_ = maxHeaderSize * 100; static const size_t MemPoolReserve_ = 32; size_t DataSize_; size_t PoolBytes_; char MemPool_[MemPoolSize_ + MemPoolReserve_]; }; //protector for limit usage tcp connection output (and used data) only from one thread at same time class TOutputLock { public: TOutputLock() noexcept : Lock_(0) { } bool TryAquire() noexcept { do { if (AtomicTryLock(&Lock_)) { return true; } } while (!AtomicGet(Lock_)); //without magic loop atomic lock some unreliable return false; } void Release() noexcept { AtomicUnlock(&Lock_); } bool IsFree() const noexcept { return !AtomicGet(Lock_); } private: TAtomic Lock_; }; class TClient { class TRequest; class TConnection; typedef TIntrusivePtr TRequestRef; typedef TIntrusivePtr TConnectionRef; class TRequest: public TThrRefBase, public TNonCopyable { public: class THandle: public TSimpleHandle { public: THandle(IOnRecv* f, const TMessage& msg, TStatCollector* s) noexcept : TSimpleHandle(f, msg, s) { } bool MessageSendedCompletely() const noexcept override { if (TSimpleHandle::MessageSendedCompletely()) { return true; } TRequestRef req = GetRequest(); if (!!req && req->RequestSendedCompletely()) { const_cast(this)->SetSendComplete(); } return TSimpleHandle::MessageSendedCompletely(); } void Cancel() noexcept override { if (TSimpleHandle::Canceled()) { return; } TRequestRef req = GetRequest(); if (!!req) { req->Cancel(); TSimpleHandle::Cancel(); } } void NotifyResponse(const TString& resp) { TNotifyHandle::NotifyResponse(resp); ReleaseRequest(); } void NotifyError(const TString& error) { TNotifyHandle::NotifyError(error); ReleaseRequest(); } void NotifyError(TErrorRef error) { TNotifyHandle::NotifyError(error); ReleaseRequest(); } //not thread safe! 
void SetRequest(const TRequestRef& r) noexcept { Req_ = r; } void ReleaseRequest() noexcept { TRequestRef tmp; TGuard g(SP_); tmp.Swap(Req_); } private: TRequestRef GetRequest() const noexcept { TGuard g(SP_); return Req_; } mutable TSpinLock SP_; TRequestRef Req_; }; typedef TIntrusivePtr THandleRef; static void Run(THandleRef& h, const TMessage& msg, TClient& clnt) { TRequestRef req(new TRequest(h, msg, clnt)); h->SetRequest(req); req->Run(req); } ~TRequest() override { DBGOUT("TClient::~TRequest()"); } private: TRequest(THandleRef& h, TMessage msg, TClient& clnt) : Hndl_(h) , Clnt_(clnt) , Msg_(std::move(msg)) , Loc_(Msg_.Addr) , Addr_(CachedResolve(TResolveInfo(Loc_.Host, Loc_.GetPort()))) , Canceled_(false) , Id_(0) { DBGOUT("TClient::TRequest()"); } void Run(TRequestRef& req) { TDestination& dest = Clnt_.Dest_.Get(Addr_->Id); dest.Run(req); } public: void OnResponse(TTcp2Message& msg) { DBGOUT("TRequest::OnResponse: " << msg.ResponseHeader().Id); THandleRef h = ReleaseHandler(); if (!h) { return; } const TResponseHeader& respHdr = msg.ResponseHeader(); if (Y_LIKELY(!respHdr.ErrorCode)) { h->NotifyResponse(msg.Content()); } else { h->NotifyError(new TError(respHdr.ErrorDescription(), TError::ProtocolSpecific, respHdr.ErrorCode)); } ReleaseConn(); } void OnError(const TString& err, const i32 systemCode = 0) { DBGOUT("TRequest::OnError: " << Id_.load(std::memory_order_acquire)); THandleRef h = ReleaseHandler(); if (!h) { return; } h->NotifyError(new TError(err, TError::UnknownType, 0, systemCode)); ReleaseConn(); } void SetConnection(TConnection* conn) noexcept { auto g = Guard(AL_); Conn_ = conn; } bool Canceled() const noexcept { return Canceled_; } const TResolvedHost* Addr() const noexcept { return Addr_; } TStringBuf Service() const noexcept { return Loc_.Service; } const TString& Data() const noexcept { return Msg_.Data; } TClient& Client() noexcept { return Clnt_; } bool RequestSendedCompletely() const noexcept { if (Id_.load(std::memory_order_acquire) == 
0) { return false; } TConnectionRef conn = GetConn(); if (!conn) { return false; } TRequestId lastSendedReqId = conn->LastSendedRequestId(); if (lastSendedReqId >= Id_.load(std::memory_order_acquire)) { return true; } else if (Y_UNLIKELY((Id_.load(std::memory_order_acquire) - lastSendedReqId) > (Max() - Max()))) { //overflow req-id value return true; } return false; } void Cancel() noexcept { Canceled_ = true; THandleRef h = ReleaseHandler(); if (!h) { return; } TConnectionRef conn = ReleaseConn(); if (!!conn && Id_.load(std::memory_order_acquire)) { conn->Cancel(Id_.load(std::memory_order_acquire)); } h->NotifyError(new TError(canceled, TError::Cancelled)); } void SetReqId(TRequestId reqId) noexcept { auto guard = Guard(IdLock_); Id_.store(reqId, std::memory_order_release); } TRequestId ReqId() const noexcept { return Id_.load(std::memory_order_acquire); } private: inline THandleRef ReleaseHandler() noexcept { THandleRef h; { auto g = Guard(AL_); h.Swap(Hndl_); } return h; } inline TConnectionRef GetConn() const noexcept { auto g = Guard(AL_); return Conn_; } inline TConnectionRef ReleaseConn() noexcept { TConnectionRef c; { auto g = Guard(AL_); c.Swap(Conn_); } return c; } mutable TAdaptiveLock AL_; //guaranted calling notify() only once (prevent race between asio thread and current) THandleRef Hndl_; TClient& Clnt_; const TMessage Msg_; const TParsedLocation Loc_; const TResolvedHost* Addr_; TConnectionRef Conn_; NAtomic::TBool Canceled_; TSpinLock IdLock_; std::atomic Id_; }; class TConnection: public TThrRefBase { enum TState { Init, Connecting, Connected, Closed, MaxState }; typedef THashMap TReqsInFly; public: class TOutputBuffers: public TMultiBuffers { public: void AddRequest(const TRequestRef& req) { Requests_.push_back(req); if (req->Service().size() > MemPoolReserve_) { TRequestHeader* hdr = new (Allocate()) TRequestHeader(req->ReqId(), req->Service().size(), req->Data().size()); AddPart(hdr, sizeof(TRequestHeader)); AddPart(req->Service().data(), 
req->Service().size()); } else { TRequestHeader* hdr = new (AllocatePlus(req->Service().size())) TRequestHeader(req->ReqId(), req->Service().size(), req->Data().size()); AddPart(hdr, sizeof(TRequestHeader) + req->Service().size()); memmove(++hdr, req->Service().data(), req->Service().size()); } AddPart(req->Data().data(), req->Data().size()); IOVec_ = TContIOVector(Parts_.data(), Parts_.size()); } void AddCancelRequest(TRequestId reqId) { TCancelHeader* hdr = new (Allocate()) TCancelHeader(reqId); AddPart(hdr, sizeof(TCancelHeader)); IOVec_ = TContIOVector(Parts_.data(), Parts_.size()); } void Clear() { TMultiBuffers::Clear(); Requests_.clear(); } private: TVector Requests_; }; TConnection(TIOService& srv) : AS_(srv) , State_(Init) , BuffSize_(TTcp2Options::InputBufferSize) , Buff_(new char[BuffSize_]) , NeedCheckReqsQueue_(0) , NeedCheckCancelsQueue_(0) , GenReqId_(0) , LastSendedReqId_(0) { } ~TConnection() override { try { DBGOUT("TClient::~TConnection()"); OnError("~"); } catch (...) { Cdbg << "tcp2::~cln_conn: " << CurrentExceptionMessage() << Endl; } } //called from client thread bool Run(TRequestRef& req) { if (Y_UNLIKELY(AtomicGet(State_) == Closed)) { return false; } req->Ref(); try { Reqs_.Enqueue(req.Get()); } catch (...) { req->UnRef(); throw; } AtomicSet(NeedCheckReqsQueue_, 1); req->SetConnection(this); TAtomicBase state = AtomicGet(State_); if (Y_LIKELY(state == Connected)) { ProcessOutputReqsQueue(); return true; } if (state == Init) { if (AtomicCas(&State_, Connecting, Init)) { try { TEndpoint addr(new NAddr::TAddrInfo(&*req->Addr()->Addr.Begin())); AS_.AsyncConnect(addr, std::bind(&TConnection::OnConnect, TConnectionRef(this), _1, _2), TTcp2Options::ConnectTimeout); } catch (...) 
{ AS_.GetIOService().Post(std::bind(&TConnection::OnErrorCallback, TConnectionRef(this), CurrentExceptionMessage())); } return true; } } state = AtomicGet(State_); if (state == Connected) { ProcessOutputReqsQueue(); } else if (state == Closed) { SafeOnError(); } return true; } //called from client thread void Cancel(TRequestId id) { Cancels_.Enqueue(id); AtomicSet(NeedCheckCancelsQueue_, 1); if (Y_LIKELY(AtomicGet(State_) == Connected)) { ProcessOutputCancelsQueue(); } } void ProcessOutputReqsQueue() { if (OutputLock_.TryAquire()) { SendMessages(false); } } void ProcessOutputCancelsQueue() { if (OutputLock_.TryAquire()) { AS_.GetIOService().Post(std::bind(&TConnection::SendMessages, TConnectionRef(this), true)); return; } } //must be called only from asio thread void ProcessReqsInFlyQueue() { if (AtomicGet(State_) == Closed) { return; } TRequest* reqPtr; while (ReqsInFlyQueue_.Dequeue(&reqPtr)) { TRequestRef reqTmp(reqPtr); reqPtr->UnRef(); ReqsInFly_[reqPtr->ReqId()].Swap(reqTmp); } } //must be called only from asio thread void OnConnect(const TErrorCode& ec, IHandlingContext&) { DBGOUT("TConnect::OnConnect: " << ec.Value()); if (Y_UNLIKELY(ec)) { if (ec.Value() == EIO) { //try get more detail error info char buf[1]; TErrorCode errConnect; AS_.ReadSome(buf, 1, errConnect); OnErrorCode(errConnect.Value() ? errConnect : ec); } else { OnErrorCode(ec); } } else { try { PrepareSocket(AS_.Native()); AtomicSet(State_, Connected); AS_.AsyncPollRead(std::bind(&TConnection::OnCanRead, TConnectionRef(this), _1, _2)); if (OutputLock_.TryAquire()) { SendMessages(true); return; } } catch (...) 
{ OnError(CurrentExceptionMessage()); } } } //must be called only after succes aquiring output void SendMessages(bool asioThread) { //DBGOUT("SendMessages"); if (Y_UNLIKELY(AtomicGet(State_) == Closed)) { if (asioThread) { OnError(Error_); } else { SafeOnError(); } return; } do { if (asioThread) { AtomicSet(NeedCheckCancelsQueue_, 0); TRequestId reqId; ProcessReqsInFlyQueue(); while (Cancels_.Dequeue(&reqId)) { TReqsInFly::iterator it = ReqsInFly_.find(reqId); if (it == ReqsInFly_.end()) { continue; } ReqsInFly_.erase(it); OutputBuffers_.AddCancelRequest(reqId); if (Y_UNLIKELY(!OutputBuffers_.HasFreeSpace())) { if (!FlushOutputBuffers(asioThread, 0)) { return; } } } } else if (AtomicGet(NeedCheckCancelsQueue_)) { AS_.GetIOService().Post(std::bind(&TConnection::SendMessages, TConnectionRef(this), true)); return; } TRequestId lastReqId = 0; { AtomicSet(NeedCheckReqsQueue_, 0); TRequest* reqPtr; while (Reqs_.Dequeue(&reqPtr)) { TRequestRef reqTmp(reqPtr); reqPtr->UnRef(); reqPtr->SetReqId(GenerateReqId()); if (reqPtr->Canceled()) { continue; } lastReqId = reqPtr->ReqId(); if (asioThread) { TRequestRef& req = ReqsInFly_[(TRequestId)reqPtr->ReqId()]; req.Swap(reqTmp); OutputBuffers_.AddRequest(req); } else { //can access to ReqsInFly_ only from asio thread, so enqueue req to update ReqsInFly_ queue try { reqTmp->Ref(); ReqsInFlyQueue_.Enqueue(reqPtr); } catch (...) 
{ reqTmp->UnRef(); throw; } OutputBuffers_.AddRequest(reqTmp); } if (Y_UNLIKELY(!OutputBuffers_.HasFreeSpace())) { if (!FlushOutputBuffers(asioThread, lastReqId)) { return; } } } } if (OutputBuffers_.HasData()) { if (!FlushOutputBuffers(asioThread, lastReqId)) { return; } } OutputLock_.Release(); if (!AtomicGet(NeedCheckReqsQueue_) && !AtomicGet(NeedCheckCancelsQueue_)) { DBGOUT("TClient::SendMessages(exit2)"); return; } } while (OutputLock_.TryAquire()); DBGOUT("TClient::SendMessages(exit1)"); } TRequestId GenerateReqId() noexcept { TRequestId reqId; { auto guard = Guard(GenReqIdLock_); reqId = ++GenReqId_; } return Y_LIKELY(reqId) ? reqId : GenerateReqId(); } //called non thread-safe (from outside thread) bool FlushOutputBuffers(bool asioThread, TRequestId reqId) { if (asioThread || TTcp2Options::ClientUseDirectWrite) { TContIOVector& vec = *OutputBuffers_.GetIOvec(); TErrorCode err; vec.Proceed(AS_.WriteSome(vec, err)); if (Y_UNLIKELY(err)) { if (asioThread) { OnErrorCode(err); } else { AS_.GetIOService().Post(std::bind(&TConnection::OnErrorCode, TConnectionRef(this), err)); } return false; } if (vec.Complete()) { LastSendedReqId_.store(reqId, std::memory_order_release); DBGOUT("Client::FlushOutputBuffers(" << reqId << ")"); OutputBuffers_.Clear(); return true; } } DBGOUT("Client::AsyncWrite(" << reqId << ")"); AS_.AsyncWrite(OutputBuffers_.GetIOvec(), std::bind(&TConnection::OnSend, TConnectionRef(this), reqId, _1, _2, _3), TTcp2Options::ServerOutputDeadline); return false; } //must be called only from asio thread void OnSend(TRequestId reqId, const TErrorCode& ec, size_t amount, IHandlingContext&) { Y_UNUSED(amount); if (Y_UNLIKELY(ec)) { OnErrorCode(ec); } else { if (Y_LIKELY(reqId)) { DBGOUT("Client::OnSend(" << reqId << ")"); LastSendedReqId_.store(reqId, std::memory_order_release); } //output already aquired, used asio thread OutputBuffers_.Clear(); SendMessages(true); } } //must be called only from asio thread void OnCanRead(const TErrorCode& ec, 
IHandlingContext& ctx) { //DBGOUT("OnCanRead(" << ec.Value() << ")"); if (Y_UNLIKELY(ec)) { OnErrorCode(ec); } else { TErrorCode ec2; OnReadSome(ec2, AS_.ReadSome(Buff_.Get(), BuffSize_, ec2), ctx); } } //must be called only from asio thread void OnReadSome(const TErrorCode& ec, size_t amount, IHandlingContext& ctx) { //DBGOUT("OnReadSome(" << ec.Value() << ", " << amount << ")"); if (Y_UNLIKELY(ec)) { OnErrorCode(ec); return; } while (1) { if (Y_UNLIKELY(!amount)) { OnError("tcp conn. closed"); return; } try { const char* buff = Buff_.Get(); size_t leftBytes = amount; do { size_t useBytes = Msg_.LoadFrom(buff, leftBytes); leftBytes -= useBytes; buff += useBytes; if (Msg_.IsComplete()) { //DBGOUT("OnReceiveMessage(" << Msg_.BaseHeader().Id << "): " << leftBytes); OnReceiveMessage(); Msg_.Clear(); } } while (leftBytes); if (amount == BuffSize_) { //try decrease system calls, - re-run ReadSome if has full filled buffer TErrorCode ecR; amount = AS_.ReadSome(Buff_.Get(), BuffSize_, ecR); if (!ecR) { continue; //process next input data } if (ecR.Value() == EAGAIN || ecR.Value() == EWOULDBLOCK) { ctx.ContinueUseHandler(); } else { OnErrorCode(ec); } } else { ctx.ContinueUseHandler(); } } catch (...) 
{ OnError(CurrentExceptionMessage()); } return; } } //must be called only from asio thread void OnErrorCode(TErrorCode ec) { OnError(ec.Text(), ec.Value()); } //must be called only from asio thread void OnErrorCallback(TString err) { OnError(err); } //must be called only from asio thread void OnError(const TString& err, const i32 systemCode = 0) { if (AtomicGet(State_) != Closed) { Error_ = err; SystemCode_ = systemCode; AtomicSet(State_, Closed); AS_.AsyncCancel(); } SafeOnError(); for (auto& it : ReqsInFly_) { it.second->OnError(err); } ReqsInFly_.clear(); } void SafeOnError() { TRequest* reqPtr; while (Reqs_.Dequeue(&reqPtr)) { TRequestRef req(reqPtr); reqPtr->UnRef(); //DBGOUT("err queue(" << AS_.Native() << "):" << size_t(reqPtr)); req->OnError(Error_, SystemCode_); } while (ReqsInFlyQueue_.Dequeue(&reqPtr)) { TRequestRef req(reqPtr); reqPtr->UnRef(); //DBGOUT("err fly queue(" << AS_.Native() << "):" << size_t(reqPtr)); req->OnError(Error_, SystemCode_); } } //must be called only from asio thread void OnReceiveMessage() { //DBGOUT("OnReceiveMessage"); const TBaseHeader& hdr = Msg_.BaseHeader(); if (hdr.Type == TBaseHeader::Response) { ProcessReqsInFlyQueue(); TReqsInFly::iterator it = ReqsInFly_.find(hdr.Id); if (it == ReqsInFly_.end()) { DBGOUT("ignore response: " << hdr.Id); return; } it->second->OnResponse(Msg_); ReqsInFly_.erase(it); } else { throw yexception() << TStringBuf("unsupported message type: ") << hdr.Type; } } TRequestId LastSendedRequestId() const noexcept { return LastSendedReqId_.load(std::memory_order_acquire); } private: NAsio::TTcpSocket AS_; TAtomic State_; //state machine status (TState) TString Error_; i32 SystemCode_ = 0; //input size_t BuffSize_; TArrayHolder Buff_; TTcp2Message Msg_; //output TOutputLock OutputLock_; TAtomic NeedCheckReqsQueue_; TLockFreeQueue Reqs_; TAtomic NeedCheckCancelsQueue_; TLockFreeQueue Cancels_; TAdaptiveLock GenReqIdLock_; std::atomic GenReqId_; std::atomic LastSendedReqId_; TLockFreeQueue 
ReqsInFlyQueue_; TReqsInFly ReqsInFly_; TOutputBuffers OutputBuffers_; }; class TDestination { public: void Run(TRequestRef& req) { while (1) { TConnectionRef conn = GetConnection(); if (!!conn && conn->Run(req)) { return; } DBGOUT("TDestination CreateConnection"); CreateConnection(conn, req->Client().ExecutorsPool().GetExecutor().GetIOService()); } } private: TConnectionRef GetConnection() { TGuard g(L_); return Conn_; } void CreateConnection(TConnectionRef& oldConn, TIOService& srv) { TConnectionRef conn(new TConnection(srv)); TGuard g(L_); if (Conn_ == oldConn) { Conn_.Swap(conn); } } TSpinLock L_; TConnectionRef Conn_; }; //////////// TClient ///////// public: TClient() : EP_(TTcp2Options::AsioClientThreads) { } ~TClient() { EP_.SyncShutdown(); } THandleRef Schedule(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) { //find exist connection or create new TRequest::THandleRef hndl(new TRequest::THandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss))); try { TRequest::Run(hndl, msg, *this); } catch (...) 
{ hndl->ResetOnRecv(); hndl->ReleaseRequest(); throw; } return hndl.Get(); } TExecutorsPool& ExecutorsPool() { return EP_; } private: NNeh::NHttp::TLockFreeSequence Dest_; TExecutorsPool EP_; }; ////////// server side //////////////////////////////////////////////////////////////////////////////////////////// class TServer: public IRequester { typedef TAutoPtr TTcpAcceptorPtr; typedef TAtomicSharedPtr TTcpSocketRef; class TConnection; typedef TIntrusivePtr TConnectionRef; struct TRequest: public IRequest { struct TState: public TThrRefBase { TState() : Canceled(false) { } TAtomicBool Canceled; }; typedef TIntrusivePtr TStateRef; TRequest(const TConnectionRef& conn, TBuffer& buf, const TString& content); ~TRequest() override; TStringBuf Scheme() const override { return TStringBuf("tcp2"); } TString RemoteHost() const override; TStringBuf Service() const override { return TStringBuf(Buf.Data() + sizeof(TRequestHeader), Buf.End()); } TStringBuf Data() const override { return TStringBuf(Content_); } TStringBuf RequestId() const override { return TStringBuf(); } bool Canceled() const override { return State->Canceled; } void SendReply(TData& data) override; void SendError(TResponseError, const TString&) override { // TODO } const TRequestHeader& RequestHeader() const noexcept { return *reinterpret_cast(Buf.Data()); } private: TConnectionRef Conn; TBuffer Buf; //service-name + message-data TString Content_; TAtomic Replied_; public: TIntrusivePtr State; }; class TConnection: public TThrRefBase { private: TConnection(TServer& srv, const TTcpSocketRef& sock) : Srv_(srv) , AS_(sock) , Canceled_(false) , RemoteHost_(NNeh::PrintHostByRfc(*AS_->RemoteEndpoint().Addr())) , BuffSize_(TTcp2Options::InputBufferSize) , Buff_(new char[BuffSize_]) , NeedCheckOutputQueue_(0) { DBGOUT("TServer::TConnection()"); } public: class TOutputBuffers: public TMultiBuffers { public: void AddResponse(TRequestId reqId, TData& data) { TResponseHeader* hdr = new (Allocate()) TResponseHeader(reqId, 
TResponseHeader::Success, data.size()); ResponseData_.push_back(TAutoPtr(new TData())); TData& movedData = *ResponseData_.back(); movedData.swap(data); AddPart(hdr, sizeof(TResponseHeader)); AddPart(movedData.data(), movedData.size()); IOVec_ = TContIOVector(Parts_.data(), Parts_.size()); } void AddError(TRequestId reqId, TResponseHeader::TErrorCode errCode) { TResponseHeader* hdr = new (Allocate()) TResponseHeader(reqId, errCode, 0); AddPart(hdr, sizeof(TResponseHeader)); IOVec_ = TContIOVector(Parts_.data(), Parts_.size()); } void Clear() { TMultiBuffers::Clear(); ResponseData_.clear(); } private: TVector> ResponseData_; }; static void Create(TServer& srv, const TTcpSocketRef& sock) { TConnectionRef conn(new TConnection(srv, sock)); conn->AS_->AsyncPollRead(std::bind(&TConnection::OnCanRead, conn, _1, _2), TTcp2Options::ServerInputDeadline); } ~TConnection() override { DBGOUT("~TServer::TConnection(" << (!AS_ ? -666 : AS_->Native()) << ")"); } private: void OnCanRead(const TErrorCode& ec, IHandlingContext& ctx) { if (ec) { OnError(); } else { TErrorCode ec2; OnReadSome(ec2, AS_->ReadSome(Buff_.Get(), BuffSize_, ec2), ctx); } } void OnError() { DBGOUT("Srv OnError(" << (!AS_ ? 
-666 : AS_->Native()) << ")" << " c=" << (size_t)this); Canceled_ = true; AS_->AsyncCancel(); } void OnReadSome(const TErrorCode& ec, size_t amount, IHandlingContext& ctx) { while (1) { if (ec || !amount) { OnError(); return; } try { const char* buff = Buff_.Get(); size_t leftBytes = amount; do { size_t useBytes = Msg_.LoadFrom(buff, leftBytes); leftBytes -= useBytes; buff += useBytes; if (Msg_.IsComplete()) { OnReceiveMessage(); } } while (leftBytes); if (amount == BuffSize_) { //try decrease system calls, - re-run ReadSome if has full filled buffer TErrorCode ecR; amount = AS_->ReadSome(Buff_.Get(), BuffSize_, ecR); if (!ecR) { continue; } if (ecR.Value() == EAGAIN || ecR.Value() == EWOULDBLOCK) { ctx.ContinueUseHandler(); } else { OnError(); } } else { ctx.ContinueUseHandler(); } } catch (...) { DBGOUT("exc. " << CurrentExceptionMessage()); OnError(); } return; } } void OnReceiveMessage() { DBGOUT("OnReceiveMessage()"); const TBaseHeader& hdr = Msg_.BaseHeader(); if (hdr.Type == TBaseHeader::Request) { TRequest* reqPtr = new TRequest(TConnectionRef(this), Msg_.Header(), Msg_.Content()); IRequestRef req(reqPtr); ReqsState_[reqPtr->RequestHeader().Id] = reqPtr->State; OnRequest(req); } else if (hdr.Type == TBaseHeader::Cancel) { OnCancelRequest(hdr.Id); } else { throw yexception() << "unsupported message type: " << (ui32)hdr.Type; } Msg_.Clear(); { TRequestId reqId; while (FinReqs_.Dequeue(&reqId)) { ReqsState_.erase(reqId); } } } void OnRequest(IRequestRef& r) { DBGOUT("OnRequest()"); Srv_.OnRequest(r); } void OnCancelRequest(TRequestId reqId) { THashMap::iterator it = ReqsState_.find(reqId); if (it == ReqsState_.end()) { return; } it->second->Canceled = true; } public: class TOutputData { public: TOutputData(TRequestId reqId) : ReqId(reqId) { } virtual ~TOutputData() { } virtual void MoveTo(TOutputBuffers& bufs) = 0; TRequestId ReqId; }; class TResponseData: public TOutputData { public: TResponseData(TRequestId reqId, TData& data) : TOutputData(reqId) { 
Data.swap(data); } void MoveTo(TOutputBuffers& bufs) override { bufs.AddResponse(ReqId, Data); } TData Data; }; class TResponseErrorData: public TOutputData { public: TResponseErrorData(TRequestId reqId, TResponseHeader::TErrorCode errorCode) : TOutputData(reqId) , ErrorCode(errorCode) { } void MoveTo(TOutputBuffers& bufs) override { bufs.AddError(ReqId, ErrorCode); } TResponseHeader::TErrorCode ErrorCode; }; //called non thread-safe (from client thread) void SendResponse(TRequestId reqId, TData& data) { DBGOUT("SendResponse: " << reqId << " " << (size_t)~data << " c=" << (size_t)this); TAutoPtr od(new TResponseData(reqId, data)); OutputData_.Enqueue(od); ProcessOutputQueue(); } //called non thread-safe (from outside thread) void SendError(TRequestId reqId, TResponseHeader::TErrorCode err) { DBGOUT("SendResponseError: " << reqId << " c=" << (size_t)this); TAutoPtr od(new TResponseErrorData(reqId, err)); OutputData_.Enqueue(od); ProcessOutputQueue(); } void ProcessOutputQueue() { AtomicSet(NeedCheckOutputQueue_, 1); if (OutputLock_.TryAquire()) { SendMessages(false); return; } DBGOUT("ProcessOutputQueue: !AquireOutputOwnership: " << (int)OutputLock_.IsFree()); } //must be called only after success aquiring output void SendMessages(bool asioThread) { DBGOUT("TServer::SendMessages(enter)"); try { do { AtomicUnlock(&NeedCheckOutputQueue_); TAutoPtr d; while (OutputData_.Dequeue(&d)) { d->MoveTo(OutputBuffers_); if (!OutputBuffers_.HasFreeSpace()) { if (!FlushOutputBuffers(asioThread)) { return; } } } if (OutputBuffers_.HasData()) { if (!FlushOutputBuffers(asioThread)) { return; } } OutputLock_.Release(); if (!AtomicGet(NeedCheckOutputQueue_)) { DBGOUT("Server::SendMessages(exit2): " << (int)OutputLock_.IsFree()); return; } } while (OutputLock_.TryAquire()); DBGOUT("Server::SendMessages(exit1)"); } catch (...) 
{ OnError(); } } bool FlushOutputBuffers(bool asioThread) { DBGOUT("FlushOutputBuffers: cnt=" << OutputBuffers_.GetIOvec()->Count() << " c=" << (size_t)this); //TODO:reseach direct write efficiency if (asioThread || TTcp2Options::ServerUseDirectWrite) { TContIOVector& vec = *OutputBuffers_.GetIOvec(); vec.Proceed(AS_->WriteSome(vec)); if (vec.Complete()) { OutputBuffers_.Clear(); //DBGOUT("WriteResponse: " << " c=" << (size_t)this); return true; } } //socket buffer filled - use async write for sending left data DBGOUT("AsyncWriteResponse: " << " [" << OutputBuffers_.GetIOvec()->Bytes() << "]" << " c=" << (size_t)this); AS_->AsyncWrite(OutputBuffers_.GetIOvec(), std::bind(&TConnection::OnSend, TConnectionRef(this), _1, _2, _3), TTcp2Options::ServerOutputDeadline); return false; } void OnFinishRequest(TRequestId reqId) { if (Y_LIKELY(!Canceled_)) { FinReqs_.Enqueue(reqId); } } private: void OnSend(const TErrorCode& ec, size_t amount, IHandlingContext&) { Y_UNUSED(amount); DBGOUT("TServer::OnSend(" << ec.Value() << ", " << amount << ")"); if (ec) { OnError(); } else { OutputBuffers_.Clear(); SendMessages(true); } } public: bool IsCanceled() const noexcept { return Canceled_; } const TString& RemoteHost() const noexcept { return RemoteHost_; } private: TServer& Srv_; TTcpSocketRef AS_; NAtomic::TBool Canceled_; TString RemoteHost_; //input size_t BuffSize_; TArrayHolder Buff_; TTcp2Message Msg_; THashMap ReqsState_; TLockFreeQueue FinReqs_; //output TOutputLock OutputLock_; //protect socket/buffers from simultaneous access from few threads TAtomic NeedCheckOutputQueue_; NNeh::TAutoLockFreeQueue OutputData_; TOutputBuffers OutputBuffers_; }; //////////// TServer ///////// public: TServer(IOnRequest* cb, ui16 port) : EP_(TTcp2Options::AsioServerThreads) , CB_(cb) { TNetworkAddress addr(port); for (TNetworkAddress::TIterator it = addr.Begin(); it != addr.End(); ++it) { TEndpoint ep(new NAddr::TAddrInfo(&*it)); TTcpAcceptorPtr a(new TTcpAcceptor(EA_.GetIOService())); 
//DBGOUT("bind:" << ep.IpToString() << ":" << ep.Port()); a->Bind(ep); a->Listen(TTcp2Options::Backlog); StartAccept(a.Get()); A_.push_back(a); } } ~TServer() override { EA_.SyncShutdown(); //cancel accepting connections A_.clear(); //stop listening EP_.SyncShutdown(); //close all exist connections } void StartAccept(TTcpAcceptor* a) { const auto s = MakeAtomicShared(EP_.Size() ? EP_.GetExecutor().GetIOService() : EA_.GetIOService()); a->AsyncAccept(*s, std::bind(&TServer::OnAccept, this, a, s, _1, _2)); } void OnAccept(TTcpAcceptor* a, TTcpSocketRef s, const TErrorCode& ec, IHandlingContext&) { if (Y_UNLIKELY(ec)) { if (ec.Value() == ECANCELED) { return; } else if (ec.Value() == EMFILE || ec.Value() == ENFILE || ec.Value() == ENOMEM || ec.Value() == ENOBUFS) { //reach some os limit, suspend accepting for preventing busyloop (100% cpu usage) TSimpleSharedPtr dt(new TDeadlineTimer(a->GetIOService())); dt->AsyncWaitExpireAt(TDuration::Seconds(30), std::bind(&TServer::OnTimeoutSuspendAccept, this, a, dt, _1, _2)); } else { Cdbg << "acc: " << ec.Text() << Endl; } } else { SetNonBlock(s->Native()); PrepareSocket(s->Native()); TConnection::Create(*this, s); } StartAccept(a); //continue accepting } void OnTimeoutSuspendAccept(TTcpAcceptor* a, TSimpleSharedPtr, const TErrorCode& ec, IHandlingContext&) { if (!ec) { DBGOUT("resume acceptor"); StartAccept(a); } } void OnRequest(IRequestRef& r) { try { CB_->OnRequest(r); } catch (...) 
{
        Cdbg << CurrentExceptionMessage() << Endl;
    }
}

private:
TVector A_; // listening acceptors (filled in the constructor, cleared in the destructor)
TIOServiceExecutor EA_; // thread where incoming tcp connections are accepted
TExecutorsPool EP_; // threads that read/write data on tcp connections (if empty, EA_ is used for r/w)
IOnRequest* CB_; // user request handler (not owned)
};

// Takes ownership of the request header buffer `buf` (swapped out of the
// caller) and copies the request content.
TServer::TRequest::TRequest(const TConnectionRef& conn, TBuffer& buf, const TString& content)
    : Conn(conn)
    , Content_(content)
    , Replied_(0)
    , State(new TState())
{
    DBGOUT("TServer::TRequest()");
    Buf.Swap(buf);
}

// If the handler never replied, an EmptyReply error is sent for this request
// id. In all cases the connection is told that the request has finished.
TServer::TRequest::~TRequest() {
    DBGOUT("TServer::~TRequest()");
    if (!AtomicGet(Replied_)) {
        Conn->SendError(RequestHeader().Id, TResponseHeader::EmptyReply);
    }
    Conn->OnFinishRequest(RequestHeader().Id);
}

TString TServer::TRequest::RemoteHost() const {
    return Conn->RemoteHost();
}

// Sends the reply payload (moved out of `data` by the connection).
// Only the first caller wins: Replied_ is flipped 0 -> 1 with a CAS, and any
// later call (or a call racing with another reply) is silently ignored.
void TServer::TRequest::SendReply(TData& data) {
    do {
        if (AtomicCas(&Replied_, 1, 0)) {
            Conn->SendResponse(RequestHeader().Id, data);
            return;
        }
    } while (AtomicGet(Replied_) == 0);
}

// neh protocol entry for the "tcp2" scheme.
// Server side: CreateRequester() starts a TServer on the location's port.
// Client side: ScheduleRequest() delegates to a process-wide singleton.
// NOTE(review): the singleton's template argument is not visible in this view
// (presumably the tcp2 client) - confirm.
class TProtocol: public IProtocol {
public:
    inline TProtocol() {
        InitNetworkSubSystem();
    }

    IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override {
        return new TServer(cb, loc.GetPort());
    }

    THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override {
        return Singleton()->Schedule(msg, fallback, ss);
    }

    TStringBuf Scheme() const noexcept override {
        return TStringBuf("tcp2");
    }

    // Forwards "name=value" option strings to TTcp2Options::Set(); returns false for unknown names.
    bool SetOption(TStringBuf name, TStringBuf value) override {
        return TTcp2Options::Set(name, value);
    }
};
}
}

// Returns the process-wide tcp2 protocol instance.
// NOTE(review): template argument of Singleton is elided in this view; presumably TProtocol - confirm.
NNeh::IProtocol* NNeh::Tcp2Protocol() {
    return Singleton();
}