_abnf.py 14 KB


  1. """
  2. """
  3. """
  4. websocket - WebSocket client library for Python
  5. Copyright (C) 2010 Hiroki Ohtani(liris)
  6. This library is free software; you can redistribute it and/or
  7. modify it under the terms of the GNU Lesser General Public
  8. License as published by the Free Software Foundation; either
  9. version 2.1 of the License, or (at your option) any later version.
  10. This library is distributed in the hope that it will be useful,
  11. but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  13. Lesser General Public License for more details.
  14. You should have received a copy of the GNU Lesser General Public
  15. License along with this library; if not, write to the Free Software
  16. Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  17. """
  18. import array
  19. import os
  20. import struct
  21. import six
  22. from ._exceptions import *
  23. from ._utils import validate_utf8
  24. from threading import Lock
  25. try:
  26. if six.PY3:
  27. import numpy
  28. else:
  29. numpy = None
  30. except ImportError:
  31. numpy = None
  32. try:
  33. # If wsaccel is available we use compiled routines to mask data.
  34. if not numpy:
  35. from wsaccel.xormask import XorMaskerSimple
  36. def _mask(_m, _d):
  37. return XorMaskerSimple(_m).process(_d)
  38. except ImportError:
  39. # wsaccel is not available, we rely on python implementations.
  40. def _mask(_m, _d):
  41. for i in range(len(_d)):
  42. _d[i] ^= _m[i % 4]
  43. if six.PY3:
  44. return _d.tobytes()
  45. else:
  46. return _d.tostring()
# Public names exported by a star-import of this module.
# Note that VALID_CLOSE_STATUS (defined below) is not part of this list.
__all__ = [
    'ABNF', 'continuous_frame', 'frame_buffer',
    'STATUS_NORMAL',
    'STATUS_GOING_AWAY',
    'STATUS_PROTOCOL_ERROR',
    'STATUS_UNSUPPORTED_DATA_TYPE',
    'STATUS_STATUS_NOT_AVAILABLE',
    'STATUS_ABNORMAL_CLOSED',
    'STATUS_INVALID_PAYLOAD',
    'STATUS_POLICY_VIOLATION',
    'STATUS_MESSAGE_TOO_BIG',
    'STATUS_INVALID_EXTENSION',
    'STATUS_UNEXPECTED_CONDITION',
    'STATUS_BAD_GATEWAY',
    'STATUS_TLS_HANDSHAKE_ERROR',
]

# closing frame status codes.
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006
STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_BAD_GATEWAY = 1014
STATUS_TLS_HANDSHAKE_ERROR = 1015

# Status codes accepted inside a close frame payload (checked by
# ABNF._is_valid_close_status, which additionally accepts 3000-4999).
# STATUS_STATUS_NOT_AVAILABLE (1005), STATUS_ABNORMAL_CLOSED (1006) and
# STATUS_TLS_HANDSHAKE_ERROR (1015) are intentionally left out of this tuple.
# NOTE(review): RFC 6455 section 7.4.1 appears to reserve those three codes
# for local use only -- confirm against the RFC before changing this list.
VALID_CLOSE_STATUS = (
    STATUS_NORMAL,
    STATUS_GOING_AWAY,
    STATUS_PROTOCOL_ERROR,
    STATUS_UNSUPPORTED_DATA_TYPE,
    STATUS_INVALID_PAYLOAD,
    STATUS_POLICY_VIOLATION,
    STATUS_MESSAGE_TOO_BIG,
    STATUS_INVALID_EXTENSION,
    STATUS_UNEXPECTED_CONDITION,
    STATUS_BAD_GATEWAY,
)
  89. class ABNF(object):
  90. """
  91. ABNF frame class.
  92. See http://tools.ietf.org/html/rfc5234
  93. and http://tools.ietf.org/html/rfc6455#section-5.2
  94. """
  95. # operation code values.
  96. OPCODE_CONT = 0x0
  97. OPCODE_TEXT = 0x1
  98. OPCODE_BINARY = 0x2
  99. OPCODE_CLOSE = 0x8
  100. OPCODE_PING = 0x9
  101. OPCODE_PONG = 0xa
  102. # available operation code value tuple
  103. OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
  104. OPCODE_PING, OPCODE_PONG)
  105. # opcode human readable string
  106. OPCODE_MAP = {
  107. OPCODE_CONT: "cont",
  108. OPCODE_TEXT: "text",
  109. OPCODE_BINARY: "binary",
  110. OPCODE_CLOSE: "close",
  111. OPCODE_PING: "ping",
  112. OPCODE_PONG: "pong"
  113. }
  114. # data length threshold.
  115. LENGTH_7 = 0x7e
  116. LENGTH_16 = 1 << 16
  117. LENGTH_63 = 1 << 63
  118. def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0,
  119. opcode=OPCODE_TEXT, mask=1, data=""):
  120. """
  121. Constructor for ABNF. Please check RFC for arguments.
  122. """
  123. self.fin = fin
  124. self.rsv1 = rsv1
  125. self.rsv2 = rsv2
  126. self.rsv3 = rsv3
  127. self.opcode = opcode
  128. self.mask = mask
  129. if data is None:
  130. data = ""
  131. self.data = data
  132. self.get_mask_key = os.urandom
  133. def validate(self, skip_utf8_validation=False):
  134. """
  135. Validate the ABNF frame.
  136. Parameters
  137. ----------
  138. skip_utf8_validation: skip utf8 validation.
  139. """
  140. if self.rsv1 or self.rsv2 or self.rsv3:
  141. raise WebSocketProtocolException("rsv is not implemented, yet")
  142. if self.opcode not in ABNF.OPCODES:
  143. raise WebSocketProtocolException("Invalid opcode %r", self.opcode)
  144. if self.opcode == ABNF.OPCODE_PING and not self.fin:
  145. raise WebSocketProtocolException("Invalid ping frame.")
  146. if self.opcode == ABNF.OPCODE_CLOSE:
  147. l = len(self.data)
  148. if not l:
  149. return
  150. if l == 1 or l >= 126:
  151. raise WebSocketProtocolException("Invalid close frame.")
  152. if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
  153. raise WebSocketProtocolException("Invalid close frame.")
  154. code = 256 * \
  155. six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2])
  156. if not self._is_valid_close_status(code):
  157. raise WebSocketProtocolException("Invalid close opcode.")
  158. @staticmethod
  159. def _is_valid_close_status(code):
  160. return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)
  161. def __str__(self):
  162. return "fin=" + str(self.fin) \
  163. + " opcode=" + str(self.opcode) \
  164. + " data=" + str(self.data)
  165. @staticmethod
  166. def create_frame(data, opcode, fin=1):
  167. """
  168. Create frame to send text, binary and other data.
  169. Parameters
  170. ----------
  171. data: <type>
  172. data to send. This is string value(byte array).
  173. If opcode is OPCODE_TEXT and this value is unicode,
  174. data value is converted into unicode string, automatically.
  175. opcode: <type>
  176. operation code. please see OPCODE_XXX.
  177. fin: <type>
  178. fin flag. if set to 0, create continue fragmentation.
  179. """
  180. if opcode == ABNF.OPCODE_TEXT and isinstance(data, six.text_type):
  181. data = data.encode("utf-8")
  182. # mask must be set if send data from client
  183. return ABNF(fin, 0, 0, 0, opcode, 1, data)
  184. def format(self):
  185. """
  186. Format this object to string(byte array) to send data to server.
  187. """
  188. if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
  189. raise ValueError("not 0 or 1")
  190. if self.opcode not in ABNF.OPCODES:
  191. raise ValueError("Invalid OPCODE")
  192. length = len(self.data)
  193. if length >= ABNF.LENGTH_63:
  194. raise ValueError("data is too long")
  195. frame_header = chr(self.fin << 7 |
  196. self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 |
  197. self.opcode)
  198. if length < ABNF.LENGTH_7:
  199. frame_header += chr(self.mask << 7 | length)
  200. frame_header = six.b(frame_header)
  201. elif length < ABNF.LENGTH_16:
  202. frame_header += chr(self.mask << 7 | 0x7e)
  203. frame_header = six.b(frame_header)
  204. frame_header += struct.pack("!H", length)
  205. else:
  206. frame_header += chr(self.mask << 7 | 0x7f)
  207. frame_header = six.b(frame_header)
  208. frame_header += struct.pack("!Q", length)
  209. if not self.mask:
  210. return frame_header + self.data
  211. else:
  212. mask_key = self.get_mask_key(4)
  213. return frame_header + self._get_masked(mask_key)
  214. def _get_masked(self, mask_key):
  215. s = ABNF.mask(mask_key, self.data)
  216. if isinstance(mask_key, six.text_type):
  217. mask_key = mask_key.encode('utf-8')
  218. return mask_key + s
  219. @staticmethod
  220. def mask(mask_key, data):
  221. """
  222. Mask or unmask data. Just do xor for each byte
  223. Parameters
  224. ----------
  225. mask_key: <type>
  226. 4 byte string(byte).
  227. data: <type>
  228. data to mask/unmask.
  229. """
  230. if data is None:
  231. data = ""
  232. if isinstance(mask_key, six.text_type):
  233. mask_key = six.b(mask_key)
  234. if isinstance(data, six.text_type):
  235. data = six.b(data)
  236. if numpy:
  237. origlen = len(data)
  238. _mask_key = mask_key[3] << 24 | mask_key[2] << 16 | mask_key[1] << 8 | mask_key[0]
  239. # We need data to be a multiple of four...
  240. data += bytes(" " * (4 - (len(data) % 4)), "us-ascii")
  241. a = numpy.frombuffer(data, dtype="uint32")
  242. masked = numpy.bitwise_xor(a, [_mask_key]).astype("uint32")
  243. if len(data) > origlen:
  244. return masked.tobytes()[:origlen]
  245. return masked.tobytes()
  246. else:
  247. _m = array.array("B", mask_key)
  248. _d = array.array("B", data)
  249. return _mask(_m, _d)
  250. class frame_buffer(object):
  251. _HEADER_MASK_INDEX = 5
  252. _HEADER_LENGTH_INDEX = 6
  253. def __init__(self, recv_fn, skip_utf8_validation):
  254. self.recv = recv_fn
  255. self.skip_utf8_validation = skip_utf8_validation
  256. # Buffers over the packets from the layer beneath until desired amount
  257. # bytes of bytes are received.
  258. self.recv_buffer = []
  259. self.clear()
  260. self.lock = Lock()
  261. def clear(self):
  262. self.header = None
  263. self.length = None
  264. self.mask = None
  265. def has_received_header(self):
  266. return self.header is None
  267. def recv_header(self):
  268. header = self.recv_strict(2)
  269. b1 = header[0]
  270. if six.PY2:
  271. b1 = ord(b1)
  272. fin = b1 >> 7 & 1
  273. rsv1 = b1 >> 6 & 1
  274. rsv2 = b1 >> 5 & 1
  275. rsv3 = b1 >> 4 & 1
  276. opcode = b1 & 0xf
  277. b2 = header[1]
  278. if six.PY2:
  279. b2 = ord(b2)
  280. has_mask = b2 >> 7 & 1
  281. length_bits = b2 & 0x7f
  282. self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)
  283. def has_mask(self):
  284. if not self.header:
  285. return False
  286. return self.header[frame_buffer._HEADER_MASK_INDEX]
  287. def has_received_length(self):
  288. return self.length is None
  289. def recv_length(self):
  290. bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
  291. length_bits = bits & 0x7f
  292. if length_bits == 0x7e:
  293. v = self.recv_strict(2)
  294. self.length = struct.unpack("!H", v)[0]
  295. elif length_bits == 0x7f:
  296. v = self.recv_strict(8)
  297. self.length = struct.unpack("!Q", v)[0]
  298. else:
  299. self.length = length_bits
  300. def has_received_mask(self):
  301. return self.mask is None
  302. def recv_mask(self):
  303. self.mask = self.recv_strict(4) if self.has_mask() else ""
  304. def recv_frame(self):
  305. with self.lock:
  306. # Header
  307. if self.has_received_header():
  308. self.recv_header()
  309. (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header
  310. # Frame length
  311. if self.has_received_length():
  312. self.recv_length()
  313. length = self.length
  314. # Mask
  315. if self.has_received_mask():
  316. self.recv_mask()
  317. mask = self.mask
  318. # Payload
  319. payload = self.recv_strict(length)
  320. if has_mask:
  321. payload = ABNF.mask(mask, payload)
  322. # Reset for next frame
  323. self.clear()
  324. frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
  325. frame.validate(self.skip_utf8_validation)
  326. return frame
  327. def recv_strict(self, bufsize):
  328. shortage = bufsize - sum(len(x) for x in self.recv_buffer)
  329. while shortage > 0:
  330. # Limit buffer size that we pass to socket.recv() to avoid
  331. # fragmenting the heap -- the number of bytes recv() actually
  332. # reads is limited by socket buffer and is relatively small,
  333. # yet passing large numbers repeatedly causes lots of large
  334. # buffers allocated and then shrunk, which results in
  335. # fragmentation.
  336. bytes_ = self.recv(min(16384, shortage))
  337. self.recv_buffer.append(bytes_)
  338. shortage -= len(bytes_)
  339. unified = six.b("").join(self.recv_buffer)
  340. if shortage == 0:
  341. self.recv_buffer = []
  342. return unified
  343. else:
  344. self.recv_buffer = [unified[bufsize:]]
  345. return unified[:bufsize]
  346. class continuous_frame(object):
  347. def __init__(self, fire_cont_frame, skip_utf8_validation):
  348. self.fire_cont_frame = fire_cont_frame
  349. self.skip_utf8_validation = skip_utf8_validation
  350. self.cont_data = None
  351. self.recving_frames = None
  352. def validate(self, frame):
  353. if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
  354. raise WebSocketProtocolException("Illegal frame")
  355. if self.recving_frames and \
  356. frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
  357. raise WebSocketProtocolException("Illegal frame")
  358. def add(self, frame):
  359. if self.cont_data:
  360. self.cont_data[1] += frame.data
  361. else:
  362. if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
  363. self.recving_frames = frame.opcode
  364. self.cont_data = [frame.opcode, frame.data]
  365. if frame.fin:
  366. self.recving_frames = None
  367. def is_fire(self, frame):
  368. return frame.fin or self.fire_cont_frame
  369. def extract(self, frame):
  370. data = self.cont_data
  371. self.cont_data = None
  372. frame.data = data[1]
  373. if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data):
  374. raise WebSocketPayloadException(
  375. "cannot decode: " + repr(frame.data))
  376. return [data[0], frame]