test_sslsocket.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. #
  19. import inspect
  20. import logging
  21. import os
  22. import platform
  23. import ssl
  24. import sys
  25. import tempfile
  26. import threading
  27. import unittest
  28. import warnings
  29. from contextlib import contextmanager
  30. from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket, _match_has_ipaddress # FIX_FOR_YANDEX moved here
  31. from thrift.transport.TTransport import TTransportException # FIX_FOR_YANDEX moved here
  32. import yatest.common as yc
  33. SCRIPT_DIR = os.path.realpath(os.path.dirname(yc.source_path(__file__)))
  34. ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
  35. SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem')
  36. SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt')
  37. SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key')
  38. CLIENT_CERT_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
  39. CLIENT_KEY_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')
  40. CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.crt')
  41. CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key')
  42. CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem')
  43. TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'
  44. class ServerAcceptor(threading.Thread):
  45. def __init__(self, server, expect_failure=False):
  46. super(ServerAcceptor, self).__init__()
  47. self.daemon = True
  48. self._server = server
  49. self._listening = threading.Event()
  50. self._port = None
  51. self._port_bound = threading.Event()
  52. self._client = None
  53. self._client_accepted = threading.Event()
  54. self._expect_failure = expect_failure
  55. frame = inspect.stack(3)[2]
  56. self.name = frame[3]
  57. del frame
  58. def run(self):
  59. self._server.listen()
  60. self._listening.set()
  61. try:
  62. address = self._server.handle.getsockname()
  63. if len(address) > 1:
  64. # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
  65. # 4-tuples (host, port, ...), but in each case port is in the second slot.
  66. self._port = address[1]
  67. finally:
  68. self._port_bound.set()
  69. try:
  70. self._client = self._server.accept()
  71. if self._client:
  72. self._client.read(5) # hello
  73. self._client.write(b"there")
  74. except Exception:
  75. logging.exception('error on server side (%s):' % self.name)
  76. if not self._expect_failure:
  77. raise
  78. finally:
  79. self._client_accepted.set()
  80. def await_listening(self):
  81. self._listening.wait()
  82. @property
  83. def port(self):
  84. self._port_bound.wait()
  85. return self._port
  86. @property
  87. def client(self):
  88. self._client_accepted.wait()
  89. return self._client
  90. def close(self):
  91. if self._client:
  92. self._client.close()
  93. self._server.close()
  94. # Python 2.6 compat
  95. class AssertRaises(object):
  96. def __init__(self, expected):
  97. self._expected = expected
  98. def __enter__(self):
  99. pass
  100. def __exit__(self, exc_type, exc_value, traceback):
  101. if not exc_type or not issubclass(exc_type, self._expected):
  102. raise Exception('fail')
  103. return True
  104. @unittest.skip("failing SSL test to be fixed in subsequent pull request")
  105. class TSSLSocketTest(unittest.TestCase):
  106. def _server_socket(self, **kwargs):
  107. return TSSLServerSocket(port=0, **kwargs)
  108. @contextmanager
  109. def _connectable_client(self, server, expect_failure=False, path=None, **client_kwargs):
  110. acc = ServerAcceptor(server, expect_failure)
  111. try:
  112. acc.start()
  113. acc.await_listening()
  114. host, port = ('localhost', acc.port) if path is None else (None, None)
  115. client = TSSLSocket(host, port, unix_socket=path, **client_kwargs)
  116. yield acc, client
  117. finally:
  118. acc.close()
  119. def _assert_connection_failure(self, server, path=None, **client_args):
  120. logging.disable(logging.CRITICAL)
  121. try:
  122. with self._connectable_client(server, True, path=path, **client_args) as (acc, client):
  123. # We need to wait for a connection failure, but not too long. 20ms is a tunable
  124. # compromise between test speed and stability
  125. client.setTimeout(20)
  126. with self._assert_raises(TTransportException):
  127. client.open()
  128. client.write(b"hello")
  129. client.read(5) # b"there"
  130. finally:
  131. logging.disable(logging.NOTSET)
  132. def _assert_raises(self, exc):
  133. if sys.hexversion >= 0x020700F0:
  134. return self.assertRaises(exc)
  135. else:
  136. return AssertRaises(exc)
  137. def _assert_connection_success(self, server, path=None, **client_args):
  138. with self._connectable_client(server, path=path, **client_args) as (acc, client):
  139. try:
  140. self.assertFalse(client.isOpen())
  141. client.open()
  142. self.assertTrue(client.isOpen())
  143. client.write(b"hello")
  144. self.assertEqual(client.read(5), b"there")
  145. self.assertTrue(acc.client is not None)
  146. finally:
  147. client.close()
  148. # deprecated feature
  149. def test_deprecation(self):
  150. with warnings.catch_warnings(record=True) as w:
  151. warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
  152. TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
  153. self.assertEqual(len(w), 1)
  154. with warnings.catch_warnings(record=True) as w:
  155. warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
  156. # Deprecated signature
  157. # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
  158. TSSLSocket('localhost', 0, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
  159. self.assertEqual(len(w), 7)
  160. with warnings.catch_warnings(record=True) as w:
  161. warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
  162. # Deprecated signature
  163. # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
  164. TSSLServerSocket(None, 0, SERVER_PEM, None, TEST_CIPHERS)
  165. self.assertEqual(len(w), 3)
  166. # deprecated feature
  167. def test_set_cert_reqs_by_validate(self):
  168. with warnings.catch_warnings(record=True) as w:
  169. warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
  170. c1 = TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
  171. self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)
  172. c1 = TSSLSocket('localhost', 0, validate=False)
  173. self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)
  174. self.assertEqual(len(w), 2)
  175. # deprecated feature
  176. def test_set_validate_by_cert_reqs(self):
  177. with warnings.catch_warnings(record=True) as w:
  178. warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
  179. c1 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE)
  180. self.assertFalse(c1.validate)
  181. c2 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
  182. self.assertTrue(c2.validate)
  183. c3 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
  184. self.assertTrue(c3.validate)
  185. self.assertEqual(len(w), 3)
  186. def test_unix_domain_socket(self):
  187. if platform.system() == 'Windows':
  188. print('skipping test_unix_domain_socket')
  189. return
  190. fd, path = tempfile.mkstemp()
  191. os.close(fd)
  192. os.unlink(path)
  193. try:
  194. server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT)
  195. self._assert_connection_success(server, path=path, cert_reqs=ssl.CERT_NONE)
  196. finally:
  197. os.unlink(path)
  198. def test_server_cert(self):
  199. server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
  200. self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
  201. server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
  202. # server cert not in ca_certs
  203. self._assert_connection_failure(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)
  204. server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
  205. self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE)
  206. def test_set_server_cert(self):
  207. server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT)
  208. with self._assert_raises(Exception):
  209. server.certfile = 'foo'
  210. with self._assert_raises(Exception):
  211. server.certfile = None
  212. server.certfile = SERVER_CERT
  213. self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
  214. def test_client_cert(self):
  215. if not _match_has_ipaddress:
  216. print('skipping test_client_cert')
  217. return
  218. server = self._server_socket(
  219. cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
  220. certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
  221. self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)
  222. server = self._server_socket(
  223. cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
  224. certfile=SERVER_CERT, ca_certs=CLIENT_CA)
  225. self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)
  226. server = self._server_socket(
  227. cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
  228. certfile=SERVER_CERT, ca_certs=CLIENT_CA)
  229. self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
  230. server = self._server_socket(
  231. cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
  232. certfile=SERVER_CERT, ca_certs=CLIENT_CA)
  233. self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
  234. def test_ciphers(self):
  235. server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
  236. self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)
  237. if not TSSLSocket._has_ciphers:
  238. # unittest.skip is not available for Python 2.6
  239. print('skipping test_ciphers')
  240. return
  241. server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
  242. self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')
  243. server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
  244. self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')
  245. def test_ssl2_and_ssl3_disabled(self):
  246. if not hasattr(ssl, 'PROTOCOL_SSLv3'):
  247. print('PROTOCOL_SSLv3 is not available')
  248. else:
  249. server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
  250. self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
  251. server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
  252. self._assert_connection_failure(server, ca_certs=SERVER_CERT)
  253. if not hasattr(ssl, 'PROTOCOL_SSLv2'):
  254. print('PROTOCOL_SSLv2 is not available')
  255. else:
  256. server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
  257. self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
  258. server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
  259. self._assert_connection_failure(server, ca_certs=SERVER_CERT)
  260. def test_newer_tls(self):
  261. if not TSSLSocket._has_ssl_context:
  262. # unittest.skip is not available for Python 2.6
  263. print('skipping test_newer_tls')
  264. return
  265. if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
  266. print('PROTOCOL_TLSv1_2 is not available')
  267. else:
  268. server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
  269. self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
  270. if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
  271. print('PROTOCOL_TLSv1_1 is not available')
  272. else:
  273. server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
  274. self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
  275. if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
  276. print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
  277. else:
  278. server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
  279. self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
  280. def test_ssl_context(self):
  281. if not TSSLSocket._has_ssl_context:
  282. # unittest.skip is not available for Python 2.6
  283. print('skipping test_ssl_context')
  284. return
  285. server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
  286. server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
  287. server_context.load_verify_locations(CLIENT_CA)
  288. server_context.verify_mode = ssl.CERT_REQUIRED
  289. server = self._server_socket(ssl_context=server_context)
  290. client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
  291. client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
  292. client_context.load_verify_locations(SERVER_CERT)
  293. client_context.verify_mode = ssl.CERT_REQUIRED
  294. self._assert_connection_success(server, ssl_context=client_context)
  295. # Add a dummy test because starting from python 3.12, if all tests in a test
  296. # file are skipped that's considered an error.
  297. class DummyTest(unittest.TestCase):
  298. def test_dummy(self):
  299. self.assertEqual(0, 0)
  300. if __name__ == '__main__':
  301. logging.basicConfig(level=logging.WARN)
  302. from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket, _match_has_ipaddress
  303. from thrift.transport.TTransport import TTransportException
  304. unittest.main()