http_proxy_sock_impl.h 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
#pragma once
#include "http.h"
#include "http_proxy.h"

#include <limits>
  4. namespace NHttp {
  5. struct TPlainSocketImpl : virtual public THttpConfig {
  6. TIntrusivePtr<TSocketDescriptor> Socket;
  7. TPlainSocketImpl()
  8. : Socket(new TSocketDescriptor())
  9. {}
  10. TPlainSocketImpl(TIntrusivePtr<TSocketDescriptor> socket)
  11. : Socket(std::move(socket))
  12. {}
  13. SOCKET GetRawSocket() const {
  14. return static_cast<SOCKET>(Socket->Socket);
  15. }
  16. void SetNonBlock(bool nonBlock = true) noexcept {
  17. try {
  18. ::SetNonBlock(Socket->Socket, nonBlock);
  19. }
  20. catch (const yexception&) {
  21. }
  22. }
  23. void SetTimeout(TDuration timeout) noexcept {
  24. try {
  25. ::SetSocketTimeout(Socket->Socket, timeout.Seconds(), timeout.MilliSecondsOfSecond());
  26. }
  27. catch (const yexception&) {
  28. }
  29. }
  30. void Shutdown() {
  31. //Socket->Socket.ShutDown(SHUT_RDWR); // KIKIMR-3895
  32. ::shutdown(Socket->Socket, SHUT_RDWR);
  33. }
  34. int Connect(const SocketAddressType& address) {
  35. return Socket->Socket.Connect(&address);
  36. }
  37. static constexpr int OnConnect(bool&, bool&) {
  38. return 1;
  39. }
  40. static constexpr int OnAccept(const TEndpointInfo&, bool&, bool&) {
  41. return 1;
  42. }
  43. bool IsGood() {
  44. int res;
  45. GetSockOpt(Socket->Socket, SOL_SOCKET, SO_ERROR, res);
  46. return res == 0;
  47. }
  48. int GetError() {
  49. int res;
  50. GetSockOpt(Socket->Socket, SOL_SOCKET, SO_ERROR, res);
  51. return res;
  52. }
  53. ssize_t Send(const void* data, size_t size, bool&, bool&) {
  54. return Socket->Socket.Send(data, size);
  55. }
  56. ssize_t Recv(void* data, size_t size, bool&, bool&) {
  57. return Socket->Socket.Recv(data, size);
  58. }
  59. };
  60. struct TSecureSocketImpl : TPlainSocketImpl, TSslHelpers {
  61. static TSecureSocketImpl* IO(BIO* bio) noexcept {
  62. return static_cast<TSecureSocketImpl*>(BIO_get_data(bio));
  63. }
  64. static int IoWrite(BIO* bio, const char* data, int dlen) noexcept {
  65. BIO_clear_retry_flags(bio);
  66. int res = IO(bio)->Socket->Socket.Send(data, dlen);
  67. if (-res == EAGAIN) {
  68. BIO_set_retry_write(bio);
  69. }
  70. return res;
  71. }
  72. static int IoRead(BIO* bio, char* data, int dlen) noexcept {
  73. BIO_clear_retry_flags(bio);
  74. int res = IO(bio)->Socket->Socket.Recv(data, dlen);
  75. if (-res == EAGAIN) {
  76. BIO_set_retry_read(bio);
  77. }
  78. return res;
  79. }
  80. static int IoPuts(BIO* bio, const char* buf) noexcept {
  81. Y_UNUSED(bio);
  82. Y_UNUSED(buf);
  83. return -2;
  84. }
  85. static int IoGets(BIO* bio, char* buf, int size) noexcept {
  86. Y_UNUSED(bio);
  87. Y_UNUSED(buf);
  88. Y_UNUSED(size);
  89. return -2;
  90. }
  91. static long IoCtrl(BIO* bio, int cmd, long larg, void* parg) noexcept {
  92. Y_UNUSED(larg);
  93. Y_UNUSED(parg);
  94. if (cmd == BIO_CTRL_FLUSH) {
  95. IO(bio)->Flush();
  96. return 1;
  97. }
  98. return -2;
  99. }
  100. static int IoCreate(BIO* bio) noexcept {
  101. BIO_set_data(bio, nullptr);
  102. BIO_set_init(bio, 1);
  103. return 1;
  104. }
  105. static int IoDestroy(BIO* bio) noexcept {
  106. BIO_set_data(bio, nullptr);
  107. BIO_set_init(bio, 0);
  108. return 1;
  109. }
  110. static BIO_METHOD* CreateIoMethod() {
  111. BIO_METHOD* method = BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SecureSocketImpl");
  112. BIO_meth_set_write(method, IoWrite);
  113. BIO_meth_set_read(method, IoRead);
  114. BIO_meth_set_puts(method, IoPuts);
  115. BIO_meth_set_gets(method, IoGets);
  116. BIO_meth_set_ctrl(method, IoCtrl);
  117. BIO_meth_set_create(method, IoCreate);
  118. BIO_meth_set_destroy(method, IoDestroy);
  119. return method;
  120. }
  121. static BIO_METHOD* IoMethod() {
  122. static BIO_METHOD* method = CreateIoMethod();
  123. return method;
  124. }
  125. TSslHolder<BIO> Bio;
  126. TSslHolder<SSL_CTX> Ctx;
  127. TSslHolder<SSL> Ssl;
  128. TSecureSocketImpl() = default;
  129. TSecureSocketImpl(TIntrusivePtr<TSocketDescriptor> socket)
  130. : TPlainSocketImpl(std::move(socket))
  131. {}
  132. void InitClientSsl() {
  133. Bio.Reset(BIO_new(IoMethod()));
  134. BIO_set_data(Bio.Get(), this);
  135. BIO_set_nbio(Bio.Get(), 1);
  136. Ctx = CreateClientContext();
  137. Ssl = ConstructSsl(Ctx.Get(), Bio.Get());
  138. SSL_set_connect_state(Ssl.Get());
  139. }
  140. void InitServerSsl(SSL_CTX* ctx) {
  141. Bio.Reset(BIO_new(IoMethod()));
  142. BIO_set_data(Bio.Get(), this);
  143. BIO_set_nbio(Bio.Get(), 1);
  144. Ssl = ConstructSsl(ctx, Bio.Get());
  145. SSL_set_accept_state(Ssl.Get());
  146. }
  147. void Flush() {}
  148. ssize_t Send(const void* data, size_t size, bool& read, bool& write) {
  149. ssize_t res = SSL_write(Ssl.Get(), data, size);
  150. if (res < 0) {
  151. res = SSL_get_error(Ssl.Get(), res);
  152. switch(res) {
  153. case SSL_ERROR_WANT_READ:
  154. read = true;
  155. return -EAGAIN;
  156. case SSL_ERROR_WANT_WRITE:
  157. write = true;
  158. return -EAGAIN;
  159. default:
  160. return -EIO;
  161. }
  162. }
  163. return res;
  164. }
  165. ssize_t Recv(void* data, size_t size, bool& read, bool& write) {
  166. ssize_t res = SSL_read(Ssl.Get(), data, size);
  167. if (res < 0) {
  168. res = SSL_get_error(Ssl.Get(), res);
  169. switch(res) {
  170. case SSL_ERROR_WANT_READ:
  171. read = true;
  172. return -EAGAIN;
  173. case SSL_ERROR_WANT_WRITE:
  174. write = true;
  175. return -EAGAIN;
  176. default:
  177. return -EIO;
  178. }
  179. }
  180. return res;
  181. }
  182. int OnConnect(bool& read, bool& write) {
  183. if (!Ssl) {
  184. InitClientSsl();
  185. }
  186. int res = SSL_connect(Ssl.Get());
  187. if (res <= 0) {
  188. res = SSL_get_error(Ssl.Get(), res);
  189. switch(res) {
  190. case SSL_ERROR_WANT_READ:
  191. read = true;
  192. return -EAGAIN;
  193. case SSL_ERROR_WANT_WRITE:
  194. write = true;
  195. return -EAGAIN;
  196. default:
  197. return -EIO;
  198. }
  199. }
  200. return res;
  201. }
  202. int OnAccept(const TEndpointInfo& endpoint, bool& read, bool& write) {
  203. if (!Ssl) {
  204. InitServerSsl(endpoint.SecureContext.Get());
  205. }
  206. int res = SSL_accept(Ssl.Get());
  207. if (res <= 0) {
  208. res = SSL_get_error(Ssl.Get(), res);
  209. switch(res) {
  210. case SSL_ERROR_WANT_READ:
  211. read = true;
  212. return -EAGAIN;
  213. case SSL_ERROR_WANT_WRITE:
  214. write = true;
  215. return -EAGAIN;
  216. default:
  217. return -EIO;
  218. }
  219. }
  220. return res;
  221. }
  222. };
  223. }