_handshake.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. """
  2. websocket - WebSocket client library for Python
  3. Copyright (C) 2010 Hiroki Ohtani(liris)
  4. This library is free software; you can redistribute it and/or
  5. modify it under the terms of the GNU Lesser General Public
  6. License as published by the Free Software Foundation; either
  7. version 2.1 of the License, or (at your option) any later version.
  8. This library is distributed in the hope that it will be useful,
  9. but WITHOUT ANY WARRANTY; without even the implied warranty of
  10. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  11. Lesser General Public License for more details.
  12. You should have received a copy of the GNU Lesser General Public
  13. License along with this library; if not, write to the Free Software
  14. Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  15. """
  16. import hashlib
  17. import hmac
  18. import os
  19. import six
  20. from ._cookiejar import SimpleCookieJar
  21. from ._exceptions import *
  22. from ._http import *
  23. from ._logging import *
  24. from ._socket import *
  25. if hasattr(six, 'PY3') and six.PY3:
  26. from base64 import encodebytes as base64encode
  27. else:
  28. from base64 import encodestring as base64encode
  29. if hasattr(six, 'PY3') and six.PY3:
  30. if hasattr(six, 'PY34') and six.PY34:
  31. from http import client as HTTPStatus
  32. else:
  33. from http import HTTPStatus
  34. else:
  35. import httplib as HTTPStatus
  36. __all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
  37. if hasattr(hmac, "compare_digest"):
  38. compare_digest = hmac.compare_digest
  39. else:
  40. def compare_digest(s1, s2):
  41. return s1 == s2
  42. # websocket supported version.
  43. VERSION = 13
  44. SUPPORTED_REDIRECT_STATUSES = (HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER,)
  45. SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,)
  46. CookieJar = SimpleCookieJar()
  47. class handshake_response(object):
  48. def __init__(self, status, headers, subprotocol):
  49. self.status = status
  50. self.headers = headers
  51. self.subprotocol = subprotocol
  52. CookieJar.add(headers.get("set-cookie"))
  53. def handshake(sock, hostname, port, resource, **options):
  54. headers, key = _get_handshake_headers(resource, hostname, port, options)
  55. header_str = "\r\n".join(headers)
  56. send(sock, header_str)
  57. dump("request header", header_str)
  58. status, resp = _get_resp_headers(sock)
  59. if status in SUPPORTED_REDIRECT_STATUSES:
  60. return handshake_response(status, resp, None)
  61. success, subproto = _validate(resp, key, options.get("subprotocols"))
  62. if not success:
  63. raise WebSocketException("Invalid WebSocket Header")
  64. return handshake_response(status, resp, subproto)
  65. def _pack_hostname(hostname):
  66. # IPv6 address
  67. if ':' in hostname:
  68. return '[' + hostname + ']'
  69. return hostname
  70. def _get_handshake_headers(resource, host, port, options):
  71. headers = [
  72. "GET %s HTTP/1.1" % resource,
  73. "Upgrade: websocket"
  74. ]
  75. if port == 80 or port == 443:
  76. hostport = _pack_hostname(host)
  77. else:
  78. hostport = "%s:%d" % (_pack_hostname(host), port)
  79. if "host" in options and options["host"] is not None:
  80. headers.append("Host: %s" % options["host"])
  81. else:
  82. headers.append("Host: %s" % hostport)
  83. if "suppress_origin" not in options or not options["suppress_origin"]:
  84. if "origin" in options and options["origin"] is not None:
  85. headers.append("Origin: %s" % options["origin"])
  86. else:
  87. headers.append("Origin: http://%s" % hostport)
  88. key = _create_sec_websocket_key()
  89. # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
  90. if 'header' not in options or 'Sec-WebSocket-Key' not in options['header']:
  91. key = _create_sec_websocket_key()
  92. headers.append("Sec-WebSocket-Key: %s" % key)
  93. else:
  94. key = options['header']['Sec-WebSocket-Key']
  95. if 'header' not in options or 'Sec-WebSocket-Version' not in options['header']:
  96. headers.append("Sec-WebSocket-Version: %s" % VERSION)
  97. if 'connection' not in options or options['connection'] is None:
  98. headers.append('Connection: Upgrade')
  99. else:
  100. headers.append(options['connection'])
  101. subprotocols = options.get("subprotocols")
  102. if subprotocols:
  103. headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
  104. if "header" in options:
  105. header = options["header"]
  106. if isinstance(header, dict):
  107. header = [
  108. ": ".join([k, v])
  109. for k, v in header.items()
  110. if v is not None
  111. ]
  112. headers.extend(header)
  113. server_cookie = CookieJar.get(host)
  114. client_cookie = options.get("cookie", None)
  115. cookie = "; ".join(filter(None, [server_cookie, client_cookie]))
  116. if cookie:
  117. headers.append("Cookie: %s" % cookie)
  118. headers.append("")
  119. headers.append("")
  120. return headers, key
  121. def _get_resp_headers(sock, success_statuses=SUCCESS_STATUSES):
  122. status, resp_headers, status_message = read_headers(sock)
  123. if status not in success_statuses:
  124. raise WebSocketBadStatusException("Handshake status %d %s", status, status_message, resp_headers)
  125. return status, resp_headers
  126. _HEADERS_TO_CHECK = {
  127. "upgrade": "websocket",
  128. "connection": "upgrade",
  129. }
  130. def _validate(headers, key, subprotocols):
  131. subproto = None
  132. for k, v in _HEADERS_TO_CHECK.items():
  133. r = headers.get(k, None)
  134. if not r:
  135. return False, None
  136. r = [x.strip().lower() for x in r.split(',')]
  137. if v not in r:
  138. return False, None
  139. if subprotocols:
  140. subproto = headers.get("sec-websocket-protocol", None)
  141. if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]:
  142. error("Invalid subprotocol: " + str(subprotocols))
  143. return False, None
  144. subproto = subproto.lower()
  145. result = headers.get("sec-websocket-accept", None)
  146. if not result:
  147. return False, None
  148. result = result.lower()
  149. if isinstance(result, six.text_type):
  150. result = result.encode('utf-8')
  151. value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')
  152. hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
  153. success = compare_digest(hashed, result)
  154. if success:
  155. return True, subproto
  156. else:
  157. return False, None
  158. def _create_sec_websocket_key():
  159. randomness = os.urandom(16)
  160. return base64encode(randomness).decode('utf-8').strip()