123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365 |
- #
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- #
- import inspect
- import logging
- import os
- import platform
- import ssl
- import sys
- import tempfile
- import threading
- import unittest
- import warnings
- from contextlib import contextmanager
- from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket, _match_has_ipaddress # FIX_FOR_YANDEX moved here
- from thrift.transport.TTransport import TTransportException # FIX_FOR_YANDEX moved here
- import yatest.common as yc
# Directory of this test module, resolved through yatest.common.source_path
# and with symlinks resolved.
SCRIPT_DIR = os.path.realpath(os.path.dirname(yc.source_path(__file__)))
# Three directory levels above SCRIPT_DIR; the TLS fixtures live in
# ROOT_DIR/test/keys.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))

# Absolute paths of the certificate/key fixtures, in declaration order.
(SERVER_PEM, SERVER_CERT, SERVER_KEY,
 CLIENT_CERT_NO_IP, CLIENT_KEY_NO_IP,
 CLIENT_CERT, CLIENT_KEY,
 CLIENT_CA) = (
    os.path.join(ROOT_DIR, 'test', 'keys', file_name)
    for file_name in (
        'server.pem', 'server.crt', 'server.key',
        'client.crt', 'client.key',          # the *_NO_IP client pair
        'client_v3.crt', 'client_v3.key',
        'CA.pem',
    )
)

# Colon-separated OpenSSL cipher list used by the cipher-related tests.
TEST_CIPHERS = ':'.join(('DES-CBC3-SHA', 'ECDHE-RSA-AES128-GCM-SHA256'))
class ServerAcceptor(threading.Thread):
    """Daemon thread that serves exactly one client on a server transport.

    The thread calls ``listen()`` on *server*, records the bound port (when
    the socket address has one), accepts a single connection, reads 5 bytes
    (the client's ``b"hello"``) and answers ``b"there"``. Callers synchronise
    with it through :meth:`await_listening`, :attr:`port` and :attr:`client`.
    These never block forever: every event is set once the thread finishes,
    even if the server fails early.

    :param server: server transport exposing ``listen()``, ``handle``,
        ``accept()`` and ``close()`` (a TSSLServerSocket in these tests).
    :param expect_failure: when true, server-side exceptions are logged but
        not re-raised, because the test provokes them deliberately.
    """

    def __init__(self, server, expect_failure=False):
        super(ServerAcceptor, self).__init__()
        self.daemon = True
        self._server = server
        self._listening = threading.Event()
        self._port = None
        self._port_bound = threading.Event()
        self._client = None
        self._client_accepted = threading.Event()
        self._expect_failure = expect_failure
        # Name the thread after the function two frames above __init__ (the
        # caller's caller) so server-side log lines can be attributed.
        # context=0: only the function name is needed, so don't read source
        # lines for every frame on the stack.
        stack = inspect.stack(0)
        try:
            if len(stack) > 2:
                self.name = stack[2][3]
        finally:
            # Drop frame references promptly to avoid reference cycles.
            del stack

    def run(self):
        try:
            self._serve_one_client()
        except Exception:
            logging.exception('error on server side (%s):', self.name)
            if not self._expect_failure:
                raise
        finally:
            # Wake every waiter no matter where the server failed. Otherwise a
            # failing listen() or getsockname() would leave await_listening(),
            # ``port`` or ``client`` blocked forever. Event.set() is idempotent.
            self._listening.set()
            self._port_bound.set()
            self._client_accepted.set()

    def _serve_one_client(self):
        """Listen, publish the bound port, then accept and answer one client."""
        self._server.listen()
        self._listening.set()
        address = self._server.handle.getsockname()
        if len(address) > 1:
            # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
            # 4-tuples (host, port, ...), but in each case port is in the second slot.
            self._port = address[1]
        # Publish the port before blocking in accept() so the client can connect.
        self._port_bound.set()
        self._client = self._server.accept()
        if self._client:
            self._client.read(5)  # hello
            self._client.write(b"there")

    def await_listening(self):
        """Block until the server has called listen() (or failed trying)."""
        self._listening.wait()

    @property
    def port(self):
        """Bound port, or None for addresses without one / on failure."""
        self._port_bound.wait()
        return self._port

    @property
    def client(self):
        """Accepted client transport after the exchange, or None on failure."""
        self._client_accepted.wait()
        return self._client

    def close(self):
        """Close the accepted client (if any) and the server transport."""
        if self._client:
            self._client.close()
        self._server.close()
# Python 2.6 compat: minimal stand-in for the context manager returned by
# unittest.TestCase.assertRaises, which only exists from Python 2.7 on.
class AssertRaises(object):
    """Context manager asserting that its body raises *expected*.

    Mirrors ``unittest.TestCase.assertRaises`` used as a context manager:

    * an exception matching *expected* (a class or tuple of classes) is
      swallowed and stored on ``self.exception``;
    * if the body raises nothing, ``AssertionError`` is raised;
    * any other exception propagates unchanged, with its own traceback.
    """

    def __init__(self, expected):
        self._expected = expected
        self.exception = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            try:
                name = self._expected.__name__
            except AttributeError:  # e.g. a tuple of exception classes
                name = str(self._expected)
            raise AssertionError('%s not raised' % name)
        if not issubclass(exc_type, self._expected):
            # Don't mask an unexpected error with a generic failure; let it
            # surface as-is (same as unittest's implementation).
            return False
        self.exception = exc_value
        return True
@unittest.skip("failing SSL test to be fixed in subsequent pull request")
class TSSLSocketTest(unittest.TestCase):
    """End-to-end tests for TSSLSocket / TSSLServerSocket.

    Each scenario wraps a freshly created server socket in a ServerAcceptor
    thread, connects a TSSLSocket client to it, and checks whether the
    connection plus a small hello/there exchange succeeds or fails.
    """

    def _server_socket(self, **kwargs):
        """Return a TSSLServerSocket on port 0 (OS-assigned) with *kwargs*."""
        return TSSLServerSocket(port=0, **kwargs)

    @contextmanager
    def _connectable_client(self, server, expect_failure=False, path=None, **client_kwargs):
        """Start an acceptor for *server* and yield ``(acceptor, client)``.

        The client is constructed but not opened. With *path* it targets that
        unix socket; otherwise it targets localhost on the acceptor's port.
        The acceptor (and with it the server socket) is closed on exit.
        """
        acc = ServerAcceptor(server, expect_failure)
        try:
            acc.start()
            acc.await_listening()
            host, port = ('localhost', acc.port) if path is None else (None, None)
            client = TSSLSocket(host, port, unix_socket=path, **client_kwargs)
            yield acc, client
        finally:
            acc.close()

    def _assert_connection_failure(self, server, path=None, **client_args):
        """Assert that a client built from *client_args* cannot talk to *server*.

        Server-side errors are expected here, so logging is disabled while
        the attempt runs.
        """
        logging.disable(logging.CRITICAL)
        try:
            with self._connectable_client(server, True, path=path, **client_args) as (_, client):
                try:
                    # We need to wait for a connection failure, but not too long. 20ms is a tunable
                    # compromise between test speed and stability
                    client.setTimeout(20)
                    with self._assert_raises(TTransportException):
                        client.open()
                        client.write(b"hello")
                        client.read(5)  # b"there"
                finally:
                    # open() may succeed and the failure only surface on
                    # write/read, so release the client socket here, as
                    # _assert_connection_success does.
                    client.close()
        finally:
            logging.disable(logging.NOTSET)

    def _assert_raises(self, exc):
        """Return an assertRaises-style context manager for *exc*."""
        if sys.hexversion >= 0x020700F0:
            return self.assertRaises(exc)
        else:
            return AssertRaises(exc)

    def _assert_connection_success(self, server, path=None, **client_args):
        """Assert that a client built from *client_args* completes the exchange."""
        with self._connectable_client(server, path=path, **client_args) as (acc, client):
            try:
                self.assertFalse(client.isOpen())
                client.open()
                self.assertTrue(client.isOpen())
                client.write(b"hello")
                self.assertEqual(client.read(5), b"there")
                self.assertIsNotNone(acc.client)
            finally:
                client.close()

    # deprecated feature
    def test_deprecation(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
            TSSLSocket('localhost', 0, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
            self.assertEqual(len(w), 7)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
            TSSLServerSocket(None, 0, SERVER_PEM, None, TEST_CIPHERS)
            self.assertEqual(len(w), 3)

    # deprecated feature
    def test_set_cert_reqs_by_validate(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)

            c1 = TSSLSocket('localhost', 0, validate=False)
            self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)

            self.assertEqual(len(w), 2)

    # deprecated feature
    def test_set_validate_by_cert_reqs(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE)
            self.assertFalse(c1.validate)

            c2 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
            self.assertTrue(c2.validate)

            c3 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
            self.assertTrue(c3.validate)

            self.assertEqual(len(w), 3)

    def test_unix_domain_socket(self):
        if platform.system() == 'Windows':
            self.skipTest('unix domain sockets are not available on Windows')
        # Reserve a unique path, then remove the file so the server can
        # create its socket there.
        fd, path = tempfile.mkstemp()
        os.close(fd)
        os.unlink(path)
        try:
            server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_success(server, path=path, cert_reqs=ssl.CERT_NONE)
        finally:
            # The socket file only exists if the server got far enough to
            # create it; don't mask the real failure with FileNotFoundError.
            if os.path.exists(path):
                os.unlink(path)

    def test_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        # server cert not in ca_certs
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE)

    def test_set_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT)
        with self._assert_raises(Exception):
            server.certfile = 'foo'
        with self._assert_raises(Exception):
            server.certfile = None
        server.certfile = SERVER_CERT
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

    def test_client_cert(self):
        if not _match_has_ipaddress:
            self.skipTest('_match_has_ipaddress is false')
        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

    def test_ciphers(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)

        if not TSSLSocket._has_ciphers:
            self.skipTest('TSSLSocket._has_ciphers is false')

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

    def test_ssl2_and_ssl3_disabled(self):
        if not hasattr(ssl, 'PROTOCOL_SSLv3'):
            print('PROTOCOL_SSLv3 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

        if not hasattr(ssl, 'PROTOCOL_SSLv2'):
            print('PROTOCOL_SSLv2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

    def test_newer_tls(self):
        if not TSSLSocket._has_ssl_context:
            self.skipTest('TSSLSocket._has_ssl_context is false')
        if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
            print('PROTOCOL_TLSv1_1 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
        else:
            # A TLSv1.2-only server must reject a TLSv1.1-only client.
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

    def test_ssl_context(self):
        if not TSSLSocket._has_ssl_context:
            self.skipTest('TSSLSocket._has_ssl_context is false')
        server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
        server_context.load_verify_locations(CLIENT_CA)
        server_context.verify_mode = ssl.CERT_REQUIRED
        server = self._server_socket(ssl_context=server_context)

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
        client_context.load_verify_locations(SERVER_CERT)
        client_context.verify_mode = ssl.CERT_REQUIRED

        self._assert_connection_success(server, ssl_context=client_context)
# TSSLSocketTest is skipped as a whole. Starting with Python 3.12 a test
# module in which every test is skipped is considered an error, so this
# module keeps one trivially passing test.
class DummyTest(unittest.TestCase):
    """Always-passing placeholder so the module never reports all-skipped."""

    def test_dummy(self):
        self.assertTrue(0 == 0)
if __name__ == '__main__':
    # Show warnings and errors, including the server-side exceptions that
    # ServerAcceptor reports via logging.exception. The thrift names the
    # tests use are already imported at the top of the module, so they are
    # not re-imported here.
    logging.basicConfig(level=logging.WARNING)
    unittest.main()
|