factory.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Copyright (c) Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. """
  4. A Factory for SSH servers.
  5. See also L{twisted.conch.openssh_compat.factory} for OpenSSH compatibility.
  6. Maintainer: Paul Swartz
  7. """
  8. import random
  9. from itertools import chain
  10. from typing import Dict, List, Optional, Tuple
  11. from twisted.conch import error
  12. from twisted.conch.ssh import _kex, connection, transport, userauth
  13. from twisted.internet import protocol
  14. from twisted.logger import Logger
  class SSHFactory(protocol.Factory):
      """
      A Factory for SSH servers.

      Host keys come from L{getPublicKeys} / L{getPrivateKeys}, unless
      C{publicKeys} / C{privateKeys} are already set on the factory.
      Diffie-Hellman group-exchange moduli come from L{getPrimes}, unless
      C{primes} is already set.
      """

      # Maps a modulus size in bits to a list of (generator, prime) tuples,
      # or None when no moduli are available. Filled in by startFactory()
      # when not already set.
      primes: Optional[Dict[int, List[Tuple[int, int]]]]
      _log = Logger()
      protocol = transport.SSHServerTransport
      # SSH service names a client may request, mapped to the class that
      # implements each one. Looked up by getService().
      services = {
          b"ssh-userauth": userauth.SSHUserAuthServer,
          b"ssh-connection": connection.SSHConnection,
      }

      def startFactory(self) -> None:
          """
          Check for public and private keys.

          Keys and primes are only loaded through the C{get*} methods when
          the corresponding attribute does not already exist. This lets
          subclasses or callers pre-populate them.

          @raise error.ConchError: If either the public or the private host
              key collection is empty.
          """
          if not hasattr(self, "publicKeys"):
              self.publicKeys = self.getPublicKeys()
          if not hasattr(self, "privateKeys"):
              self.privateKeys = self.getPrivateKeys()
          if not self.publicKeys or not self.privateKeys:
              raise error.ConchError("no host keys, failing")
          if not hasattr(self, "primes"):
              self.primes = self.getPrimes()

      def buildProtocol(self, addr):
          """
          Create an instance of the server side of the SSH protocol.

          The transport advertises every signature algorithm supported by the
          factory's private host keys. If no Diffie-Hellman moduli are
          available, any key exchange algorithm that is neither a fixed group
          nor elliptic-curve based is removed from the transport's offer.

          @type addr: L{twisted.internet.interfaces.IAddress} provider
          @param addr: The address of the peer that has connected.
          @rtype: L{twisted.conch.ssh.transport.SSHServerTransport}
          @return: The built transport.
          """
          t = protocol.Factory.buildProtocol(self, addr)
          t.supportedPublicKeys = list(
              chain.from_iterable(
                  key.supportedSignatureAlgorithms()
                  for key in self.privateKeys.values()
              )
          )
          if not self.primes:
              self._log.info(
                  "disabling non-fixed-group key exchange algorithms "
                  "because we cannot find moduli file"
              )
              # Bind a filtered copy on the instance so the class-level list
              # shared by other transports is left untouched.
              t.supportedKeyExchanges = [
                  kexAlgorithm
                  for kexAlgorithm in t.supportedKeyExchanges
                  if _kex.isFixedGroup(kexAlgorithm)
                  or _kex.isEllipticCurve(kexAlgorithm)
              ]
          return t

      def getPublicKeys(self):
          """
          Called when the factory is started to get the public portions of the
          server's host keys. Returns a dictionary mapping SSH key types to
          public key strings.

          @rtype: L{dict}
          @raise NotImplementedError: Always; subclasses must override this
              method or set C{publicKeys} before the factory starts.
          """
          raise NotImplementedError("getPublicKeys unimplemented")

      def getPrivateKeys(self):
          """
          Called when the factory is started to get the private portions of the
          server's host keys. Returns a dictionary mapping SSH key types to
          L{twisted.conch.ssh.keys.Key} objects.

          @rtype: L{dict}
          @raise NotImplementedError: Always; subclasses must override this
              method or set C{privateKeys} before the factory starts.
          """
          raise NotImplementedError("getPrivateKeys unimplemented")

      def getPrimes(self) -> Optional[Dict[int, List[Tuple[int, int]]]]:
          """
          Called when the factory is started to get Diffie-Hellman generators
          and primes to use. Returns a dictionary mapping number of bits to
          lists of tuples of (generator, prime).

          This default implementation returns L{None}. With no primes,
          L{buildProtocol} restricts the transport to fixed-group and
          elliptic-curve key exchange algorithms.
          """
          return None

      def getDHPrime(self, bits: int) -> Tuple[int, int]:
          """
          Return a tuple of (g, p) for a Diffie-Hellman process, with p being
          as close to C{bits} bits as possible.

          When two available sizes are equally close to C{bits}, the one that
          comes first in C{self.primes} is used. One (g, p) pair is then picked
          at random from the list for that size.

          C{self.primes} must be a non-empty mapping. L{buildProtocol}
          disables the key exchange algorithms that need it when it is empty.
          """
          assert self.primes is not None, "Factory should have been started by now."
          # min() returns the first minimal element in iteration order. This
          # gives the same choice as sorted(...)[0] without a full sort.
          realBits = min(self.primes, key=lambda size: abs(size - bits))
          return random.choice(self.primes[realBits])

      def getService(self, transport, service):
          """
          Return a class to use as a service for the given transport.

          Only C{ssh-userauth} may be requested until the transport has an
          C{avatar} attribute (presumably set once authentication succeeds).

          @type transport: L{transport.SSHServerTransport}
          @type service: L{bytes}
          @rtype: subclass of L{service.SSHService}, or L{None} if the service
              is unknown or not yet permitted for this transport.
          """
          if service == b"ssh-userauth" or hasattr(transport, "avatar"):
              # An unknown name yields None, the same "not available" result
              # as a premature request, instead of raising KeyError.
              return self.services.get(service)
          return None