123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453 |
- import array
- import os
- import struct
- import sys
- from threading import Lock
- from typing import Callable, Optional, Union
- from ._exceptions import WebSocketPayloadException, WebSocketProtocolException
- from ._utils import validate_utf8
- """
- _abnf.py
- websocket - WebSocket client library for Python
- Copyright 2024 engn33r
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
try:
    # Prefer the compiled XOR routine from wsaccel when it is installed.
    # It is only around 10% faster than the pure-Python fallback below,
    # and wsaccel itself is unmaintained.
    from wsaccel.xormask import XorMaskerSimple

    def _mask(mask_value: array.array, data_value: array.array) -> bytes:
        """XOR ``data_value`` with the repeating ``mask_value`` using wsaccel."""
        masker = XorMaskerSimple(mask_value)
        masked: bytes = masker.process(data_value)
        return masked

except ImportError:
    # wsaccel is not available: use the websocket-client implementation.
    native_byteorder = sys.byteorder

    def _mask(mask_value: array.array, data_value: array.array) -> bytes:
        """
        XOR ``data_value`` with the repeating ``mask_value``.

        Both buffers are turned into one big integer each so the XOR is a
        single integer operation instead of a per-byte Python loop.
        """
        size = len(data_value)
        payload_as_int = int.from_bytes(data_value, native_byteorder)
        # Repeat the mask in groups of four and append the leading bytes
        # needed to reach exactly ``size`` bytes.
        whole_repeats, leftover = divmod(size, 4)
        stretched_mask = mask_value * whole_repeats + mask_value[:leftover]
        mask_as_int = int.from_bytes(stretched_mask, native_byteorder)
        xored = payload_as_int ^ mask_as_int
        return xored.to_bytes(size, native_byteorder)
# Public names exported by ``from ... import *``.
# STATUS_SERVICE_RESTART and STATUS_TRY_AGAIN_LATER are listed as well so
# that every close status constant defined in this module is exported.
__all__ = [
    "ABNF",
    "continuous_frame",
    "frame_buffer",
    "STATUS_NORMAL",
    "STATUS_GOING_AWAY",
    "STATUS_PROTOCOL_ERROR",
    "STATUS_UNSUPPORTED_DATA_TYPE",
    "STATUS_STATUS_NOT_AVAILABLE",
    "STATUS_ABNORMAL_CLOSED",
    "STATUS_INVALID_PAYLOAD",
    "STATUS_POLICY_VIOLATION",
    "STATUS_MESSAGE_TOO_BIG",
    "STATUS_INVALID_EXTENSION",
    "STATUS_UNEXPECTED_CONDITION",
    "STATUS_SERVICE_RESTART",
    "STATUS_TRY_AGAIN_LATER",
    "STATUS_BAD_GATEWAY",
    "STATUS_TLS_HANDSHAKE_ERROR",
]
# Status codes carried in the first two bytes of a close frame payload.
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003

# The next two are defined here but deliberately left out of
# VALID_CLOSE_STATUS below.
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006

STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_SERVICE_RESTART = 1012
STATUS_TRY_AGAIN_LATER = 1013
STATUS_BAD_GATEWAY = 1014

# Also left out of VALID_CLOSE_STATUS.
STATUS_TLS_HANDSHAKE_ERROR = 1015

# Status codes accepted when validating a received close frame.
VALID_CLOSE_STATUS = (
    STATUS_NORMAL,
    STATUS_GOING_AWAY,
    STATUS_PROTOCOL_ERROR,
    STATUS_UNSUPPORTED_DATA_TYPE,
) + (
    STATUS_INVALID_PAYLOAD,
    STATUS_POLICY_VIOLATION,
    STATUS_MESSAGE_TOO_BIG,
    STATUS_INVALID_EXTENSION,
    STATUS_UNEXPECTED_CONDITION,
) + (
    STATUS_SERVICE_RESTART,
    STATUS_TRY_AGAIN_LATER,
    STATUS_BAD_GATEWAY,
)
class ABNF:
    """
    ABNF frame class.

    Holds the header flags, opcode and payload of a single WebSocket frame
    and knows how to validate and serialize it.

    See http://tools.ietf.org/html/rfc5234
    and http://tools.ietf.org/html/rfc6455#section-5.2
    """

    # operation code values.
    OPCODE_CONT = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xA

    # available operation code value tuple
    OPCODES = (
        OPCODE_CONT,
        OPCODE_TEXT,
        OPCODE_BINARY,
        OPCODE_CLOSE,
        OPCODE_PING,
        OPCODE_PONG,
    )

    # opcode human readable string
    OPCODE_MAP = {
        OPCODE_CONT: "cont",
        OPCODE_TEXT: "text",
        OPCODE_BINARY: "binary",
        OPCODE_CLOSE: "close",
        OPCODE_PING: "ping",
        OPCODE_PONG: "pong",
    }

    # data length threshold.
    LENGTH_7 = 0x7E
    LENGTH_16 = 1 << 16
    LENGTH_63 = 1 << 63

    def __init__(
        self,
        fin: int = 0,
        rsv1: int = 0,
        rsv2: int = 0,
        rsv3: int = 0,
        opcode: int = OPCODE_TEXT,
        mask_value: int = 1,
        data: Union[str, bytes, None] = "",
    ) -> None:
        """
        Constructor for ABNF. Please check RFC for arguments.

        Parameters
        ----------
        fin, rsv1, rsv2, rsv3: int
            header flag bits; format() requires each to be 0 or 1.
        opcode: int
            operation code. please see OPCODE_MAP.
        mask_value: int
            mask flag; when truthy, format() masks the payload.
        data: str, bytes or None
            frame payload. None is stored as an empty string.
        """
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3
        self.opcode = opcode
        self.mask_value = mask_value
        if data is None:
            data = ""
        self.data = data
        # Callable returning the mask key; format() calls it with 4.
        self.get_mask_key = os.urandom

    def validate(self, skip_utf8_validation: bool = False) -> None:
        """
        Validate the ABNF frame.

        Parameters
        ----------
        skip_utf8_validation: skip utf8 validation.

        Raises
        ------
        WebSocketProtocolException
            if a reserved bit is set, the opcode is unknown, a ping frame
            is not final, or a close frame payload is malformed.
        """
        if self.rsv1 or self.rsv2 or self.rsv3:
            raise WebSocketProtocolException("rsv is not implemented, yet")

        if self.opcode not in ABNF.OPCODES:
            # Interpolate the value; passing it as a second positional
            # argument would leave a literal "%r" in the message.
            raise WebSocketProtocolException(f"Invalid opcode {self.opcode!r}")

        if self.opcode == ABNF.OPCODE_PING and not self.fin:
            raise WebSocketProtocolException("Invalid ping frame.")

        if self.opcode == ABNF.OPCODE_CLOSE:
            length = len(self.data)
            if not length:
                # A close frame with no payload is accepted as-is.
                return
            # One byte cannot hold the 2-byte status code; 126 or more
            # bytes is rejected as too long.
            if length == 1 or length >= 126:
                raise WebSocketProtocolException("Invalid close frame.")
            # Anything after the status code is checked as UTF-8 text.
            if (
                length > 2
                and not skip_utf8_validation
                and not validate_utf8(self.data[2:])
            ):
                raise WebSocketProtocolException("Invalid close frame.")

            # Status code: first two payload bytes, big-endian.
            code = 256 * int(self.data[0]) + int(self.data[1])
            if not self._is_valid_close_status(code):
                raise WebSocketProtocolException(f"Invalid close opcode {code!r}")

    @staticmethod
    def _is_valid_close_status(code: int) -> bool:
        """Return True for codes in VALID_CLOSE_STATUS or in 3000-4999."""
        return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)

    def __str__(self) -> str:
        return f"fin={self.fin} opcode={self.opcode} data={self.data}"

    @staticmethod
    def create_frame(data: Union[bytes, str], opcode: int, fin: int = 1) -> "ABNF":
        """
        Create frame to send text, binary and other data.

        Parameters
        ----------
        data: bytes or str
            data to send. If opcode is OPCODE_TEXT and this value is a str,
            it is encoded to UTF-8 bytes automatically.
        opcode: int
            operation code. please see OPCODE_MAP.
        fin: int
            fin flag. if set to 0, create continue fragmentation.

        Returns
        -------
        ABNF
            a new frame with the mask flag set.
        """
        if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
            data = data.encode("utf-8")
        # mask must be set if send data from client
        return ABNF(fin, 0, 0, 0, opcode, 1, data)

    def format(self) -> bytes:
        """
        Format this object to string(byte array) to send data to server.

        A str payload is encoded to UTF-8 (and stored back on ``self.data``)
        before the length is measured, so the length field always counts
        the bytes that are actually sent.

        Returns
        -------
        bytes
            the frame header followed by the payload; when the mask flag
            is set, the mask key and the masked payload follow the header.

        Raises
        ------
        ValueError
            if a flag bit is not 0 or 1, the opcode is unknown, or the
            payload is too long.
        """
        if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
            raise ValueError("not 0 or 1")
        if self.opcode not in ABNF.OPCODES:
            raise ValueError("Invalid OPCODE")

        # Encode first: len() of a str counts characters, not bytes.
        if isinstance(self.data, str):
            self.data = self.data.encode("utf-8")

        length = len(self.data)
        if length >= ABNF.LENGTH_63:
            raise ValueError("data is too long")

        first_byte = (
            self.fin << 7
            | self.rsv1 << 6
            | self.rsv2 << 5
            | self.rsv3 << 4
            | self.opcode
        )
        mask_bit = self.mask_value << 7

        if length < ABNF.LENGTH_7:
            # Short payload: the length fits in the 7-bit field.
            frame_header = bytes((first_byte, mask_bit | length))
        elif length < ABNF.LENGTH_16:
            # 0x7E marker followed by a 16-bit big-endian length.
            frame_header = bytes((first_byte, mask_bit | 0x7E))
            frame_header += struct.pack("!H", length)
        else:
            # 0x7F marker followed by a 64-bit big-endian length.
            frame_header = bytes((first_byte, mask_bit | 0x7F))
            frame_header += struct.pack("!Q", length)

        if not self.mask_value:
            return frame_header + self.data

        mask_key = self.get_mask_key(4)
        return frame_header + self._get_masked(mask_key)

    def _get_masked(self, mask_key: Union[str, bytes]) -> bytes:
        """
        Return ``mask_key`` followed by ``self.data`` masked with it.

        A str key is converted with latin-1, the same encoding mask() uses,
        so the key bytes written out are the bytes used for the XOR.
        """
        if isinstance(mask_key, str):
            mask_key = mask_key.encode("latin-1")
        return mask_key + ABNF.mask(mask_key, self.data)

    @staticmethod
    def mask(mask_key: Union[str, bytes], data: Union[str, bytes]) -> bytes:
        """
        Mask or unmask data. Just do xor for each byte

        Parameters
        ----------
        mask_key: bytes or str
            4 byte mask.
        data: bytes or str
            data to mask/unmask.

        Returns
        -------
        bytes
            the XOR of data with the repeated mask key.
        """
        if data is None:
            data = ""

        # str inputs are mapped one character to one byte.
        if isinstance(mask_key, str):
            mask_key = mask_key.encode("latin-1")

        if isinstance(data, str):
            data = data.encode("latin-1")

        return _mask(array.array("B", mask_key), array.array("B", data))
- class frame_buffer:
- _HEADER_MASK_INDEX = 5
- _HEADER_LENGTH_INDEX = 6
- def __init__(
- self, recv_fn: Callable[[int], int], skip_utf8_validation: bool
- ) -> None:
- self.recv = recv_fn
- self.skip_utf8_validation = skip_utf8_validation
- # Buffers over the packets from the layer beneath until desired amount
- # bytes of bytes are received.
- self.recv_buffer: list = []
- self.clear()
- self.lock = Lock()
- def clear(self) -> None:
- self.header: Optional[tuple] = None
- self.length: Optional[int] = None
- self.mask_value: Union[bytes, str, None] = None
- def has_received_header(self) -> bool:
- return self.header is None
- def recv_header(self) -> None:
- header = self.recv_strict(2)
- b1 = header[0]
- fin = b1 >> 7 & 1
- rsv1 = b1 >> 6 & 1
- rsv2 = b1 >> 5 & 1
- rsv3 = b1 >> 4 & 1
- opcode = b1 & 0xF
- b2 = header[1]
- has_mask = b2 >> 7 & 1
- length_bits = b2 & 0x7F
- self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)
- def has_mask(self) -> Union[bool, int]:
- if not self.header:
- return False
- header_val: int = self.header[frame_buffer._HEADER_MASK_INDEX]
- return header_val
- def has_received_length(self) -> bool:
- return self.length is None
- def recv_length(self) -> None:
- bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
- length_bits = bits & 0x7F
- if length_bits == 0x7E:
- v = self.recv_strict(2)
- self.length = struct.unpack("!H", v)[0]
- elif length_bits == 0x7F:
- v = self.recv_strict(8)
- self.length = struct.unpack("!Q", v)[0]
- else:
- self.length = length_bits
- def has_received_mask(self) -> bool:
- return self.mask_value is None
- def recv_mask(self) -> None:
- self.mask_value = self.recv_strict(4) if self.has_mask() else ""
- def recv_frame(self) -> ABNF:
- with self.lock:
- # Header
- if self.has_received_header():
- self.recv_header()
- (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header
- # Frame length
- if self.has_received_length():
- self.recv_length()
- length = self.length
- # Mask
- if self.has_received_mask():
- self.recv_mask()
- mask_value = self.mask_value
- # Payload
- payload = self.recv_strict(length)
- if has_mask:
- payload = ABNF.mask(mask_value, payload)
- # Reset for next frame
- self.clear()
- frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
- frame.validate(self.skip_utf8_validation)
- return frame
- def recv_strict(self, bufsize: int) -> bytes:
- shortage = bufsize - sum(map(len, self.recv_buffer))
- while shortage > 0:
- # Limit buffer size that we pass to socket.recv() to avoid
- # fragmenting the heap -- the number of bytes recv() actually
- # reads is limited by socket buffer and is relatively small,
- # yet passing large numbers repeatedly causes lots of large
- # buffers allocated and then shrunk, which results in
- # fragmentation.
- bytes_ = self.recv(min(16384, shortage))
- self.recv_buffer.append(bytes_)
- shortage -= len(bytes_)
- unified = b"".join(self.recv_buffer)
- if shortage == 0:
- self.recv_buffer = []
- return unified
- else:
- self.recv_buffer = [unified[bufsize:]]
- return unified[:bufsize]
- class continuous_frame:
- def __init__(self, fire_cont_frame: bool, skip_utf8_validation: bool) -> None:
- self.fire_cont_frame = fire_cont_frame
- self.skip_utf8_validation = skip_utf8_validation
- self.cont_data: Optional[list] = None
- self.recving_frames: Optional[int] = None
- def validate(self, frame: ABNF) -> None:
- if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
- raise WebSocketProtocolException("Illegal frame")
- if self.recving_frames and frame.opcode in (
- ABNF.OPCODE_TEXT,
- ABNF.OPCODE_BINARY,
- ):
- raise WebSocketProtocolException("Illegal frame")
- def add(self, frame: ABNF) -> None:
- if self.cont_data:
- self.cont_data[1] += frame.data
- else:
- if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
- self.recving_frames = frame.opcode
- self.cont_data = [frame.opcode, frame.data]
- if frame.fin:
- self.recving_frames = None
- def is_fire(self, frame: ABNF) -> Union[bool, int]:
- return frame.fin or self.fire_cont_frame
- def extract(self, frame: ABNF) -> tuple:
- data = self.cont_data
- self.cont_data = None
- frame.data = data[1]
- if (
- not self.fire_cont_frame
- and data[0] == ABNF.OPCODE_TEXT
- and not self.skip_utf8_validation
- and not validate_utf8(frame.data)
- ):
- raise WebSocketPayloadException(f"cannot decode: {repr(frame.data)}")
- return data[0], frame
|