sslproto.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926
  1. # Contains code from https://github.com/MagicStack/uvloop/tree/v0.16.0
  2. # SPDX-License-Identifier: PSF-2.0 AND (MIT OR Apache-2.0)
  3. # SPDX-FileCopyrightText: Copyright (c) 2015-2021 MagicStack Inc. http://magic.io
  4. import collections
  5. import enum
  6. import warnings
  7. try:
  8. import ssl
  9. except ImportError: # pragma: no cover
  10. ssl = None
  11. from . import constants
  12. from . import exceptions
  13. from . import protocols
  14. from . import transports
  15. from .log import logger
  16. if ssl is not None:
  17. SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError)
  18. class SSLProtocolState(enum.Enum):
  19. UNWRAPPED = "UNWRAPPED"
  20. DO_HANDSHAKE = "DO_HANDSHAKE"
  21. WRAPPED = "WRAPPED"
  22. FLUSHING = "FLUSHING"
  23. SHUTDOWN = "SHUTDOWN"
  24. class AppProtocolState(enum.Enum):
  25. # This tracks the state of app protocol (https://git.io/fj59P):
  26. #
  27. # INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST
  28. #
  29. # * cm: connection_made()
  30. # * dr: data_received()
  31. # * er: eof_received()
  32. # * cl: connection_lost()
  33. STATE_INIT = "STATE_INIT"
  34. STATE_CON_MADE = "STATE_CON_MADE"
  35. STATE_EOF = "STATE_EOF"
  36. STATE_CON_LOST = "STATE_CON_LOST"
  37. def _create_transport_context(server_side, server_hostname):
  38. if server_side:
  39. raise ValueError('Server side SSL needs a valid SSLContext')
  40. # Client side may pass ssl=True to use a default
  41. # context; in that case the sslcontext passed is None.
  42. # The default is secure for client connections.
  43. # Python 3.4+: use up-to-date strong settings.
  44. sslcontext = ssl.create_default_context()
  45. if not server_hostname:
  46. sslcontext.check_hostname = False
  47. return sslcontext
  48. def add_flowcontrol_defaults(high, low, kb):
  49. if high is None:
  50. if low is None:
  51. hi = kb * 1024
  52. else:
  53. lo = low
  54. hi = 4 * lo
  55. else:
  56. hi = high
  57. if low is None:
  58. lo = hi // 4
  59. else:
  60. lo = low
  61. if not hi >= lo >= 0:
  62. raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
  63. (hi, lo))
  64. return hi, lo
  65. class _SSLProtocolTransport(transports._FlowControlMixin,
  66. transports.Transport):
  67. _start_tls_compatible = True
  68. _sendfile_compatible = constants._SendfileMode.FALLBACK
  69. def __init__(self, loop, ssl_protocol):
  70. self._loop = loop
  71. self._ssl_protocol = ssl_protocol
  72. self._closed = False
  73. def get_extra_info(self, name, default=None):
  74. """Get optional transport information."""
  75. return self._ssl_protocol._get_extra_info(name, default)
  76. def set_protocol(self, protocol):
  77. self._ssl_protocol._set_app_protocol(protocol)
  78. def get_protocol(self):
  79. return self._ssl_protocol._app_protocol
  80. def is_closing(self):
  81. return self._closed
  82. def close(self):
  83. """Close the transport.
  84. Buffered data will be flushed asynchronously. No more data
  85. will be received. After all buffered data is flushed, the
  86. protocol's connection_lost() method will (eventually) called
  87. with None as its argument.
  88. """
  89. if not self._closed:
  90. self._closed = True
  91. self._ssl_protocol._start_shutdown()
  92. else:
  93. self._ssl_protocol = None
  94. def __del__(self, _warnings=warnings):
  95. if not self._closed:
  96. self._closed = True
  97. _warnings.warn(
  98. "unclosed transport <asyncio._SSLProtocolTransport "
  99. "object>", ResourceWarning)
  100. def is_reading(self):
  101. return not self._ssl_protocol._app_reading_paused
  102. def pause_reading(self):
  103. """Pause the receiving end.
  104. No data will be passed to the protocol's data_received()
  105. method until resume_reading() is called.
  106. """
  107. self._ssl_protocol._pause_reading()
  108. def resume_reading(self):
  109. """Resume the receiving end.
  110. Data received will once again be passed to the protocol's
  111. data_received() method.
  112. """
  113. self._ssl_protocol._resume_reading()
  114. def set_write_buffer_limits(self, high=None, low=None):
  115. """Set the high- and low-water limits for write flow control.
  116. These two values control when to call the protocol's
  117. pause_writing() and resume_writing() methods. If specified,
  118. the low-water limit must be less than or equal to the
  119. high-water limit. Neither value can be negative.
  120. The defaults are implementation-specific. If only the
  121. high-water limit is given, the low-water limit defaults to an
  122. implementation-specific value less than or equal to the
  123. high-water limit. Setting high to zero forces low to zero as
  124. well, and causes pause_writing() to be called whenever the
  125. buffer becomes non-empty. Setting low to zero causes
  126. resume_writing() to be called only once the buffer is empty.
  127. Use of zero for either limit is generally sub-optimal as it
  128. reduces opportunities for doing I/O and computation
  129. concurrently.
  130. """
  131. self._ssl_protocol._set_write_buffer_limits(high, low)
  132. self._ssl_protocol._control_app_writing()
  133. def get_write_buffer_limits(self):
  134. return (self._ssl_protocol._outgoing_low_water,
  135. self._ssl_protocol._outgoing_high_water)
  136. def get_write_buffer_size(self):
  137. """Return the current size of the write buffers."""
  138. return self._ssl_protocol._get_write_buffer_size()
  139. def set_read_buffer_limits(self, high=None, low=None):
  140. """Set the high- and low-water limits for read flow control.
  141. These two values control when to call the upstream transport's
  142. pause_reading() and resume_reading() methods. If specified,
  143. the low-water limit must be less than or equal to the
  144. high-water limit. Neither value can be negative.
  145. The defaults are implementation-specific. If only the
  146. high-water limit is given, the low-water limit defaults to an
  147. implementation-specific value less than or equal to the
  148. high-water limit. Setting high to zero forces low to zero as
  149. well, and causes pause_reading() to be called whenever the
  150. buffer becomes non-empty. Setting low to zero causes
  151. resume_reading() to be called only once the buffer is empty.
  152. Use of zero for either limit is generally sub-optimal as it
  153. reduces opportunities for doing I/O and computation
  154. concurrently.
  155. """
  156. self._ssl_protocol._set_read_buffer_limits(high, low)
  157. self._ssl_protocol._control_ssl_reading()
  158. def get_read_buffer_limits(self):
  159. return (self._ssl_protocol._incoming_low_water,
  160. self._ssl_protocol._incoming_high_water)
  161. def get_read_buffer_size(self):
  162. """Return the current size of the read buffer."""
  163. return self._ssl_protocol._get_read_buffer_size()
  164. @property
  165. def _protocol_paused(self):
  166. # Required for sendfile fallback pause_writing/resume_writing logic
  167. return self._ssl_protocol._app_writing_paused
  168. def write(self, data):
  169. """Write some data bytes to the transport.
  170. This does not block; it buffers the data and arranges for it
  171. to be sent out asynchronously.
  172. """
  173. if not isinstance(data, (bytes, bytearray, memoryview)):
  174. raise TypeError(f"data: expecting a bytes-like instance, "
  175. f"got {type(data).__name__}")
  176. if not data:
  177. return
  178. self._ssl_protocol._write_appdata((data,))
  179. def writelines(self, list_of_data):
  180. """Write a list (or any iterable) of data bytes to the transport.
  181. The default implementation concatenates the arguments and
  182. calls write() on the result.
  183. """
  184. self._ssl_protocol._write_appdata(list_of_data)
  185. def write_eof(self):
  186. """Close the write end after flushing buffered data.
  187. This raises :exc:`NotImplementedError` right now.
  188. """
  189. raise NotImplementedError
  190. def can_write_eof(self):
  191. """Return True if this transport supports write_eof(), False if not."""
  192. return False
  193. def abort(self):
  194. """Close the transport immediately.
  195. Buffered data will be lost. No more data will be received.
  196. The protocol's connection_lost() method will (eventually) be
  197. called with None as its argument.
  198. """
  199. self._force_close(None)
  200. def _force_close(self, exc):
  201. self._closed = True
  202. if self._ssl_protocol is not None:
  203. self._ssl_protocol._abort(exc)
  204. def _test__append_write_backlog(self, data):
  205. # for test only
  206. self._ssl_protocol._write_backlog.append(data)
  207. self._ssl_protocol._write_buffer_size += len(data)
  208. class SSLProtocol(protocols.BufferedProtocol):
  209. max_size = 256 * 1024 # Buffer size passed to read()
  210. _handshake_start_time = None
  211. _handshake_timeout_handle = None
  212. _shutdown_timeout_handle = None
  213. def __init__(self, loop, app_protocol, sslcontext, waiter,
  214. server_side=False, server_hostname=None,
  215. call_connection_made=True,
  216. ssl_handshake_timeout=None,
  217. ssl_shutdown_timeout=None):
  218. if ssl is None:
  219. raise RuntimeError("stdlib ssl module not available")
  220. self._ssl_buffer = bytearray(self.max_size)
  221. self._ssl_buffer_view = memoryview(self._ssl_buffer)
  222. if ssl_handshake_timeout is None:
  223. ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
  224. elif ssl_handshake_timeout <= 0:
  225. raise ValueError(
  226. f"ssl_handshake_timeout should be a positive number, "
  227. f"got {ssl_handshake_timeout}")
  228. if ssl_shutdown_timeout is None:
  229. ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT
  230. elif ssl_shutdown_timeout <= 0:
  231. raise ValueError(
  232. f"ssl_shutdown_timeout should be a positive number, "
  233. f"got {ssl_shutdown_timeout}")
  234. if not sslcontext:
  235. sslcontext = _create_transport_context(
  236. server_side, server_hostname)
  237. self._server_side = server_side
  238. if server_hostname and not server_side:
  239. self._server_hostname = server_hostname
  240. else:
  241. self._server_hostname = None
  242. self._sslcontext = sslcontext
  243. # SSL-specific extra info. More info are set when the handshake
  244. # completes.
  245. self._extra = dict(sslcontext=sslcontext)
  246. # App data write buffering
  247. self._write_backlog = collections.deque()
  248. self._write_buffer_size = 0
  249. self._waiter = waiter
  250. self._loop = loop
  251. self._set_app_protocol(app_protocol)
  252. self._app_transport = None
  253. self._app_transport_created = False
  254. # transport, ex: SelectorSocketTransport
  255. self._transport = None
  256. self._ssl_handshake_timeout = ssl_handshake_timeout
  257. self._ssl_shutdown_timeout = ssl_shutdown_timeout
  258. # SSL and state machine
  259. self._incoming = ssl.MemoryBIO()
  260. self._outgoing = ssl.MemoryBIO()
  261. self._state = SSLProtocolState.UNWRAPPED
  262. self._conn_lost = 0 # Set when connection_lost called
  263. if call_connection_made:
  264. self._app_state = AppProtocolState.STATE_INIT
  265. else:
  266. self._app_state = AppProtocolState.STATE_CON_MADE
  267. self._sslobj = self._sslcontext.wrap_bio(
  268. self._incoming, self._outgoing,
  269. server_side=self._server_side,
  270. server_hostname=self._server_hostname)
  271. # Flow Control
  272. self._ssl_writing_paused = False
  273. self._app_reading_paused = False
  274. self._ssl_reading_paused = False
  275. self._incoming_high_water = 0
  276. self._incoming_low_water = 0
  277. self._set_read_buffer_limits()
  278. self._eof_received = False
  279. self._app_writing_paused = False
  280. self._outgoing_high_water = 0
  281. self._outgoing_low_water = 0
  282. self._set_write_buffer_limits()
  283. self._get_app_transport()
  284. def _set_app_protocol(self, app_protocol):
  285. self._app_protocol = app_protocol
  286. # Make fast hasattr check first
  287. if (hasattr(app_protocol, 'get_buffer') and
  288. isinstance(app_protocol, protocols.BufferedProtocol)):
  289. self._app_protocol_get_buffer = app_protocol.get_buffer
  290. self._app_protocol_buffer_updated = app_protocol.buffer_updated
  291. self._app_protocol_is_buffer = True
  292. else:
  293. self._app_protocol_is_buffer = False
  294. def _wakeup_waiter(self, exc=None):
  295. if self._waiter is None:
  296. return
  297. if not self._waiter.cancelled():
  298. if exc is not None:
  299. self._waiter.set_exception(exc)
  300. else:
  301. self._waiter.set_result(None)
  302. self._waiter = None
  303. def _get_app_transport(self):
  304. if self._app_transport is None:
  305. if self._app_transport_created:
  306. raise RuntimeError('Creating _SSLProtocolTransport twice')
  307. self._app_transport = _SSLProtocolTransport(self._loop, self)
  308. self._app_transport_created = True
  309. return self._app_transport
  310. def connection_made(self, transport):
  311. """Called when the low-level connection is made.
  312. Start the SSL handshake.
  313. """
  314. self._transport = transport
  315. self._start_handshake()
  316. def connection_lost(self, exc):
  317. """Called when the low-level connection is lost or closed.
  318. The argument is an exception object or None (the latter
  319. meaning a regular EOF is received or the connection was
  320. aborted or closed).
  321. """
  322. self._write_backlog.clear()
  323. self._outgoing.read()
  324. self._conn_lost += 1
  325. # Just mark the app transport as closed so that its __dealloc__
  326. # doesn't complain.
  327. if self._app_transport is not None:
  328. self._app_transport._closed = True
  329. if self._state != SSLProtocolState.DO_HANDSHAKE:
  330. if (
  331. self._app_state == AppProtocolState.STATE_CON_MADE or
  332. self._app_state == AppProtocolState.STATE_EOF
  333. ):
  334. self._app_state = AppProtocolState.STATE_CON_LOST
  335. self._loop.call_soon(self._app_protocol.connection_lost, exc)
  336. self._set_state(SSLProtocolState.UNWRAPPED)
  337. self._transport = None
  338. self._app_transport = None
  339. self._app_protocol = None
  340. self._wakeup_waiter(exc)
  341. if self._shutdown_timeout_handle:
  342. self._shutdown_timeout_handle.cancel()
  343. self._shutdown_timeout_handle = None
  344. if self._handshake_timeout_handle:
  345. self._handshake_timeout_handle.cancel()
  346. self._handshake_timeout_handle = None
  347. def get_buffer(self, n):
  348. want = n
  349. if want <= 0 or want > self.max_size:
  350. want = self.max_size
  351. if len(self._ssl_buffer) < want:
  352. self._ssl_buffer = bytearray(want)
  353. self._ssl_buffer_view = memoryview(self._ssl_buffer)
  354. return self._ssl_buffer_view
  355. def buffer_updated(self, nbytes):
  356. self._incoming.write(self._ssl_buffer_view[:nbytes])
  357. if self._state == SSLProtocolState.DO_HANDSHAKE:
  358. self._do_handshake()
  359. elif self._state == SSLProtocolState.WRAPPED:
  360. self._do_read()
  361. elif self._state == SSLProtocolState.FLUSHING:
  362. self._do_flush()
  363. elif self._state == SSLProtocolState.SHUTDOWN:
  364. self._do_shutdown()
  365. def eof_received(self):
  366. """Called when the other end of the low-level stream
  367. is half-closed.
  368. If this returns a false value (including None), the transport
  369. will close itself. If it returns a true value, closing the
  370. transport is up to the protocol.
  371. """
  372. self._eof_received = True
  373. try:
  374. if self._loop.get_debug():
  375. logger.debug("%r received EOF", self)
  376. if self._state == SSLProtocolState.DO_HANDSHAKE:
  377. self._on_handshake_complete(ConnectionResetError)
  378. elif self._state == SSLProtocolState.WRAPPED:
  379. self._set_state(SSLProtocolState.FLUSHING)
  380. if self._app_reading_paused:
  381. return True
  382. else:
  383. self._do_flush()
  384. elif self._state == SSLProtocolState.FLUSHING:
  385. self._do_write()
  386. self._set_state(SSLProtocolState.SHUTDOWN)
  387. self._do_shutdown()
  388. elif self._state == SSLProtocolState.SHUTDOWN:
  389. self._do_shutdown()
  390. except Exception:
  391. self._transport.close()
  392. raise
  393. def _get_extra_info(self, name, default=None):
  394. if name in self._extra:
  395. return self._extra[name]
  396. elif self._transport is not None:
  397. return self._transport.get_extra_info(name, default)
  398. else:
  399. return default
  400. def _set_state(self, new_state):
  401. allowed = False
  402. if new_state == SSLProtocolState.UNWRAPPED:
  403. allowed = True
  404. elif (
  405. self._state == SSLProtocolState.UNWRAPPED and
  406. new_state == SSLProtocolState.DO_HANDSHAKE
  407. ):
  408. allowed = True
  409. elif (
  410. self._state == SSLProtocolState.DO_HANDSHAKE and
  411. new_state == SSLProtocolState.WRAPPED
  412. ):
  413. allowed = True
  414. elif (
  415. self._state == SSLProtocolState.WRAPPED and
  416. new_state == SSLProtocolState.FLUSHING
  417. ):
  418. allowed = True
  419. elif (
  420. self._state == SSLProtocolState.FLUSHING and
  421. new_state == SSLProtocolState.SHUTDOWN
  422. ):
  423. allowed = True
  424. if allowed:
  425. self._state = new_state
  426. else:
  427. raise RuntimeError(
  428. 'cannot switch state from {} to {}'.format(
  429. self._state, new_state))
  430. # Handshake flow
  431. def _start_handshake(self):
  432. if self._loop.get_debug():
  433. logger.debug("%r starts SSL handshake", self)
  434. self._handshake_start_time = self._loop.time()
  435. else:
  436. self._handshake_start_time = None
  437. self._set_state(SSLProtocolState.DO_HANDSHAKE)
  438. # start handshake timeout count down
  439. self._handshake_timeout_handle = \
  440. self._loop.call_later(self._ssl_handshake_timeout,
  441. lambda: self._check_handshake_timeout())
  442. self._do_handshake()
  443. def _check_handshake_timeout(self):
  444. if self._state == SSLProtocolState.DO_HANDSHAKE:
  445. msg = (
  446. f"SSL handshake is taking longer than "
  447. f"{self._ssl_handshake_timeout} seconds: "
  448. f"aborting the connection"
  449. )
  450. self._fatal_error(ConnectionAbortedError(msg))
  451. def _do_handshake(self):
  452. try:
  453. self._sslobj.do_handshake()
  454. except SSLAgainErrors:
  455. self._process_outgoing()
  456. except ssl.SSLError as exc:
  457. self._on_handshake_complete(exc)
  458. else:
  459. self._on_handshake_complete(None)
  460. def _on_handshake_complete(self, handshake_exc):
  461. if self._handshake_timeout_handle is not None:
  462. self._handshake_timeout_handle.cancel()
  463. self._handshake_timeout_handle = None
  464. sslobj = self._sslobj
  465. try:
  466. if handshake_exc is None:
  467. self._set_state(SSLProtocolState.WRAPPED)
  468. else:
  469. raise handshake_exc
  470. peercert = sslobj.getpeercert()
  471. except Exception as exc:
  472. handshake_exc = None
  473. self._set_state(SSLProtocolState.UNWRAPPED)
  474. if isinstance(exc, ssl.CertificateError):
  475. msg = 'SSL handshake failed on verifying the certificate'
  476. else:
  477. msg = 'SSL handshake failed'
  478. self._fatal_error(exc, msg)
  479. self._wakeup_waiter(exc)
  480. return
  481. if self._loop.get_debug():
  482. dt = self._loop.time() - self._handshake_start_time
  483. logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)
  484. # Add extra info that becomes available after handshake.
  485. self._extra.update(peercert=peercert,
  486. cipher=sslobj.cipher(),
  487. compression=sslobj.compression(),
  488. ssl_object=sslobj)
  489. if self._app_state == AppProtocolState.STATE_INIT:
  490. self._app_state = AppProtocolState.STATE_CON_MADE
  491. self._app_protocol.connection_made(self._get_app_transport())
  492. self._wakeup_waiter()
  493. self._do_read()
  494. # Shutdown flow
  495. def _start_shutdown(self):
  496. if (
  497. self._state in (
  498. SSLProtocolState.FLUSHING,
  499. SSLProtocolState.SHUTDOWN,
  500. SSLProtocolState.UNWRAPPED
  501. )
  502. ):
  503. return
  504. if self._app_transport is not None:
  505. self._app_transport._closed = True
  506. if self._state == SSLProtocolState.DO_HANDSHAKE:
  507. self._abort(None)
  508. else:
  509. self._set_state(SSLProtocolState.FLUSHING)
  510. self._shutdown_timeout_handle = self._loop.call_later(
  511. self._ssl_shutdown_timeout,
  512. lambda: self._check_shutdown_timeout()
  513. )
  514. self._do_flush()
  515. def _check_shutdown_timeout(self):
  516. if (
  517. self._state in (
  518. SSLProtocolState.FLUSHING,
  519. SSLProtocolState.SHUTDOWN
  520. )
  521. ):
  522. self._transport._force_close(
  523. exceptions.TimeoutError('SSL shutdown timed out'))
  524. def _do_flush(self):
  525. self._do_read()
  526. self._set_state(SSLProtocolState.SHUTDOWN)
  527. self._do_shutdown()
  528. def _do_shutdown(self):
  529. try:
  530. if not self._eof_received:
  531. self._sslobj.unwrap()
  532. except SSLAgainErrors:
  533. self._process_outgoing()
  534. except ssl.SSLError as exc:
  535. self._on_shutdown_complete(exc)
  536. else:
  537. self._process_outgoing()
  538. self._call_eof_received()
  539. self._on_shutdown_complete(None)
  540. def _on_shutdown_complete(self, shutdown_exc):
  541. if self._shutdown_timeout_handle is not None:
  542. self._shutdown_timeout_handle.cancel()
  543. self._shutdown_timeout_handle = None
  544. if shutdown_exc:
  545. self._fatal_error(shutdown_exc)
  546. else:
  547. self._loop.call_soon(self._transport.close)
  548. def _abort(self, exc):
  549. self._set_state(SSLProtocolState.UNWRAPPED)
  550. if self._transport is not None:
  551. self._transport._force_close(exc)
  552. # Outgoing flow
  553. def _write_appdata(self, list_of_data):
  554. if (
  555. self._state in (
  556. SSLProtocolState.FLUSHING,
  557. SSLProtocolState.SHUTDOWN,
  558. SSLProtocolState.UNWRAPPED
  559. )
  560. ):
  561. if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
  562. logger.warning('SSL connection is closed')
  563. self._conn_lost += 1
  564. return
  565. for data in list_of_data:
  566. self._write_backlog.append(data)
  567. self._write_buffer_size += len(data)
  568. try:
  569. if self._state == SSLProtocolState.WRAPPED:
  570. self._do_write()
  571. except Exception as ex:
  572. self._fatal_error(ex, 'Fatal error on SSL protocol')
  573. def _do_write(self):
  574. try:
  575. while self._write_backlog:
  576. data = self._write_backlog[0]
  577. count = self._sslobj.write(data)
  578. data_len = len(data)
  579. if count < data_len:
  580. self._write_backlog[0] = data[count:]
  581. self._write_buffer_size -= count
  582. else:
  583. del self._write_backlog[0]
  584. self._write_buffer_size -= data_len
  585. except SSLAgainErrors:
  586. pass
  587. self._process_outgoing()
  588. def _process_outgoing(self):
  589. if not self._ssl_writing_paused:
  590. data = self._outgoing.read()
  591. if len(data):
  592. self._transport.write(data)
  593. self._control_app_writing()
  594. # Incoming flow
  595. def _do_read(self):
  596. if (
  597. self._state not in (
  598. SSLProtocolState.WRAPPED,
  599. SSLProtocolState.FLUSHING,
  600. )
  601. ):
  602. return
  603. try:
  604. if not self._app_reading_paused:
  605. if self._app_protocol_is_buffer:
  606. self._do_read__buffered()
  607. else:
  608. self._do_read__copied()
  609. if self._write_backlog:
  610. self._do_write()
  611. else:
  612. self._process_outgoing()
  613. self._control_ssl_reading()
  614. except Exception as ex:
  615. self._fatal_error(ex, 'Fatal error on SSL protocol')
  616. def _do_read__buffered(self):
  617. offset = 0
  618. count = 1
  619. buf = self._app_protocol_get_buffer(self._get_read_buffer_size())
  620. wants = len(buf)
  621. try:
  622. count = self._sslobj.read(wants, buf)
  623. if count > 0:
  624. offset = count
  625. while offset < wants:
  626. count = self._sslobj.read(wants - offset, buf[offset:])
  627. if count > 0:
  628. offset += count
  629. else:
  630. break
  631. else:
  632. self._loop.call_soon(lambda: self._do_read())
  633. except SSLAgainErrors:
  634. pass
  635. if offset > 0:
  636. self._app_protocol_buffer_updated(offset)
  637. if not count:
  638. # close_notify
  639. self._call_eof_received()
  640. self._start_shutdown()
  641. def _do_read__copied(self):
  642. chunk = b'1'
  643. zero = True
  644. one = False
  645. try:
  646. while True:
  647. chunk = self._sslobj.read(self.max_size)
  648. if not chunk:
  649. break
  650. if zero:
  651. zero = False
  652. one = True
  653. first = chunk
  654. elif one:
  655. one = False
  656. data = [first, chunk]
  657. else:
  658. data.append(chunk)
  659. except SSLAgainErrors:
  660. pass
  661. if one:
  662. self._app_protocol.data_received(first)
  663. elif not zero:
  664. self._app_protocol.data_received(b''.join(data))
  665. if not chunk:
  666. # close_notify
  667. self._call_eof_received()
  668. self._start_shutdown()
  669. def _call_eof_received(self):
  670. try:
  671. if self._app_state == AppProtocolState.STATE_CON_MADE:
  672. self._app_state = AppProtocolState.STATE_EOF
  673. keep_open = self._app_protocol.eof_received()
  674. if keep_open:
  675. logger.warning('returning true from eof_received() '
  676. 'has no effect when using ssl')
  677. except (KeyboardInterrupt, SystemExit):
  678. raise
  679. except BaseException as ex:
  680. self._fatal_error(ex, 'Error calling eof_received()')
  681. # Flow control for writes from APP socket
  682. def _control_app_writing(self):
  683. size = self._get_write_buffer_size()
  684. if size >= self._outgoing_high_water and not self._app_writing_paused:
  685. self._app_writing_paused = True
  686. try:
  687. self._app_protocol.pause_writing()
  688. except (KeyboardInterrupt, SystemExit):
  689. raise
  690. except BaseException as exc:
  691. self._loop.call_exception_handler({
  692. 'message': 'protocol.pause_writing() failed',
  693. 'exception': exc,
  694. 'transport': self._app_transport,
  695. 'protocol': self,
  696. })
  697. elif size <= self._outgoing_low_water and self._app_writing_paused:
  698. self._app_writing_paused = False
  699. try:
  700. self._app_protocol.resume_writing()
  701. except (KeyboardInterrupt, SystemExit):
  702. raise
  703. except BaseException as exc:
  704. self._loop.call_exception_handler({
  705. 'message': 'protocol.resume_writing() failed',
  706. 'exception': exc,
  707. 'transport': self._app_transport,
  708. 'protocol': self,
  709. })
  710. def _get_write_buffer_size(self):
  711. return self._outgoing.pending + self._write_buffer_size
  712. def _set_write_buffer_limits(self, high=None, low=None):
  713. high, low = add_flowcontrol_defaults(
  714. high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE)
  715. self._outgoing_high_water = high
  716. self._outgoing_low_water = low
  717. # Flow control for reads to APP socket
  718. def _pause_reading(self):
  719. self._app_reading_paused = True
  720. def _resume_reading(self):
  721. if self._app_reading_paused:
  722. self._app_reading_paused = False
  723. def resume():
  724. if self._state == SSLProtocolState.WRAPPED:
  725. self._do_read()
  726. elif self._state == SSLProtocolState.FLUSHING:
  727. self._do_flush()
  728. elif self._state == SSLProtocolState.SHUTDOWN:
  729. self._do_shutdown()
  730. self._loop.call_soon(resume)
  731. # Flow control for reads from SSL socket
  732. def _control_ssl_reading(self):
  733. size = self._get_read_buffer_size()
  734. if size >= self._incoming_high_water and not self._ssl_reading_paused:
  735. self._ssl_reading_paused = True
  736. self._transport.pause_reading()
  737. elif size <= self._incoming_low_water and self._ssl_reading_paused:
  738. self._ssl_reading_paused = False
  739. self._transport.resume_reading()
  740. def _set_read_buffer_limits(self, high=None, low=None):
  741. high, low = add_flowcontrol_defaults(
  742. high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ)
  743. self._incoming_high_water = high
  744. self._incoming_low_water = low
  745. def _get_read_buffer_size(self):
  746. return self._incoming.pending
  747. # Flow control for writes to SSL socket
  748. def pause_writing(self):
  749. """Called when the low-level transport's buffer goes over
  750. the high-water mark.
  751. """
  752. assert not self._ssl_writing_paused
  753. self._ssl_writing_paused = True
  754. def resume_writing(self):
  755. """Called when the low-level transport's buffer drains below
  756. the low-water mark.
  757. """
  758. assert self._ssl_writing_paused
  759. self._ssl_writing_paused = False
  760. self._process_outgoing()
  761. def _fatal_error(self, exc, message='Fatal error on transport'):
  762. if self._transport:
  763. self._transport._force_close(exc)
  764. if isinstance(exc, OSError):
  765. if self._loop.get_debug():
  766. logger.debug("%r: %s", self, message, exc_info=True)
  767. elif not isinstance(exc, exceptions.CancelledError):
  768. self._loop.call_exception_handler({
  769. 'message': message,
  770. 'exception': exc,
  771. 'transport': self._transport,
  772. 'protocol': self,
  773. })