_abnf.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. import array
  2. import os
  3. import struct
  4. import sys
  5. from threading import Lock
  6. from typing import Callable, Optional, Union
  7. from ._exceptions import WebSocketPayloadException, WebSocketProtocolException
  8. from ._utils import validate_utf8
  9. """
  10. _abnf.py
  11. websocket - WebSocket client library for Python
  12. Copyright 2024 engn33r
  13. Licensed under the Apache License, Version 2.0 (the "License");
  14. you may not use this file except in compliance with the License.
  15. You may obtain a copy of the License at
  16. http://www.apache.org/licenses/LICENSE-2.0
  17. Unless required by applicable law or agreed to in writing, software
  18. distributed under the License is distributed on an "AS IS" BASIS,
  19. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. See the License for the specific language governing permissions and
  21. limitations under the License.
  22. """
  23. try:
  24. # If wsaccel is available, use compiled routines to mask data.
  25. # wsaccel only provides around a 10% speed boost compared
  26. # to the websocket-client _mask() implementation.
  27. # Note that wsaccel is unmaintained.
  28. from wsaccel.xormask import XorMaskerSimple
  29. def _mask(mask_value: array.array, data_value: array.array) -> bytes:
  30. mask_result: bytes = XorMaskerSimple(mask_value).process(data_value)
  31. return mask_result
  32. except ImportError:
  33. # wsaccel is not available, use websocket-client _mask()
  34. native_byteorder = sys.byteorder
  35. def _mask(mask_value: array.array, data_value: array.array) -> bytes:
  36. datalen = len(data_value)
  37. int_data_value = int.from_bytes(data_value, native_byteorder)
  38. int_mask_value = int.from_bytes(
  39. mask_value * (datalen // 4) + mask_value[: datalen % 4], native_byteorder
  40. )
  41. return (int_data_value ^ int_mask_value).to_bytes(datalen, native_byteorder)
  42. __all__ = [
  43. "ABNF",
  44. "continuous_frame",
  45. "frame_buffer",
  46. "STATUS_NORMAL",
  47. "STATUS_GOING_AWAY",
  48. "STATUS_PROTOCOL_ERROR",
  49. "STATUS_UNSUPPORTED_DATA_TYPE",
  50. "STATUS_STATUS_NOT_AVAILABLE",
  51. "STATUS_ABNORMAL_CLOSED",
  52. "STATUS_INVALID_PAYLOAD",
  53. "STATUS_POLICY_VIOLATION",
  54. "STATUS_MESSAGE_TOO_BIG",
  55. "STATUS_INVALID_EXTENSION",
  56. "STATUS_UNEXPECTED_CONDITION",
  57. "STATUS_BAD_GATEWAY",
  58. "STATUS_TLS_HANDSHAKE_ERROR",
  59. ]
  60. # closing frame status codes.
  61. STATUS_NORMAL = 1000
  62. STATUS_GOING_AWAY = 1001
  63. STATUS_PROTOCOL_ERROR = 1002
  64. STATUS_UNSUPPORTED_DATA_TYPE = 1003
  65. STATUS_STATUS_NOT_AVAILABLE = 1005
  66. STATUS_ABNORMAL_CLOSED = 1006
  67. STATUS_INVALID_PAYLOAD = 1007
  68. STATUS_POLICY_VIOLATION = 1008
  69. STATUS_MESSAGE_TOO_BIG = 1009
  70. STATUS_INVALID_EXTENSION = 1010
  71. STATUS_UNEXPECTED_CONDITION = 1011
  72. STATUS_SERVICE_RESTART = 1012
  73. STATUS_TRY_AGAIN_LATER = 1013
  74. STATUS_BAD_GATEWAY = 1014
  75. STATUS_TLS_HANDSHAKE_ERROR = 1015
  76. VALID_CLOSE_STATUS = (
  77. STATUS_NORMAL,
  78. STATUS_GOING_AWAY,
  79. STATUS_PROTOCOL_ERROR,
  80. STATUS_UNSUPPORTED_DATA_TYPE,
  81. STATUS_INVALID_PAYLOAD,
  82. STATUS_POLICY_VIOLATION,
  83. STATUS_MESSAGE_TOO_BIG,
  84. STATUS_INVALID_EXTENSION,
  85. STATUS_UNEXPECTED_CONDITION,
  86. STATUS_SERVICE_RESTART,
  87. STATUS_TRY_AGAIN_LATER,
  88. STATUS_BAD_GATEWAY,
  89. )
  90. class ABNF:
  91. """
  92. ABNF frame class.
  93. See http://tools.ietf.org/html/rfc5234
  94. and http://tools.ietf.org/html/rfc6455#section-5.2
  95. """
  96. # operation code values.
  97. OPCODE_CONT = 0x0
  98. OPCODE_TEXT = 0x1
  99. OPCODE_BINARY = 0x2
  100. OPCODE_CLOSE = 0x8
  101. OPCODE_PING = 0x9
  102. OPCODE_PONG = 0xA
  103. # available operation code value tuple
  104. OPCODES = (
  105. OPCODE_CONT,
  106. OPCODE_TEXT,
  107. OPCODE_BINARY,
  108. OPCODE_CLOSE,
  109. OPCODE_PING,
  110. OPCODE_PONG,
  111. )
  112. # opcode human readable string
  113. OPCODE_MAP = {
  114. OPCODE_CONT: "cont",
  115. OPCODE_TEXT: "text",
  116. OPCODE_BINARY: "binary",
  117. OPCODE_CLOSE: "close",
  118. OPCODE_PING: "ping",
  119. OPCODE_PONG: "pong",
  120. }
  121. # data length threshold.
  122. LENGTH_7 = 0x7E
  123. LENGTH_16 = 1 << 16
  124. LENGTH_63 = 1 << 63
  125. def __init__(
  126. self,
  127. fin: int = 0,
  128. rsv1: int = 0,
  129. rsv2: int = 0,
  130. rsv3: int = 0,
  131. opcode: int = OPCODE_TEXT,
  132. mask_value: int = 1,
  133. data: Union[str, bytes, None] = "",
  134. ) -> None:
  135. """
  136. Constructor for ABNF. Please check RFC for arguments.
  137. """
  138. self.fin = fin
  139. self.rsv1 = rsv1
  140. self.rsv2 = rsv2
  141. self.rsv3 = rsv3
  142. self.opcode = opcode
  143. self.mask_value = mask_value
  144. if data is None:
  145. data = ""
  146. self.data = data
  147. self.get_mask_key = os.urandom
  148. def validate(self, skip_utf8_validation: bool = False) -> None:
  149. """
  150. Validate the ABNF frame.
  151. Parameters
  152. ----------
  153. skip_utf8_validation: skip utf8 validation.
  154. """
  155. if self.rsv1 or self.rsv2 or self.rsv3:
  156. raise WebSocketProtocolException("rsv is not implemented, yet")
  157. if self.opcode not in ABNF.OPCODES:
  158. raise WebSocketProtocolException("Invalid opcode %r", self.opcode)
  159. if self.opcode == ABNF.OPCODE_PING and not self.fin:
  160. raise WebSocketProtocolException("Invalid ping frame.")
  161. if self.opcode == ABNF.OPCODE_CLOSE:
  162. l = len(self.data)
  163. if not l:
  164. return
  165. if l == 1 or l >= 126:
  166. raise WebSocketProtocolException("Invalid close frame.")
  167. if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
  168. raise WebSocketProtocolException("Invalid close frame.")
  169. code = 256 * int(self.data[0]) + int(self.data[1])
  170. if not self._is_valid_close_status(code):
  171. raise WebSocketProtocolException("Invalid close opcode %r", code)
  172. @staticmethod
  173. def _is_valid_close_status(code: int) -> bool:
  174. return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)
  175. def __str__(self) -> str:
  176. return f"fin={self.fin} opcode={self.opcode} data={self.data}"
  177. @staticmethod
  178. def create_frame(data: Union[bytes, str], opcode: int, fin: int = 1) -> "ABNF":
  179. """
  180. Create frame to send text, binary and other data.
  181. Parameters
  182. ----------
  183. data: str
  184. data to send. This is string value(byte array).
  185. If opcode is OPCODE_TEXT and this value is unicode,
  186. data value is converted into unicode string, automatically.
  187. opcode: int
  188. operation code. please see OPCODE_MAP.
  189. fin: int
  190. fin flag. if set to 0, create continue fragmentation.
  191. """
  192. if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
  193. data = data.encode("utf-8")
  194. # mask must be set if send data from client
  195. return ABNF(fin, 0, 0, 0, opcode, 1, data)
  196. def format(self) -> bytes:
  197. """
  198. Format this object to string(byte array) to send data to server.
  199. """
  200. if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
  201. raise ValueError("not 0 or 1")
  202. if self.opcode not in ABNF.OPCODES:
  203. raise ValueError("Invalid OPCODE")
  204. length = len(self.data)
  205. if length >= ABNF.LENGTH_63:
  206. raise ValueError("data is too long")
  207. frame_header = chr(
  208. self.fin << 7
  209. | self.rsv1 << 6
  210. | self.rsv2 << 5
  211. | self.rsv3 << 4
  212. | self.opcode
  213. ).encode("latin-1")
  214. if length < ABNF.LENGTH_7:
  215. frame_header += chr(self.mask_value << 7 | length).encode("latin-1")
  216. elif length < ABNF.LENGTH_16:
  217. frame_header += chr(self.mask_value << 7 | 0x7E).encode("latin-1")
  218. frame_header += struct.pack("!H", length)
  219. else:
  220. frame_header += chr(self.mask_value << 7 | 0x7F).encode("latin-1")
  221. frame_header += struct.pack("!Q", length)
  222. if not self.mask_value:
  223. if isinstance(self.data, str):
  224. self.data = self.data.encode("utf-8")
  225. return frame_header + self.data
  226. mask_key = self.get_mask_key(4)
  227. return frame_header + self._get_masked(mask_key)
  228. def _get_masked(self, mask_key: Union[str, bytes]) -> bytes:
  229. s = ABNF.mask(mask_key, self.data)
  230. if isinstance(mask_key, str):
  231. mask_key = mask_key.encode("utf-8")
  232. return mask_key + s
  233. @staticmethod
  234. def mask(mask_key: Union[str, bytes], data: Union[str, bytes]) -> bytes:
  235. """
  236. Mask or unmask data. Just do xor for each byte
  237. Parameters
  238. ----------
  239. mask_key: bytes or str
  240. 4 byte mask.
  241. data: bytes or str
  242. data to mask/unmask.
  243. """
  244. if data is None:
  245. data = ""
  246. if isinstance(mask_key, str):
  247. mask_key = mask_key.encode("latin-1")
  248. if isinstance(data, str):
  249. data = data.encode("latin-1")
  250. return _mask(array.array("B", mask_key), array.array("B", data))
  251. class frame_buffer:
  252. _HEADER_MASK_INDEX = 5
  253. _HEADER_LENGTH_INDEX = 6
  254. def __init__(
  255. self, recv_fn: Callable[[int], int], skip_utf8_validation: bool
  256. ) -> None:
  257. self.recv = recv_fn
  258. self.skip_utf8_validation = skip_utf8_validation
  259. # Buffers over the packets from the layer beneath until desired amount
  260. # bytes of bytes are received.
  261. self.recv_buffer: list = []
  262. self.clear()
  263. self.lock = Lock()
  264. def clear(self) -> None:
  265. self.header: Optional[tuple] = None
  266. self.length: Optional[int] = None
  267. self.mask_value: Union[bytes, str, None] = None
  268. def has_received_header(self) -> bool:
  269. return self.header is None
  270. def recv_header(self) -> None:
  271. header = self.recv_strict(2)
  272. b1 = header[0]
  273. fin = b1 >> 7 & 1
  274. rsv1 = b1 >> 6 & 1
  275. rsv2 = b1 >> 5 & 1
  276. rsv3 = b1 >> 4 & 1
  277. opcode = b1 & 0xF
  278. b2 = header[1]
  279. has_mask = b2 >> 7 & 1
  280. length_bits = b2 & 0x7F
  281. self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)
  282. def has_mask(self) -> Union[bool, int]:
  283. if not self.header:
  284. return False
  285. header_val: int = self.header[frame_buffer._HEADER_MASK_INDEX]
  286. return header_val
  287. def has_received_length(self) -> bool:
  288. return self.length is None
  289. def recv_length(self) -> None:
  290. bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
  291. length_bits = bits & 0x7F
  292. if length_bits == 0x7E:
  293. v = self.recv_strict(2)
  294. self.length = struct.unpack("!H", v)[0]
  295. elif length_bits == 0x7F:
  296. v = self.recv_strict(8)
  297. self.length = struct.unpack("!Q", v)[0]
  298. else:
  299. self.length = length_bits
  300. def has_received_mask(self) -> bool:
  301. return self.mask_value is None
  302. def recv_mask(self) -> None:
  303. self.mask_value = self.recv_strict(4) if self.has_mask() else ""
  304. def recv_frame(self) -> ABNF:
  305. with self.lock:
  306. # Header
  307. if self.has_received_header():
  308. self.recv_header()
  309. (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header
  310. # Frame length
  311. if self.has_received_length():
  312. self.recv_length()
  313. length = self.length
  314. # Mask
  315. if self.has_received_mask():
  316. self.recv_mask()
  317. mask_value = self.mask_value
  318. # Payload
  319. payload = self.recv_strict(length)
  320. if has_mask:
  321. payload = ABNF.mask(mask_value, payload)
  322. # Reset for next frame
  323. self.clear()
  324. frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
  325. frame.validate(self.skip_utf8_validation)
  326. return frame
  327. def recv_strict(self, bufsize: int) -> bytes:
  328. shortage = bufsize - sum(map(len, self.recv_buffer))
  329. while shortage > 0:
  330. # Limit buffer size that we pass to socket.recv() to avoid
  331. # fragmenting the heap -- the number of bytes recv() actually
  332. # reads is limited by socket buffer and is relatively small,
  333. # yet passing large numbers repeatedly causes lots of large
  334. # buffers allocated and then shrunk, which results in
  335. # fragmentation.
  336. bytes_ = self.recv(min(16384, shortage))
  337. self.recv_buffer.append(bytes_)
  338. shortage -= len(bytes_)
  339. unified = b"".join(self.recv_buffer)
  340. if shortage == 0:
  341. self.recv_buffer = []
  342. return unified
  343. else:
  344. self.recv_buffer = [unified[bufsize:]]
  345. return unified[:bufsize]
  346. class continuous_frame:
  347. def __init__(self, fire_cont_frame: bool, skip_utf8_validation: bool) -> None:
  348. self.fire_cont_frame = fire_cont_frame
  349. self.skip_utf8_validation = skip_utf8_validation
  350. self.cont_data: Optional[list] = None
  351. self.recving_frames: Optional[int] = None
  352. def validate(self, frame: ABNF) -> None:
  353. if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
  354. raise WebSocketProtocolException("Illegal frame")
  355. if self.recving_frames and frame.opcode in (
  356. ABNF.OPCODE_TEXT,
  357. ABNF.OPCODE_BINARY,
  358. ):
  359. raise WebSocketProtocolException("Illegal frame")
  360. def add(self, frame: ABNF) -> None:
  361. if self.cont_data:
  362. self.cont_data[1] += frame.data
  363. else:
  364. if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
  365. self.recving_frames = frame.opcode
  366. self.cont_data = [frame.opcode, frame.data]
  367. if frame.fin:
  368. self.recving_frames = None
  369. def is_fire(self, frame: ABNF) -> Union[bool, int]:
  370. return frame.fin or self.fire_cont_frame
  371. def extract(self, frame: ABNF) -> tuple:
  372. data = self.cont_data
  373. self.cont_data = None
  374. frame.data = data[1]
  375. if (
  376. not self.fire_cont_frame
  377. and data[0] == ABNF.OPCODE_TEXT
  378. and not self.skip_utf8_validation
  379. and not validate_utf8(frame.data)
  380. ):
  381. raise WebSocketPayloadException(f"cannot decode: {repr(frame.data)}")
  382. return data[0], frame