__main__.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import argparse
  2. import errno
  3. import logging
  4. import os
  5. import tarfile
  6. import tempfile
  7. import time
  8. import glob
  9. import platform
  10. import yatest.common
  11. from library.python.testing.recipe import declare_recipe, set_env
  12. from library.python.testing.recipe.ports import get_port_range, release_port_range
  13. # copy-paste from yql/library/postgresql/recipe
  14. POSTGRES_TGZ = 'postgres.tgz'
  15. def wait(cmd, process):
  16. res = process.communicate()
  17. code = process.poll()
  18. if code != 0:
  19. raise RuntimeError("cmd failed: %s, code: %s, stdout/stderr: %s / %s", (cmd, code, res[0], res[1]))
  20. logging.debug(res[1])
  21. class PG:
  22. def __init__(self, install_dir):
  23. self.bin_dir = os.path.join(install_dir, "bin")
  24. lib_dir = os.path.join(install_dir, "lib")
  25. self.env = os.environ.copy()
  26. if 'LD_LIBRARY_PATH' in self.env:
  27. self.env['LD_LIBRARY_PATH'] += ':' + lib_dir
  28. else:
  29. self.env['LD_LIBRARY_PATH'] = lib_dir
  30. self.env['LANG'] = 'en_US.UTF-8'
  31. self.env['LC_MESSAGES'] = 'en_US.UTF-8'
  32. def run(self, dbdir, port, user, dbname, max_conn):
  33. self.port = port
  34. self.user = user
  35. self.dbname = dbname
  36. self.dbdir = dbdir
  37. self.max_conn = max_conn
  38. self._init_db()
  39. self._run()
  40. self._create_db()
  41. def migrate(self, migrations_dir):
  42. logging.info("migrations dir: " + migrations_dir)
  43. files = sorted(glob.glob(migrations_dir + "/*.sql"))
  44. for fname in files:
  45. self._execute_sql_file(fname)
  46. def get_pid(self):
  47. return self.handle.process.pid
  48. def _init_db(self):
  49. logging.info("Executing initdb...")
  50. initdb_path = os.path.join(self.bin_dir, 'initdb')
  51. cmd = [initdb_path, '-A trust', '-D', self.dbdir , '-U', self.user]
  52. yatest.common.execute(cmd, env=self.env)
  53. logging.info("Database initiated")
  54. def _run(self):
  55. logging.info("Running postgres...")
  56. postgres_path = os.path.join(self.bin_dir, 'postgres')
  57. cmd = [postgres_path, '-D', self.dbdir, '--port={}'.format(self.port), '-N{}'.format(self.max_conn), '-c', 'timezone=UTC']
  58. self.handle = yatest.common.execute(cmd, env=self.env, wait=False)
  59. logging.info("Postgres is running on port {} with pid {}".format(self.port, self.get_pid()))
  60. def _create_db(self):
  61. logging.info("Creating the database...")
  62. createdb_path = os.path.join(self.bin_dir, 'createdb')
  63. locale = "en_US.utf8"
  64. if platform.system() == "Darwin":
  65. locale = "en_US.UTF-8"
  66. # Try 10 times, since postgres might not be initialized and accepting connections yet
  67. ex = None
  68. for attempt in range(10):
  69. cmd = [createdb_path,
  70. '-p', str(self.port),
  71. '-U', self.user, self.dbname,
  72. '-E', 'UTF8', f'--locale={locale}', '--template=template0']
  73. ex = yatest.common.execute(cmd, check_exit_code=False, env=self.env)
  74. if ex.returncode == 0:
  75. logging.info("Database created")
  76. return
  77. logging.warn("Database creation failed (attempt {}): {}".format(attempt, ex.stderr))
  78. time.sleep(1)
  79. raise Exception(f"cannot create Database, {ex.stderr}")
  80. def _execute_sql_file(self, file_name):
  81. logging.info('Executing {}...'.format(file_name))
  82. psql_path = os.path.join(self.bin_dir, 'psql')
  83. cmd = [psql_path, '-p', str(self.port), '-U', self.user, '-q', '-A', '-d', self.dbname, '-f', file_name]
  84. yatest.common.execute(cmd, env=self.env)
  85. logging.info('Success')
  86. def start(argv):
  87. logging.info("Start postgresql")
  88. def parse_argv(argv):
  89. parser = argparse.ArgumentParser()
  90. parser.add_argument('--archive', metavar='<file>', type=str,
  91. help='postgres.tgz path',
  92. default=POSTGRES_TGZ)
  93. parser.add_argument('-s', '--schema-migrations-dir', type=str, help='directory with DB migrations to run for each schema')
  94. parser.add_argument('-m', '--migrations-dir', type=str, help='directory with DB migrations run')
  95. parser.add_argument('-n', '--max_connections', type=int, default=20, help='set max number of connections to DB')
  96. return parser.parse_args(argv)
  97. def parse_dir(arg):
  98. return [yatest.common.source_path(d.strip()) for d in arg.split(',') if d.strip()]
  99. args = parse_argv(argv)
  100. postgres_dir = tempfile.mkdtemp(prefix='postgres_')
  101. logging.info("Postgresql dir:" + postgres_dir)
  102. working_dir = os.path.join(postgres_dir, 'wd')
  103. os.mkdir(working_dir)
  104. logging.info("Working dir:" + working_dir)
  105. tgz = tarfile.open(args.archive)
  106. tgz.extractall(path=postgres_dir)
  107. port = get_port_range()
  108. dbname = 'test-db'
  109. username = 'test'
  110. pg = PG(os.path.join(postgres_dir, 'postgres'))
  111. pg.run(dbdir=working_dir, port=port, dbname=dbname, user=username, max_conn=args.max_connections)
  112. if args.schema_migrations_dir:
  113. set_env("PG_MIGRATIONS_DIR", ','.join(parse_dir(args.schema_migrations_dir)))
  114. if args.migrations_dir:
  115. for dir in parse_dir(args.migrations_dir):
  116. pg.migrate(dir)
  117. with open("postgres_recipe.pid", "w") as pid_file:
  118. pid_file.write(str(pg.get_pid()))
  119. set_env("POSTGRES_RECIPE_HOST", 'localhost')
  120. set_env("POSTGRES_RECIPE_PORT", str(port))
  121. set_env("POSTGRES_RECIPE_DBNAME", dbname)
  122. set_env("POSTGRES_RECIPE_USER", username)
  123. set_env("POSTGRES_RECIPE_MAX_CONNECTIONS", str(args.max_connections))
  124. set_env("POSTGRES_RECIPE_BIN_DIR", pg.bin_dir)
  125. def is_running(pid):
  126. try:
  127. os.kill(pid, 0)
  128. except OSError as err:
  129. if err.errno == errno.ESRCH:
  130. return False
  131. return True
  132. def stop(argv):
  133. if len(yatest.common.get_param("hang_with_pg", "")):
  134. while True:
  135. continue
  136. release_port_range()
  137. logging.info("Stop postgresql")
  138. with open("postgres_recipe.pid", "r") as f:
  139. pid = int(f.read())
  140. os.kill(pid, 15)
  141. _SHUTDOWN_TIMEOUT = 10
  142. seconds = _SHUTDOWN_TIMEOUT
  143. while is_running(pid) and seconds > 0:
  144. time.sleep(1)
  145. seconds -= 1
  146. if is_running(pid):
  147. logging.error('postgres is still running after %d seconds' % seconds)
  148. os.kill(pid, 9)
  149. if is_running(pid):
  150. logging.error("postgres failed to shutdown after kill 9!")
  151. if __name__ == "__main__":
  152. declare_recipe(start, stop)