http_proxy_sock_impl.h 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  #pragma once
  #include "http.h"
  #include "http_proxy.h"

  #include <limits>
  4. namespace NHttp {
  5. struct TPlainSocketImpl : virtual public THttpConfig {
  6. TIntrusivePtr<TSocketDescriptor> Socket;
  7. TString Host;
  8. TPlainSocketImpl() = default;
  9. void Create(int af) {
  10. Socket = new TSocketDescriptor(af);
  11. }
  12. TPlainSocketImpl(TIntrusivePtr<TSocketDescriptor> socket)
  13. : Socket(std::move(socket))
  14. {}
  15. SOCKET GetRawSocket() const {
  16. return static_cast<SOCKET>(Socket->Socket);
  17. }
  18. void SetNonBlock(bool nonBlock = true) noexcept {
  19. try {
  20. ::SetNonBlock(Socket->Socket, nonBlock);
  21. }
  22. catch (const yexception&) {
  23. }
  24. }
  25. void SetTimeout(TDuration timeout) noexcept {
  26. try {
  27. ::SetSocketTimeout(Socket->Socket, timeout.Seconds(), timeout.MilliSecondsOfSecond());
  28. }
  29. catch (const yexception&) {
  30. }
  31. }
  32. void Shutdown() {
  33. //Socket->Socket.ShutDown(SHUT_RDWR); // KIKIMR-3895
  34. if (Socket) {
  35. ::shutdown(Socket->Socket, SHUT_RDWR);
  36. }
  37. }
  38. int Connect(SocketAddressType address) {
  39. return Socket->Socket.Connect(address.get());
  40. }
  41. static constexpr int OnConnect(bool&, bool&) {
  42. return 1;
  43. }
  44. static int OnAccept(std::shared_ptr<TPrivateEndpointInfo>, bool&, bool&) {
  45. return 1;
  46. }
  47. bool IsGood() {
  48. int res;
  49. GetSockOpt(Socket->Socket, SOL_SOCKET, SO_ERROR, res);
  50. return res == 0;
  51. }
  52. int GetError() {
  53. int res;
  54. GetSockOpt(Socket->Socket, SOL_SOCKET, SO_ERROR, res);
  55. return res;
  56. }
  57. ssize_t Send(const void* data, size_t size, bool&, bool&) {
  58. return Socket->Socket.Send(data, size);
  59. }
  60. ssize_t Recv(void* data, size_t size, bool&, bool&) {
  61. return Socket->Socket.Recv(data, size);
  62. }
  63. void SetHost(const TString& host) {
  64. Host = host;
  65. }
  66. };
  67. struct TSecureSocketImpl : TPlainSocketImpl, TSslHelpers {
  68. static TSecureSocketImpl* IO(BIO* bio) noexcept {
  69. return static_cast<TSecureSocketImpl*>(BIO_get_data(bio));
  70. }
  71. static int IoWrite(BIO* bio, const char* data, int dlen) noexcept {
  72. BIO_clear_retry_flags(bio);
  73. int res = IO(bio)->Socket->Socket.Send(data, dlen);
  74. if (-res == EAGAIN) {
  75. BIO_set_retry_write(bio);
  76. }
  77. return res;
  78. }
  79. static int IoRead(BIO* bio, char* data, int dlen) noexcept {
  80. BIO_clear_retry_flags(bio);
  81. int res = IO(bio)->Socket->Socket.Recv(data, dlen);
  82. if (-res == EAGAIN) {
  83. BIO_set_retry_read(bio);
  84. }
  85. return res;
  86. }
  87. static int IoPuts(BIO* bio, const char* buf) noexcept {
  88. Y_UNUSED(bio);
  89. Y_UNUSED(buf);
  90. return -2;
  91. }
  92. static int IoGets(BIO* bio, char* buf, int size) noexcept {
  93. Y_UNUSED(bio);
  94. Y_UNUSED(buf);
  95. Y_UNUSED(size);
  96. return -2;
  97. }
  98. static long IoCtrl(BIO* bio, int cmd, long larg, void* parg) noexcept {
  99. Y_UNUSED(larg);
  100. Y_UNUSED(parg);
  101. if (cmd == BIO_CTRL_FLUSH) {
  102. IO(bio)->Flush();
  103. return 1;
  104. }
  105. return -2;
  106. }
  107. static int IoCreate(BIO* bio) noexcept {
  108. BIO_set_data(bio, nullptr);
  109. BIO_set_init(bio, 1);
  110. return 1;
  111. }
  112. static int IoDestroy(BIO* bio) noexcept {
  113. BIO_set_data(bio, nullptr);
  114. BIO_set_init(bio, 0);
  115. return 1;
  116. }
  117. static BIO_METHOD* CreateIoMethod() {
  118. BIO_METHOD* method = BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SecureSocketImpl");
  119. BIO_meth_set_write(method, IoWrite);
  120. BIO_meth_set_read(method, IoRead);
  121. BIO_meth_set_puts(method, IoPuts);
  122. BIO_meth_set_gets(method, IoGets);
  123. BIO_meth_set_ctrl(method, IoCtrl);
  124. BIO_meth_set_create(method, IoCreate);
  125. BIO_meth_set_destroy(method, IoDestroy);
  126. return method;
  127. }
  128. static BIO_METHOD* IoMethod() {
  129. static BIO_METHOD* method = CreateIoMethod();
  130. return method;
  131. }
  132. TSslHolder<BIO> Bio;
  133. TSslHolder<SSL_CTX> Ctx;
  134. TSslHolder<SSL> Ssl;
  135. TSecureSocketImpl() = default;
  136. TSecureSocketImpl(TIntrusivePtr<TSocketDescriptor> socket)
  137. : TPlainSocketImpl(std::move(socket))
  138. {}
  139. void InitClientSsl() {
  140. Bio.Reset(BIO_new(IoMethod()));
  141. BIO_set_data(Bio.Get(), this);
  142. BIO_set_nbio(Bio.Get(), 1);
  143. Ctx = CreateClientContext();
  144. Ssl = ConstructSsl(Ctx.Get(), Bio.Get());
  145. if (!Host.Empty()) {
  146. SSL_set_tlsext_host_name(Ssl.Get(), Host.c_str());
  147. }
  148. SSL_set_connect_state(Ssl.Get());
  149. }
  150. void InitServerSsl(SSL_CTX* ctx) {
  151. Bio.Reset(BIO_new(IoMethod()));
  152. BIO_set_data(Bio.Get(), this);
  153. BIO_set_nbio(Bio.Get(), 1);
  154. Ssl = ConstructSsl(ctx, Bio.Get());
  155. SSL_set_accept_state(Ssl.Get());
  156. }
  157. void Flush() {}
  158. ssize_t Send(const void* data, size_t size, bool& read, bool& write) {
  159. ssize_t res = SSL_write(Ssl.Get(), data, size);
  160. if (res < 0) {
  161. res = SSL_get_error(Ssl.Get(), res);
  162. switch(res) {
  163. case SSL_ERROR_WANT_READ:
  164. read = true;
  165. return -EAGAIN;
  166. case SSL_ERROR_WANT_WRITE:
  167. write = true;
  168. return -EAGAIN;
  169. default:
  170. return -EIO;
  171. }
  172. }
  173. return res;
  174. }
  175. ssize_t Recv(void* data, size_t size, bool& read, bool& write) {
  176. ssize_t res = SSL_read(Ssl.Get(), data, size);
  177. if (res < 0) {
  178. res = SSL_get_error(Ssl.Get(), res);
  179. switch(res) {
  180. case SSL_ERROR_WANT_READ:
  181. read = true;
  182. return -EAGAIN;
  183. case SSL_ERROR_WANT_WRITE:
  184. write = true;
  185. return -EAGAIN;
  186. default:
  187. return -EIO;
  188. }
  189. }
  190. return res;
  191. }
  192. int OnConnect(bool& read, bool& write) {
  193. if (!Ssl) {
  194. InitClientSsl();
  195. }
  196. int res = SSL_connect(Ssl.Get());
  197. if (res <= 0) {
  198. res = SSL_get_error(Ssl.Get(), res);
  199. switch(res) {
  200. case SSL_ERROR_WANT_READ:
  201. read = true;
  202. return -EAGAIN;
  203. case SSL_ERROR_WANT_WRITE:
  204. write = true;
  205. return -EAGAIN;
  206. default:
  207. return -EIO;
  208. }
  209. }
  210. return res;
  211. }
  212. int OnAccept(std::shared_ptr<TPrivateEndpointInfo> endpoint, bool& read, bool& write) {
  213. if (!Ssl) {
  214. InitServerSsl(endpoint->SecureContext.Get());
  215. }
  216. int res = SSL_accept(Ssl.Get());
  217. if (res <= 0) {
  218. res = SSL_get_error(Ssl.Get(), res);
  219. switch(res) {
  220. case SSL_ERROR_WANT_READ:
  221. read = true;
  222. return -EAGAIN;
  223. case SSL_ERROR_WANT_WRITE:
  224. write = true;
  225. return -EAGAIN;
  226. default:
  227. return -EIO;
  228. }
  229. }
  230. return res;
  231. }
  232. };
  233. }