sslproto.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929
  1. # Contains code from https://github.com/MagicStack/uvloop/tree/v0.16.0
  2. # SPDX-License-Identifier: PSF-2.0 AND (MIT OR Apache-2.0)
  3. # SPDX-FileCopyrightText: Copyright (c) 2015-2021 MagicStack Inc. http://magic.io
  4. import collections
  5. import enum
  6. import warnings
  7. try:
  8. import ssl
  9. except ImportError: # pragma: no cover
  10. ssl = None
  11. from . import constants
  12. from . import exceptions
  13. from . import protocols
  14. from . import transports
  15. from .log import logger
  16. if ssl is not None:
  17. SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError)
  18. class SSLProtocolState(enum.Enum):
  19. UNWRAPPED = "UNWRAPPED"
  20. DO_HANDSHAKE = "DO_HANDSHAKE"
  21. WRAPPED = "WRAPPED"
  22. FLUSHING = "FLUSHING"
  23. SHUTDOWN = "SHUTDOWN"
  24. class AppProtocolState(enum.Enum):
  25. # This tracks the state of app protocol (https://git.io/fj59P):
  26. #
  27. # INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST
  28. #
  29. # * cm: connection_made()
  30. # * dr: data_received()
  31. # * er: eof_received()
  32. # * cl: connection_lost()
  33. STATE_INIT = "STATE_INIT"
  34. STATE_CON_MADE = "STATE_CON_MADE"
  35. STATE_EOF = "STATE_EOF"
  36. STATE_CON_LOST = "STATE_CON_LOST"
  37. def _create_transport_context(server_side, server_hostname):
  38. if server_side:
  39. raise ValueError('Server side SSL needs a valid SSLContext')
  40. # Client side may pass ssl=True to use a default
  41. # context; in that case the sslcontext passed is None.
  42. # The default is secure for client connections.
  43. # Python 3.4+: use up-to-date strong settings.
  44. sslcontext = ssl.create_default_context()
  45. if not server_hostname:
  46. sslcontext.check_hostname = False
  47. return sslcontext
  48. def add_flowcontrol_defaults(high, low, kb):
  49. if high is None:
  50. if low is None:
  51. hi = kb * 1024
  52. else:
  53. lo = low
  54. hi = 4 * lo
  55. else:
  56. hi = high
  57. if low is None:
  58. lo = hi // 4
  59. else:
  60. lo = low
  61. if not hi >= lo >= 0:
  62. raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
  63. (hi, lo))
  64. return hi, lo
  65. class _SSLProtocolTransport(transports._FlowControlMixin,
  66. transports.Transport):
  67. _start_tls_compatible = True
  68. _sendfile_compatible = constants._SendfileMode.FALLBACK
  69. def __init__(self, loop, ssl_protocol):
  70. self._loop = loop
  71. self._ssl_protocol = ssl_protocol
  72. self._closed = False
  73. def get_extra_info(self, name, default=None):
  74. """Get optional transport information."""
  75. return self._ssl_protocol._get_extra_info(name, default)
  76. def set_protocol(self, protocol):
  77. self._ssl_protocol._set_app_protocol(protocol)
  78. def get_protocol(self):
  79. return self._ssl_protocol._app_protocol
  80. def is_closing(self):
  81. return self._closed or self._ssl_protocol._is_transport_closing()
  82. def close(self):
  83. """Close the transport.
  84. Buffered data will be flushed asynchronously. No more data
  85. will be received. After all buffered data is flushed, the
  86. protocol's connection_lost() method will (eventually) called
  87. with None as its argument.
  88. """
  89. if not self._closed:
  90. self._closed = True
  91. self._ssl_protocol._start_shutdown()
  92. else:
  93. self._ssl_protocol = None
  94. def __del__(self, _warnings=warnings):
  95. if not self._closed:
  96. self._closed = True
  97. _warnings.warn(
  98. "unclosed transport <asyncio._SSLProtocolTransport "
  99. "object>", ResourceWarning)
  100. def is_reading(self):
  101. return not self._ssl_protocol._app_reading_paused
  102. def pause_reading(self):
  103. """Pause the receiving end.
  104. No data will be passed to the protocol's data_received()
  105. method until resume_reading() is called.
  106. """
  107. self._ssl_protocol._pause_reading()
  108. def resume_reading(self):
  109. """Resume the receiving end.
  110. Data received will once again be passed to the protocol's
  111. data_received() method.
  112. """
  113. self._ssl_protocol._resume_reading()
  114. def set_write_buffer_limits(self, high=None, low=None):
  115. """Set the high- and low-water limits for write flow control.
  116. These two values control when to call the protocol's
  117. pause_writing() and resume_writing() methods. If specified,
  118. the low-water limit must be less than or equal to the
  119. high-water limit. Neither value can be negative.
  120. The defaults are implementation-specific. If only the
  121. high-water limit is given, the low-water limit defaults to an
  122. implementation-specific value less than or equal to the
  123. high-water limit. Setting high to zero forces low to zero as
  124. well, and causes pause_writing() to be called whenever the
  125. buffer becomes non-empty. Setting low to zero causes
  126. resume_writing() to be called only once the buffer is empty.
  127. Use of zero for either limit is generally sub-optimal as it
  128. reduces opportunities for doing I/O and computation
  129. concurrently.
  130. """
  131. self._ssl_protocol._set_write_buffer_limits(high, low)
  132. self._ssl_protocol._control_app_writing()
  133. def get_write_buffer_limits(self):
  134. return (self._ssl_protocol._outgoing_low_water,
  135. self._ssl_protocol._outgoing_high_water)
  136. def get_write_buffer_size(self):
  137. """Return the current size of the write buffers."""
  138. return self._ssl_protocol._get_write_buffer_size()
  139. def set_read_buffer_limits(self, high=None, low=None):
  140. """Set the high- and low-water limits for read flow control.
  141. These two values control when to call the upstream transport's
  142. pause_reading() and resume_reading() methods. If specified,
  143. the low-water limit must be less than or equal to the
  144. high-water limit. Neither value can be negative.
  145. The defaults are implementation-specific. If only the
  146. high-water limit is given, the low-water limit defaults to an
  147. implementation-specific value less than or equal to the
  148. high-water limit. Setting high to zero forces low to zero as
  149. well, and causes pause_reading() to be called whenever the
  150. buffer becomes non-empty. Setting low to zero causes
  151. resume_reading() to be called only once the buffer is empty.
  152. Use of zero for either limit is generally sub-optimal as it
  153. reduces opportunities for doing I/O and computation
  154. concurrently.
  155. """
  156. self._ssl_protocol._set_read_buffer_limits(high, low)
  157. self._ssl_protocol._control_ssl_reading()
  158. def get_read_buffer_limits(self):
  159. return (self._ssl_protocol._incoming_low_water,
  160. self._ssl_protocol._incoming_high_water)
  161. def get_read_buffer_size(self):
  162. """Return the current size of the read buffer."""
  163. return self._ssl_protocol._get_read_buffer_size()
  164. @property
  165. def _protocol_paused(self):
  166. # Required for sendfile fallback pause_writing/resume_writing logic
  167. return self._ssl_protocol._app_writing_paused
  168. def write(self, data):
  169. """Write some data bytes to the transport.
  170. This does not block; it buffers the data and arranges for it
  171. to be sent out asynchronously.
  172. """
  173. if not isinstance(data, (bytes, bytearray, memoryview)):
  174. raise TypeError(f"data: expecting a bytes-like instance, "
  175. f"got {type(data).__name__}")
  176. if not data:
  177. return
  178. self._ssl_protocol._write_appdata((data,))
  179. def writelines(self, list_of_data):
  180. """Write a list (or any iterable) of data bytes to the transport.
  181. The default implementation concatenates the arguments and
  182. calls write() on the result.
  183. """
  184. self._ssl_protocol._write_appdata(list_of_data)
  185. def write_eof(self):
  186. """Close the write end after flushing buffered data.
  187. This raises :exc:`NotImplementedError` right now.
  188. """
  189. raise NotImplementedError
  190. def can_write_eof(self):
  191. """Return True if this transport supports write_eof(), False if not."""
  192. return False
  193. def abort(self):
  194. """Close the transport immediately.
  195. Buffered data will be lost. No more data will be received.
  196. The protocol's connection_lost() method will (eventually) be
  197. called with None as its argument.
  198. """
  199. self._force_close(None)
  200. def _force_close(self, exc):
  201. self._closed = True
  202. if self._ssl_protocol is not None:
  203. self._ssl_protocol._abort(exc)
  204. def _test__append_write_backlog(self, data):
  205. # for test only
  206. self._ssl_protocol._write_backlog.append(data)
  207. self._ssl_protocol._write_buffer_size += len(data)
  208. class SSLProtocol(protocols.BufferedProtocol):
  209. max_size = 256 * 1024 # Buffer size passed to read()
  210. _handshake_start_time = None
  211. _handshake_timeout_handle = None
  212. _shutdown_timeout_handle = None
  213. def __init__(self, loop, app_protocol, sslcontext, waiter,
  214. server_side=False, server_hostname=None,
  215. call_connection_made=True,
  216. ssl_handshake_timeout=None,
  217. ssl_shutdown_timeout=None):
  218. if ssl is None:
  219. raise RuntimeError("stdlib ssl module not available")
  220. self._ssl_buffer = bytearray(self.max_size)
  221. self._ssl_buffer_view = memoryview(self._ssl_buffer)
  222. if ssl_handshake_timeout is None:
  223. ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
  224. elif ssl_handshake_timeout <= 0:
  225. raise ValueError(
  226. f"ssl_handshake_timeout should be a positive number, "
  227. f"got {ssl_handshake_timeout}")
  228. if ssl_shutdown_timeout is None:
  229. ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT
  230. elif ssl_shutdown_timeout <= 0:
  231. raise ValueError(
  232. f"ssl_shutdown_timeout should be a positive number, "
  233. f"got {ssl_shutdown_timeout}")
  234. if not sslcontext:
  235. sslcontext = _create_transport_context(
  236. server_side, server_hostname)
  237. self._server_side = server_side
  238. if server_hostname and not server_side:
  239. self._server_hostname = server_hostname
  240. else:
  241. self._server_hostname = None
  242. self._sslcontext = sslcontext
  243. # SSL-specific extra info. More info are set when the handshake
  244. # completes.
  245. self._extra = dict(sslcontext=sslcontext)
  246. # App data write buffering
  247. self._write_backlog = collections.deque()
  248. self._write_buffer_size = 0
  249. self._waiter = waiter
  250. self._loop = loop
  251. self._set_app_protocol(app_protocol)
  252. self._app_transport = None
  253. self._app_transport_created = False
  254. # transport, ex: SelectorSocketTransport
  255. self._transport = None
  256. self._ssl_handshake_timeout = ssl_handshake_timeout
  257. self._ssl_shutdown_timeout = ssl_shutdown_timeout
  258. # SSL and state machine
  259. self._incoming = ssl.MemoryBIO()
  260. self._outgoing = ssl.MemoryBIO()
  261. self._state = SSLProtocolState.UNWRAPPED
  262. self._conn_lost = 0 # Set when connection_lost called
  263. if call_connection_made:
  264. self._app_state = AppProtocolState.STATE_INIT
  265. else:
  266. self._app_state = AppProtocolState.STATE_CON_MADE
  267. self._sslobj = self._sslcontext.wrap_bio(
  268. self._incoming, self._outgoing,
  269. server_side=self._server_side,
  270. server_hostname=self._server_hostname)
  271. # Flow Control
  272. self._ssl_writing_paused = False
  273. self._app_reading_paused = False
  274. self._ssl_reading_paused = False
  275. self._incoming_high_water = 0
  276. self._incoming_low_water = 0
  277. self._set_read_buffer_limits()
  278. self._eof_received = False
  279. self._app_writing_paused = False
  280. self._outgoing_high_water = 0
  281. self._outgoing_low_water = 0
  282. self._set_write_buffer_limits()
  283. self._get_app_transport()
  284. def _set_app_protocol(self, app_protocol):
  285. self._app_protocol = app_protocol
  286. # Make fast hasattr check first
  287. if (hasattr(app_protocol, 'get_buffer') and
  288. isinstance(app_protocol, protocols.BufferedProtocol)):
  289. self._app_protocol_get_buffer = app_protocol.get_buffer
  290. self._app_protocol_buffer_updated = app_protocol.buffer_updated
  291. self._app_protocol_is_buffer = True
  292. else:
  293. self._app_protocol_is_buffer = False
  294. def _wakeup_waiter(self, exc=None):
  295. if self._waiter is None:
  296. return
  297. if not self._waiter.cancelled():
  298. if exc is not None:
  299. self._waiter.set_exception(exc)
  300. else:
  301. self._waiter.set_result(None)
  302. self._waiter = None
  303. def _get_app_transport(self):
  304. if self._app_transport is None:
  305. if self._app_transport_created:
  306. raise RuntimeError('Creating _SSLProtocolTransport twice')
  307. self._app_transport = _SSLProtocolTransport(self._loop, self)
  308. self._app_transport_created = True
  309. return self._app_transport
  310. def _is_transport_closing(self):
  311. return self._transport is not None and self._transport.is_closing()
  312. def connection_made(self, transport):
  313. """Called when the low-level connection is made.
  314. Start the SSL handshake.
  315. """
  316. self._transport = transport
  317. self._start_handshake()
  318. def connection_lost(self, exc):
  319. """Called when the low-level connection is lost or closed.
  320. The argument is an exception object or None (the latter
  321. meaning a regular EOF is received or the connection was
  322. aborted or closed).
  323. """
  324. self._write_backlog.clear()
  325. self._outgoing.read()
  326. self._conn_lost += 1
  327. # Just mark the app transport as closed so that its __dealloc__
  328. # doesn't complain.
  329. if self._app_transport is not None:
  330. self._app_transport._closed = True
  331. if self._state != SSLProtocolState.DO_HANDSHAKE:
  332. if (
  333. self._app_state == AppProtocolState.STATE_CON_MADE or
  334. self._app_state == AppProtocolState.STATE_EOF
  335. ):
  336. self._app_state = AppProtocolState.STATE_CON_LOST
  337. self._loop.call_soon(self._app_protocol.connection_lost, exc)
  338. self._set_state(SSLProtocolState.UNWRAPPED)
  339. self._transport = None
  340. self._app_transport = None
  341. self._app_protocol = None
  342. self._wakeup_waiter(exc)
  343. if self._shutdown_timeout_handle:
  344. self._shutdown_timeout_handle.cancel()
  345. self._shutdown_timeout_handle = None
  346. if self._handshake_timeout_handle:
  347. self._handshake_timeout_handle.cancel()
  348. self._handshake_timeout_handle = None
  349. def get_buffer(self, n):
  350. want = n
  351. if want <= 0 or want > self.max_size:
  352. want = self.max_size
  353. if len(self._ssl_buffer) < want:
  354. self._ssl_buffer = bytearray(want)
  355. self._ssl_buffer_view = memoryview(self._ssl_buffer)
  356. return self._ssl_buffer_view
  357. def buffer_updated(self, nbytes):
  358. self._incoming.write(self._ssl_buffer_view[:nbytes])
  359. if self._state == SSLProtocolState.DO_HANDSHAKE:
  360. self._do_handshake()
  361. elif self._state == SSLProtocolState.WRAPPED:
  362. self._do_read()
  363. elif self._state == SSLProtocolState.FLUSHING:
  364. self._do_flush()
  365. elif self._state == SSLProtocolState.SHUTDOWN:
  366. self._do_shutdown()
  367. def eof_received(self):
  368. """Called when the other end of the low-level stream
  369. is half-closed.
  370. If this returns a false value (including None), the transport
  371. will close itself. If it returns a true value, closing the
  372. transport is up to the protocol.
  373. """
  374. self._eof_received = True
  375. try:
  376. if self._loop.get_debug():
  377. logger.debug("%r received EOF", self)
  378. if self._state == SSLProtocolState.DO_HANDSHAKE:
  379. self._on_handshake_complete(ConnectionResetError)
  380. elif self._state == SSLProtocolState.WRAPPED:
  381. self._set_state(SSLProtocolState.FLUSHING)
  382. if self._app_reading_paused:
  383. return True
  384. else:
  385. self._do_flush()
  386. elif self._state == SSLProtocolState.FLUSHING:
  387. self._do_write()
  388. self._set_state(SSLProtocolState.SHUTDOWN)
  389. self._do_shutdown()
  390. elif self._state == SSLProtocolState.SHUTDOWN:
  391. self._do_shutdown()
  392. except Exception:
  393. self._transport.close()
  394. raise
  395. def _get_extra_info(self, name, default=None):
  396. if name in self._extra:
  397. return self._extra[name]
  398. elif self._transport is not None:
  399. return self._transport.get_extra_info(name, default)
  400. else:
  401. return default
  402. def _set_state(self, new_state):
  403. allowed = False
  404. if new_state == SSLProtocolState.UNWRAPPED:
  405. allowed = True
  406. elif (
  407. self._state == SSLProtocolState.UNWRAPPED and
  408. new_state == SSLProtocolState.DO_HANDSHAKE
  409. ):
  410. allowed = True
  411. elif (
  412. self._state == SSLProtocolState.DO_HANDSHAKE and
  413. new_state == SSLProtocolState.WRAPPED
  414. ):
  415. allowed = True
  416. elif (
  417. self._state == SSLProtocolState.WRAPPED and
  418. new_state == SSLProtocolState.FLUSHING
  419. ):
  420. allowed = True
  421. elif (
  422. self._state == SSLProtocolState.FLUSHING and
  423. new_state == SSLProtocolState.SHUTDOWN
  424. ):
  425. allowed = True
  426. if allowed:
  427. self._state = new_state
  428. else:
  429. raise RuntimeError(
  430. 'cannot switch state from {} to {}'.format(
  431. self._state, new_state))
  432. # Handshake flow
  433. def _start_handshake(self):
  434. if self._loop.get_debug():
  435. logger.debug("%r starts SSL handshake", self)
  436. self._handshake_start_time = self._loop.time()
  437. else:
  438. self._handshake_start_time = None
  439. self._set_state(SSLProtocolState.DO_HANDSHAKE)
  440. # start handshake timeout count down
  441. self._handshake_timeout_handle = \
  442. self._loop.call_later(self._ssl_handshake_timeout,
  443. lambda: self._check_handshake_timeout())
  444. self._do_handshake()
  445. def _check_handshake_timeout(self):
  446. if self._state == SSLProtocolState.DO_HANDSHAKE:
  447. msg = (
  448. f"SSL handshake is taking longer than "
  449. f"{self._ssl_handshake_timeout} seconds: "
  450. f"aborting the connection"
  451. )
  452. self._fatal_error(ConnectionAbortedError(msg))
  453. def _do_handshake(self):
  454. try:
  455. self._sslobj.do_handshake()
  456. except SSLAgainErrors:
  457. self._process_outgoing()
  458. except ssl.SSLError as exc:
  459. self._on_handshake_complete(exc)
  460. else:
  461. self._on_handshake_complete(None)
  462. def _on_handshake_complete(self, handshake_exc):
  463. if self._handshake_timeout_handle is not None:
  464. self._handshake_timeout_handle.cancel()
  465. self._handshake_timeout_handle = None
  466. sslobj = self._sslobj
  467. try:
  468. if handshake_exc is None:
  469. self._set_state(SSLProtocolState.WRAPPED)
  470. else:
  471. raise handshake_exc
  472. peercert = sslobj.getpeercert()
  473. except Exception as exc:
  474. handshake_exc = None
  475. self._set_state(SSLProtocolState.UNWRAPPED)
  476. if isinstance(exc, ssl.CertificateError):
  477. msg = 'SSL handshake failed on verifying the certificate'
  478. else:
  479. msg = 'SSL handshake failed'
  480. self._fatal_error(exc, msg)
  481. self._wakeup_waiter(exc)
  482. return
  483. if self._loop.get_debug():
  484. dt = self._loop.time() - self._handshake_start_time
  485. logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)
  486. # Add extra info that becomes available after handshake.
  487. self._extra.update(peercert=peercert,
  488. cipher=sslobj.cipher(),
  489. compression=sslobj.compression(),
  490. ssl_object=sslobj)
  491. if self._app_state == AppProtocolState.STATE_INIT:
  492. self._app_state = AppProtocolState.STATE_CON_MADE
  493. self._app_protocol.connection_made(self._get_app_transport())
  494. self._wakeup_waiter()
  495. self._do_read()
  496. # Shutdown flow
  497. def _start_shutdown(self):
  498. if (
  499. self._state in (
  500. SSLProtocolState.FLUSHING,
  501. SSLProtocolState.SHUTDOWN,
  502. SSLProtocolState.UNWRAPPED
  503. )
  504. ):
  505. return
  506. if self._app_transport is not None:
  507. self._app_transport._closed = True
  508. if self._state == SSLProtocolState.DO_HANDSHAKE:
  509. self._abort(None)
  510. else:
  511. self._set_state(SSLProtocolState.FLUSHING)
  512. self._shutdown_timeout_handle = self._loop.call_later(
  513. self._ssl_shutdown_timeout,
  514. lambda: self._check_shutdown_timeout()
  515. )
  516. self._do_flush()
  517. def _check_shutdown_timeout(self):
  518. if (
  519. self._state in (
  520. SSLProtocolState.FLUSHING,
  521. SSLProtocolState.SHUTDOWN
  522. )
  523. ):
  524. self._transport._force_close(
  525. exceptions.TimeoutError('SSL shutdown timed out'))
  526. def _do_flush(self):
  527. self._do_read()
  528. self._set_state(SSLProtocolState.SHUTDOWN)
  529. self._do_shutdown()
  530. def _do_shutdown(self):
  531. try:
  532. if not self._eof_received:
  533. self._sslobj.unwrap()
  534. except SSLAgainErrors:
  535. self._process_outgoing()
  536. except ssl.SSLError as exc:
  537. self._on_shutdown_complete(exc)
  538. else:
  539. self._process_outgoing()
  540. self._call_eof_received()
  541. self._on_shutdown_complete(None)
  542. def _on_shutdown_complete(self, shutdown_exc):
  543. if self._shutdown_timeout_handle is not None:
  544. self._shutdown_timeout_handle.cancel()
  545. self._shutdown_timeout_handle = None
  546. if shutdown_exc:
  547. self._fatal_error(shutdown_exc)
  548. else:
  549. self._loop.call_soon(self._transport.close)
  550. def _abort(self, exc):
  551. self._set_state(SSLProtocolState.UNWRAPPED)
  552. if self._transport is not None:
  553. self._transport._force_close(exc)
  554. # Outgoing flow
  555. def _write_appdata(self, list_of_data):
  556. if (
  557. self._state in (
  558. SSLProtocolState.FLUSHING,
  559. SSLProtocolState.SHUTDOWN,
  560. SSLProtocolState.UNWRAPPED
  561. )
  562. ):
  563. if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
  564. logger.warning('SSL connection is closed')
  565. self._conn_lost += 1
  566. return
  567. for data in list_of_data:
  568. self._write_backlog.append(data)
  569. self._write_buffer_size += len(data)
  570. try:
  571. if self._state == SSLProtocolState.WRAPPED:
  572. self._do_write()
  573. except Exception as ex:
  574. self._fatal_error(ex, 'Fatal error on SSL protocol')
  575. def _do_write(self):
  576. try:
  577. while self._write_backlog:
  578. data = self._write_backlog[0]
  579. count = self._sslobj.write(data)
  580. data_len = len(data)
  581. if count < data_len:
  582. self._write_backlog[0] = data[count:]
  583. self._write_buffer_size -= count
  584. else:
  585. del self._write_backlog[0]
  586. self._write_buffer_size -= data_len
  587. except SSLAgainErrors:
  588. pass
  589. self._process_outgoing()
  590. def _process_outgoing(self):
  591. if not self._ssl_writing_paused:
  592. data = self._outgoing.read()
  593. if len(data):
  594. self._transport.write(data)
  595. self._control_app_writing()
  596. # Incoming flow
  597. def _do_read(self):
  598. if (
  599. self._state not in (
  600. SSLProtocolState.WRAPPED,
  601. SSLProtocolState.FLUSHING,
  602. )
  603. ):
  604. return
  605. try:
  606. if not self._app_reading_paused:
  607. if self._app_protocol_is_buffer:
  608. self._do_read__buffered()
  609. else:
  610. self._do_read__copied()
  611. if self._write_backlog:
  612. self._do_write()
  613. else:
  614. self._process_outgoing()
  615. self._control_ssl_reading()
  616. except Exception as ex:
  617. self._fatal_error(ex, 'Fatal error on SSL protocol')
  618. def _do_read__buffered(self):
  619. offset = 0
  620. count = 1
  621. buf = self._app_protocol_get_buffer(self._get_read_buffer_size())
  622. wants = len(buf)
  623. try:
  624. count = self._sslobj.read(wants, buf)
  625. if count > 0:
  626. offset = count
  627. while offset < wants:
  628. count = self._sslobj.read(wants - offset, buf[offset:])
  629. if count > 0:
  630. offset += count
  631. else:
  632. break
  633. else:
  634. self._loop.call_soon(lambda: self._do_read())
  635. except SSLAgainErrors:
  636. pass
  637. if offset > 0:
  638. self._app_protocol_buffer_updated(offset)
  639. if not count:
  640. # close_notify
  641. self._call_eof_received()
  642. self._start_shutdown()
  643. def _do_read__copied(self):
  644. chunk = b'1'
  645. zero = True
  646. one = False
  647. try:
  648. while True:
  649. chunk = self._sslobj.read(self.max_size)
  650. if not chunk:
  651. break
  652. if zero:
  653. zero = False
  654. one = True
  655. first = chunk
  656. elif one:
  657. one = False
  658. data = [first, chunk]
  659. else:
  660. data.append(chunk)
  661. except SSLAgainErrors:
  662. pass
  663. if one:
  664. self._app_protocol.data_received(first)
  665. elif not zero:
  666. self._app_protocol.data_received(b''.join(data))
  667. if not chunk:
  668. # close_notify
  669. self._call_eof_received()
  670. self._start_shutdown()
  671. def _call_eof_received(self):
  672. try:
  673. if self._app_state == AppProtocolState.STATE_CON_MADE:
  674. self._app_state = AppProtocolState.STATE_EOF
  675. keep_open = self._app_protocol.eof_received()
  676. if keep_open:
  677. logger.warning('returning true from eof_received() '
  678. 'has no effect when using ssl')
  679. except (KeyboardInterrupt, SystemExit):
  680. raise
  681. except BaseException as ex:
  682. self._fatal_error(ex, 'Error calling eof_received()')
  683. # Flow control for writes from APP socket
  684. def _control_app_writing(self):
  685. size = self._get_write_buffer_size()
  686. if size >= self._outgoing_high_water and not self._app_writing_paused:
  687. self._app_writing_paused = True
  688. try:
  689. self._app_protocol.pause_writing()
  690. except (KeyboardInterrupt, SystemExit):
  691. raise
  692. except BaseException as exc:
  693. self._loop.call_exception_handler({
  694. 'message': 'protocol.pause_writing() failed',
  695. 'exception': exc,
  696. 'transport': self._app_transport,
  697. 'protocol': self,
  698. })
  699. elif size <= self._outgoing_low_water and self._app_writing_paused:
  700. self._app_writing_paused = False
  701. try:
  702. self._app_protocol.resume_writing()
  703. except (KeyboardInterrupt, SystemExit):
  704. raise
  705. except BaseException as exc:
  706. self._loop.call_exception_handler({
  707. 'message': 'protocol.resume_writing() failed',
  708. 'exception': exc,
  709. 'transport': self._app_transport,
  710. 'protocol': self,
  711. })
  712. def _get_write_buffer_size(self):
  713. return self._outgoing.pending + self._write_buffer_size
  714. def _set_write_buffer_limits(self, high=None, low=None):
  715. high, low = add_flowcontrol_defaults(
  716. high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE)
  717. self._outgoing_high_water = high
  718. self._outgoing_low_water = low
  719. # Flow control for reads to APP socket
  720. def _pause_reading(self):
  721. self._app_reading_paused = True
  722. def _resume_reading(self):
  723. if self._app_reading_paused:
  724. self._app_reading_paused = False
  725. def resume():
  726. if self._state == SSLProtocolState.WRAPPED:
  727. self._do_read()
  728. elif self._state == SSLProtocolState.FLUSHING:
  729. self._do_flush()
  730. elif self._state == SSLProtocolState.SHUTDOWN:
  731. self._do_shutdown()
  732. self._loop.call_soon(resume)
  733. # Flow control for reads from SSL socket
  734. def _control_ssl_reading(self):
  735. size = self._get_read_buffer_size()
  736. if size >= self._incoming_high_water and not self._ssl_reading_paused:
  737. self._ssl_reading_paused = True
  738. self._transport.pause_reading()
  739. elif size <= self._incoming_low_water and self._ssl_reading_paused:
  740. self._ssl_reading_paused = False
  741. self._transport.resume_reading()
  742. def _set_read_buffer_limits(self, high=None, low=None):
  743. high, low = add_flowcontrol_defaults(
  744. high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ)
  745. self._incoming_high_water = high
  746. self._incoming_low_water = low
  747. def _get_read_buffer_size(self):
  748. return self._incoming.pending
  749. # Flow control for writes to SSL socket
  750. def pause_writing(self):
  751. """Called when the low-level transport's buffer goes over
  752. the high-water mark.
  753. """
  754. assert not self._ssl_writing_paused
  755. self._ssl_writing_paused = True
  756. def resume_writing(self):
  757. """Called when the low-level transport's buffer drains below
  758. the low-water mark.
  759. """
  760. assert self._ssl_writing_paused
  761. self._ssl_writing_paused = False
  762. self._process_outgoing()
  763. def _fatal_error(self, exc, message='Fatal error on transport'):
  764. if self._transport:
  765. self._transport._force_close(exc)
  766. if isinstance(exc, OSError):
  767. if self._loop.get_debug():
  768. logger.debug("%r: %s", self, message, exc_info=True)
  769. elif not isinstance(exc, exceptions.CancelledError):
  770. self._loop.call_exception_handler({
  771. 'message': message,
  772. 'exception': exc,
  773. 'transport': self._transport,
  774. 'protocol': self,
  775. })