123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276 |
- # coding=utf-8
- import os
- import errno
- import socket
- import random
- import logging
- import platform
- import threading
- import six
# Largest unsigned 16-bit value (65535); also the highest TCP/UDP port number.
UI16MAXVAL = 0xFFFF
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class PortManagerException(Exception):
    """Raised when PortManager cannot find or reserve the requested port(s)."""
class PortManager(object):
    """
    Hands out free TCP/UDP ports and keeps them reserved until released.

    When a sync directory is configured (``sync_dir`` argument or the
    ``PORT_SYNC_PATH`` environment variable), each captured port is also
    guarded by a non-blocking file lock named after the port inside that
    directory, so cooperating processes sharing the directory do not pick
    the same port.

    When ``NO_RANDOM_PORTS`` is set in the environment, an explicitly
    requested (non-zero) port is returned unchanged without any probing.

    Can be used as a context manager: all captured ports are released on exit.

    See documentation here
    https://wiki.yandex-team.ru/yatool/test/#python-acquire-ports
    """

    def __init__(self, sync_dir=None):
        self._sync_dir = sync_dir or os.environ.get('PORT_SYNC_PATH')
        if self._sync_dir:
            _makedirs(self._sync_dir)
        # Sequence of (left, right) pairs; ports are probed in [left, right).
        self._valid_range = get_valid_port_range()
        self._valid_port_count = self._count_valid_ports()
        # Captured port -> its file lock (None when no sync dir is used).
        self._filelocks = {}
        # Held while self._filelocks is mutated.
        self._lock = threading.Lock()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def get_port(self, port=0):
        '''
        Gets free TCP port
        '''
        return self.get_tcp_port(port)

    def get_tcp_port(self, port=0):
        '''
        Gets free TCP port
        '''
        return self._get_port(port, socket.SOCK_STREAM)

    def get_udp_port(self, port=0):
        '''
        Gets free UDP port
        '''
        return self._get_port(port, socket.SOCK_DGRAM)

    def get_tcp_and_udp_port(self, port=0):
        '''
        Gets one free port for use in both TCP and UDP protocols.

        A TCP port is captured first; if the same number is busy for UDP,
        the TCP port is released and another one is tried.

        :raises PortManagerException: if no port free for both protocols is
            found within the retry budget.
        '''
        if port and self._no_random_ports():
            return port
        retries = 20
        for _ in range(retries):
            result_port = self.get_tcp_port()
            if self.is_port_free(result_port, socket.SOCK_DGRAM):
                # Don't try to _capture_port(), it's already captured in the get_tcp_port()
                return result_port
            # UDP side is taken: give the TCP port back and try another one.
            self.release_port(result_port)
        raise PortManagerException('Failed to find port')

    def release_port(self, port):
        '''
        Releases a single previously captured port (no-op if not captured).
        '''
        with self._lock:
            self._release_port_no_lock(port)

    def _release_port_no_lock(self, port):
        # Caller must hold self._lock.
        filelock = self._filelocks.pop(port, None)
        if filelock is not None:
            filelock.release()

    def release(self):
        '''
        Releases every port captured by this manager.
        '''
        with self._lock:
            while self._filelocks:
                _, filelock = self._filelocks.popitem()
                if filelock is not None:
                    filelock.release()

    def get_port_range(self, start_port, count, random_start=True):
        '''
        Captures ``count`` consecutive free TCP ports and returns the first one.

        :param start_port: returned as-is when NO_RANDOM_PORTS is set.
        :param count: number of consecutive ports required (must be > 0).
        :param random_start: start scanning each valid range at a random
            position instead of at its left edge.
        :returns: the first port of the captured run; all ``count`` ports
            stay captured until released.
        :raises PortManagerException: if no gap-free run is found.
        '''
        assert count > 0
        if start_port and self._no_random_ports():
            return start_port

        candidates = []

        def drop_candidates():
            for port in candidates:
                self._release_port_no_lock(port)
            candidates[:] = []

        with self._lock:
            for _ in six.moves.range(128):
                for left, right in self._valid_range:
                    if right - left < count:
                        continue

                    if random_start:
                        # Never start past ``right - count``: from there the
                        # range cannot hold ``count`` ports and the attempt
                        # would be wasted. right - left >= count here, so the
                        # bound is still >= left.
                        upper = min(right - ((right - left) // 2), right - count)
                        start = random.randint(left, upper)
                    else:
                        start = left
                    for probe_port in six.moves.range(start, right):
                        if self._capture_port_no_lock(probe_port, socket.SOCK_STREAM):
                            candidates.append(probe_port)
                        else:
                            # A busy port breaks the run; collect afresh.
                            drop_candidates()

                        if len(candidates) == count:
                            return candidates[0]
                    # Can't find required number of ports without gap in the current range
                    drop_candidates()

            raise PortManagerException(
                "Failed to find valid port range (start_port: {} count: {}) (range: {} used: {})".format(
                    start_port, count, self._valid_range, self._filelocks
                )
            )

    def _count_valid_ports(self):
        # Each (left, right) pair contributes right - left probe-able ports.
        res = 0
        for left, right in self._valid_range:
            res += right - left
        assert res, ('There are no available valid ports', self._valid_range)
        return res

    def _get_port(self, port, sock_type):
        '''
        Captures one free port of ``sock_type``.

        Walks the concatenated valid ranges starting at a random offset, so
        every valid port is probed at most once.

        :raises PortManagerException: if every valid port is taken or none
            of them could be captured.
        '''
        if port and self._no_random_ports():
            return port

        if len(self._filelocks) >= self._valid_port_count:
            raise PortManagerException("All valid ports are taken ({}): {}".format(self._valid_range, self._filelocks))

        salt = random.randint(0, UI16MAXVAL)
        for attempt in six.moves.range(self._valid_port_count):
            # Index into the flattened port space, then map it into the
            # (left, right) range it falls in.
            probe_port = (salt + attempt) % self._valid_port_count

            for left, right in self._valid_range:
                if probe_port >= (right - left):
                    probe_port -= right - left
                else:
                    probe_port += left
                    break
            if not self._capture_port(probe_port, sock_type):
                continue
            return probe_port

        raise PortManagerException(
            "Failed to find valid port (range: {} used: {})".format(self._valid_range, self._filelocks)
        )

    def _capture_port(self, port, sock_type):
        with self._lock:
            return self._capture_port_no_lock(port, sock_type)

    def is_port_free(self, port, sock_type=socket.SOCK_STREAM):
        '''
        Returns False if binding ``port`` on '::' fails with EADDRINUSE,
        True if the bind succeeds; other socket errors propagate.
        '''
        sock = socket.socket(socket.AF_INET6, sock_type)
        try:
            sock.bind(('::', port))
            # NOTE(review): SO_REUSEADDR set after bind() does not affect the
            # bind above; kept for behavioral parity — confirm intent.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        except socket.error as e:
            if e.errno == errno.EADDRINUSE:
                return False
            raise
        finally:
            sock.close()
        return True

    def _capture_port_no_lock(self, port, sock_type):
        '''
        Tries to reserve ``port``; the caller must hold ``self._lock``.

        :returns: True if the port was free and is now recorded in
            ``self._filelocks``; False otherwise.
        '''
        if port in self._filelocks:
            return False

        filelock = None
        if self._sync_dir:
            # yatest.common should try to be hermetic and don't have peerdirs
            # otherwise, PYTEST_SCRIPT (aka USE_ARCADIA_PYTHON=no) won't work
            import library.python.filelock

            filelock = library.python.filelock.FileLock(os.path.join(self._sync_dir, str(port)))
            if not filelock.acquire(blocking=False):
                return False

        captured = False
        try:
            if self.is_port_free(port, sock_type):
                self._filelocks[port] = filelock
                captured = True
        finally:
            # Release the file lock if the port was busy OR is_port_free
            # raised — otherwise it would stay held but untracked.
            if not captured and filelock is not None:
                filelock.release()
        return captured

    def _no_random_ports(self):
        # Truthy (the raw env string) when NO_RANDOM_PORTS is set and non-empty.
        return os.environ.get("NO_RANDOM_PORTS")
def get_valid_port_range():
    """
    Return the port ranges PortManager may allocate from.

    If ``VALID_PORT_RANGE`` is set to ``"<left>:<right>"`` it is used as the
    only range. Otherwise the OS ephemeral range (see get_ephemeral_range)
    is cut out of 1025..65535, yielding up to two ``(first, last)`` pairs.

    NOTE(review): the computed pairs use an inclusive upper bound, while
    PortManager probes ``[left, right)``, so the last port of each computed
    range is never handed out — confirm this is intended.

    :returns: list of two-element ranges.
    :raises ValueError: if ``VALID_PORT_RANGE`` is not two integers
        separated by a single colon.
    """
    given_range = os.environ.get('VALID_PORT_RANGE')
    if given_range and ':' in given_range:
        # maxsplit=1 guarantees exactly two fields; extra colons land in the
        # second field and make int() fail loudly instead of producing a
        # three-element "range".
        return [list(int(x) for x in given_range.split(':', 1))]

    first_valid = 1025
    last_valid = UI16MAXVAL

    first_eph, last_eph = get_ephemeral_range()
    first_invalid = max(first_eph, first_valid)
    last_invalid = min(last_eph, last_valid)

    ranges = []
    if first_invalid > first_valid:
        ranges.append((first_valid, min(first_invalid - 1, last_valid)))
    if last_invalid < last_valid:
        # Clamp so an ephemeral range lying entirely below first_valid cannot
        # leak privileged ports (< 1025) into the result.
        ranges.append((max(last_invalid + 1, first_valid), last_valid))
    return ranges
def get_ephemeral_range():
    """
    Return the OS ephemeral port range as a ``(first, last)`` pair.

    Linux: parsed from /proc/sys/net/ipv4/ip_local_port_range.
    macOS: read from the net.inet.ip.portrange.first/last sysctls.
    Any other system, or a value that cannot be read or parsed, falls back
    to the IANA suggestion 49152..65535.
    """
    system = platform.system()
    if system == 'Linux':
        filename = "/proc/sys/net/ipv4/ip_local_port_range"
        if os.path.exists(filename):
            with open(filename) as afile:
                data = afile.read(1024)  # fix for musl
            try:
                port_range = tuple(map(int, data.strip().split()))
            except ValueError:
                # Non-numeric content: treat like a wrong field count so we
                # log and fall back instead of crashing.
                port_range = ()
            if len(port_range) == 2:
                return port_range
            logger.warning("Bad ip_local_port_range format: '%s'. Going to use IANA suggestion", data)
    elif system == 'Darwin':
        first = _sysctlbyname_uint("net.inet.ip.portrange.first")
        last = _sysctlbyname_uint("net.inet.ip.portrange.last")
        if first and last:
            return first, last

    # IANA suggestion
    return (1 << 15) + (1 << 14), UI16MAXVAL
- def _sysctlbyname_uint(name):
- try:
- from ctypes import CDLL, c_uint, byref
- from ctypes.util import find_library
- except ImportError:
- return
- libc = CDLL(find_library("c"))
- size = c_uint(0)
- res = c_uint(0)
- libc.sysctlbyname(name, None, byref(size), None, 0)
- libc.sysctlbyname(name, byref(res), byref(size), None, 0)
- return res.value
- def _makedirs(path):
- try:
- os.makedirs(path)
- except OSError as e:
- if e.errno == errno.EEXIST:
- return
- raise
|