web_ws.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. import asyncio
  2. import base64
  3. import binascii
  4. import hashlib
  5. import json
  6. import sys
  7. from typing import Any, Final, Iterable, Optional, Tuple, cast
  8. import attr
  9. from multidict import CIMultiDict
  10. from . import hdrs
  11. from .abc import AbstractStreamWriter
  12. from .helpers import call_later, set_exception, set_result
  13. from .http import (
  14. WS_CLOSED_MESSAGE,
  15. WS_CLOSING_MESSAGE,
  16. WS_KEY,
  17. WebSocketError,
  18. WebSocketReader,
  19. WebSocketWriter,
  20. WSCloseCode,
  21. WSMessage,
  22. WSMsgType as WSMsgType,
  23. ws_ext_gen,
  24. ws_ext_parse,
  25. )
  26. from .log import ws_logger
  27. from .streams import EofStream, FlowControlDataQueue
  28. from .typedefs import JSONDecoder, JSONEncoder
  29. from .web_exceptions import HTTPBadRequest, HTTPException
  30. from .web_request import BaseRequest
  31. from .web_response import StreamResponse
  32. if sys.version_info >= (3, 11):
  33. import asyncio as async_timeout
  34. else:
  35. import async_timeout
  36. __all__ = (
  37. "WebSocketResponse",
  38. "WebSocketReady",
  39. "WSMsgType",
  40. )
  41. THRESHOLD_CONNLOST_ACCESS: Final[int] = 5
  42. @attr.s(auto_attribs=True, frozen=True, slots=True)
  43. class WebSocketReady:
  44. ok: bool
  45. protocol: Optional[str]
  46. def __bool__(self) -> bool:
  47. return self.ok
  48. class WebSocketResponse(StreamResponse):
  49. _length_check = False
  50. def __init__(
  51. self,
  52. *,
  53. timeout: float = 10.0,
  54. receive_timeout: Optional[float] = None,
  55. autoclose: bool = True,
  56. autoping: bool = True,
  57. heartbeat: Optional[float] = None,
  58. protocols: Iterable[str] = (),
  59. compress: bool = True,
  60. max_msg_size: int = 4 * 1024 * 1024,
  61. ) -> None:
  62. super().__init__(status=101)
  63. self._protocols = protocols
  64. self._ws_protocol: Optional[str] = None
  65. self._writer: Optional[WebSocketWriter] = None
  66. self._reader: Optional[FlowControlDataQueue[WSMessage]] = None
  67. self._closed = False
  68. self._closing = False
  69. self._conn_lost = 0
  70. self._close_code: Optional[int] = None
  71. self._loop: Optional[asyncio.AbstractEventLoop] = None
  72. self._waiting: Optional[asyncio.Future[bool]] = None
  73. self._exception: Optional[BaseException] = None
  74. self._timeout = timeout
  75. self._receive_timeout = receive_timeout
  76. self._autoclose = autoclose
  77. self._autoping = autoping
  78. self._heartbeat = heartbeat
  79. self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
  80. if heartbeat is not None:
  81. self._pong_heartbeat = heartbeat / 2.0
  82. self._pong_response_cb: Optional[asyncio.TimerHandle] = None
  83. self._compress = compress
  84. self._max_msg_size = max_msg_size
  85. def _cancel_heartbeat(self) -> None:
  86. if self._pong_response_cb is not None:
  87. self._pong_response_cb.cancel()
  88. self._pong_response_cb = None
  89. if self._heartbeat_cb is not None:
  90. self._heartbeat_cb.cancel()
  91. self._heartbeat_cb = None
  92. def _reset_heartbeat(self) -> None:
  93. self._cancel_heartbeat()
  94. if self._heartbeat is not None:
  95. assert self._loop is not None
  96. self._heartbeat_cb = call_later(
  97. self._send_heartbeat,
  98. self._heartbeat,
  99. self._loop,
  100. timeout_ceil_threshold=self._req._protocol._timeout_ceil_threshold
  101. if self._req is not None
  102. else 5,
  103. )
  104. def _send_heartbeat(self) -> None:
  105. if self._heartbeat is not None and not self._closed:
  106. assert self._loop is not None
  107. # fire-and-forget a task is not perfect but maybe ok for
  108. # sending ping. Otherwise we need a long-living heartbeat
  109. # task in the class.
  110. self._loop.create_task(self._writer.ping()) # type: ignore[union-attr]
  111. if self._pong_response_cb is not None:
  112. self._pong_response_cb.cancel()
  113. self._pong_response_cb = call_later(
  114. self._pong_not_received,
  115. self._pong_heartbeat,
  116. self._loop,
  117. timeout_ceil_threshold=self._req._protocol._timeout_ceil_threshold
  118. if self._req is not None
  119. else 5,
  120. )
  121. def _pong_not_received(self) -> None:
  122. if self._req is not None and self._req.transport is not None:
  123. self._closed = True
  124. self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
  125. self._exception = asyncio.TimeoutError()
  126. async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
  127. # make pre-check to don't hide it by do_handshake() exceptions
  128. if self._payload_writer is not None:
  129. return self._payload_writer
  130. protocol, writer = self._pre_start(request)
  131. payload_writer = await super().prepare(request)
  132. assert payload_writer is not None
  133. self._post_start(request, protocol, writer)
  134. await payload_writer.drain()
  135. return payload_writer
  136. def _handshake(
  137. self, request: BaseRequest
  138. ) -> Tuple["CIMultiDict[str]", str, bool, bool]:
  139. headers = request.headers
  140. if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip():
  141. raise HTTPBadRequest(
  142. text=(
  143. "No WebSocket UPGRADE hdr: {}\n Can "
  144. '"Upgrade" only to "WebSocket".'
  145. ).format(headers.get(hdrs.UPGRADE))
  146. )
  147. if "upgrade" not in headers.get(hdrs.CONNECTION, "").lower():
  148. raise HTTPBadRequest(
  149. text="No CONNECTION upgrade hdr: {}".format(
  150. headers.get(hdrs.CONNECTION)
  151. )
  152. )
  153. # find common sub-protocol between client and server
  154. protocol = None
  155. if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
  156. req_protocols = [
  157. str(proto.strip())
  158. for proto in headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",")
  159. ]
  160. for proto in req_protocols:
  161. if proto in self._protocols:
  162. protocol = proto
  163. break
  164. else:
  165. # No overlap found: Return no protocol as per spec
  166. ws_logger.warning(
  167. "Client protocols %r don’t overlap server-known ones %r",
  168. req_protocols,
  169. self._protocols,
  170. )
  171. # check supported version
  172. version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, "")
  173. if version not in ("13", "8", "7"):
  174. raise HTTPBadRequest(text=f"Unsupported version: {version}")
  175. # check client handshake for validity
  176. key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
  177. try:
  178. if not key or len(base64.b64decode(key)) != 16:
  179. raise HTTPBadRequest(text=f"Handshake error: {key!r}")
  180. except binascii.Error:
  181. raise HTTPBadRequest(text=f"Handshake error: {key!r}") from None
  182. accept_val = base64.b64encode(
  183. hashlib.sha1(key.encode() + WS_KEY).digest()
  184. ).decode()
  185. response_headers = CIMultiDict(
  186. {
  187. hdrs.UPGRADE: "websocket",
  188. hdrs.CONNECTION: "upgrade",
  189. hdrs.SEC_WEBSOCKET_ACCEPT: accept_val,
  190. }
  191. )
  192. notakeover = False
  193. compress = 0
  194. if self._compress:
  195. extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
  196. # Server side always get return with no exception.
  197. # If something happened, just drop compress extension
  198. compress, notakeover = ws_ext_parse(extensions, isserver=True)
  199. if compress:
  200. enabledext = ws_ext_gen(
  201. compress=compress, isserver=True, server_notakeover=notakeover
  202. )
  203. response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext
  204. if protocol:
  205. response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol
  206. return (
  207. response_headers,
  208. protocol,
  209. compress,
  210. notakeover,
  211. ) # type: ignore[return-value]
  212. def _pre_start(self, request: BaseRequest) -> Tuple[str, WebSocketWriter]:
  213. self._loop = request._loop
  214. headers, protocol, compress, notakeover = self._handshake(request)
  215. self.set_status(101)
  216. self.headers.update(headers)
  217. self.force_close()
  218. self._compress = compress
  219. transport = request._protocol.transport
  220. assert transport is not None
  221. writer = WebSocketWriter(
  222. request._protocol, transport, compress=compress, notakeover=notakeover
  223. )
  224. return protocol, writer
  225. def _post_start(
  226. self, request: BaseRequest, protocol: str, writer: WebSocketWriter
  227. ) -> None:
  228. self._ws_protocol = protocol
  229. self._writer = writer
  230. self._reset_heartbeat()
  231. loop = self._loop
  232. assert loop is not None
  233. self._reader = FlowControlDataQueue(request._protocol, 2**16, loop=loop)
  234. request.protocol.set_parser(
  235. WebSocketReader(self._reader, self._max_msg_size, compress=self._compress)
  236. )
  237. # disable HTTP keepalive for WebSocket
  238. request.protocol.keep_alive(False)
  239. def can_prepare(self, request: BaseRequest) -> WebSocketReady:
  240. if self._writer is not None:
  241. raise RuntimeError("Already started")
  242. try:
  243. _, protocol, _, _ = self._handshake(request)
  244. except HTTPException:
  245. return WebSocketReady(False, None)
  246. else:
  247. return WebSocketReady(True, protocol)
  248. @property
  249. def closed(self) -> bool:
  250. return self._closed
  251. @property
  252. def close_code(self) -> Optional[int]:
  253. return self._close_code
  254. @property
  255. def ws_protocol(self) -> Optional[str]:
  256. return self._ws_protocol
  257. @property
  258. def compress(self) -> bool:
  259. return self._compress
  260. def get_extra_info(self, name: str, default: Any = None) -> Any:
  261. """Get optional transport information.
  262. If no value associated with ``name`` is found, ``default`` is returned.
  263. """
  264. writer = self._writer
  265. if writer is None:
  266. return default
  267. transport = writer.transport
  268. if transport is None:
  269. return default
  270. return transport.get_extra_info(name, default)
  271. def exception(self) -> Optional[BaseException]:
  272. return self._exception
  273. async def ping(self, message: bytes = b"") -> None:
  274. if self._writer is None:
  275. raise RuntimeError("Call .prepare() first")
  276. await self._writer.ping(message)
  277. async def pong(self, message: bytes = b"") -> None:
  278. # unsolicited pong
  279. if self._writer is None:
  280. raise RuntimeError("Call .prepare() first")
  281. await self._writer.pong(message)
  282. async def send_str(self, data: str, compress: Optional[bool] = None) -> None:
  283. if self._writer is None:
  284. raise RuntimeError("Call .prepare() first")
  285. if not isinstance(data, str):
  286. raise TypeError("data argument must be str (%r)" % type(data))
  287. await self._writer.send(data, binary=False, compress=compress)
  288. async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None:
  289. if self._writer is None:
  290. raise RuntimeError("Call .prepare() first")
  291. if not isinstance(data, (bytes, bytearray, memoryview)):
  292. raise TypeError("data argument must be byte-ish (%r)" % type(data))
  293. await self._writer.send(data, binary=True, compress=compress)
  294. async def send_json(
  295. self,
  296. data: Any,
  297. compress: Optional[bool] = None,
  298. *,
  299. dumps: JSONEncoder = json.dumps,
  300. ) -> None:
  301. await self.send_str(dumps(data), compress=compress)
  302. async def write_eof(self) -> None: # type: ignore[override]
  303. if self._eof_sent:
  304. return
  305. if self._payload_writer is None:
  306. raise RuntimeError("Response has not been started")
  307. await self.close()
  308. self._eof_sent = True
  309. async def close(
  310. self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True
  311. ) -> bool:
  312. """Close websocket connection."""
  313. if self._writer is None:
  314. raise RuntimeError("Call .prepare() first")
  315. self._cancel_heartbeat()
  316. reader = self._reader
  317. assert reader is not None
  318. # we need to break `receive()` cycle first,
  319. # `close()` may be called from different task
  320. if self._waiting is not None and not self._closed:
  321. reader.feed_data(WS_CLOSING_MESSAGE, 0)
  322. await self._waiting
  323. if self._closed:
  324. return False
  325. self._closed = True
  326. try:
  327. await self._writer.close(code, message)
  328. writer = self._payload_writer
  329. assert writer is not None
  330. if drain:
  331. await writer.drain()
  332. except (asyncio.CancelledError, asyncio.TimeoutError):
  333. self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
  334. raise
  335. except Exception as exc:
  336. self._exception = exc
  337. self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
  338. return True
  339. if self._closing:
  340. self._close_transport()
  341. return True
  342. reader = self._reader
  343. assert reader is not None
  344. try:
  345. async with async_timeout.timeout(self._timeout):
  346. msg = await reader.read()
  347. except asyncio.CancelledError:
  348. self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
  349. raise
  350. except Exception as exc:
  351. self._exception = exc
  352. self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
  353. return True
  354. if msg.type == WSMsgType.CLOSE:
  355. self._set_code_close_transport(msg.data)
  356. return True
  357. self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
  358. self._exception = asyncio.TimeoutError()
  359. return True
  360. def _set_closing(self, code: WSCloseCode) -> None:
  361. """Set the close code and mark the connection as closing."""
  362. self._closing = True
  363. self._close_code = code
  364. def _set_code_close_transport(self, code: WSCloseCode) -> None:
  365. """Set the close code and close the transport."""
  366. self._close_code = code
  367. self._close_transport()
  368. def _close_transport(self) -> None:
  369. """Close the transport."""
  370. if self._req is not None and self._req.transport is not None:
  371. self._req.transport.close()
  372. async def receive(self, timeout: Optional[float] = None) -> WSMessage:
  373. if self._reader is None:
  374. raise RuntimeError("Call .prepare() first")
  375. loop = self._loop
  376. assert loop is not None
  377. while True:
  378. if self._waiting is not None:
  379. raise RuntimeError("Concurrent call to receive() is not allowed")
  380. if self._closed:
  381. self._conn_lost += 1
  382. if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
  383. raise RuntimeError("WebSocket connection is closed.")
  384. return WS_CLOSED_MESSAGE
  385. elif self._closing:
  386. return WS_CLOSING_MESSAGE
  387. try:
  388. self._waiting = loop.create_future()
  389. try:
  390. async with async_timeout.timeout(timeout or self._receive_timeout):
  391. msg = await self._reader.read()
  392. self._reset_heartbeat()
  393. finally:
  394. waiter = self._waiting
  395. set_result(waiter, True)
  396. self._waiting = None
  397. except asyncio.TimeoutError:
  398. raise
  399. except EofStream:
  400. self._close_code = WSCloseCode.OK
  401. await self.close()
  402. return WSMessage(WSMsgType.CLOSED, None, None)
  403. except WebSocketError as exc:
  404. self._close_code = exc.code
  405. await self.close(code=exc.code)
  406. return WSMessage(WSMsgType.ERROR, exc, None)
  407. except Exception as exc:
  408. self._exception = exc
  409. self._set_closing(WSCloseCode.ABNORMAL_CLOSURE)
  410. await self.close()
  411. return WSMessage(WSMsgType.ERROR, exc, None)
  412. if msg.type == WSMsgType.CLOSE:
  413. self._set_closing(msg.data)
  414. # Could be closed while awaiting reader.
  415. if not self._closed and self._autoclose:
  416. # The client is likely going to close the
  417. # connection out from under us so we do not
  418. # want to drain any pending writes as it will
  419. # likely result writing to a broken pipe.
  420. await self.close(drain=False)
  421. elif msg.type == WSMsgType.CLOSING:
  422. self._set_closing(WSCloseCode.OK)
  423. elif msg.type == WSMsgType.PING and self._autoping:
  424. await self.pong(msg.data)
  425. continue
  426. elif msg.type == WSMsgType.PONG and self._autoping:
  427. continue
  428. return msg
  429. async def receive_str(self, *, timeout: Optional[float] = None) -> str:
  430. msg = await self.receive(timeout)
  431. if msg.type != WSMsgType.TEXT:
  432. raise TypeError(
  433. "Received message {}:{!r} is not WSMsgType.TEXT".format(
  434. msg.type, msg.data
  435. )
  436. )
  437. return cast(str, msg.data)
  438. async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
  439. msg = await self.receive(timeout)
  440. if msg.type != WSMsgType.BINARY:
  441. raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
  442. return cast(bytes, msg.data)
  443. async def receive_json(
  444. self, *, loads: JSONDecoder = json.loads, timeout: Optional[float] = None
  445. ) -> Any:
  446. data = await self.receive_str(timeout=timeout)
  447. return loads(data)
  448. async def write(self, data: bytes) -> None:
  449. raise RuntimeError("Cannot call .write() for websocket")
  450. def __aiter__(self) -> "WebSocketResponse":
  451. return self
  452. async def __anext__(self) -> WSMessage:
  453. msg = await self.receive()
  454. if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
  455. raise StopAsyncIteration
  456. return msg
  457. def _cancel(self, exc: BaseException) -> None:
  458. # web_protocol calls this from connection_lost
  459. # or when the server is shutting down.
  460. self._closing = True
  461. if self._reader is not None:
  462. set_exception(self._reader, exc)