_handshake.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
"""
_handshake.py
websocket - WebSocket client library for Python
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
  15. import hashlib
  16. import hmac
  17. import os
  18. from base64 import encodebytes as base64encode
  19. from http import HTTPStatus
  20. from ._cookiejar import SimpleCookieJar
  21. from ._exceptions import WebSocketException, WebSocketBadStatusException
  22. from ._http import read_headers
  23. from ._logging import dump, error
  24. from ._socket import send
  25. __all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
  26. # websocket supported version.
  27. VERSION = 13
  28. SUPPORTED_REDIRECT_STATUSES = (
  29. HTTPStatus.MOVED_PERMANENTLY,
  30. HTTPStatus.FOUND,
  31. HTTPStatus.SEE_OTHER,
  32. HTTPStatus.TEMPORARY_REDIRECT,
  33. HTTPStatus.PERMANENT_REDIRECT,
  34. )
  35. SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,)
  36. CookieJar = SimpleCookieJar()
  37. class handshake_response:
  38. def __init__(self, status: int, headers: dict, subprotocol):
  39. self.status = status
  40. self.headers = headers
  41. self.subprotocol = subprotocol
  42. CookieJar.add(headers.get("set-cookie"))
  43. def handshake(
  44. sock, url: str, hostname: str, port: int, resource: str, **options
  45. ) -> handshake_response:
  46. headers, key = _get_handshake_headers(resource, url, hostname, port, options)
  47. header_str = "\r\n".join(headers)
  48. send(sock, header_str)
  49. dump("request header", header_str)
  50. status, resp = _get_resp_headers(sock)
  51. if status in SUPPORTED_REDIRECT_STATUSES:
  52. return handshake_response(status, resp, None)
  53. success, subproto = _validate(resp, key, options.get("subprotocols"))
  54. if not success:
  55. raise WebSocketException("Invalid WebSocket Header")
  56. return handshake_response(status, resp, subproto)
  57. def _pack_hostname(hostname: str) -> str:
  58. # IPv6 address
  59. if ":" in hostname:
  60. return f"[{hostname}]"
  61. return hostname
  62. def _get_handshake_headers(
  63. resource: str, url: str, host: str, port: int, options: dict
  64. ) -> tuple:
  65. headers = [f"GET {resource} HTTP/1.1", "Upgrade: websocket"]
  66. if port in [80, 443]:
  67. hostport = _pack_hostname(host)
  68. else:
  69. hostport = f"{_pack_hostname(host)}:{port}"
  70. if options.get("host"):
  71. headers.append(f'Host: {options["host"]}')
  72. else:
  73. headers.append(f"Host: {hostport}")
  74. # scheme indicates whether http or https is used in Origin
  75. # The same approach is used in parse_url of _url.py to set default port
  76. scheme, url = url.split(":", 1)
  77. if not options.get("suppress_origin"):
  78. if "origin" in options and options["origin"] is not None:
  79. headers.append(f'Origin: {options["origin"]}')
  80. elif scheme == "wss":
  81. headers.append(f"Origin: https://{hostport}")
  82. else:
  83. headers.append(f"Origin: http://{hostport}")
  84. key = _create_sec_websocket_key()
  85. # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
  86. if not options.get("header") or "Sec-WebSocket-Key" not in options["header"]:
  87. headers.append(f"Sec-WebSocket-Key: {key}")
  88. else:
  89. key = options["header"]["Sec-WebSocket-Key"]
  90. if not options.get("header") or "Sec-WebSocket-Version" not in options["header"]:
  91. headers.append(f"Sec-WebSocket-Version: {VERSION}")
  92. if not options.get("connection"):
  93. headers.append("Connection: Upgrade")
  94. else:
  95. headers.append(options["connection"])
  96. if subprotocols := options.get("subprotocols"):
  97. headers.append(f'Sec-WebSocket-Protocol: {",".join(subprotocols)}')
  98. if header := options.get("header"):
  99. if isinstance(header, dict):
  100. header = [": ".join([k, v]) for k, v in header.items() if v is not None]
  101. headers.extend(header)
  102. server_cookie = CookieJar.get(host)
  103. client_cookie = options.get("cookie", None)
  104. if cookie := "; ".join(filter(None, [server_cookie, client_cookie])):
  105. headers.append(f"Cookie: {cookie}")
  106. headers.extend(("", ""))
  107. return headers, key
  108. def _get_resp_headers(sock, success_statuses: tuple = SUCCESS_STATUSES) -> tuple:
  109. status, resp_headers, status_message = read_headers(sock)
  110. if status not in success_statuses:
  111. content_len = resp_headers.get("content-length")
  112. if content_len:
  113. response_body = sock.recv(
  114. int(content_len)
  115. ) # read the body of the HTTP error message response and include it in the exception
  116. else:
  117. response_body = None
  118. raise WebSocketBadStatusException(
  119. f"Handshake status {status} {status_message} -+-+- {resp_headers} -+-+- {response_body}",
  120. status,
  121. status_message,
  122. resp_headers,
  123. response_body,
  124. )
  125. return status, resp_headers
  126. _HEADERS_TO_CHECK = {
  127. "upgrade": "websocket",
  128. "connection": "upgrade",
  129. }
  130. def _validate(headers, key: str, subprotocols) -> tuple:
  131. subproto = None
  132. for k, v in _HEADERS_TO_CHECK.items():
  133. r = headers.get(k, None)
  134. if not r:
  135. return False, None
  136. r = [x.strip().lower() for x in r.split(",")]
  137. if v not in r:
  138. return False, None
  139. if subprotocols:
  140. subproto = headers.get("sec-websocket-protocol", None)
  141. if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]:
  142. error(f"Invalid subprotocol: {subprotocols}")
  143. return False, None
  144. subproto = subproto.lower()
  145. result = headers.get("sec-websocket-accept", None)
  146. if not result:
  147. return False, None
  148. result = result.lower()
  149. if isinstance(result, str):
  150. result = result.encode("utf-8")
  151. value = f"{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11".encode("utf-8")
  152. hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
  153. if hmac.compare_digest(hashed, result):
  154. return True, subproto
  155. else:
  156. return False, None
  157. def _create_sec_websocket_key() -> str:
  158. randomness = os.urandom(16)
  159. return base64encode(randomness).decode("utf-8").strip()