build_catboost.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from __future__ import print_function
  2. import sys
  3. import os
  4. import shutil
  5. import subprocess
  6. def get_value(val):
  7. dct = val.split('=', 1)
  8. if len(dct) > 1:
  9. return dct[1]
  10. return ''
  11. class BuildCbBase(object):
  12. def run(self, cbmodel, cbname, cb_cpp_path):
  13. data_prefix = "CB_External_"
  14. data = data_prefix + cbname
  15. datasize = data + "Size"
  16. cbtype = "const NCatboostCalcer::TCatboostCalcer"
  17. cbload = "(ReadModel({0}, {1}, EModelType::CatboostBinary))".format(data, datasize)
  18. cb_cpp_tmp_path = cb_cpp_path + ".tmp"
  19. cb_cpp_tmp = open(cb_cpp_tmp_path, 'w')
  20. cb_cpp_tmp.write("#include <kernel/catboost/catboost_calcer.h>\n")
  21. ro_data_path = os.path.dirname(cb_cpp_path) + "/" + data_prefix + cbname + ".rodata"
  22. cb_cpp_tmp.write("namespace{\n")
  23. cb_cpp_tmp.write(" extern \"C\" {\n")
  24. cb_cpp_tmp.write(" extern const unsigned char {1}{0}[];\n".format(cbname, data_prefix))
  25. cb_cpp_tmp.write(" extern const ui32 {1}{0}Size;\n".format(cbname, data_prefix))
  26. cb_cpp_tmp.write(" }\n")
  27. cb_cpp_tmp.write("}\n")
  28. archiverCall = subprocess.Popen(
  29. [self.archiver, "-q", "-p", "-o", ro_data_path, cbmodel], stdout=None, stderr=subprocess.PIPE
  30. )
  31. archiverCall.wait()
  32. cb_cpp_tmp.write("extern {0} {1};\n".format(cbtype, cbname))
  33. cb_cpp_tmp.write("{0} {1}{2};".format(cbtype, cbname, cbload))
  34. cb_cpp_tmp.close()
  35. shutil.move(cb_cpp_tmp_path, cb_cpp_path)
  36. class BuildCb(BuildCbBase):
  37. def run(self, argv):
  38. if len(argv) < 5:
  39. print("BuildCb.Run(<ARCADIA_ROOT> <archiver> <mninfo> <mnname> <cppOutput> [params...])", file=sys.stderr)
  40. sys.exit(1)
  41. self.SrcRoot = argv[0]
  42. self.archiver = argv[1]
  43. cbmodel = argv[2]
  44. cbname = argv[3]
  45. cb_cpp_path = argv[4]
  46. super(BuildCb, self).run(cbmodel, cbname, cb_cpp_path)
  47. def build_cb_f(argv):
  48. build_cb = BuildCb()
  49. build_cb.run(argv)
  50. if __name__ == '__main__':
  51. if len(sys.argv) < 2:
  52. print("Usage: build_cb.py <funcName> <args...>", file=sys.stderr)
  53. sys.exit(1)
  54. if sys.argv[2:]:
  55. globals()[sys.argv[1]](sys.argv[2:])
  56. else:
  57. globals()[sys.argv[1]]()