// py_test_engine.h
#pragma once

#include "py_cast.h"
#include "py_yql_module.h"
#include "py_utils.h"

#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/mkql_type_builder.h>
#include <yql/essentials/minikql/computation/mkql_value_builder.h>
#include <yql/essentials/udfs/common/python/python_udf/python_udf.h>

#include <library/cpp/testing/unittest/registar.h>

#include <utility>

#define PYTHON_TEST_TAG "Python2Test"

using namespace NKikimr;
using namespace NMiniKQL;
  13. namespace NPython {
  14. //////////////////////////////////////////////////////////////////////////////
  15. // TPyInitializer
  16. //////////////////////////////////////////////////////////////////////////////
  17. struct TPyInitializer {
  18. TPyInitializer() {
  19. PrepareYqlModule();
  20. Py_Initialize();
  21. InitYqlModule(NYql::NUdf::EPythonFlavor::Arcadia);
  22. }
  23. ~TPyInitializer() {
  24. TermYqlModule();
  25. Py_Finalize();
  26. }
  27. };
  28. //////////////////////////////////////////////////////////////////////////////
  29. // TPythonTestEngine
  30. //////////////////////////////////////////////////////////////////////////////
  31. class TPythonTestEngine {
  32. public:
  33. TPythonTestEngine()
  34. : MemInfo_("Memory")
  35. , Alloc_(__LOCATION__)
  36. , Env_(Alloc_)
  37. , TypeInfoHelper_(new TTypeInfoHelper)
  38. , FunctionInfoBuilder_(Env_, TypeInfoHelper_, "", nullptr, {})
  39. {
  40. HolderFactory_ = MakeHolder<THolderFactory>(
  41. Alloc_.Ref(),
  42. MemInfo_,
  43. nullptr);
  44. ValueBuilder_ = MakeHolder<TDefaultValueBuilder>(*HolderFactory_, NUdf::EValidatePolicy::Exception);
  45. BindTerminator_ = MakeHolder<TBindTerminator>(ValueBuilder_.Get());
  46. Singleton<TPyInitializer>();
  47. CastCtx_ = MakeIntrusive<TPyCastContext>(&GetValueBuilder(),
  48. MakeIntrusive<TPyContext>(TypeInfoHelper_.Get(), NUdf::TStringRef::Of(PYTHON_TEST_TAG), NUdf::TSourcePosition())
  49. );
  50. }
  51. ~TPythonTestEngine() {
  52. PyCleanup();
  53. }
  54. NUdf::IFunctionTypeInfoBuilder& GetTypeBuilder() {
  55. return FunctionInfoBuilder_;
  56. }
  57. const NUdf::IValueBuilder& GetValueBuilder() const {
  58. return *ValueBuilder_;
  59. }
  60. template <typename TChecker>
  61. void ToMiniKQL(NUdf::TType* udfType, const TStringBuf& script, TChecker&& checker) {
  62. TPyObjectPtr result = RunPythonFunction(script);
  63. UNIT_ASSERT_C(!!result, script);
  64. TType* type = static_cast<TType*>(udfType);
  65. auto value = FromPyObject(CastCtx_, type, result.Get());
  66. checker(value);
  67. }
  68. template <typename TExpectedType, typename TChecker>
  69. void ToMiniKQL(const TStringBuf& script, TChecker&& checker) {
  70. auto type = GetTypeBuilder().SimpleType<TExpectedType>();
  71. ToMiniKQL<TChecker>(type, script, std::move(checker));
  72. }
  73. template <typename TChecker>
  74. void ToMiniKQLWithArg(
  75. NUdf::TType* udfType, PyObject* argValue,
  76. const TStringBuf& script, TChecker&& checker)
  77. {
  78. TPyObjectPtr args = Py_BuildValue("(O)", argValue);
  79. auto result = RunPythonFunction(script, args.Get());
  80. if (!result || PyErr_Occurred()) {
  81. PyErr_Print();
  82. UNIT_FAIL("function execution error");
  83. }
  84. TType* type = static_cast<TType*>(udfType);
  85. auto value = FromPyObject(CastCtx_, type, result.Get());
  86. checker(value);
  87. }
  88. template <typename TExpectedType, typename TChecker>
  89. void ToMiniKQLWithArg(
  90. PyObject* argValue,
  91. const TStringBuf& script, TChecker&& checker)
  92. {
  93. auto type = GetTypeBuilder().SimpleType<TExpectedType>();
  94. ToMiniKQLWithArg<TChecker>(type, argValue, script, std::move(checker));
  95. }
  96. template <typename TMiniKQLValueBuilder>
  97. TPyObjectPtr ToPython(
  98. NUdf::TType* udfType,
  99. TMiniKQLValueBuilder&& builder,
  100. const TStringBuf& script)
  101. {
  102. try {
  103. TType* type = static_cast<TType*>(udfType);
  104. NUdf::TUnboxedValue value = builder(type, GetValueBuilder());
  105. TPyObjectPtr pyValue = ToPyObject(CastCtx_, type, value);
  106. if (!pyValue || PyErr_Occurred()) {
  107. PyErr_Print();
  108. UNIT_FAIL("object execution error");
  109. }
  110. TPyObjectPtr args = Py_BuildValue("(O)", pyValue.Get());
  111. auto result = RunPythonFunction(script, args.Get());
  112. if (!result || PyErr_Occurred()) {
  113. PyErr_Print();
  114. UNIT_FAIL("function execution error");
  115. }
  116. return result;
  117. } catch (const yexception& e) {
  118. Cerr << e << Endl;
  119. UNIT_FAIL("cast error");
  120. }
  121. Py_RETURN_NONE;
  122. }
  123. template <typename TExpectedType, typename TMiniKQLValueBuilder>
  124. TPyObjectPtr ToPython(TMiniKQLValueBuilder&& builder, const TStringBuf& script) {
  125. auto type = GetTypeBuilder().SimpleType<TExpectedType>();
  126. return ToPython<TMiniKQLValueBuilder>(type, std::move(builder), script);
  127. }
  128. NUdf::TUnboxedValue FromPython(NUdf::TType* udfType, const TStringBuf& script) {
  129. auto result = RunPythonFunction(script);
  130. if (!result || PyErr_Occurred()) {
  131. PyErr_Print();
  132. UNIT_FAIL("function execution error");
  133. }
  134. TType* type = static_cast<TType*>(udfType);
  135. return FromPyObject(CastCtx_, type, result.Get());
  136. }
  137. template <typename TExpectedType>
  138. NUdf::TUnboxedValue FromPython(const TStringBuf& script) {
  139. auto type = GetTypeBuilder().SimpleType<TExpectedType>();
  140. return FromPython(type, script);
  141. }
  142. template <typename TArgumentType, typename TReturnType = TArgumentType, typename TMiniKQLValueBuilder>
  143. NUdf::TUnboxedValue ToPythonAndBack(TMiniKQLValueBuilder&& builder, const TStringBuf& script) {
  144. const auto aType = GetTypeBuilder().SimpleType<TArgumentType>();
  145. const auto result = ToPython<TMiniKQLValueBuilder>(aType, std::move(builder), script);
  146. if (!result || PyErr_Occurred()) {
  147. PyErr_Print();
  148. UNIT_FAIL("function execution error");
  149. }
  150. const auto rType = static_cast<TType*>(GetTypeBuilder().SimpleType<TReturnType>());
  151. return FromPyObject(CastCtx_, rType, result.Get());
  152. }
  153. template <typename TArgumentType, typename TReturnType = TArgumentType, typename TMiniKQLValueBuilder, typename TChecker>
  154. void ToPythonAndBack(TMiniKQLValueBuilder&& builder, const TStringBuf& script, TChecker&& checker) {
  155. const auto result = ToPythonAndBack<TArgumentType, TReturnType, TMiniKQLValueBuilder>(std::move(builder), script);
  156. checker(result);
  157. }
  158. private:
  159. TPyObjectPtr RunPythonFunction(
  160. const TStringBuf& script, PyObject* args = nullptr)
  161. {
  162. TString filename(TStringBuf("embedded:test.py"));
  163. TPyObjectPtr code(Py_CompileString(script.data(), filename.data(), Py_file_input));
  164. if (!code) {
  165. PyErr_Print();
  166. UNIT_FAIL("can't compile python script");
  167. }
  168. TString moduleName(TStringBuf("py_cast_ut"));
  169. TPyObjectPtr module(PyImport_ExecCodeModule(moduleName.begin(), code.Get()));
  170. if (!module) {
  171. PyErr_Print();
  172. UNIT_FAIL("can't create python module");
  173. }
  174. TPyObjectPtr function(PyObject_GetAttrString(module.Get(), "Test"));
  175. if (!function) {
  176. PyErr_Print();
  177. UNIT_FAIL("function 'Test' is not found in module");
  178. }
  179. return PyObject_CallObject(function.Get(), args);
  180. }
  181. private:
  182. TMemoryUsageInfo MemInfo_;
  183. TScopedAlloc Alloc_;
  184. TTypeEnvironment Env_;
  185. const NUdf::ITypeInfoHelper::TPtr TypeInfoHelper_;
  186. TFunctionTypeInfoBuilder FunctionInfoBuilder_;
  187. THolder<THolderFactory> HolderFactory_;
  188. THolder<TDefaultValueBuilder> ValueBuilder_;
  189. THolder<TBindTerminator> BindTerminator_;
  190. TPyCastContext::TPtr CastCtx_;
  191. };
  192. } // namespace NPython