socket_ut.cpp 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. #include "socket.h"
  2. #include "pair.h"
  3. #include <library/cpp/testing/unittest/registar.h>
  4. #include <util/string/builder.h>
  5. #include <util/generic/vector.h>
  6. #include <ctime>
  7. #ifdef _linux_
  8. #include <linux/version.h>
  9. #include <sys/utsname.h>
  10. #endif
  11. class TSockTest: public TTestBase {
  12. UNIT_TEST_SUITE(TSockTest);
  13. UNIT_TEST(TestSock);
  14. UNIT_TEST(TestTimeout);
  15. #ifndef _win_ // Test hangs on Windows
  16. UNIT_TEST_EXCEPTION(TestConnectionRefused, yexception);
  17. #endif
  18. UNIT_TEST(TestNetworkResolutionError);
  19. UNIT_TEST(TestNetworkResolutionErrorMessage);
  20. UNIT_TEST(TestBrokenPipe);
  21. UNIT_TEST(TestClose);
  22. UNIT_TEST_SUITE_END();
  23. public:
  24. void TestSock();
  25. void TestTimeout();
  26. void TestConnectionRefused();
  27. void TestNetworkResolutionError();
  28. void TestNetworkResolutionErrorMessage();
  29. void TestBrokenPipe();
  30. void TestClose();
  31. };
  32. UNIT_TEST_SUITE_REGISTRATION(TSockTest);
  33. void TSockTest::TestSock() {
  34. TNetworkAddress addr("yandex.ru", 80);
  35. TSocket s(addr);
  36. TSocketOutput so(s);
  37. TSocketInput si(s);
  38. const TStringBuf req = "GET / HTTP/1.1\r\nHost: yandex.ru\r\n\r\n";
  39. so.Write(req.data(), req.size());
  40. UNIT_ASSERT(!si.ReadLine().empty());
  41. }
  42. void TSockTest::TestTimeout() {
  43. static const int timeout = 1000;
  44. i64 startTime = millisec();
  45. try {
  46. TNetworkAddress addr("localhost", 1313);
  47. TSocket s(addr, TDuration::MilliSeconds(timeout));
  48. } catch (const yexception&) {
  49. }
  50. int realTimeout = (int)(millisec() - startTime);
  51. if (realTimeout > timeout + 2000) {
  52. TString err = TStringBuilder() << "Timeout exceeded: " << realTimeout << " ms (expected " << timeout << " ms)";
  53. UNIT_FAIL(err);
  54. }
  55. }
  56. void TSockTest::TestConnectionRefused() {
  57. TNetworkAddress addr("localhost", 1313);
  58. TSocket s(addr);
  59. }
  60. void TSockTest::TestNetworkResolutionError() {
  61. TString errMsg;
  62. try {
  63. TNetworkAddress addr("", 0);
  64. } catch (const TNetworkResolutionError& e) {
  65. errMsg = e.what();
  66. }
  67. if (errMsg.empty()) {
  68. return; // on Windows getaddrinfo("", 0, ...) returns "OK"
  69. }
  70. int expectedErr = EAI_NONAME;
  71. TString expectedErrMsg = gai_strerror(expectedErr);
  72. if (errMsg.find(expectedErrMsg) == TString::npos) {
  73. UNIT_FAIL("TNetworkResolutionError contains\nInvalid msg: " + errMsg + "\nExpected msg: " + expectedErrMsg + "\n");
  74. }
  75. }
  76. void TSockTest::TestNetworkResolutionErrorMessage() {
  77. #ifdef _unix_
  78. auto str = [](int code) -> TString {
  79. return TNetworkResolutionError(code).what();
  80. };
  81. auto expected = [](int code) -> TString {
  82. return gai_strerror(code);
  83. };
  84. struct TErrnoGuard {
  85. TErrnoGuard()
  86. : PrevValue_(errno)
  87. {
  88. }
  89. ~TErrnoGuard() {
  90. errno = PrevValue_;
  91. }
  92. private:
  93. int PrevValue_;
  94. } g;
  95. UNIT_ASSERT_VALUES_EQUAL(expected(0) + "(0): ", str(0));
  96. UNIT_ASSERT_VALUES_EQUAL(expected(-9) + "(-9): ", str(-9));
  97. errno = 0;
  98. UNIT_ASSERT_VALUES_EQUAL(expected(EAI_SYSTEM) + "(" + IntToString<10>(EAI_SYSTEM) + "; errno=0): ",
  99. str(EAI_SYSTEM));
  100. errno = 110;
  101. UNIT_ASSERT_VALUES_EQUAL(expected(EAI_SYSTEM) + "(" + IntToString<10>(EAI_SYSTEM) + "; errno=110): ",
  102. str(EAI_SYSTEM));
  103. #endif
  104. }
  105. class TTempEnableSigPipe {
  106. public:
  107. TTempEnableSigPipe() {
  108. OriginalSigHandler_ = signal(SIGPIPE, SIG_DFL);
  109. Y_ABORT_UNLESS(OriginalSigHandler_ != SIG_ERR);
  110. }
  111. ~TTempEnableSigPipe() {
  112. auto ret = signal(SIGPIPE, OriginalSigHandler_);
  113. Y_ABORT_UNLESS(ret != SIG_ERR);
  114. }
  115. private:
  116. void (*OriginalSigHandler_)(int);
  117. };
  118. void TSockTest::TestBrokenPipe() {
  119. TTempEnableSigPipe guard;
  120. SOCKET socks[2];
  121. int ret = SocketPair(socks);
  122. UNIT_ASSERT_VALUES_EQUAL(ret, 0);
  123. TSocket sender(socks[0]);
  124. TSocket receiver(socks[1]);
  125. receiver.ShutDown(SHUT_RDWR);
  126. int sent = sender.Send("FOO", 3);
  127. UNIT_ASSERT(sent < 0);
  128. IOutputStream::TPart parts[] = {
  129. {"foo", 3},
  130. {"bar", 3},
  131. };
  132. sent = sender.SendV(parts, 2);
  133. UNIT_ASSERT(sent < 0);
  134. }
  135. void TSockTest::TestClose() {
  136. SOCKET socks[2];
  137. UNIT_ASSERT_EQUAL(SocketPair(socks), 0);
  138. TSocket receiver(socks[1]);
  139. UNIT_ASSERT_EQUAL(static_cast<SOCKET>(receiver), socks[1]);
  140. #if defined _linux_
  141. UNIT_ASSERT_GE(fcntl(socks[1], F_GETFD), 0);
  142. receiver.Close();
  143. UNIT_ASSERT_EQUAL(fcntl(socks[1], F_GETFD), -1);
  144. #else
  145. receiver.Close();
  146. #endif
  147. UNIT_ASSERT_EQUAL(static_cast<SOCKET>(receiver), INVALID_SOCKET);
  148. }
  149. class TPollTest: public TTestBase {
  150. UNIT_TEST_SUITE(TPollTest);
  151. UNIT_TEST(TestPollInOut);
  152. UNIT_TEST_SUITE_END();
  153. public:
  154. inline TPollTest() {
  155. srand(static_cast<unsigned int>(time(nullptr)));
  156. }
  157. void TestPollInOut();
  158. private:
  159. sockaddr_in GetAddress(ui32 ip, ui16 port);
  160. SOCKET CreateSocket();
  161. SOCKET StartServerSocket(ui16 port, int backlog);
  162. SOCKET StartClientSocket(ui32 ip, ui16 port);
  163. SOCKET AcceptConnection(SOCKET serverSocket);
  164. };
  165. UNIT_TEST_SUITE_REGISTRATION(TPollTest);
  166. sockaddr_in TPollTest::GetAddress(ui32 ip, ui16 port) {
  167. struct sockaddr_in addr;
  168. memset(&addr, 0, sizeof(addr));
  169. addr.sin_family = AF_INET;
  170. addr.sin_port = htons(port);
  171. addr.sin_addr.s_addr = htonl(ip);
  172. return addr;
  173. }
  174. SOCKET TPollTest::CreateSocket() {
  175. SOCKET s = socket(AF_INET, SOCK_STREAM, 0);
  176. if (s == INVALID_SOCKET) {
  177. ythrow yexception() << "Can not create socket (" << LastSystemErrorText() << ")";
  178. }
  179. return s;
  180. }
  181. SOCKET TPollTest::StartServerSocket(ui16 port, int backlog) {
  182. TSocketHolder s(CreateSocket());
  183. sockaddr_in addr = GetAddress(ntohl(INADDR_ANY), port);
  184. if (bind(s, (sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) {
  185. ythrow yexception() << "Can not bind server socket (" << LastSystemErrorText() << ")";
  186. }
  187. if (listen(s, backlog) == SOCKET_ERROR) {
  188. ythrow yexception() << "Can not listen on server socket (" << LastSystemErrorText() << ")";
  189. }
  190. return s.Release();
  191. }
  192. SOCKET TPollTest::StartClientSocket(ui32 ip, ui16 port) {
  193. TSocketHolder s(CreateSocket());
  194. sockaddr_in addr = GetAddress(ip, port);
  195. if (connect(s, (sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) {
  196. ythrow yexception() << "Can not connect client socket (" << LastSystemErrorText() << ")";
  197. }
  198. return s.Release();
  199. }
  200. SOCKET TPollTest::AcceptConnection(SOCKET serverSocket) {
  201. SOCKET connectedSocket = accept(serverSocket, nullptr, nullptr);
  202. if (connectedSocket == INVALID_SOCKET) {
  203. ythrow yexception() << "Can not accept connection on server socket (" << LastSystemErrorText() << ")";
  204. }
  205. return connectedSocket;
  206. }
  207. void TPollTest::TestPollInOut() {
  208. #ifdef _win_
  209. const size_t socketCount = 1000;
  210. ui16 port = static_cast<ui16>(1300 + rand() % 97);
  211. TSocketHolder serverSocket = StartServerSocket(port, socketCount);
  212. ui32 localIp = ntohl(inet_addr("127.0.0.1"));
  213. TVector<TSimpleSharedPtr<TSocketHolder>> clientSockets;
  214. TVector<TSimpleSharedPtr<TSocketHolder>> connectedSockets;
  215. TVector<pollfd> fds;
  216. for (size_t i = 0; i < socketCount; ++i) {
  217. TSimpleSharedPtr<TSocketHolder> clientSocket(new TSocketHolder(StartClientSocket(localIp, port)));
  218. clientSockets.push_back(clientSocket);
  219. if (i % 5 == 0 || i % 5 == 2) {
  220. char buffer = 'c';
  221. if (send(*clientSocket, &buffer, 1, 0) == -1)
  222. ythrow yexception() << "Can not send (" << LastSystemErrorText() << ")";
  223. }
  224. TSimpleSharedPtr<TSocketHolder> connectedSocket(new TSocketHolder(AcceptConnection(serverSocket)));
  225. connectedSockets.push_back(connectedSocket);
  226. if (i % 5 == 2 || i % 5 == 3) {
  227. closesocket(*clientSocket);
  228. shutdown(*clientSocket, SD_BOTH);
  229. }
  230. }
  231. int expectedCount = 0;
  232. for (size_t i = 0; i < connectedSockets.size(); ++i) {
  233. pollfd fd = {(i % 5 == 4) ? INVALID_SOCKET : static_cast<SOCKET>(*connectedSockets[i]), POLLIN | POLLOUT, 0};
  234. fds.push_back(fd);
  235. if (i % 5 != 4)
  236. ++expectedCount;
  237. }
  238. int polledCount = poll(&fds[0], fds.size(), INFTIM);
  239. UNIT_ASSERT_EQUAL(expectedCount, polledCount);
  240. for (size_t i = 0; i < connectedSockets.size(); ++i) {
  241. short revents = fds[i].revents;
  242. if (i % 5 == 0) {
  243. UNIT_ASSERT_EQUAL(static_cast<short>(POLLRDNORM | POLLWRNORM), revents);
  244. } else if (i % 5 == 1) {
  245. UNIT_ASSERT_EQUAL(static_cast<short>(POLLOUT | POLLWRNORM), revents);
  246. } else if (i % 5 == 2) {
  247. UNIT_ASSERT_EQUAL(static_cast<short>(POLLHUP | POLLRDNORM | POLLWRNORM), revents);
  248. } else if (i % 5 == 3) {
  249. UNIT_ASSERT_EQUAL(static_cast<short>(POLLHUP | POLLWRNORM), revents);
  250. } else if (i % 5 == 4) {
  251. UNIT_ASSERT_EQUAL(static_cast<short>(POLLNVAL), revents);
  252. }
  253. }
  254. #endif
  255. }