module.h 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. #pragma once
  2. #define PY_SSIZE_T_CLEAN
  3. #include <Python.h>
  4. #include "ptr.h"
  5. #include "cast.h"
  6. #include "exceptions.h"
  7. #include <util/generic/function.h>
  8. namespace NPyBind {
  9. #if PY_MAJOR_VERSION >= 3
  10. namespace NPrivate {
  11. using TFinalizationCallBack = std::function<void()>;
  12. void AddFinalizationCallBack(TFinalizationCallBack);
  13. class TAtExitRegistrar: private TNonCopyable {
  14. TAtExitRegistrar(TPyObjectPtr module);
  15. public:
  16. static void Instantiate(TPyObjectPtr module) {
  17. static TAtExitRegistrar registrar(module);
  18. Y_UNUSED(registrar);
  19. }
  20. };
  21. class TPyBindModuleRegistrar: private TNonCopyable {
  22. TPyBindModuleRegistrar();
  23. TPyObjectPtr Module;
  24. public:
  25. static void Instantiate() {
  26. static TPyBindModuleRegistrar registrar;
  27. Y_UNUSED(registrar);
  28. }
  29. };
  30. } //NPrivate
  31. #endif
  32. class TModuleHolder {
  33. private:
  34. TModuleHolder(const TModuleHolder&);
  35. TModuleHolder& operator=(const TModuleHolder&);
  36. TModuleHolder();
  37. private:
  38. typedef PyCFunction TModuleMethod;
  39. #if PY_MAJOR_VERSION >= 3
  40. typedef PyObject* (*TModuleInitFunc)();
  41. #else
  42. typedef void (*TModuleInitFunc)();
  43. #endif
  44. struct TMethodDef {
  45. TString Name;
  46. TModuleMethod Method;
  47. TString Description;
  48. int Flags;
  49. TMethodDef(const TString& name, TModuleMethod method, const TString& descr, int flags)
  50. : Name(name)
  51. , Method(method)
  52. , Description(descr)
  53. , Flags(flags)
  54. {
  55. }
  56. operator PyMethodDef() const {
  57. PyMethodDef cur = {Name.c_str(), Method, Flags, Description.c_str()};
  58. return cur;
  59. }
  60. };
  61. typedef TSimpleSharedPtr<TVector<TMethodDef>> TMethodDefVecPtr;
  62. typedef TSimpleSharedPtr<TVector<PyMethodDef>> TPyMethodDefVecPtr;
  63. TVector<TMethodDefVecPtr> Methods;
  64. TVector<TPyMethodDefVecPtr> Defs;
  65. #if PY_MAJOR_VERSION >= 3
  66. //because the md_name will leak otherwise
  67. class TPyModuleDefWithName {
  68. PyModuleDef Def;
  69. TString Name;
  70. public:
  71. explicit TPyModuleDefWithName(TString name, TPyMethodDefVecPtr moduleDefs)
  72. : Name(std::move(name))
  73. {
  74. Def = PyModuleDef{
  75. PyModuleDef_HEAD_INIT,
  76. Name.c_str(),
  77. nullptr,
  78. -1,
  79. moduleDefs->data(),
  80. nullptr, nullptr, nullptr, nullptr
  81. };
  82. }
  83. PyModuleDef* GetDefPtr() {
  84. return &Def;
  85. }
  86. };
  87. TVector<TSimpleSharedPtr<TPyModuleDefWithName>> ModuleDefs;
  88. #endif
  89. template <TModuleMethod method>
  90. static PyObject* MethodWrapper(PyObject* obj, PyObject* args) {
  91. try {
  92. PyObject* res = method(obj, args);
  93. if (!res && !PyErr_Occurred())
  94. ythrow yexception() << "\nModule method exited with NULL, but didn't set Error.\n Options:\n -- Return correct value or None;\n -- Set python exception;\n -- Throw c++ exception.";
  95. return res;
  96. } catch (const std::exception& ex) {
  97. PyErr_SetString(TExceptionsHolder::Instance().ToPyException(ex).Get(), ex.what());
  98. } catch (...) {
  99. PyErr_SetString(PyExc_RuntimeError, "Unknown error occurred while trying to call module method");
  100. }
  101. return nullptr;
  102. }
  103. public:
  104. static TModuleHolder& Instance() {
  105. static TModuleHolder Holder;
  106. return Holder;
  107. }
  108. void ImportModule(TPyObjectPtr module, const char* const name, TModuleInitFunc initFunc) {
  109. PyImport_AppendInittab(const_cast<char*>(name), initFunc);
  110. TPyObjectPtr importedModule(PyImport_ImportModule(name), true);
  111. PyModule_AddObject(module.Get(), name, importedModule.Get());
  112. }
  113. template <TModuleMethod method>
  114. void AddModuleMethod(const TString& name, const TString& descr = "") {
  115. Methods.back()->push_back(TMethodDef(name, MethodWrapper<method>, descr, METH_VARARGS));
  116. }
  117. TPyObjectPtr InitModule(const TString& name) {
  118. Defs.push_back(new TVector<PyMethodDef>(Methods.back()->begin(), Methods.back()->end()));
  119. PyMethodDef blank = {nullptr, nullptr, 0, nullptr};
  120. Defs.back()->push_back(blank);
  121. #if PY_MAJOR_VERSION >= 3
  122. ModuleDefs.push_back(MakeSimpleShared<TPyModuleDefWithName>(name, Defs.back()));
  123. TPyObjectPtr res(PyModule_Create(ModuleDefs.back()->GetDefPtr()));
  124. NPrivate::TAtExitRegistrar::Instantiate(res);
  125. NPrivate::TPyBindModuleRegistrar::Instantiate();
  126. #else
  127. TPyObjectPtr res(Py_InitModule(name.c_str(), &(Defs.back()->at(0))));
  128. #endif
  129. Methods.push_back(new TVector<TMethodDef>);
  130. return res;
  131. }
  132. };
  133. template <typename TMethodSignature, TMethodSignature method>
  134. class TModuleMethodCaller {
  135. private:
  136. template <typename TResult, typename... Args>
  137. struct TCaller {
  138. static PyObject* Call(PyObject* args) {
  139. return BuildPyObject(Apply(method, GetArguments<Args...>(args)));
  140. }
  141. };
  142. template <typename TResult, typename... Args>
  143. static PyObject* InternalCall(TResult (*)(Args...), PyObject* args) {
  144. return BuildPyObject(Apply(method, GetArguments<Args...>(args)));
  145. }
  146. public:
  147. static PyObject* Call(PyObject*, PyObject* args) {
  148. if (args && (!PyTuple_Check(args) || PyTuple_Size(args) != TFunctionArgs<TMethodSignature>::Length)) {
  149. ythrow yexception() << "Method takes " << (size_t)(TFunctionArgs<TMethodSignature>::Length) << " arguments, " << PyTuple_Size(args) << " provided";
  150. }
  151. return InternalCall(method, args);
  152. }
  153. };
  154. }