#pragma once #include #include #include #include #include #include "poller.h" #include "interconnect_address.h" #include #include namespace NInterconnect { class TSocket: public NActors::TSharedDescriptor, public TNonCopyable { protected: TSocket(SOCKET fd); virtual ~TSocket() override; SOCKET Descriptor; virtual int GetDescriptor() override; private: friend class TSecureSocket; SOCKET ReleaseDescriptor() { return std::exchange(Descriptor, INVALID_SOCKET); } public: operator SOCKET() const { return Descriptor; } int Bind(const TAddress& addr) const; int Shutdown(int how) const; int GetConnectStatus() const; }; class TStreamSocket: public TSocket { public: TStreamSocket(SOCKET fd); static TIntrusivePtr Make(int domain); virtual ssize_t Send(const void* msg, size_t len, TString *err = nullptr) const; virtual ssize_t Recv(void* buf, size_t len, TString *err = nullptr) const; virtual ssize_t WriteV(const struct iovec* iov, int iovcnt) const; virtual ssize_t ReadV(const struct iovec* iov, int iovcnt) const; int Connect(const TAddress& addr) const; int Connect(const NAddr::IRemoteAddr* addr) const; int Listen(int backlog) const; int Accept(TAddress& acceptedAddr) const; ssize_t GetUnsentQueueSize() const; void SetSendBufferSize(i32 len) const; ui32 GetSendBufferSize() const; }; class TSecureSocketContext { class TImpl; THolder Impl; friend class TSecureSocket; public: TSecureSocketContext(const TString& certificate, const TString& privateKey, const TString& caFilePath, const TString& ciphers); ~TSecureSocketContext(); public: using TPtr = std::shared_ptr; }; class TSecureSocket : public TStreamSocket { TSecureSocketContext::TPtr Context; class TImpl; THolder Impl; public: enum class EStatus { SUCCESS, ERROR, WANT_READ, WANT_WRITE, }; public: TSecureSocket(TStreamSocket& socket, TSecureSocketContext::TPtr context); ~TSecureSocket(); EStatus Establish(bool server, bool authOnly, TString& err) const; TIntrusivePtr Detach(); ssize_t Send(const void* msg, size_t len, TString *err) const override; ssize_t Recv(void* msg, size_t len, TString *err) const override; ssize_t WriteV(const struct iovec* iov, int iovcnt) const override; ssize_t ReadV(const struct iovec* iov, int iovcnt) const override; TString GetCipherName() const; int GetCipherBits() const; TString GetProtocolName() const; TString GetPeerCommonName() const; bool WantRead() const; bool WantWrite() const; }; class TDatagramSocket: public TSocket { public: typedef std::shared_ptr TPtr; TDatagramSocket(SOCKET fd); static TPtr Make(int domain); ssize_t SendTo(const void* msg, size_t len, const TAddress& toAddr) const; ssize_t RecvFrom(void* buf, size_t len, TAddress& fromAddr) const; }; }