py_yql_module.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. #include "py_yql_module.h"
  2. #include "py_void.h"
  3. #include "py_iterator.h"
  4. #include "py_list.h"
  5. #include "py_dict.h"
  6. #include "py_stream.h"
  7. #include "py_utils.h"
  8. #include "py_callable.h"
  9. #include <library/cpp/resource/resource.h>
  10. #include <yql/essentials/udfs/common/python/python_udf/python_udf.h>
  11. namespace NPython {
  12. static PyMethodDef ModuleMethods[] = {
  13. { nullptr, nullptr, 0, nullptr } /* sentinel */
  14. };
  15. #define MODULE_NAME "yql"
  16. #if PY_MAJOR_VERSION >= 3
  17. #define MODULE_NAME_TYPING "yql.typing"
  18. #endif
  19. #define MODULE_INITIALIZED_ATTRIBUTE "_initialized"
  20. PyDoc_STRVAR(ModuleDoc,
  21. "This module provides YQL specific types for Python.");
  22. #if PY_MAJOR_VERSION >= 3
  23. PyDoc_STRVAR(ModuleDocTyping,
  24. "This module provides annotations for YQL types for Python.");
  25. #endif
  26. PyDoc_STRVAR(StopIterationException_doc,
  27. "Can be throwed to yield stream iteration.");
  // Readies a static type object; throws yexception if PyType_Ready() fails.
  #define PREPARE_TYPE(Name, Type) \
      do { \
          if (PyType_Ready(Type) < 0) { \
              throw yexception() << "Can't prepare type: " << (Name); \
          } \
      } while (0)

  // Readies a static type object and adds it as attribute `Name` of the module
  // held in a local variable named `module` at the expansion site.
  // PyModule_AddObject() steals the reference only on success, so the extra
  // reference taken here is given back when adding fails.
  #define REGISTER_TYPE(Name, Type) \
      do { \
          PREPARE_TYPE(Name, Type); \
          Py_INCREF(Type); \
          if (PyModule_AddObject(module, (Name), (PyObject*) (Type)) < 0) { \
              Py_DECREF((PyObject*) (Type)); \
              throw yexception() << "Can't add type: " << (Name); \
          } \
      } while (0)

  // Stores `Object` under key `Name` in the dictionary held in a local variable
  // named `dict` at the expansion site. PyDict_SetItemString() does not steal
  // the reference, so the caller keeps its own.
  #define REGISTER_OBJECT(Name, Object) \
      do { \
          if (PyDict_SetItemString(dict, (Name), (PyObject *) (Object)) < 0) { \
              throw yexception() << "Can't register object: " << (Name); \
          } \
      } while (0)

  // Creates the exception class MODULE_NAME "." Name (with docstring `Doc`), caches
  // it in the variable `Object` and registers it in `dict`. Nothing is done when
  // `Object` is already non-null, so an existing class is neither recreated nor
  // re-registered.
  #define REGISTER_EXCEPTION(Name, Object, Doc) \
      do { \
          if (!Object) { \
              Object = PyErr_NewExceptionWithDoc((char*) MODULE_NAME "." Name, Doc, nullptr, nullptr); \
              if (!Object) { \
                  throw yexception() << "Can't register exception: " << (Name); \
              } \
              REGISTER_OBJECT(Name, Object); \
          } \
      } while (0)
  57. #if PY_MAJOR_VERSION >= 3
  58. static PyModuleDef ModuleDefinition = {
  59. PyModuleDef_HEAD_INIT,
  60. INIT_MEMBER(m_name, MODULE_NAME),
  61. INIT_MEMBER(m_doc, ModuleDoc),
  62. INIT_MEMBER(m_size, -1),
  63. INIT_MEMBER(m_methods, ModuleMethods),
  64. INIT_MEMBER(m_slots, nullptr),
  65. INIT_MEMBER(m_traverse, nullptr),
  66. INIT_MEMBER(m_clear, nullptr),
  67. INIT_MEMBER(m_free, nullptr),
  68. };
  69. static PyModuleDef ModuleDefinitionTyping = {
  70. PyModuleDef_HEAD_INIT,
  71. INIT_MEMBER(m_name, MODULE_NAME_TYPING),
  72. INIT_MEMBER(m_doc, ModuleDocTyping),
  73. INIT_MEMBER(m_size, -1),
  74. INIT_MEMBER(m_methods, nullptr),
  75. INIT_MEMBER(m_slots, nullptr),
  76. INIT_MEMBER(m_traverse, nullptr),
  77. INIT_MEMBER(m_clear, nullptr),
  78. INIT_MEMBER(m_free, nullptr),
  79. };
  80. PyMODINIT_FUNC PyInit_YQL(void)
  81. {
  82. auto mod = PyModule_Create(&ModuleDefinition);
  83. PyModule_AddObject(mod, "__path__", Py_BuildValue("()"));
  84. return mod;
  85. }
  86. void go_throw();
  87. PyMODINIT_FUNC PyInit_YQLTyping(void)
  88. {
  89. return PyModule_Create(&ModuleDefinitionTyping);
  90. }
  91. #else
  92. PyMODINIT_FUNC PyInit_YQL(void)
  93. {
  94. Py_InitModule3(MODULE_NAME, ModuleMethods, ModuleDoc);
  95. }
  96. #endif
  97. void PrepareYqlModule() {
  98. PyImport_AppendInittab(MODULE_NAME, &PyInit_YQL);
  99. #if PY_MAJOR_VERSION >= 3
  100. PyImport_AppendInittab(MODULE_NAME_TYPING, &PyInit_YQLTyping);
  101. #endif
  102. }
  #if PY_MAJOR_VERSION >= 3
  // Inserts an already created module object into sys.modules under `name`, so
  // later imports of `name` find it there. sys.modules takes its own reference;
  // the caller's reference to `module` is not consumed.
  // Throws yexception if `module` is not a module, if sys.modules is unavailable
  // or not a plain dict, or if the insertion fails.
  void RegisterRuntimeModule(const char* name, PyObject* module) {
      const bool isModule = module != nullptr && PyModule_Check(module);
      if (!isModule) {
          throw yexception() << "Invalid object for module " << name;
      }

      PyObject* const sysModules = PyImport_GetModuleDict(); // borrowed
      const bool isPlainDict = sysModules != nullptr && PyDict_CheckExact(sysModules);
      if (!isPlainDict) {
          throw yexception() << "Can't get sys.modules dictionary";
      }

      const int status = PyDict_SetItemString(sysModules, name, module);
      if (status < 0) {
          throw yexception() << "Can't register module " << name;
      }
  }
  #endif
  118. void InitYqlModule(NYql::NUdf::EPythonFlavor pythonFlavor, bool standalone) {
  119. TPyObjectPtr m = PyImport_ImportModule(MODULE_NAME);
  120. if (!standalone && !m) {
  121. PyErr_Clear();
  122. #if PY_MAJOR_VERSION >= 3
  123. m = PyInit_YQL();
  124. RegisterRuntimeModule(MODULE_NAME, m.Get());
  125. #else
  126. PyInit_YQL();
  127. #endif
  128. m = PyImport_ImportModule(MODULE_NAME);
  129. }
  130. PyObject* module = m.Get();
  131. if (!module) {
  132. throw yexception() << "Can't get YQL module.";
  133. }
  134. TPyObjectPtr initialized = PyObject_GetAttrString(module, MODULE_INITIALIZED_ATTRIBUTE);
  135. if (!initialized) {
  136. PyErr_Clear();
  137. } else if (initialized.Get() == Py_True) {
  138. return;
  139. }
  140. PyObject* dict = PyModule_GetDict(module);
  141. REGISTER_TYPE("TVoid", &PyVoidType);
  142. REGISTER_OBJECT("Void", &PyVoidObject);
  143. PREPARE_TYPE("TIterator", &PyIteratorType);
  144. PREPARE_TYPE("TPairIterator", &PyPairIteratorType);
  145. PREPARE_TYPE("TDict", &PyLazyDictType);
  146. PREPARE_TYPE("TSet", &PyLazySetType);
  147. PREPARE_TYPE("TLazyListIterator", &PyLazyListIteratorType);
  148. PREPARE_TYPE("TLazyList", &PyLazyListType);
  149. PREPARE_TYPE("TThinListIterator", &PyThinListIteratorType);
  150. PREPARE_TYPE("TThinList", &PyThinListType);
  151. PREPARE_TYPE("TStream", &PyStreamType);
  152. PREPARE_TYPE("TCallable", &PyCallableType);
  153. REGISTER_EXCEPTION("TYieldIteration", PyYieldIterationException, StopIterationException_doc);
  154. #if PY_MAJOR_VERSION >= 3
  155. if (pythonFlavor == NYql::NUdf::EPythonFlavor::Arcadia) {
  156. if (!standalone) {
  157. TPyObjectPtr typingModule = PyImport_ImportModule(MODULE_NAME_TYPING);
  158. if (!typingModule) {
  159. PyErr_Clear();
  160. typingModule = PyInit_YQLTyping();
  161. RegisterRuntimeModule(MODULE_NAME_TYPING, typingModule.Get());
  162. }
  163. }
  164. const auto typing = NResource::Find(TStringBuf("typing.py"));
  165. const auto rc = PyRun_SimpleStringFlags(typing.c_str(), nullptr);
  166. if (rc < 0) {
  167. // Not sure if PyErr_Print() works after PyRun_SimpleStringFlags,
  168. // but just in case...
  169. PyErr_Print();
  170. ythrow yexception() << "Can't parse YQL type annotations module";
  171. }
  172. auto processError = [&] (PyObject* obj, TStringBuf message) {
  173. if (obj) {
  174. return;
  175. }
  176. PyObject *ptype, *pvalue, *ptraceback;
  177. PyErr_Fetch(&ptype, &pvalue, &ptraceback);
  178. if (pvalue) {
  179. auto pstr = PyObject_Str(pvalue);
  180. if (pstr) {
  181. if (auto err_msg = PyUnicode_AsUTF8(pstr)) {
  182. Cerr << err_msg << Endl;
  183. }
  184. }
  185. PyErr_Restore(ptype, pvalue, ptraceback);
  186. }
  187. ythrow yexception() << "Can't setup YQL type annotations module: " << message;
  188. };
  189. auto main = PyImport_ImportModule("__main__");
  190. processError(main, "PyImport_ImportModule");
  191. auto function = PyObject_GetAttrString(main, "main");
  192. processError(function, "PyObject_GetAttrString");
  193. auto args = PyTuple_New(0);
  194. processError(args, "PyTuple_New");
  195. auto result = PyObject_CallObject(function, args);
  196. processError(result, "PyObject_CallObject");
  197. Py_DECREF(result);
  198. Py_DECREF(args);
  199. Py_DECREF(function);
  200. Py_DECREF(main);
  201. }
  202. #endif
  203. REGISTER_OBJECT(MODULE_INITIALIZED_ATTRIBUTE, Py_True);
  204. }
  205. void TermYqlModule() {
  206. PyYieldIterationException = nullptr;
  207. }
  208. } // namspace NPython