#include "socket.h"

#include "pair.h"

#include <library/cpp/testing/unittest/registar.h>

#include <util/generic/vector.h>
#include <util/string/builder.h>

#include <cerrno>
#include <csignal>
#include <cstdlib>
#include <ctime>

#ifdef _linux_
    // fcntl(F_GETFD) is used by TSockTest::TestClose to check that a descriptor was released.
    #include <fcntl.h>
    #include <unistd.h>
#endif

class TSockTest: public TTestBase {
    UNIT_TEST_SUITE(TSockTest);
    UNIT_TEST(TestSock);
    UNIT_TEST(TestTimeout);
#ifndef _win_ // Test hangs on Windows
    UNIT_TEST_EXCEPTION(TestConnectionRefused, yexception);
#endif
    UNIT_TEST(TestNetworkResolutionError);
    UNIT_TEST(TestNetworkResolutionErrorMessage);
    UNIT_TEST(TestBrokenPipe);
    UNIT_TEST(TestClose);
    UNIT_TEST_SUITE_END();

public:
    void TestSock();
    void TestTimeout();
    void TestConnectionRefused();
    void TestNetworkResolutionError();
    void TestNetworkResolutionErrorMessage();
    void TestBrokenPipe();
    void TestClose();
};

UNIT_TEST_SUITE_REGISTRATION(TSockTest);

// Live-network smoke test: sends a minimal HTTP/1.1 request to yandex.ru:80
// and expects the first line of the reply to be non-empty.
void TSockTest::TestSock() {
    TNetworkAddress addr("yandex.ru", 80);
    TSocket s(addr);
    TSocketOutput so(s);
    TSocketInput si(s);
    const TStringBuf req = "GET / HTTP/1.1\r\nHost: yandex.ru\r\n\r\n";

    so.Write(req.data(), req.size());

    UNIT_ASSERT(!si.ReadLine().empty());
}

// Connecting to localhost:1313 with a 1000 ms timeout must give up (or finish)
// within the timeout plus 2000 ms of slack. Whether the connect itself
// succeeds or throws is irrelevant; only the elapsed time is checked.
void TSockTest::TestTimeout() {
    static const int timeout = 1000;
    const i64 startTime = millisec();
    try {
        TNetworkAddress addr("localhost", 1313);
        TSocket s(addr, TDuration::MilliSeconds(timeout));
    } catch (const yexception&) {
        // Deliberately ignored: a refused or timed-out connect is an acceptable outcome here.
    }
    const int realTimeout = static_cast<int>(millisec() - startTime);
    if (realTimeout > timeout + 2000) {
        TString err = TStringBuilder() << "Timeout exceeded: " << realTimeout << " ms (expected " << timeout << " ms)";
        UNIT_FAIL(err);
    }
}

// Expected to throw yexception (see UNIT_TEST_EXCEPTION above): assumes
// nothing listens on localhost:1313 during the test run.
void TSockTest::TestConnectionRefused() {
    TNetworkAddress addr("localhost", 1313);
    TSocket s(addr);
}

// Resolving an empty host name must fail with a message that contains
// gai_strerror(EAI_NONAME). Platforms where the lookup succeeds are skipped.
void TSockTest::TestNetworkResolutionError() {
    TString errMsg;
    try {
        TNetworkAddress addr("", 0);
    } catch (const TNetworkResolutionError& e) {
        errMsg = e.what();
    }

    if (errMsg.empty()) {
        return; // on Windows getaddrinfo("", 0, ...) returns "OK"
    }

    const int expectedErr = EAI_NONAME;
    const TString expectedErrMsg = gai_strerror(expectedErr);
    if (errMsg.find(expectedErrMsg) == TString::npos) {
        UNIT_FAIL("TNetworkResolutionError contains\nInvalid msg: " + errMsg + "\nExpected msg: " + expectedErrMsg + "\n");
    }
}

// Pins the exact text of TNetworkResolutionError::what() on Unix:
// "<gai_strerror(code)>(<code>): " in general, and additionally "; errno=<n>"
// for EAI_SYSTEM. errno is saved and restored around the test.
void TSockTest::TestNetworkResolutionErrorMessage() {
#ifdef _unix_
    auto str = [](int code) -> TString {
        return TNetworkResolutionError(code).what();
    };
    auto expected = [](int code) -> TString {
        return gai_strerror(code);
    };

    struct TErrnoGuard {
        TErrnoGuard()
            : PrevValue_(errno)
        {
        }

        ~TErrnoGuard() {
            errno = PrevValue_;
        }

    private:
        int PrevValue_;
    } g;

    UNIT_ASSERT_VALUES_EQUAL(expected(0) + "(0): ", str(0));
    UNIT_ASSERT_VALUES_EQUAL(expected(-9) + "(-9): ", str(-9));

    errno = 0;
    UNIT_ASSERT_VALUES_EQUAL(expected(EAI_SYSTEM) + "(" + IntToString<10>(EAI_SYSTEM) + "; errno=0): ",
                             str(EAI_SYSTEM));
    errno = 110;
    UNIT_ASSERT_VALUES_EQUAL(expected(EAI_SYSTEM) + "(" + IntToString<10>(EAI_SYSTEM) + "; errno=110): ",
                             str(EAI_SYSTEM));
#endif
}

// Scope guard: installs the default SIGPIPE disposition (SIG_DFL, which
// terminates the process) for its lifetime and restores the previous handler
// on destruction. Aborts if either signal() call fails.
class TTempEnableSigPipe {
public:
    TTempEnableSigPipe() {
        OriginalSigHandler_ = signal(SIGPIPE, SIG_DFL);
        Y_ABORT_UNLESS(OriginalSigHandler_ != SIG_ERR);
    }

    ~TTempEnableSigPipe() {
        auto ret = signal(SIGPIPE, OriginalSigHandler_);
        Y_ABORT_UNLESS(ret != SIG_ERR);
    }

    // A copy would restore the saved handler twice; the guard is scope-only.
    TTempEnableSigPipe(const TTempEnableSigPipe&) = delete;
    TTempEnableSigPipe& operator=(const TTempEnableSigPipe&) = delete;

private:
    void (*OriginalSigHandler_)(int);
};

// With SIGPIPE at its default (fatal) disposition, Send and SendV on a socket
// whose peer has been shut down must return a negative value. If TSocket
// raised SIGPIPE instead, the test process would be killed.
void TSockTest::TestBrokenPipe() {
    TTempEnableSigPipe guard;

    SOCKET socks[2];

    int ret = SocketPair(socks);
    UNIT_ASSERT_VALUES_EQUAL(ret, 0);

    TSocket sender(socks[0]);
    TSocket receiver(socks[1]);
    receiver.ShutDown(SHUT_RDWR);
    int sent = sender.Send("FOO", 3);
    UNIT_ASSERT(sent < 0);

    IOutputStream::TPart parts[] = {
        {"foo", 3},
        {"bar", 3},
    };
    sent = sender.SendV(parts, 2);
    UNIT_ASSERT(sent < 0);
}

// TSocket::Close() must release the descriptor and reset the handle to
// INVALID_SOCKET. On Linux the release is verified directly: fcntl(F_GETFD)
// succeeds before Close() and fails afterwards.
void TSockTest::TestClose() {
    SOCKET socks[2];

    UNIT_ASSERT_EQUAL(SocketPair(socks), 0);
    // Own the other end of the pair too, so it is closed when the test ends
    // (previously socks[0] was leaked).
    TSocket sender(socks[0]);
    TSocket receiver(socks[1]);

    UNIT_ASSERT_EQUAL(static_cast<SOCKET>(receiver), socks[1]);

#if defined _linux_
    UNIT_ASSERT_GE(fcntl(socks[1], F_GETFD), 0);
    receiver.Close();
    UNIT_ASSERT_EQUAL(fcntl(socks[1], F_GETFD), -1);
#else
    receiver.Close();
#endif

    UNIT_ASSERT_EQUAL(static_cast<SOCKET>(receiver), INVALID_SOCKET);
}

class TPollTest: public TTestBase {
    UNIT_TEST_SUITE(TPollTest);
    UNIT_TEST(TestPollInOut);
    UNIT_TEST_SUITE_END();

public:
    // Seeds rand(), which TestPollInOut uses to pick a listening port.
    inline TPollTest() {
        srand(static_cast<unsigned int>(time(nullptr)));
    }

    void TestPollInOut();

private:
    sockaddr_in GetAddress(ui32 ip, ui16 port);
    SOCKET CreateSocket();
    SOCKET StartServerSocket(ui16 port, int backlog);
    SOCKET StartClientSocket(ui32 ip, ui16 port);
    SOCKET AcceptConnection(SOCKET serverSocket);
};

UNIT_TEST_SUITE_REGISTRATION(TPollTest);

// Builds a zeroed IPv4 sockaddr_in. `ip` and `port` are given in host byte
// order and converted to network byte order here.
sockaddr_in TPollTest::GetAddress(ui32 ip, ui16 port) {
    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);
    addr.sin_addr.s_addr = htonl(ip);
    return addr;
}

// Creates an IPv4 TCP socket; throws yexception on failure.
SOCKET TPollTest::CreateSocket() {
    SOCKET s = socket(AF_INET, SOCK_STREAM, 0);
    if (s == INVALID_SOCKET) {
        ythrow yexception() << "Can not create socket (" << LastSystemErrorText() << ")";
    }
    return s;
}

// Creates a socket bound to INADDR_ANY:`port` and listening with `backlog`.
// The socket is held by TSocketHolder until success, so it is closed if bind
// or listen throws. The caller owns the returned socket.
SOCKET TPollTest::StartServerSocket(ui16 port, int backlog) {
    TSocketHolder s(CreateSocket());
    sockaddr_in addr = GetAddress(ntohl(INADDR_ANY), port);
    if (bind(s, (sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) {
        ythrow yexception() << "Can not bind server socket (" << LastSystemErrorText() << ")";
    }
    if (listen(s, backlog) == SOCKET_ERROR) {
        ythrow yexception() << "Can not listen on server socket (" << LastSystemErrorText() << ")";
    }
    return s.Release();
}

// Creates a socket connected to `ip`:`port` (host byte order). The socket is
// closed if connect throws; otherwise the caller owns it.
SOCKET TPollTest::StartClientSocket(ui32 ip, ui16 port) {
    TSocketHolder s(CreateSocket());
    sockaddr_in addr = GetAddress(ip, port);
    if (connect(s, (sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) {
        ythrow yexception() << "Can not connect client socket (" << LastSystemErrorText() << ")";
    }
    return s.Release();
}

// Accepts one pending connection on `serverSocket`; throws yexception on failure.
SOCKET TPollTest::AcceptConnection(SOCKET serverSocket) {
    SOCKET connectedSocket = accept(serverSocket, nullptr, nullptr);
    if (connectedSocket == INVALID_SOCKET) {
        ythrow yexception() << "Can not accept connection on server socket (" << LastSystemErrorText() << ")";
    }
    return connectedSocket;
}

// Windows-only check of the poll() event masks reported for accepted sockets.
// For connection i, the client side behaves according to i % 5:
//   0 - sends one byte              -> POLLRDNORM | POLLWRNORM
//   1 - idle                        -> POLLOUT | POLLWRNORM
//   2 - sends one byte, then closes -> POLLHUP | POLLRDNORM | POLLWRNORM
//   3 - closes without sending      -> POLLHUP | POLLWRNORM
//   4 - passed to poll() as INVALID_SOCKET -> POLLNVAL
//       (these entries are not counted in poll()'s return value)
void TPollTest::TestPollInOut() {
#ifdef _win_
    const size_t socketCount = 1000;

    const ui16 port = static_cast<ui16>(1300 + rand() % 97);
    TSocketHolder serverSocket(StartServerSocket(port, static_cast<int>(socketCount)));

    const ui32 localIp = ntohl(inet_addr("127.0.0.1"));

    TVector<TSimpleSharedPtr<TSocketHolder>> clientSockets;
    TVector<TSimpleSharedPtr<TSocketHolder>> connectedSockets;
    TVector<pollfd> fds;

    for (size_t i = 0; i < socketCount; ++i) {
        TSimpleSharedPtr<TSocketHolder> clientSocket(new TSocketHolder(StartClientSocket(localIp, port)));
        clientSockets.push_back(clientSocket);

        if (i % 5 == 0 || i % 5 == 2) {
            char buffer = 'c';
            if (send(*clientSocket, &buffer, 1, 0) == -1) {
                ythrow yexception() << "Can not send (" << LastSystemErrorText() << ")";
            }
        }

        TSimpleSharedPtr<TSocketHolder> connectedSocket(new TSocketHolder(AcceptConnection(serverSocket)));
        connectedSockets.push_back(connectedSocket);

        if (i % 5 == 2 || i % 5 == 3) {
            // Take the handle out of the holder before closing it. Previously the
            // holder kept the closed handle and closed it a second time on
            // destruction. A shutdown() issued after the close was a no-op on
            // the dead handle, so it has been dropped.
            closesocket(clientSocket->Release());
        }
    }

    int expectedCount = 0;
    for (size_t i = 0; i < connectedSockets.size(); ++i) {
        pollfd fd = {(i % 5 == 4) ? INVALID_SOCKET : static_cast<SOCKET>(*connectedSockets[i]), POLLIN | POLLOUT, 0};
        fds.push_back(fd);
        if (i % 5 != 4) {
            ++expectedCount;
        }
    }

    int polledCount = poll(fds.data(), fds.size(), INFTIM);
    UNIT_ASSERT_EQUAL(expectedCount, polledCount);

    for (size_t i = 0; i < connectedSockets.size(); ++i) {
        const short revents = fds[i].revents;
        if (i % 5 == 0) {
            UNIT_ASSERT_EQUAL(static_cast<short>(POLLRDNORM | POLLWRNORM), revents);
        } else if (i % 5 == 1) {
            UNIT_ASSERT_EQUAL(static_cast<short>(POLLOUT | POLLWRNORM), revents);
        } else if (i % 5 == 2) {
            UNIT_ASSERT_EQUAL(static_cast<short>(POLLHUP | POLLRDNORM | POLLWRNORM), revents);
        } else if (i % 5 == 3) {
            UNIT_ASSERT_EQUAL(static_cast<short>(POLLHUP | POLLWRNORM), revents);
        } else if (i % 5 == 4) {
            UNIT_ASSERT_EQUAL(static_cast<short>(POLLNVAL), revents);
        }
    }
#endif
}