gen_mx_table.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import sys
# C++ source template for the generated matrixnet dispatch table.
#
# The script at the bottom of this file renders it with
# `tmpl % (init, body, table)`, so it has exactly three `%s` slots, in order:
#   1. inside the TFormulas constructor: one
#      `(*this)[id] = new TFml(ar.ObjectBlobByKey("/<file>"));` statement per
#      formula given on the command line;
#   2. after the TFormulas struct: the per-formula `yabs_<name>` and
#      `yabs_<name>_factor_count` static wrapper functions;
#   3. inside the `yabs_funcs[]` initializer: one `{calc, factor_count}` entry
#      per table slot, `{nullptr, nullptr}` for slots without a formula.
#
# NOTE: the literal 10000 in the last line of the template is the same number
# the script passes to range() when it builds slot 3; keep the two in sync.
# Any literal '%' added to this text must be written as '%%'.
tmpl = """
#include "yabs_mx_calc_table.h"
#include <kernel/matrixnet/mn_sse.h>
#include <library/cpp/archive/yarchive.h>
#include <util/memory/blob.h>
#include <util/generic/hash.h>
#include <util/generic/ptr.h>
#include <util/generic/singleton.h>
using namespace NMatrixnet;
extern "C" {
extern const unsigned char MxFormulas[];
extern const ui32 MxFormulasSize;
}
namespace {
struct TFml: public TBlob, public TMnSseInfo {
inline TFml(const TBlob& b)
: TBlob(b)
, TMnSseInfo(Data(), Size())
{
}
};
struct TFormulas: public THashMap<size_t, TAutoPtr<TFml>> {
inline TFormulas() {
TBlob b = TBlob::NoCopy(MxFormulas, MxFormulasSize);
TArchiveReader ar(b);
%s
}
inline const TMnSseInfo& at(size_t n) const noexcept {
return *find(n)->second;
}
};
%s
static func_descr_t yabs_funcs[] = {
%s
};
}
yabs_mx_calc_table_t yabs_mx_calc_table = {YABS_MX_CALC_VERSION, 10000, 0, yabs_funcs};
"""
  40. if __name__ == '__main__':
  41. init = []
  42. body = []
  43. defs = {}
  44. for i in sys.argv[1:]:
  45. name = i.replace('.', '_')
  46. num = long(name.split('_')[1])
  47. init.append('(*this)[%s] = new TFml(ar.ObjectBlobByKey("%s"));' % (num, '/' + i))
  48. f1 = 'static void yabs_%s(size_t count, const float** args, double* res) {Singleton<TFormulas>()->at(%s).DoCalcRelevs(args, res, count);}' % (name, num)
  49. f2 = 'static size_t yabs_%s_factor_count() {return Singleton<TFormulas>()->at(%s).MaxFactorIndex() + 1;}' % (name, num)
  50. body.append(f1)
  51. body.append(f2)
  52. d1 = 'yabs_%s' % name
  53. d2 = 'yabs_%s_factor_count' % name
  54. defs[num] = '{%s, %s}' % (d1, d2)
  55. print tmpl % ('\n'.join(init), '\n\n'.join(body), ',\n'.join((defs.get(i, '{nullptr, nullptr}') for i in range(0, 10000))))