WebSocket.cpp 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. //
  2. // WebSocket.cpp
  3. //
  4. // Library: Net
  5. // Package: WebSocket
  6. // Module: WebSocket
  7. //
  8. // Copyright (c) 2012, Applied Informatics Software Engineering GmbH.
  9. // and Contributors.
  10. //
  11. // SPDX-License-Identifier: BSL-1.0
  12. //
  13. #include "Poco/Net/WebSocket.h"
  14. #include "Poco/Net/WebSocketImpl.h"
  15. #include "Poco/Net/HTTPServerRequestImpl.h"
  16. #include "Poco/Net/HTTPServerResponse.h"
  17. #include "Poco/Net/HTTPClientSession.h"
  18. #include "Poco/Net/HTTPServerSession.h"
  19. #include "Poco/Net/NetException.h"
  20. #include "Poco/MemoryStream.h"
  21. #include "Poco/NullStream.h"
  22. #include "Poco/BinaryWriter.h"
  23. #include "Poco/SHA1Engine.h"
  24. #include "Poco/Base64Encoder.h"
  25. #include "Poco/String.h"
  26. #include "Poco/Random.h"
  27. #include "Poco/StreamCopier.h"
  28. #include <sstream>
  29. namespace Poco {
  30. namespace Net {
  31. const std::string WebSocket::WEBSOCKET_GUID("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
  32. const std::string WebSocket::WEBSOCKET_VERSION("13");
  33. HTTPCredentials WebSocket::_defaultCreds;
  34. WebSocket::WebSocket(HTTPServerRequest& request, HTTPServerResponse& response):
  35. StreamSocket(accept(request, response))
  36. {
  37. }
  38. WebSocket::WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response):
  39. StreamSocket(connect(cs, request, response, _defaultCreds))
  40. {
  41. }
  42. WebSocket::WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response, HTTPCredentials& credentials):
  43. StreamSocket(connect(cs, request, response, credentials))
  44. {
  45. }
  46. WebSocket::WebSocket(const Socket& socket):
  47. StreamSocket(socket)
  48. {
  49. if (!dynamic_cast<WebSocketImpl*>(impl()))
  50. throw InvalidArgumentException("Cannot assign incompatible socket");
  51. }
  52. WebSocket::~WebSocket()
  53. {
  54. }
  55. WebSocket& WebSocket::operator = (const Socket& socket)
  56. {
  57. if (dynamic_cast<WebSocketImpl*>(socket.impl()))
  58. Socket::operator = (socket);
  59. else
  60. throw InvalidArgumentException("Cannot assign incompatible socket");
  61. return *this;
  62. }
  63. void WebSocket::shutdown()
  64. {
  65. shutdown(WS_NORMAL_CLOSE);
  66. }
  67. void WebSocket::shutdown(Poco::UInt16 statusCode, const std::string& statusMessage)
  68. {
  69. Poco::Buffer<char> buffer(statusMessage.size() + 2);
  70. Poco::MemoryOutputStream ostr(buffer.begin(), buffer.size());
  71. Poco::BinaryWriter writer(ostr, Poco::BinaryWriter::NETWORK_BYTE_ORDER);
  72. writer << statusCode;
  73. writer.writeRaw(statusMessage);
  74. sendFrame(buffer.begin(), static_cast<int>(ostr.charsWritten()), FRAME_FLAG_FIN | FRAME_OP_CLOSE);
  75. }
  76. int WebSocket::sendFrame(const void* buffer, int length, int flags)
  77. {
  78. flags |= FRAME_OP_SETRAW;
  79. return static_cast<WebSocketImpl*>(impl())->sendBytes(buffer, length, flags);
  80. }
  81. int WebSocket::receiveFrame(void* buffer, int length, int& flags)
  82. {
  83. int n = static_cast<WebSocketImpl*>(impl())->receiveBytes(buffer, length, 0);
  84. flags = static_cast<WebSocketImpl*>(impl())->frameFlags();
  85. return n;
  86. }
  87. int WebSocket::receiveFrame(Poco::Buffer<char>& buffer, int& flags)
  88. {
  89. int n = static_cast<WebSocketImpl*>(impl())->receiveBytes(buffer, 0);
  90. flags = static_cast<WebSocketImpl*>(impl())->frameFlags();
  91. return n;
  92. }
  93. WebSocket::Mode WebSocket::mode() const
  94. {
  95. return static_cast<WebSocketImpl*>(impl())->mustMaskPayload() ? WS_CLIENT : WS_SERVER;
  96. }
  97. void WebSocket::setMaxPayloadSize(int maxPayloadSize)
  98. {
  99. static_cast<WebSocketImpl*>(impl())->setMaxPayloadSize(maxPayloadSize);
  100. }
  101. int WebSocket::getMaxPayloadSize() const
  102. {
  103. return static_cast<WebSocketImpl*>(impl())->getMaxPayloadSize();
  104. }
  105. WebSocketImpl* WebSocket::accept(HTTPServerRequest& request, HTTPServerResponse& response)
  106. {
  107. if (request.hasToken("Connection", "upgrade") && icompare(request.get("Upgrade", ""), "websocket") == 0)
  108. {
  109. std::string version = request.get("Sec-WebSocket-Version", "");
  110. if (version.empty()) throw WebSocketException("Missing Sec-WebSocket-Version in handshake request", WS_ERR_HANDSHAKE_NO_VERSION);
  111. if (version != WEBSOCKET_VERSION) throw WebSocketException("Unsupported WebSocket version requested", version, WS_ERR_HANDSHAKE_UNSUPPORTED_VERSION);
  112. std::string key = request.get("Sec-WebSocket-Key", "");
  113. Poco::trimInPlace(key);
  114. if (key.empty()) throw WebSocketException("Missing Sec-WebSocket-Key in handshake request", WS_ERR_HANDSHAKE_NO_KEY);
  115. response.setStatusAndReason(HTTPResponse::HTTP_SWITCHING_PROTOCOLS);
  116. response.set("Upgrade", "websocket");
  117. response.set("Connection", "Upgrade");
  118. response.set("Sec-WebSocket-Accept", computeAccept(key));
  119. response.setContentLength(HTTPResponse::UNKNOWN_CONTENT_LENGTH);
  120. response.send().flush();
  121. HTTPServerRequestImpl& requestImpl = static_cast<HTTPServerRequestImpl&>(request);
  122. return new WebSocketImpl(static_cast<StreamSocketImpl*>(requestImpl.detachSocket().impl()), requestImpl.session(), false);
  123. }
  124. else throw WebSocketException("No WebSocket handshake", WS_ERR_NO_HANDSHAKE);
  125. }
  126. WebSocketImpl* WebSocket::connect(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response, HTTPCredentials& credentials)
  127. {
  128. if (!cs.getProxyHost().empty() && !cs.secure())
  129. {
  130. cs.proxyTunnel();
  131. }
  132. std::string key = createKey();
  133. request.set("Connection", "Upgrade");
  134. request.set("Upgrade", "websocket");
  135. request.set("Sec-WebSocket-Version", WEBSOCKET_VERSION);
  136. request.set("Sec-WebSocket-Key", key);
  137. request.setChunkedTransferEncoding(false);
  138. cs.setKeepAlive(true);
  139. cs.sendRequest(request);
  140. std::istream& istr = cs.receiveResponse(response);
  141. if (response.getStatus() == HTTPResponse::HTTP_SWITCHING_PROTOCOLS)
  142. {
  143. return completeHandshake(cs, response, key);
  144. }
  145. else if (response.getStatus() == HTTPResponse::HTTP_UNAUTHORIZED)
  146. {
  147. if (!credentials.empty())
  148. {
  149. Poco::NullOutputStream null;
  150. Poco::StreamCopier::copyStream(istr, null);
  151. credentials.authenticate(request, response);
  152. if (!cs.getProxyHost().empty() && !cs.secure())
  153. {
  154. cs.reset();
  155. cs.proxyTunnel();
  156. }
  157. cs.sendRequest(request);
  158. cs.receiveResponse(response);
  159. if (response.getStatus() == HTTPResponse::HTTP_SWITCHING_PROTOCOLS)
  160. {
  161. return completeHandshake(cs, response, key);
  162. }
  163. else if (response.getStatus() == HTTPResponse::HTTP_UNAUTHORIZED)
  164. {
  165. throw WebSocketException("Not authorized", WS_ERR_UNAUTHORIZED);
  166. }
  167. }
  168. else throw WebSocketException("Not authorized", WS_ERR_UNAUTHORIZED);
  169. }
  170. if (response.getStatus() == HTTPResponse::HTTP_OK)
  171. {
  172. throw WebSocketException("The server does not understand the WebSocket protocol", WS_ERR_NO_HANDSHAKE);
  173. }
  174. else
  175. {
  176. throw WebSocketException("Cannot upgrade to WebSocket connection", response.getReason(), WS_ERR_NO_HANDSHAKE);
  177. }
  178. }
  179. WebSocketImpl* WebSocket::completeHandshake(HTTPClientSession& cs, HTTPResponse& response, const std::string& key)
  180. {
  181. std::string connection = response.get("Connection", "");
  182. if (Poco::icompare(connection, "Upgrade") != 0)
  183. throw WebSocketException("No Connection: Upgrade header in handshake response", WS_ERR_NO_HANDSHAKE);
  184. std::string upgrade = response.get("Upgrade", "");
  185. if (Poco::icompare(upgrade, "websocket") != 0)
  186. throw WebSocketException("No Upgrade: websocket header in handshake response", WS_ERR_NO_HANDSHAKE);
  187. std::string accept = response.get("Sec-WebSocket-Accept", "");
  188. if (accept != computeAccept(key))
  189. throw WebSocketException("Invalid or missing Sec-WebSocket-Accept header in handshake response", WS_ERR_HANDSHAKE_ACCEPT);
  190. return new WebSocketImpl(static_cast<StreamSocketImpl*>(cs.detachSocket().impl()), cs, true);
  191. }
  192. std::string WebSocket::createKey()
  193. {
  194. Poco::Random rnd;
  195. std::ostringstream ostr;
  196. Poco::Base64Encoder base64(ostr);
  197. Poco::BinaryWriter writer(base64);
  198. writer << rnd.next() << rnd.next() << rnd.next() << rnd.next();
  199. base64.close();
  200. return ostr.str();
  201. }
  202. std::string WebSocket::computeAccept(const std::string& key)
  203. {
  204. std::string accept(key);
  205. accept += WEBSOCKET_GUID;
  206. Poco::SHA1Engine sha1;
  207. sha1.update(accept);
  208. Poco::DigestEngine::Digest d = sha1.digest();
  209. std::ostringstream ostr;
  210. Poco::Base64Encoder base64(ostr);
  211. base64.write(reinterpret_cast<const char*>(&d[0]), d.size());
  212. base64.close();
  213. return ostr.str();
  214. }
  215. } } // namespace Poco::Net