client_proto.py
  1. import asyncio
  2. from contextlib import suppress
  3. from typing import Any, Optional, Tuple
  4. from .base_protocol import BaseProtocol
  5. from .client_exceptions import (
  6. ClientOSError,
  7. ClientPayloadError,
  8. ServerDisconnectedError,
  9. ServerTimeoutError,
  10. )
  11. from .helpers import (
  12. _EXC_SENTINEL,
  13. BaseTimerContext,
  14. set_exception,
  15. status_code_must_be_empty_body,
  16. )
  17. from .http import HttpResponseParser, RawResponseMessage
  18. from .http_exceptions import HttpProcessingError
  19. from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader
class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamReader]]):
    """Helper class to adapt between Protocol and StreamReader.

    Receives raw bytes from the transport, runs them through an
    ``HttpResponseParser`` (or a custom payload parser after a connection
    upgrade), and feeds ``(message, payload)`` pairs into the ``DataQueue``
    side of this object for the client code to consume.  Also owns the
    per-read timeout handling.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        BaseProtocol.__init__(self, loop=loop)
        DataQueue.__init__(self, loop)
        # Set when the connection must not be reused (forced close,
        # errors, protocol upgrade, leftover data).
        self._should_close: bool = False
        # Payload stream of the response currently being parsed, if any.
        self._payload: Optional[StreamReader] = None
        # When True, response bodies are replaced with EMPTY_PAYLOAD.
        self._skip_payload: bool = False
        # Custom parser installed via set_parser() after an upgrade
        # (e.g. a websocket reader); None while doing plain HTTP.
        self._payload_parser = None
        self._timer: Optional[BaseTimerContext] = None
        # Bytes received before a parser was (re)installed; replayed once
        # set_parser()/set_response_params() provides one.
        self._tail: bytes = b""
        # True once the HTTP parser reported a protocol upgrade.
        self._upgraded: bool = False
        self._parser: Optional[HttpResponseParser] = None
        self._read_timeout: Optional[float] = None
        self._read_timeout_handle: Optional[asyncio.TimerHandle] = None
        self._timeout_ceil_threshold: Optional[float] = 5

    @property
    def upgraded(self) -> bool:
        """True if the connection was upgraded (e.g. to websocket)."""
        return self._upgraded

    @property
    def should_close(self) -> bool:
        """True when the connection cannot be safely reused.

        Any unread payload, an upgrade, a pending exception, an installed
        custom parser, queued messages, or buffered tail bytes all make
        reuse unsafe.
        """
        # Note: `A and B or C` parses as `(A and B) or C`.
        if self._payload is not None and not self._payload.is_eof() or self._upgraded:
            return True

        return (
            self._should_close
            or self._upgraded
            or self.exception() is not None
            or self._payload_parser is not None
            or len(self) > 0
            or bool(self._tail)
        )

    def force_close(self) -> None:
        """Mark the connection as non-reusable."""
        self._should_close = True

    def close(self) -> None:
        """Close the underlying transport and drop payload/timeout state."""
        transport = self.transport
        if transport is not None:
            transport.close()
            self.transport = None
            self._payload = None
            self._drop_timeout()

    def is_connected(self) -> bool:
        """True while a transport exists and is not shutting down."""
        return self.transport is not None and not self.transport.is_closing()

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Handle transport loss: finalize parsers and propagate errors.

        Translates the raw transport exception into a client-level one
        (``ClientOSError`` / ``ServerDisconnectedError``) when the stream
        had not reached EOF, and marks any incomplete payload with a
        ``ClientPayloadError``.
        """
        self._drop_timeout()

        original_connection_error = exc
        reraised_exc = original_connection_error

        connection_closed_cleanly = original_connection_error is None

        if self._payload_parser is not None:
            with suppress(Exception):  # FIXME: log this somehow?
                self._payload_parser.feed_eof()

        uncompleted = None
        if self._parser is not None:
            try:
                uncompleted = self._parser.feed_eof()
            except Exception as underlying_exc:
                # EOF arrived mid-message: the response payload is broken.
                if self._payload is not None:
                    client_payload_exc_msg = (
                        f"Response payload is not completed: {underlying_exc !r}"
                    )
                    if not connection_closed_cleanly:
                        client_payload_exc_msg = (
                            f"{client_payload_exc_msg !s}. "
                            f"{original_connection_error !r}"
                        )
                    set_exception(
                        self._payload,
                        ClientPayloadError(client_payload_exc_msg),
                        underlying_exc,
                    )

        if not self.is_eof():
            if isinstance(original_connection_error, OSError):
                reraised_exc = ClientOSError(*original_connection_error.args)
            if connection_closed_cleanly:
                reraised_exc = ServerDisconnectedError(uncompleted)
            # assigns self._should_close to True as side effect,
            # we do it anyway below
            underlying_non_eof_exc = (
                _EXC_SENTINEL
                if connection_closed_cleanly
                else original_connection_error
            )
            assert underlying_non_eof_exc is not None
            assert reraised_exc is not None
            self.set_exception(reraised_exc, underlying_non_eof_exc)

        self._should_close = True
        self._parser = None
        self._payload = None
        self._payload_parser = None
        self._reading_paused = False

        super().connection_lost(reraised_exc)

    def eof_received(self) -> None:
        # should call parser.feed_eof() most likely
        self._drop_timeout()

    def pause_reading(self) -> None:
        """Pause the transport and suspend the read timeout."""
        super().pause_reading()
        self._drop_timeout()

    def resume_reading(self) -> None:
        """Resume the transport and restart the read timeout."""
        super().resume_reading()
        self._reschedule_timeout()

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: BaseException = _EXC_SENTINEL,
    ) -> None:
        """Record an exception on the queue; the connection becomes unusable."""
        self._should_close = True
        self._drop_timeout()
        super().set_exception(exc, exc_cause)

    def set_parser(self, parser: Any, payload: Any) -> None:
        """Install a custom payload parser (used after a protocol upgrade).

        Any bytes buffered in ``self._tail`` are immediately replayed
        through the new parser.
        """
        # TODO: actual types are:
        #   parser: WebSocketReader
        #   payload: FlowControlDataQueue
        # but they are not generic enough
        # Need an ABC for both types
        self._payload = payload
        self._payload_parser = parser

        self._drop_timeout()

        if self._tail:
            data, self._tail = self._tail, b""
            self.data_received(data)

    def set_response_params(
        self,
        *,
        timer: Optional[BaseTimerContext] = None,
        skip_payload: bool = False,
        read_until_eof: bool = False,
        auto_decompress: bool = True,
        read_timeout: Optional[float] = None,
        read_bufsize: int = 2**16,
        timeout_ceil_threshold: float = 5,
        max_line_size: int = 8190,
        max_field_size: int = 8190,
    ) -> None:
        """Create the HTTP response parser for the next response.

        Buffered tail bytes (received before the parser existed) are fed
        into the fresh parser right away.
        """
        self._skip_payload = skip_payload

        self._read_timeout = read_timeout

        self._timeout_ceil_threshold = timeout_ceil_threshold

        self._parser = HttpResponseParser(
            self,
            self._loop,
            read_bufsize,
            timer=timer,
            payload_exception=ClientPayloadError,
            response_with_body=not skip_payload,
            read_until_eof=read_until_eof,
            auto_decompress=auto_decompress,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
        )

        if self._tail:
            data, self._tail = self._tail, b""
            self.data_received(data)

    def _drop_timeout(self) -> None:
        """Cancel the pending read-timeout callback, if any."""
        if self._read_timeout_handle is not None:
            self._read_timeout_handle.cancel()
            self._read_timeout_handle = None

    def _reschedule_timeout(self) -> None:
        """Restart the read timeout from now (no-op if no timeout is set)."""
        timeout = self._read_timeout
        if self._read_timeout_handle is not None:
            self._read_timeout_handle.cancel()

        if timeout:
            self._read_timeout_handle = self._loop.call_later(
                timeout, self._on_read_timeout
            )
        else:
            self._read_timeout_handle = None

    def start_timeout(self) -> None:
        """Arm the read timeout (public entry point)."""
        self._reschedule_timeout()

    def _on_read_timeout(self) -> None:
        """Read timeout fired: fail the queue and the active payload."""
        exc = ServerTimeoutError("Timeout on reading data from socket")
        self.set_exception(exc)
        if self._payload is not None:
            set_exception(self._payload, exc)

    def data_received(self, data: bytes) -> None:
        """Feed bytes from the transport into the active parser.

        Dispatch order: custom payload parser (post-upgrade) first; else
        buffer into ``_tail`` if no HTTP parser is ready; else run the
        HTTP parser and enqueue parsed ``(message, payload)`` pairs.
        """
        self._reschedule_timeout()

        if not data:
            return

        # custom payload parser
        if self._payload_parser is not None:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self._payload = None
                self._payload_parser = None

                if tail:
                    # leftover bytes belong to the next protocol unit
                    self.data_received(tail)
            return
        else:
            if self._upgraded or self._parser is None:
                # i.e. websocket connection, websocket parser is not set yet
                self._tail += data
            else:
                # parse http messages
                try:
                    messages, upgraded, tail = self._parser.feed_data(data)
                except BaseException as underlying_exc:
                    if self.transport is not None:
                        # connection.release() could be called BEFORE
                        # data_received(), the transport is already
                        # closed in this case
                        self.transport.close()
                    # should_close is True after the call
                    self.set_exception(HttpProcessingError(), underlying_exc)
                    return

                self._upgraded = upgraded

                payload: Optional[StreamReader] = None
                for message, payload in messages:
                    if message.should_close:
                        self._should_close = True

                    self._payload = payload

                    if self._skip_payload or status_code_must_be_empty_body(
                        message.code
                    ):
                        self.feed_data((message, EMPTY_PAYLOAD), 0)
                    else:
                        self.feed_data((message, payload), 0)

                if payload is not None:
                    # new message(s) was processed
                    # register timeout handler unsubscribing
                    # either on end-of-stream or immediately for
                    # EMPTY_PAYLOAD
                    if payload is not EMPTY_PAYLOAD:
                        payload.on_eof(self._drop_timeout)
                    else:
                        self._drop_timeout()

                if tail:
                    if upgraded:
                        # post-upgrade bytes: re-dispatch through the
                        # custom-parser / tail-buffer path above
                        self.data_received(tail)
                    else:
                        self._tail = tail