network.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. # coding=utf-8
  2. import os
  3. import errno
  4. import socket
  5. import random
  6. import logging
  7. import platform
  8. import threading
  9. import six
  10. UI16MAXVAL = (1 << 16) - 1
  11. logger = logging.getLogger(__name__)
  12. class PortManagerException(Exception):
  13. pass
  14. class PortManager(object):
  15. """
  16. See documentation here
  17. https://wiki.yandex-team.ru/yatool/test/#python-acquire-ports
  18. """
  19. def __init__(self, sync_dir=None):
  20. self._sync_dir = sync_dir or os.environ.get('PORT_SYNC_PATH')
  21. if self._sync_dir:
  22. _makedirs(self._sync_dir)
  23. self._valid_range = get_valid_port_range()
  24. self._valid_port_count = self._count_valid_ports()
  25. self._filelocks = {}
  26. self._lock = threading.Lock()
  27. def __enter__(self):
  28. return self
  29. def __exit__(self, type, value, traceback):
  30. self.release()
  31. def get_port(self, port=0):
  32. '''
  33. Gets free TCP port
  34. '''
  35. return self.get_tcp_port(port)
  36. def get_tcp_port(self, port=0):
  37. '''
  38. Gets free TCP port
  39. '''
  40. return self._get_port(port, socket.SOCK_STREAM)
  41. def get_udp_port(self, port=0):
  42. '''
  43. Gets free UDP port
  44. '''
  45. return self._get_port(port, socket.SOCK_DGRAM)
  46. def get_tcp_and_udp_port(self, port=0):
  47. '''
  48. Gets one free port for use in both TCP and UDP protocols
  49. '''
  50. if port and self._no_random_ports():
  51. return port
  52. retries = 20
  53. while retries > 0:
  54. retries -= 1
  55. result_port = self.get_tcp_port()
  56. if not self.is_port_free(result_port, socket.SOCK_DGRAM):
  57. self.release_port(result_port)
  58. # Don't try to _capture_port(), it's already captured in the get_tcp_port()
  59. return result_port
  60. raise Exception('Failed to find port')
  61. def release_port(self, port):
  62. with self._lock:
  63. self._release_port_no_lock(port)
  64. def _release_port_no_lock(self, port):
  65. filelock = self._filelocks.pop(port, None)
  66. if filelock:
  67. filelock.release()
  68. def release(self):
  69. with self._lock:
  70. while self._filelocks:
  71. _, filelock = self._filelocks.popitem()
  72. if filelock:
  73. filelock.release()
  74. def get_port_range(self, start_port, count, random_start=True):
  75. assert count > 0
  76. if start_port and self._no_random_ports():
  77. return start_port
  78. candidates = []
  79. def drop_candidates():
  80. for port in candidates:
  81. self._release_port_no_lock(port)
  82. candidates[:] = []
  83. with self._lock:
  84. for attempts in six.moves.range(128):
  85. for left, right in self._valid_range:
  86. if right - left < count:
  87. continue
  88. if random_start:
  89. start = random.randint(left, right - ((right - left) // 2))
  90. else:
  91. start = left
  92. for probe_port in six.moves.range(start, right):
  93. if self._capture_port_no_lock(probe_port, socket.SOCK_STREAM):
  94. candidates.append(probe_port)
  95. else:
  96. drop_candidates()
  97. if len(candidates) == count:
  98. return candidates[0]
  99. # Can't find required number of ports without gap in the current range
  100. drop_candidates()
  101. raise PortManagerException(
  102. "Failed to find valid port range (start_port: {} count: {}) (range: {} used: {})".format(
  103. start_port, count, self._valid_range, self._filelocks
  104. )
  105. )
  106. def _count_valid_ports(self):
  107. res = 0
  108. for left, right in self._valid_range:
  109. res += right - left
  110. assert res, ('There are no available valid ports', self._valid_range)
  111. return res
  112. def _get_port(self, port, sock_type):
  113. if port and self._no_random_ports():
  114. return port
  115. if len(self._filelocks) >= self._valid_port_count:
  116. raise PortManagerException("All valid ports are taken ({}): {}".format(self._valid_range, self._filelocks))
  117. salt = random.randint(0, UI16MAXVAL)
  118. for attempt in six.moves.range(self._valid_port_count):
  119. probe_port = (salt + attempt) % self._valid_port_count
  120. for left, right in self._valid_range:
  121. if probe_port >= (right - left):
  122. probe_port -= right - left
  123. else:
  124. probe_port += left
  125. break
  126. if not self._capture_port(probe_port, sock_type):
  127. continue
  128. return probe_port
  129. raise PortManagerException(
  130. "Failed to find valid port (range: {} used: {})".format(self._valid_range, self._filelocks)
  131. )
  132. def _capture_port(self, port, sock_type):
  133. with self._lock:
  134. return self._capture_port_no_lock(port, sock_type)
  135. def is_port_free(self, port, sock_type=socket.SOCK_STREAM):
  136. sock = socket.socket(socket.AF_INET6, sock_type)
  137. try:
  138. sock.bind(('::', port))
  139. sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  140. except socket.error as e:
  141. if e.errno == errno.EADDRINUSE:
  142. return False
  143. raise
  144. finally:
  145. sock.close()
  146. return True
  147. def _capture_port_no_lock(self, port, sock_type):
  148. if port in self._filelocks:
  149. return False
  150. filelock = None
  151. if self._sync_dir:
  152. # yatest.common should try to be hermetic and don't have peerdirs
  153. # otherwise, PYTEST_SCRIPT (aka USE_ARCADIA_PYTHON=no) won't work
  154. import library.python.filelock
  155. filelock = library.python.filelock.FileLock(os.path.join(self._sync_dir, str(port)))
  156. if not filelock.acquire(blocking=False):
  157. return False
  158. if self.is_port_free(port, sock_type):
  159. self._filelocks[port] = filelock
  160. return True
  161. else:
  162. filelock.release()
  163. return False
  164. if self.is_port_free(port, sock_type):
  165. self._filelocks[port] = filelock
  166. return True
  167. if filelock:
  168. filelock.release()
  169. return False
  170. def _no_random_ports(self):
  171. return os.environ.get("NO_RANDOM_PORTS")
  172. def get_valid_port_range():
  173. first_valid = 1025
  174. last_valid = UI16MAXVAL
  175. given_range = os.environ.get('VALID_PORT_RANGE')
  176. if given_range and ':' in given_range:
  177. return [list(int(x) for x in given_range.split(':', 2))]
  178. first_eph, last_eph = get_ephemeral_range()
  179. first_invalid = max(first_eph, first_valid)
  180. last_invalid = min(last_eph, last_valid)
  181. ranges = []
  182. if first_invalid > first_valid:
  183. ranges.append((first_valid, first_invalid - 1))
  184. if last_invalid < last_valid:
  185. ranges.append((last_invalid + 1, last_valid))
  186. return ranges
  187. def get_ephemeral_range():
  188. if platform.system() == 'Linux':
  189. filename = "/proc/sys/net/ipv4/ip_local_port_range"
  190. if os.path.exists(filename):
  191. with open(filename) as afile:
  192. data = afile.read(1024) # fix for musl
  193. port_range = tuple(map(int, data.strip().split()))
  194. if len(port_range) == 2:
  195. return port_range
  196. else:
  197. logger.warning("Bad ip_local_port_range format: '%s'. Going to use IANA suggestion", data)
  198. elif platform.system() == 'Darwin':
  199. first = _sysctlbyname_uint("net.inet.ip.portrange.first")
  200. last = _sysctlbyname_uint("net.inet.ip.portrange.last")
  201. if first and last:
  202. return first, last
  203. # IANA suggestion
  204. return (1 << 15) + (1 << 14), UI16MAXVAL
  205. def _sysctlbyname_uint(name):
  206. try:
  207. from ctypes import CDLL, c_uint, byref
  208. from ctypes.util import find_library
  209. except ImportError:
  210. return
  211. libc = CDLL(find_library("c"))
  212. size = c_uint(0)
  213. res = c_uint(0)
  214. libc.sysctlbyname(name, None, byref(size), None, 0)
  215. libc.sysctlbyname(name, byref(res), byref(size), None, 0)
  216. return res.value
  217. def _makedirs(path):
  218. try:
  219. os.makedirs(path)
  220. except OSError as e:
  221. if e.errno == errno.EEXIST:
  222. return
  223. raise