test_cli.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. """
  2. Unit tests for CLI entry points.
  3. """
  4. from __future__ import print_function
  5. import functools
  6. import io
  7. import os
  8. import sys
  9. import typing
  10. import unittest
  11. from contextlib import contextmanager, redirect_stdout, redirect_stderr
  12. import rsa
  13. import rsa.cli
  14. import rsa.util
  15. @contextmanager
  16. def captured_output() -> typing.Generator:
  17. """Captures output to stdout and stderr"""
  18. # According to mypy, we're not supposed to change buf_out.buffer.
  19. # However, this is just a test, and it works, hence the 'type: ignore'.
  20. buf_out = io.StringIO()
  21. buf_out.buffer = io.BytesIO() # type: ignore
  22. buf_err = io.StringIO()
  23. buf_err.buffer = io.BytesIO() # type: ignore
  24. with redirect_stdout(buf_out), redirect_stderr(buf_err):
  25. yield buf_out, buf_err
  26. def get_bytes_out(buf) -> bytes:
  27. return buf.buffer.getvalue()
  28. @contextmanager
  29. def cli_args(*new_argv):
  30. """Updates sys.argv[1:] for a single test."""
  31. old_args = sys.argv[:]
  32. sys.argv[1:] = [str(arg) for arg in new_argv]
  33. try:
  34. yield
  35. finally:
  36. sys.argv[1:] = old_args
  37. def remove_if_exists(fname):
  38. """Removes a file if it exists."""
  39. if os.path.exists(fname):
  40. os.unlink(fname)
  41. def cleanup_files(*filenames):
  42. """Makes sure the files don't exist when the test runs, and deletes them afterward."""
  43. def remove():
  44. for fname in filenames:
  45. remove_if_exists(fname)
  46. def decorator(func):
  47. @functools.wraps(func)
  48. def wrapper(*args, **kwargs):
  49. remove()
  50. try:
  51. return func(*args, **kwargs)
  52. finally:
  53. remove()
  54. return wrapper
  55. return decorator
  56. class AbstractCliTest(unittest.TestCase):
  57. @classmethod
  58. def setUpClass(cls):
  59. # Ensure there is a key to use
  60. cls.pub_key, cls.priv_key = rsa.newkeys(512)
  61. cls.pub_fname = "%s.pub" % cls.__name__
  62. cls.priv_fname = "%s.key" % cls.__name__
  63. with open(cls.pub_fname, "wb") as outfile:
  64. outfile.write(cls.pub_key.save_pkcs1())
  65. with open(cls.priv_fname, "wb") as outfile:
  66. outfile.write(cls.priv_key.save_pkcs1())
  67. @classmethod
  68. def tearDownClass(cls):
  69. if hasattr(cls, "pub_fname"):
  70. remove_if_exists(cls.pub_fname)
  71. if hasattr(cls, "priv_fname"):
  72. remove_if_exists(cls.priv_fname)
  73. def assertExits(self, status_code, func, *args, **kwargs):
  74. try:
  75. func(*args, **kwargs)
  76. except SystemExit as ex:
  77. if status_code == ex.code:
  78. return
  79. self.fail(
  80. "SystemExit() raised by %r, but exited with code %r, expected %r"
  81. % (func, ex.code, status_code)
  82. )
  83. else:
  84. self.fail("SystemExit() not raised by %r" % func)
  85. class KeygenTest(AbstractCliTest):
  86. def test_keygen_no_args(self):
  87. with captured_output(), cli_args():
  88. self.assertExits(1, rsa.cli.keygen)
  89. def test_keygen_priv_stdout(self):
  90. with captured_output() as (out, err):
  91. with cli_args(128):
  92. rsa.cli.keygen()
  93. lines = get_bytes_out(out).splitlines()
  94. self.assertEqual(b"-----BEGIN RSA PRIVATE KEY-----", lines[0])
  95. self.assertEqual(b"-----END RSA PRIVATE KEY-----", lines[-1])
  96. # The key size should be shown on stderr
  97. self.assertTrue("128-bit key" in err.getvalue())
  98. @cleanup_files("test_cli_privkey_out.pem")
  99. def test_keygen_priv_out_pem(self):
  100. with captured_output() as (out, err):
  101. with cli_args("--out=test_cli_privkey_out.pem", "--form=PEM", 128):
  102. rsa.cli.keygen()
  103. # The key size should be shown on stderr
  104. self.assertTrue("128-bit key" in err.getvalue())
  105. # The output file should be shown on stderr
  106. self.assertTrue("test_cli_privkey_out.pem" in err.getvalue())
  107. # If we can load the file as PEM, it's good enough.
  108. with open("test_cli_privkey_out.pem", "rb") as pemfile:
  109. rsa.PrivateKey.load_pkcs1(pemfile.read())
  110. @cleanup_files("test_cli_privkey_out.der")
  111. def test_keygen_priv_out_der(self):
  112. with captured_output() as (out, err):
  113. with cli_args("--out=test_cli_privkey_out.der", "--form=DER", 128):
  114. rsa.cli.keygen()
  115. # The key size should be shown on stderr
  116. self.assertTrue("128-bit key" in err.getvalue())
  117. # The output file should be shown on stderr
  118. self.assertTrue("test_cli_privkey_out.der" in err.getvalue())
  119. # If we can load the file as der, it's good enough.
  120. with open("test_cli_privkey_out.der", "rb") as derfile:
  121. rsa.PrivateKey.load_pkcs1(derfile.read(), format="DER")
  122. @cleanup_files("test_cli_privkey_out.pem", "test_cli_pubkey_out.pem")
  123. def test_keygen_pub_out_pem(self):
  124. with captured_output() as (out, err):
  125. with cli_args(
  126. "--out=test_cli_privkey_out.pem",
  127. "--pubout=test_cli_pubkey_out.pem",
  128. "--form=PEM",
  129. 256,
  130. ):
  131. rsa.cli.keygen()
  132. # The key size should be shown on stderr
  133. self.assertTrue("256-bit key" in err.getvalue())
  134. # The output files should be shown on stderr
  135. self.assertTrue("test_cli_privkey_out.pem" in err.getvalue())
  136. self.assertTrue("test_cli_pubkey_out.pem" in err.getvalue())
  137. # If we can load the file as PEM, it's good enough.
  138. with open("test_cli_pubkey_out.pem", "rb") as pemfile:
  139. rsa.PublicKey.load_pkcs1(pemfile.read())
  140. class EncryptDecryptTest(AbstractCliTest):
  141. def test_empty_decrypt(self):
  142. with captured_output(), cli_args():
  143. self.assertExits(1, rsa.cli.decrypt)
  144. def test_empty_encrypt(self):
  145. with captured_output(), cli_args():
  146. self.assertExits(1, rsa.cli.encrypt)
  147. @cleanup_files("encrypted.txt", "cleartext.txt")
  148. def test_encrypt_decrypt(self):
  149. with open("cleartext.txt", "wb") as outfile:
  150. outfile.write(b"Hello cleartext RSA users!")
  151. with cli_args("-i", "cleartext.txt", "--out=encrypted.txt", self.pub_fname):
  152. with captured_output():
  153. rsa.cli.encrypt()
  154. with cli_args("-i", "encrypted.txt", self.priv_fname):
  155. with captured_output() as (out, err):
  156. rsa.cli.decrypt()
  157. # We should have the original cleartext on stdout now.
  158. output = get_bytes_out(out)
  159. self.assertEqual(b"Hello cleartext RSA users!", output)
  160. @cleanup_files("encrypted.txt", "cleartext.txt")
  161. def test_encrypt_decrypt_unhappy(self):
  162. with open("cleartext.txt", "wb") as outfile:
  163. outfile.write(b"Hello cleartext RSA users!")
  164. with cli_args("-i", "cleartext.txt", "--out=encrypted.txt", self.pub_fname):
  165. with captured_output():
  166. rsa.cli.encrypt()
  167. # Change a few bytes in the encrypted stream.
  168. with open("encrypted.txt", "r+b") as encfile:
  169. encfile.seek(40)
  170. encfile.write(b"hahaha")
  171. with cli_args("-i", "encrypted.txt", self.priv_fname):
  172. with captured_output() as (out, err):
  173. self.assertRaises(rsa.DecryptionError, rsa.cli.decrypt)
  174. class SignVerifyTest(AbstractCliTest):
  175. def test_empty_verify(self):
  176. with captured_output(), cli_args():
  177. self.assertExits(1, rsa.cli.verify)
  178. def test_empty_sign(self):
  179. with captured_output(), cli_args():
  180. self.assertExits(1, rsa.cli.sign)
  181. @cleanup_files("signature.txt", "cleartext.txt")
  182. def test_sign_verify(self):
  183. with open("cleartext.txt", "wb") as outfile:
  184. outfile.write(b"Hello RSA users!")
  185. with cli_args("-i", "cleartext.txt", "--out=signature.txt", self.priv_fname, "SHA-256"):
  186. with captured_output():
  187. rsa.cli.sign()
  188. with cli_args("-i", "cleartext.txt", self.pub_fname, "signature.txt"):
  189. with captured_output() as (out, err):
  190. rsa.cli.verify()
  191. self.assertFalse(b"Verification OK" in get_bytes_out(out))
  192. @cleanup_files("signature.txt", "cleartext.txt")
  193. def test_sign_verify_unhappy(self):
  194. with open("cleartext.txt", "wb") as outfile:
  195. outfile.write(b"Hello RSA users!")
  196. with cli_args("-i", "cleartext.txt", "--out=signature.txt", self.priv_fname, "SHA-256"):
  197. with captured_output():
  198. rsa.cli.sign()
  199. # Change a few bytes in the cleartext file.
  200. with open("cleartext.txt", "r+b") as encfile:
  201. encfile.seek(6)
  202. encfile.write(b"DSA")
  203. with cli_args("-i", "cleartext.txt", self.pub_fname, "signature.txt"):
  204. with captured_output() as (out, err):
  205. self.assertExits("Verification failed.", rsa.cli.verify)
  206. class PrivatePublicTest(AbstractCliTest):
  207. """Test CLI command to convert a private to a public key."""
  208. @cleanup_files("test_private_to_public.pem")
  209. def test_private_to_public(self):
  210. with cli_args("-i", self.priv_fname, "-o", "test_private_to_public.pem"):
  211. with captured_output():
  212. rsa.util.private_to_public()
  213. # Check that the key is indeed valid.
  214. with open("test_private_to_public.pem", "rb") as pemfile:
  215. key = rsa.PublicKey.load_pkcs1(pemfile.read())
  216. self.assertEqual(self.priv_key.n, key.n)
  217. self.assertEqual(self.priv_key.e, key.e)