#include "rpc.h"
#include "rq.h"
#include "multi.h"
#include "location.h"

#include <library/cpp/threading/thread_local/thread_local.h>

#include <util/generic/hash.h>
#include <util/system/spinlock.h>
#include <util/thread/factory.h>

using namespace NNeh;

namespace {
    // One registered service: (listen address, service implementation).
    typedef std::pair<TString, IServiceRef> TServiceDescr;
    typedef TVector<TServiceDescr> TServicesBase;

    // Service registry plus request dispatch behind NNeh::CreateLoop().
    // Requests reach OnRequest() from the requester created in Listen(),
    // Loop() or ForkLoop(). They are either served on the spot (Listen mode)
    // or queued in RQ_ and served by worker threads (Loop / ForkLoop mode).
    class TServices: public TServicesBase, public TThrRefBase, public IOnRequest {
        // Service name (TParsedLocation::Service of a listen address) -> service.
        typedef THashMap<TStringBuf, IServiceRef> TSrvs;

        // Per-thread snapshot of the service table. Version is compared with
        // SelfVersion_ to detect services added after the snapshot was built.
        struct TVersionedServiceMap {
            TSrvs Srvs;
            i64 Version = 0;
        };

        // Worker body. Pops requests from RQ_ until it gets a null request
        // (the stop sentinel), then re-schedules the sentinel so that sibling
        // workers blocked in Next() stop as well.
        struct TFunc: public IThreadFactory::IThreadAble {
            inline TFunc(TServices* parent)
                : Parent(parent)
            {
            }

            void DoExecute() override {
                TThread::SetCurrentThreadName("NehTFunc");
                TVersionedServiceMap mp;

                while (true) {
                    IRequestRef req = Parent->RQ_->Next();

                    if (!req) {
                        break;
                    }

                    Parent->ServeRequest(mp, req);
                }

                // Pass the stop sentinel on to the next worker.
                Parent->RQ_->Schedule(nullptr);
            }

            TServices* Parent;
        };

    public:
        inline TServices()
            : RQ_(CreateRequestQueue())
        {
        }

        // @param check  optional pre-filter. A non-empty result is sent back
        //               as an error instead of dispatching the request.
        inline TServices(TCheck check)
            : RQ_(CreateRequestQueue())
            , C_(check)
        {
        }

        inline ~TServices() override {
            LF_.Destroy();
        }

        // Registers `srv` for listen address `service` (thread-safe via L_).
        // Bumping SelfVersion_ makes per-thread maps rebuild lazily.
        inline void Add(const TString& service, IServiceRef srv) {
            TGuard<TSpinLock> guard(L_);

            push_back(std::make_pair(service, srv));
            AtomicIncrement(SelfVersion_);
        }

        // Starts listening without worker threads. Requests are served
        // directly inside OnRequest(). Cannot be mixed with Loop/ForkLoop.
        inline void Listen() {
            Y_ENSURE(!HasLoop_ || !*HasLoop_);
            HasLoop_ = false;
            RR_ = MultiRequester(ListenAddrs(), this);
        }

        // Blocking serve loop. Runs `threads` workers (the calling thread is
        // one of them; 0 behaves like 1) and returns after Stop() was called
        // and every worker has finished.
        // Throws if the requester cannot be created or a worker cannot be
        // started. All started workers are joined before the exception leaves.
        inline void Loop(size_t threads) {
            Y_ENSURE(!HasLoop_ || *HasLoop_);
            HasLoop_ = true;

            // Keeps this object alive for the whole loop, even if the owning
            // TServicesFace is released meanwhile.
            TIntrusivePtr<TServices> self(this);
            IRequesterRef rr = MultiRequester(ListenAddrs(), this);
            TFunc func(this);

            typedef TAutoPtr<IThreadFactory::IThread> IThreadRef;
            TVector<IThreadRef> thrs;
            // Reserve up front so push_back cannot throw (and drop a running
            // thread handle) after Run() has already started a thread.
            thrs.reserve(threads);

            auto joinAll = [&thrs]() {
                for (auto& thr : thrs) {
                    thr->Join();
                }
                thrs.clear();
            };

            try {
                for (size_t i = 1; i < threads; ++i) {
                    thrs.push_back(SystemThreadFactory()->Run(&func));
                }

                func.Execute();
            } catch (...) {
                // Started workers reference the stack object `func`. Stop and
                // join them before it goes out of scope (same idea as the
                // "paranoid mode" in TLoopFunc's constructor).
                Stop();
                joinAll();
                RQ_->Clear();
                throw;
            }

            joinAll();
            RQ_->Clear();
        }

        // Non-blocking variant of Loop(). Starts `threads` background workers
        // owned by LF_. Stop them with SyncStopFork().
        inline void ForkLoop(size_t threads) {
            Y_ENSURE(!HasLoop_ || *HasLoop_);
            HasLoop_ = true;
            // Binding the listen port(s) may fail here, so expect exceptions.
            IRequesterRef rr = MultiRequester(ListenAddrs(), this);
            LF_.Reset(new TLoopFunc(this, threads, rr));
        }

        // Asks workers to stop by queueing the null-request sentinel.
        inline void Stop() {
            RQ_->Schedule(nullptr);
        }

        // Stops ForkLoop() workers, waits for them and drops queued requests.
        inline void SyncStopFork() {
            Stop();

            if (LF_) {
                LF_->SyncStop();
            }

            RQ_->Clear();
            LF_.Destroy();
        }

        // IOnRequest entry point for the requester created in Listen(),
        // Loop() or ForkLoop(). Applies the optional check C_, then either
        // serves the request directly with a thread-local service map
        // (Listen mode) or queues it for the workers.
        void OnRequest(IRequestRef req) override {
            if (C_) {
                if (auto error = C_(req)) {
                    req->SendError(*error);
                    return;
                }
            }

            if (!*HasLoop_) {
                ServeRequest(LocalMap_.GetRef(), req);
            } else {
                RQ_->Schedule(req);
            }
        }

    private:
        // Background worker pool used by ForkLoop(). It owns its threads and
        // keeps the requester alive while they run.
        class TLoopFunc: public TFunc {
        public:
            TLoopFunc(TServices* parent, size_t threads, IRequesterRef& rr)
                : TFunc(parent)
                , RR_(rr)
            {
                T_.reserve(threads);

                try {
                    for (size_t i = 0; i < threads; ++i) {
                        T_.push_back(SystemThreadFactory()->Run(this));
                    }
                } catch (...) {
                    // Paranoid mode: join whatever was started before rethrowing.
                    SyncStop();
                    throw;
                }
            }

            ~TLoopFunc() override {
                try {
                    SyncStop();
                } catch (...) {
                    Cdbg << TStringBuf("neh rpc ~loop_func: ") << CurrentExceptionMessage() << Endl;
                }
            }

            // Sends the stop sentinel and joins all workers. Idempotent.
            void SyncStop() {
                if (T_.empty()) {
                    return;
                }

                Parent->Stop();

                for (size_t i = 0; i < T_.size(); ++i) {
                    T_[i]->Join();
                }

                T_.clear();
            }

        private:
            typedef TAutoPtr<IThreadFactory::IThread> IThreadRef;
            TVector<IThreadRef> T_;
            IRequesterRef RR_;
        };

        // Dispatches `req` to the service named by req->Service(). Lookup order:
        //   1. the cached map `mp`,
        //   2. after refreshing `mp` if services were added since it was built,
        //   3. the "*" wildcard service.
        // If nothing matches, replies with IRequest::NotExistService.
        // Exceptions from the service are logged to Cdbg and swallowed. No
        // reply is sent on the service's behalf in that case.
        inline void ServeRequest(TVersionedServiceMap& mp, IRequestRef req) {
            if (!req) {
                return;
            }

            const TStringBuf name = req->Service();
            TSrvs::const_iterator it = mp.Srvs.find(name);

            if (Y_UNLIKELY(it == mp.Srvs.end())) {
                if (UpdateServices(mp.Srvs, mp.Version)) {
                    it = mp.Srvs.find(name);
                }
            }

            if (Y_UNLIKELY(it == mp.Srvs.end())) {
                it = mp.Srvs.find(TStringBuf("*"));
            }

            if (Y_UNLIKELY(it == mp.Srvs.end())) {
                req->SendError(IRequest::NotExistService);
            } else {
                try {
                    it->second->ServeRequest(req);
                } catch (...) {
                    Cdbg << CurrentExceptionMessage() << Endl;
                }
            }
        }

        // Rebuilds `srvs` from the registry if `version` is stale.
        // Returns true if a rebuild happened. The new version is read under
        // L_, so it matches the contents that were copied.
        inline bool UpdateServices(TSrvs& srvs, i64& version) const {
            if (AtomicGet(SelfVersion_) == version) {
                return false;
            }

            srvs.clear();

            TGuard<TSpinLock> guard(L_);

            for (const auto& it : *this) {
                srvs[TParsedLocation(it.first).Service] = it.second;
            }
            version = AtomicGet(SelfVersion_);

            return true;
        }

        // Snapshot of all registered listen addresses, taken under L_.
        inline TListenAddrs ListenAddrs() const {
            TListenAddrs addrs;

            {
                TGuard<TSpinLock> guard(L_);

                for (const auto& it : *this) {
                    addrs.push_back(it.first);
                }
            }

            return addrs;
        }

        TSpinLock L_;
        IRequestQueueRef RQ_;
        THolder<TLoopFunc> LF_;
        TAtomic SelfVersion_ = 1;
        TCheck C_;

        NThreading::TThreadLocalValue<TVersionedServiceMap> LocalMap_;

        // Unset until Listen() (false) or Loop()/ForkLoop() (true) is called.
        TMaybe<bool> HasLoop_;

        // Listen-mode requester. It holds a raw `this` as IOnRequest and
        // presumably may call OnRequest() until it is destroyed. OnRequest()
        // reads HasLoop_, C_, LocalMap_ and RQ_. Members are destroyed in
        // reverse declaration order, so declaring RR_ last makes it go first,
        // while everything OnRequest() touches is still alive.
        IRequesterRef RR_;
    };

    // Public IServices facade over a ref-counted TServices.
    class TServicesFace: public IServices {
    public:
        inline TServicesFace()
            : S_(new TServices())
        {
        }

        inline TServicesFace(TCheck check)
            : S_(new TServices(check))
        {
        }

        void DoAdd(const TString& service, IServiceRef srv) override {
            S_->Add(service, srv);
        }

        void Loop(size_t threads) override {
            S_->Loop(threads);
        }

        void ForkLoop(size_t threads) override {
            S_->ForkLoop(threads);
        }

        void SyncStopFork() override {
            S_->SyncStopFork();
        }

        void Stop() override {
            S_->Stop();
        }

        void Listen() override {
            S_->Listen();
        }

    private:
        TIntrusivePtr<TServices> S_;
    };
}

// Adapts a plain callback to the IService interface.
IServiceRef NNeh::Wrap(const TServiceFunction& func) {
    struct TWrapper: public IService {
        inline TWrapper(const TServiceFunction& f)
            : F(f)
        {
        }

        void ServeRequest(const IRequestRef& request) override {
            F(request);
        }

        TServiceFunction F;
    };

    return new TWrapper(func);
}

// Creates a service loop without a request pre-check.
IServicesRef NNeh::CreateLoop() {
    return new TServicesFace();
}

// Creates a service loop whose requests are first passed through `check`.
IServicesRef NNeh::CreateLoop(TCheck check) {
    return new TServicesFace(check);
}