_socket.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import errno
  2. import selectors
  3. import socket
  4. from typing import Union
  5. from ._exceptions import (
  6. WebSocketConnectionClosedException,
  7. WebSocketTimeoutException,
  8. )
  9. from ._ssl_compat import SSLError, SSLWantReadError, SSLWantWriteError
  10. from ._utils import extract_error_code, extract_err_message
  11. """
  12. _socket.py
  13. websocket - WebSocket client library for Python
  14. Copyright 2024 engn33r
  15. Licensed under the Apache License, Version 2.0 (the "License");
  16. you may not use this file except in compliance with the License.
  17. You may obtain a copy of the License at
  18. http://www.apache.org/licenses/LICENSE-2.0
  19. Unless required by applicable law or agreed to in writing, software
  20. distributed under the License is distributed on an "AS IS" BASIS,
  21. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. See the License for the specific language governing permissions and
  23. limitations under the License.
  24. """
  # Socket options applied by default: TCP_NODELAY is always set, and each
  # keepalive option is added only when this platform's socket module exposes
  # the corresponding constant.
  DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)] + [
      (level, getattr(socket, option_name), value)
      for level, option_name, value in (
          (socket.SOL_SOCKET, "SO_KEEPALIVE", 1),
          (socket.SOL_TCP, "TCP_KEEPIDLE", 30),
          (socket.SOL_TCP, "TCP_KEEPINTVL", 10),
          (socket.SOL_TCP, "TCP_KEEPCNT", 3),
      )
      if hasattr(socket, option_name)
  ]

  # Module-wide default timeout; None until something assigns a value.
  _default_timeout = None
  # Public names exported by a star-import of this module.
  __all__ = [
      # default option list
      "DEFAULT_SOCKET_OPTION",
      # option container
      "sock_opt",
      # default-timeout accessors
      "setdefaulttimeout",
      "getdefaulttimeout",
      # socket I/O helpers
      "recv",
      "recv_line",
      "send",
  ]
  class sock_opt:
      """
      Container for socket options, SSL options and a timeout.

      Attributes
      ----------
      sockopt: list
          The socket options passed in, or a new empty list when ``None``.
      sslopt: dict
          The SSL options passed in, or a new empty dict when ``None``.
      timeout: None
          Initialised to ``None``; this class never assigns another value.
      """

      def __init__(self, sockopt: list, sslopt: dict) -> None:
          # Only None is replaced; an empty list/dict supplied by the caller is
          # stored as the very same object.
          self.sockopt = [] if sockopt is None else sockopt
          self.sslopt = {} if sslopt is None else sslopt
          self.timeout = None
  def setdefaulttimeout(timeout: Union[int, float, None]) -> None:
      """
      Replace the module-wide default timeout.

      Parameters
      ----------
      timeout: int, float or None
          New default socket timeout in seconds. The value is stored exactly
          as given; no validation is performed here.
      """
      global _default_timeout
      _default_timeout = timeout
  def getdefaulttimeout() -> Union[int, float, None]:
      """
      Report the module-wide default timeout.

      Returns
      ----------
      int, float or None
          The current value of the global timeout setting (in seconds).
      """
      return _default_timeout
  def recv(sock: socket.socket, bufsize: int) -> bytes:
      """
      Receive up to ``bufsize`` bytes from ``sock``.

      If the socket's timeout is anything other than 0 and the first read fails
      with ``SSLWantReadError`` or with an error code of ``EAGAIN`` /
      ``EWOULDBLOCK``, the call waits for the socket to become readable for up
      to ``sock.gettimeout()`` seconds and then reads once more.

      Parameters
      ----------
      sock: socket.socket
          Socket to read from.
      bufsize: int
          Maximum number of bytes to read in one call.

      Returns
      ----------
      bytes
          The data received; never empty.

      Raises
      ----------
      WebSocketConnectionClosedException
          If ``sock`` is falsy, or if no data was obtained (an empty read, or
          the wait for readability expired).
      WebSocketTimeoutException
          If the read raised a timeout error.
      """
      if not sock:
          raise WebSocketConnectionClosedException("socket is already closed.")

      def _recv():
          # First attempt; only "no data yet" conditions fall through to the wait.
          try:
              return sock.recv(bufsize)
          except SSLWantReadError:
              pass
          except socket.error as exc:
              if extract_error_code(exc) not in (errno.EAGAIN, errno.EWOULDBLOCK):
                  raise

          # The context manager closes the selector even when register() or
          # select() raises, so the selector is never leaked.
          with selectors.DefaultSelector() as sel:
              sel.register(sock, selectors.EVENT_READ)
              ready = sel.select(sock.gettimeout())

          if ready:
              return sock.recv(bufsize)
          # NOTE(review): an expired wait yields None, which is reported below
          # as a lost connection rather than a timeout. Kept as-is because
          # callers may rely on it.
          return None

      try:
          if sock.gettimeout() == 0:
              # Timeout of 0: a single attempt, no waiting.
              data = sock.recv(bufsize)
          else:
              data = _recv()
      except TimeoutError as exc:
          raise WebSocketTimeoutException("Connection timed out") from exc
      except socket.timeout as exc:
          raise WebSocketTimeoutException(extract_err_message(exc)) from exc
      except SSLError as exc:
          message = extract_err_message(exc)
          if isinstance(message, str) and "timed out" in message:
              raise WebSocketTimeoutException(message) from exc
          raise

      if not data:
          raise WebSocketConnectionClosedException("Connection to remote host was lost.")
      return data
  def recv_line(sock: socket.socket) -> bytes:
      """
      Read from ``sock`` one byte at a time until a newline byte arrives.

      Parameters
      ----------
      sock: socket.socket
          Socket to read from; every byte is fetched through ``recv``.

      Returns
      ----------
      bytes
          Everything read, including the terminating newline byte.
      """
      pieces = []
      latest = None
      while latest != b"\n":
          latest = recv(sock, 1)
          pieces.append(latest)
      return b"".join(pieces)
  def send(sock: socket.socket, data: Union[bytes, str]) -> int:
      """
      Send ``data`` over ``sock`` with a single ``sock.send`` call.

      If the socket's timeout is anything other than 0 and the first attempt
      fails with ``SSLWantWriteError`` or with an error code of ``EAGAIN`` /
      ``EWOULDBLOCK``, the call waits for the socket to become writable for up
      to ``sock.gettimeout()`` seconds and then tries once more.

      Parameters
      ----------
      sock: socket.socket
          Socket to write to.
      data: bytes or str
          Payload; a ``str`` is encoded as UTF-8 before sending.

      Returns
      ----------
      int
          Whatever ``sock.send`` returns (the number of bytes sent).
          NOTE(review): when the wait for writability expires nothing is sent
          and ``None`` is returned instead; kept for backward compatibility.

      Raises
      ----------
      WebSocketConnectionClosedException
          If ``sock`` is falsy.
      WebSocketTimeoutException
          If the send raised a timeout error, or any error whose message
          contains "timed out".
      """
      if isinstance(data, str):
          data = data.encode("utf-8")

      if not sock:
          raise WebSocketConnectionClosedException("socket is already closed.")

      def _send():
          # First attempt; only "cannot write yet" conditions fall through.
          try:
              return sock.send(data)
          except SSLWantWriteError:
              pass
          except socket.error as exc:
              # A missing error code (None) is not in the retryable set either,
              # so it is re-raised by the same test.
              if extract_error_code(exc) not in (errno.EAGAIN, errno.EWOULDBLOCK):
                  raise

          # The context manager closes the selector even when register() or
          # select() raises, so the selector is never leaked.
          with selectors.DefaultSelector() as sel:
              sel.register(sock, selectors.EVENT_WRITE)
              ready = sel.select(sock.gettimeout())

          if ready:
              return sock.send(data)
          return None

      try:
          if sock.gettimeout() == 0:
              # Timeout of 0: a single attempt, no waiting.
              return sock.send(data)
          return _send()
      except socket.timeout as exc:
          raise WebSocketTimeoutException(extract_err_message(exc)) from exc
      except Exception as exc:
          message = extract_err_message(exc)
          if isinstance(message, str) and "timed out" in message:
              raise WebSocketTimeoutException(message) from exc
          raise