#include "udp.h"
#include "details.h"
#include "neh.h"
#include "location.h"
#include "utils.h"
#include "factory.h"

// NOTE(review): the angle-bracket header names below were restored from the
// symbols this file uses - confirm the list against the build description.
#include <library/cpp/dns/cache.h>

#include <util/network/socket.h>
#include <util/network/address.h>
#include <util/generic/deque.h>
#include <util/generic/hash.h>
#include <util/generic/string.h>
#include <util/generic/buffer.h>
#include <util/generic/singleton.h>
#include <util/digest/murmur.h>
#include <util/random/random.h>
#include <util/ysaveload.h>
#include <util/system/thread.h>
#include <util/system/pipe.h>
#include <util/system/error.h>
#include <util/stream/mem.h>
#include <util/stream/buffer.h>
#include <util/string/cast.h>

using namespace NNeh;
using namespace NDns;
using namespace NAddr;

// NOTE(review): explicit template arguments in this file (TAutoPtr<...>,
// Spawn<...>, Singleton<...>, static_cast<...>, ...) were restored from the way
// each value is used - confirm them against the declarations in the project headers.

namespace {
    namespace NUdp {
        // Packet kind, stored as one char right after the checksum of every datagram.
        enum EPacketType {
            PT_REQUEST = 1,
            PT_RESPONSE = 2,
            // Sent by ~TProto() to its own socket to stop the worker threads.
            PT_STOP = 3,
            // Never sent over the wire: synthesized locally when recvfrom() times out.
            PT_TIMEOUT = 4
        };

        // Request handle returned to the caller of TProtocol::ScheduleRequest().
        struct TUdpHandle: public TNotifyHandle {
            inline TUdpHandle(IOnRecv* r, const TMessage& msg, TStatCollector* sc) noexcept
                : TNotifyHandle(r, msg, sc)
            {
            }

            void Cancel() noexcept override {
                THandle::Cancel(); //inform stat collector
            }

            bool MessageSendedCompletely() const noexcept override {
                //TODO
                return true;
            }
        };

        // Builds the loopback address (same family and port as the address socket
        // `s` is bound to). Used to deliver the stop packet to our own socket.
        // Throws yexception for address families other than AF_INET / AF_INET6.
        static inline IRemoteAddrPtr GetSendAddr(SOCKET s) {
            IRemoteAddrPtr local = GetSockAddr(s);
            const sockaddr* addr = local->Addr();

            switch (addr->sa_family) {
                case AF_INET: {
                    const TIpAddress a = *reinterpret_cast<const sockaddr_in*>(addr);

                    return MakeHolder<TIPv4Addr>(TIpAddress(InetToHost(INADDR_LOOPBACK), a.Port()));
                }

                case AF_INET6: {
                    sockaddr_in6 a = *reinterpret_cast<const sockaddr_in6*>(addr);

                    a.sin6_addr = in6addr_loopback;

                    return MakeHolder<TIPv6Addr>(a);
                }
            }

            ythrow yexception() << "unsupported";
        }

        typedef ui32 TCheckSum;

        // Returns a 16-byte binary request identifier made of two random 64-bit numbers.
        static inline TString GenerateGuid() {
            const ui64 res[2] = {
                RandomNumber<ui64>(), RandomNumber<ui64>()};

            return TString(reinterpret_cast<const char*>(res), sizeof(res));
        }

        // Checksum of a packet body: MurmurHash converted with HostToInet().
        static inline TCheckSum Sum(const TStringBuf& s) noexcept {
            return HostToInet(MurmurHash<TCheckSum>(s.data(), s.size()));
        }

        struct TPacket;

        template <class T>
        static inline void Serialize(TPacket& p, const T& t);

        // One datagram. Layout of Data: [TCheckSum][char type][payload...];
        // the checksum covers everything after itself (see Body()).
        struct TPacket {
            // Empty packet addressed to `addr`.
            inline TPacket(IRemoteAddrPtr addr)
                : Addr(std::move(addr))
            {
            }

            // Packet addressed to `addr` whose body is the serialized form of `t`.
            template <class T>
            inline TPacket(const T& t, IRemoteAddrPtr addr)
                : Addr(std::move(addr))
            {
                NUdp::Serialize(*this, t);
            }

            // Receives one datagram from `s`, using `tmp` as the receive buffer.
            //  - on success Data/Addr are filled and the checksum is verified
            //    (throws yexception on a short packet or checksum mismatch);
            //  - on EAGAIN/EWOULDBLOCK the packet becomes a synthetic PT_TIMEOUT
            //    packet (Addr stays empty);
            //  - EINTR restarts the call;
            //  - any other error throws TSystemError.
            inline TPacket(TSocketHolder& s, TBuffer& tmp) {
                TAutoPtr<TOpaqueAddr> addr(new TOpaqueAddr());

                while (true) {
                    const int rv = recvfrom(s, tmp.Data(), tmp.size(), MSG_WAITALL, addr->MutableAddr(), addr->LenPtr());

                    if (rv >= 0) {
                        Data.Append(tmp.Data(), static_cast<size_t>(rv));
                        Addr.Reset(addr.Release());
                        CheckSign();

                        return;
                    }

                    const int err = LastSystemError();

                    if (err == EINTR) {
                        // interrupted by a signal - just call recvfrom() again
                        continue;
                    }

                    if (err == EAGAIN || err == EWOULDBLOCK) {
                        // room for the (unused) checksum plus the type byte
                        Data.Resize(sizeof(TCheckSum) + 1);
                        *(Data.data() + sizeof(TCheckSum)) = static_cast<char>(PT_TIMEOUT);

                        return;
                    }

                    ythrow TSystemError() << "recv failed";
                }
            }

            // Stamps the checksum and sends the datagram to Addr.
            // A send failure is only logged to Cdbg (best effort).
            inline void SendTo(TSocketHolder& s) {
                Sign();

                if (sendto(s, Data.data(), Data.size(), 0, Addr->Addr(), Addr->Len()) < 0) {
                    Cdbg << LastSystemErrorText() << Endl;
                }
            }

            IRemoteAddrPtr Addr;
            TBuffer Data;

            // Writes the checksum of Body() into the first sizeof(TCheckSum) bytes.
            inline void Sign() {
                const TCheckSum sum = CalcSign();

                memcpy(Data.Data(), &sum, sizeof(sum));
            }

            // Packet type byte (an EPacketType value stored as char).
            inline char Type() const noexcept {
                return *(Data.data() + sizeof(TCheckSum));
            }

            // Throws yexception if the packet is shorter than 16 bytes or the
            // stored checksum does not match the body.
            inline void CheckSign() const {
                if (Data.size() < 16) {
                    ythrow yexception() << "small packet";
                }

                if (StoredSign() != CalcSign()) {
                    ythrow yexception() << "bad checksum";
                }
            }

            inline TCheckSum StoredSign() const noexcept {
                TCheckSum sum;

                memcpy(&sum, Data.Data(), sizeof(sum));

                return sum;
            }

            inline TCheckSum CalcSign() const noexcept {
                return Sum(Body());
            }

            // Everything after the checksum: type byte + payload.
            inline TStringBuf Body() const noexcept {
                return TStringBuf(Data.data() + sizeof(TCheckSum), Data.End());
            }
        };

        typedef TAutoPtr<TPacket> TPacketRef;

        // Input stream over the body (type byte + payload) of a packet.
        class TPacketInput: public TMemoryInput {
        public:
            inline TPacketInput(const TPacket& p)
                : TMemoryInput(p.Body().data(), p.Body().size())
            {
            }
        };

        // Output stream appending to a packet's buffer; the constructor first
        // skips sizeof(TCheckSum) bytes that Sign() fills in later.
        class TPacketOutput: public TBufferOutput {
        public:
            inline TPacketOutput(TPacket& p)
                : TBufferOutput(p.Data)
            {
                p.Data.Proceed(sizeof(TCheckSum));
            }
        };

        // Writes the type byte of `t` followed by its payload.
        template <class T>
        static inline void Serialize(TPacketOutput* out, const T& t) {
            Save(out, t.Type());
            t.Serialize(out);
        }

        // Serializes `t` into packet `p` (after the checksum placeholder).
        template <class T>
        static inline void Serialize(TPacket& p, const T& t) {
            TPacketOutput out(p);

            NUdp::Serialize(&out, t);
        }

        namespace NPrivate {
            // Reads the type byte and, if it matches t.Type(), the payload of `t`.
            // Throws yexception on a type mismatch.
            template <class T>
            static inline void Deserialize(TPacketInput* in, T& t) {
                char type;
                Load(in, type);

                if (type != t.Type()) {
                    ythrow yexception() << "unsupported packet";
                }

                t.Deserialize(in);
            }

            template <class T>
            static inline void Deserialize(const TPacket& p, T& t) {
                TPacketInput in(p);

                Deserialize(&in, t);
            }
        }

        // PT_REQUEST payload: request id, service name and request body.
        struct TRequestPacket {
            TString Guid;
            TString Service;
            TString Data;

            // Parses a received packet; throws if it is not a valid request.
            inline TRequestPacket(const TPacket& p) {
                NPrivate::Deserialize(p, *this);
            }

            // New outgoing request with a freshly generated Guid.
            inline TRequestPacket(const TString& srv, const TString& data)
                : Guid(GenerateGuid())
                , Service(srv)
                , Data(data)
            {
            }

            inline char Type() const noexcept {
                return static_cast<char>(PT_REQUEST);
            }

            inline void Serialize(TPacketOutput* out) const {
                Save(out, Guid);
                Save(out, Service);
                Save(out, Data);
            }

            inline void Deserialize(TPacketInput* in) {
                Load(in, Guid);
                Load(in, Service);
                Load(in, Data);
            }
        };

        // PT_RESPONSE payload: id of the request being answered and response data.
        template <class TStore>
        struct TResponsePacket {
            TString Guid;
            TStore Data;

            // Takes the contents of `data` by swap: `data` is left holding the
            // (default-constructed) previous value of Data.
            inline TResponsePacket(const TString& guid, TStore& data)
                : Guid(guid)
            {
                Data.swap(data);
            }

            // Parses a received packet; throws if it is not a valid response.
            inline TResponsePacket(const TPacket& p) {
                NPrivate::Deserialize(p, *this);
            }

            inline char Type() const noexcept {
                return static_cast<char>(PT_RESPONSE);
            }

            inline void Serialize(TPacketOutput* out) const {
                Save(out, Guid);
                Save(out, Data);
            }

            inline void Deserialize(TPacketInput* in) {
                Load(in, Guid);
                Load(in, Data);
            }
        };

        // PT_STOP payload. The string is filler: it brings the datagram above
        // the 16-byte minimum enforced by TPacket::CheckSign().
        struct TStopPacket {
            inline char Type() const noexcept {
                return static_cast<char>(PT_STOP);
            }

            inline void Serialize(TPacketOutput* out) const {
                Save(out, TString("stop packet"));
            }
        };

        // Thrown by CreateSocket() when bind() fails.
        struct TBindError: public TSystemError {
        };

        // A bound socket together with its address family.
        struct TSocketDescr {
            // Takes ownership of the descriptor held by `s`.
            inline TSocketDescr(TSocketHolder& s, int family)
                : S(s.Release())
                , Family(family)
            {
            }

            TSocketHolder S;
            int Family;
        };

        typedef TAutoPtr<TSocketDescr> TSocketRef;
        typedef TVector<TSocketRef> TSockets;

        // Creates a UDP socket bound to `addr` and stores it in `s`.
        // Throws TSystemError if the socket cannot be created and TBindError if
        // it cannot be bound; `s` is left untouched in both cases.
        static inline void CreateSocket(TSocketHolder& s, const IRemoteAddr& addr) {
            TSocketHolder res(socket(addr.Addr()->sa_family, SOCK_DGRAM, IPPROTO_UDP));

            if (!res) {
                ythrow TSystemError() << "can not create socket";
            }

            FixIPv6ListenSocket(res);

            if (bind(res, addr.Addr(), addr.Len()) != 0) {
                ythrow TBindError() << "can not bind " << PrintHostAndPort(addr);
            }

            res.Swap(s);
        }

        // Appends to `s` one bound socket per address TNetworkAddress(port) yields.
        static inline void CreateSockets(TSockets& s, ui16 port) {
            TNetworkAddress addr(port);

            for (TNetworkAddress::TIterator it = addr.Begin(); it != addr.End(); ++it) {
                TSocketHolder res;

                CreateSocket(res, TAddrInfo(&*it));

                s.push_back(new TSocketDescr(res, it->ai_family));
            }
        }

        // Keeps picking a random port in [5000, 6000) until every socket binds.
        // Only TBindError triggers a retry; other exceptions propagate.
        static inline void CreateSocketsOnRandomPort(TSockets& s) {
            while (true) {
                try {
                    TSockets tmp;

                    CreateSockets(tmp, 5000 + (RandomNumber<ui16>() % 1000));
                    tmp.swap(s);

                    return;
                } catch (const TBindError&) {
                }
            }
        }

        typedef ui64 TTimeStamp;

        // Coarse clock: one unit is 2^31 CPU cycles (presumably on the order of
        // a second on current hardware).
        static inline TTimeStamp TimeStamp() noexcept {
            return GetCycleCount() >> 31;
        }

        // Bookkeeping for a request awaiting its response; linked into
        // TProto::TInFly in creation order.
        struct TRequestDescr: public TIntrusiveListItem<TRequestDescr> {
            inline TRequestDescr(const TString& guid, const TNotifyHandleRef& hndl, const TMessage& msg)
                : Guid(guid)
                , Hndl(hndl)
                , Msg(msg)
                , TS(TimeStamp())
            {
            }

            TString Guid;
            TNotifyHandleRef Hndl;
            TMessage Msg;
            TTimeStamp TS; // creation time, used for expiry in TInFly::EraseStale()
        };

        typedef TAutoPtr<TRequestDescr> TRequestDescrRef;

        // Endpoint for one socket: owns the socket plus a receiver and a sender
        // thread. Serves incoming requests when a callback is supplied and
        // matches incoming responses to outstanding requests.
        class TProto {
            // Incoming request handed to the IOnRequest callback.
            class TRequest: public IRequest, public TRequestPacket {
            public:
                // Parses `p` (throws on a malformed request) and takes over its
                // sender address.
                inline TRequest(TPacket& p, TProto* parent)
                    : TRequestPacket(p)
                    , Addr_(std::move(p.Addr))
                    , H_(PrintHostByRfc(*Addr_))
                    , P_(parent)
                {
                }

                TStringBuf Scheme() const override {
                    return TStringBuf("udp");
                }

                TString RemoteHost() const override {
                    return H_;
                }

                // The accessors below share their names with TRequestPacket's data
                // members, hence the qualified access.
                TStringBuf Service() const override {
                    return TRequestPacket::Service;
                }

                TStringBuf Data() const override {
                    return TRequestPacket::Data;
                }

                TStringBuf RequestId() const override {
                    return TRequestPacket::Guid;
                }

                bool Canceled() const override {
                    //TODO ?
                    return false;
                }

                // Queues the response for sending. Consumes Addr_, so it is
                // meant to be called at most once.
                void SendReply(TData& data) override {
                    P_->Schedule(new TPacket(TResponsePacket<TData>(Guid, data), std::move(Addr_)));
                }

                void SendError(TResponseError, const TString&) override {
                    // TODO
                }

            private:
                IRemoteAddrPtr Addr_;
                TString H_;
                TProto* P_;
            };

        public:
            // Takes ownership of socket `s`, sets its timeout to 10 (presumably
            // seconds - confirm against SetSocketTimeout) and starts both threads.
            // `cb` may be nullptr (client mode): incoming requests are dropped.
            inline TProto(IOnRequest* cb, TSocketHolder& s)
                : CB_(cb)
                , ToSendEv_(TSystemEvent::rAuto)
                , S_(s.Release())
            {
                SetSocketTimeout(S_, 10);
                Thrs_.push_back(Spawn<TProto, &TProto::ExecuteRecv>(this));
                Thrs_.push_back(Spawn<TProto, &TProto::ExecuteSend>(this));
            }

            // Sends a stop packet to our own socket over loopback: the receiver
            // thread sees PT_STOP, pushes a null packet to stop the sender, and
            // both threads are joined.
            inline ~TProto() {
                Schedule(new TPacket(TStopPacket(), GetSendAddr(S_)));

                for (auto& thr : Thrs_) {
                    thr->Join();
                }
            }

            // Blocks until a packet (or a synthetic timeout packet) is available.
            // Reception errors and bad checksums are logged and skipped.
            inline TPacketRef Recv() {
                TBuffer tmp;

                tmp.Resize(128 * 1024);

                while (true) {
                    try {
                        return new TPacket(S_, tmp);
                    } catch (...) {
                        Cdbg << CurrentExceptionMessage() << Endl;

                        continue;
                    }
                }
            }

            typedef THashMap<TString, TRequestDescrRef> TInFlyBase;

            // Outstanding requests: a hash map by Guid (which owns the
            // descriptors) plus an intrusive list in insertion order.
            struct TInFly: public TInFlyBase, public TIntrusiveList<TRequestDescr> {
                typedef TInFlyBase::iterator TIter;
                typedef TInFlyBase::const_iterator TContsIter;

                // Registers `d`; ownership moves into the map.
                inline void Insert(TRequestDescrRef& d) {
                    PushBack(d.Get());
                    (*this)[d->Guid] = d;
                }

                // Fails with "request timeout" and removes the requests at the
                // head of the list that are more than 120 TimeStamp() units old.
                inline void EraseStale() noexcept {
                    const TTimeStamp now = TimeStamp();

                    for (TIterator it = Begin(); (it != End()) && (it->TS < now) && ((now - it->TS) > 120);) {
                        it->Hndl->NotifyError("request timeout");

                        // Copy the key and advance first: erase() destroys the
                        // descriptor the iterator points at.
                        TString safe_key = (it++)->Guid;
                        erase(safe_key);
                    }
                }
            };

            // Receiver thread body: dispatches packets until PT_STOP arrives.
            inline void ExecuteRecv() {
                SetHighestThreadPriority();

                TInFly infly;

                while (true) {
                    TPacketRef p = Recv();

                    // Recv() only validates the checksum. Payload parsing and user
                    // callbacks can still throw on data coming from the network, and
                    // an exception must not escape this thread's entry function -
                    // log and drop the packet, as Recv() does.
                    try {
                        switch (static_cast<EPacketType>(p->Type())) {
                            case PT_REQUEST:
                                if (CB_) {
                                    CB_->OnRequest(new TRequest(*p, this));
                                } else {
                                    //skip request in case of client
                                }

                                break;

                            case PT_RESPONSE: {
                                CancelStaleRequests(infly);

                                TResponsePacket<TString> rp(*p);

                                TInFly::TIter it = static_cast<TInFlyBase&>(infly).find(rp.Guid);

                                if (it == static_cast<TInFlyBase&>(infly).end()) {
                                    break;
                                }

                                const TRequestDescrRef& d = it->second;
                                d->Hndl->NotifyResponse(rp.Data);

                                infly.erase(it);

                                break;
                            }

                            case PT_STOP:
                                // a null packet tells the sender thread to exit
                                Schedule(nullptr);

                                return;

                            case PT_TIMEOUT:
                                CancelStaleRequests(infly);

                                break;
                        }
                    } catch (...) {
                        Cdbg << CurrentExceptionMessage() << Endl;
                    }
                }
            }

            // Sender thread body: sends queued packets until a null packet arrives.
            inline void ExecuteSend() {
                SetHighestThreadPriority();

                while (true) {
                    TPacketRef p;

                    while (!ToSend_.Dequeue(&p)) {
                        ToSendEv_.Wait();
                    }

                    //shutdown
                    if (!p) {
                        return;
                    }

                    p->SendTo(S_);
                }
            }

            // Queues `p` for sending and wakes the sender thread.
            inline void Schedule(TPacketRef p) {
                ToSend_.Enqueue(p);
                ToSendEv_.Signal();
            }

            // Registers an outgoing request (picked up later by the receiver
            // thread) and queues its packet.
            inline void Schedule(TRequestDescrRef dsc, TPacketRef p) {
                ScheduledReqs_.Enqueue(dsc);
                Schedule(p);
            }

        protected:
            // Moves newly scheduled requests into `infly`, then expires stale ones.
            void CancelStaleRequests(TInFly& infly) {
                TRequestDescrRef d;

                while (ScheduledReqs_.Dequeue(&d)) {
                    infly.Insert(d);
                }

                infly.EraseStale();
            }

            IOnRequest* CB_;
            NNeh::TAutoLockFreeQueue<TPacket> ToSend_;
            NNeh::TAutoLockFreeQueue<TRequestDescr> ScheduledReqs_;
            TSystemEvent ToSendEv_;
            TSocketHolder S_;
            TVector<IThreadRef> Thrs_;
        };

        // A set of TProto endpoints, one per address family.
        class TProtos {
        public:
            // Client mode: sockets on a random port, no request callback.
            inline TProtos() {
                TSockets s;

                CreateSocketsOnRandomPort(s);
                Init(nullptr, s);
            }

            // Server mode: sockets on `port`, requests delivered to `cb`.
            inline TProtos(IOnRequest* cb, ui16 port) {
                TSockets s;

                CreateSockets(s, port);
                Init(cb, s);
            }

            static inline TProtos* Instance() {
                return Singleton<TProtos>();
            }

            // Resolves msg.Addr and sends the request through the endpoint of
            // the first resolved address whose family we have a socket for.
            // Throws yexception if no resolved address has a matching endpoint.
            inline void Schedule(const TMessage& msg, const TNotifyHandleRef& hndl) {
                TParsedLocation loc(msg.Addr);
                const TNetworkAddress* addr = &CachedThrResolve(TResolveInfo(loc.Host, loc.GetPort()))->Addr;

                for (TNetworkAddress::TIterator ai = addr->Begin(); ai != addr->End(); ++ai) {
                    TProto* proto = Find(ai->ai_family);

                    if (proto) {
                        TRequestPacket rp(ToString(loc.Service), msg.Data);
                        TRequestDescrRef rd(new TRequestDescr(rp.Guid, hndl, msg));
                        IRemoteAddrPtr raddr(new TAddrInfo(&*ai));
                        TPacketRef p(new TPacket(rp, std::move(raddr)));

                        proto->Schedule(rd, p);

                        return;
                    }
                }

                ythrow yexception() << "unsupported protocol family";
            }

        private:
            // Creates one endpoint per socket; each endpoint takes its socket over.
            inline void Init(IOnRequest* cb, TSockets& s) {
                for (auto& it : s) {
                    P_[it->Family] = new TProto(cb, it->S);
                }
            }

            // Endpoint for the address family, or nullptr if there is none.
            inline TProto* Find(int family) const {
                TProtoStorage::const_iterator it = P_.find(family);

                if (it == P_.end()) {
                    return nullptr;
                }

                return it->second.Get();
            }

        private:
            typedef TAutoPtr<TProto> TProtoRef;
            typedef THashMap<int, TProtoRef> TProtoStorage;

            TProtoStorage P_;
        };

        // Server-side requester: listens on a port and serves requests via `cb`.
        class TRequester: public IRequester, public TProtos {
        public:
            inline TRequester(IOnRequest* cb, ui16 port)
                : TProtos(cb, port)
            {
            }
        };

        // IProtocol implementation for the "udp" scheme.
        class TProtocol: public IProtocol {
        public:
            IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override {
                return new TRequester(cb, loc.GetPort());
            }

            // Sends `msg` through the shared client-side TProtos instance; the
            // returned handle is notified of the response or of a timeout.
            THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override {
                TNotifyHandleRef ret(new TUdpHandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss)));

                TProtos::Instance()->Schedule(msg, ret);

                return ret.Get();
            }

            TStringBuf Scheme() const noexcept override {
                return TStringBuf("udp");
            }
        };
    }
}

// Returns the process-wide UDP protocol singleton.
IProtocol* NNeh::UdpProtocol() {
    return Singleton<NUdp::TProtocol>();
}