py_callable.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. #include "py_callable.h"
  2. #include "py_cast.h"
  3. #include "py_errors.h"
  4. #include "py_gil.h"
  5. #include "py_stream.h"
  6. #include "py_utils.h"
  7. #include <yql/essentials/public/udf/udf_value.h>
  8. #include <yql/essentials/public/udf/udf_value_builder.h>
  9. #include <yql/essentials/public/udf/udf_type_inspection.h>
  10. #include <yql/essentials/public/udf/udf_terminator.h>
  11. #include <library/cpp/containers/stack_vector/stack_vec.h>
  12. #include <util/string/builder.h>
  13. using namespace NKikimr;
  14. namespace NPython {
  15. namespace {
  16. //////////////////////////////////////////////////////////////////////////////
  17. // TPyCallableObject
  18. //////////////////////////////////////////////////////////////////////////////
// Python object that wraps a UDF callable value so it can be invoked from Python.
//
// The struct begins with PyObject_HEAD, which is what allows CastToCallable() to
// reinterpret a PyObject* as a TPyCallableObject*. Keep PyObject_HEAD as the
// first member.
struct TPyCallableObject
{
    PyObject_HEAD;
    // Cast context used for argument/result conversion in CallableCall().
    TPyCastContext::TPtr CastCtx;
    // Callable type passed to FromPyArgs() when converting Python arguments.
    const NUdf::TType* Type;
    // The wrapped boxed value that CallableCall() runs; it is assigned in
    // ToPyCallable() via Value.Set(castCtx->PyCtx, ...).
    // NOTE(review): TPyCleanupListItem presumably ties the value's lifetime to
    // the Python context cleanup list - confirm against its definition.
    TPyCleanupListItem<NUdf::IBoxedValuePtr> Value;
    // Inspector built once from Type; provides the argument count and return type.
    NUdf::TCallableTypeInspector Inspector;

    // Stores the context and type and builds the inspector from them.
    // Value is left default-constructed; the creator is expected to set it.
    TPyCallableObject(const TPyCastContext::TPtr& castCtx, const NUdf::TType* type)
        : CastCtx(castCtx)
        , Type(type)
        , Inspector(*castCtx->PyCtx->TypeInfoHelper, type)
    {}
};
  32. inline TPyCallableObject* CastToCallable(PyObject* o)
  33. {
  34. return reinterpret_cast<TPyCallableObject*>(o);
  35. }
  36. void CallableDealloc(PyObject* self)
  37. {
  38. delete CastToCallable(self);
  39. }
  40. PyObject* CallableRepr(PyObject*)
  41. {
  42. // TODO: print callable signature
  43. return PyRepr("<yql.TCallable>").Release();
  44. }
  45. PyObject* CallableCall(PyObject *self, PyObject *args, PyObject *kwargs)
  46. {
  47. Y_UNUSED(kwargs);
  48. PY_TRY {
  49. TPyCallableObject* callable = CastToCallable(self);
  50. auto callableType = callable->Type;
  51. auto valueBuilder = callable->CastCtx->ValueBuilder;
  52. const auto& inspector = callable->Inspector;
  53. TSmallVec<NUdf::TUnboxedValue> cArgs;
  54. cArgs.resize(inspector.GetArgsCount());
  55. FromPyArgs(callable->CastCtx, callableType, args, cArgs.data(), inspector);
  56. NUdf::TUnboxedValue result;
  57. {
  58. TPyGilUnlocker unlock;
  59. result = NUdf::TBoxedValueAccessor::Run(*callable->Value.Get(), valueBuilder, cArgs.data());
  60. }
  61. return ToPyObject(callable->CastCtx, inspector.GetReturnType(), result).Release();
  62. } PY_CATCH(nullptr)
  63. }
  64. }
// Python type object for "yql.TCallable" (instances are TPyCallableObject).
//
// Only tp_name, tp_basicsize, tp_dealloc, tp_repr, tp_call and tp_doc carry
// non-null values; every other slot is explicitly zeroed. The preprocessor
// branches select the slots that exist in the Python version being built
// against (tp_print vs tp_vectorcall_offset, tp_compare vs tp_as_async, etc.).
// NOTE(review): INIT_MEMBER is defined elsewhere; the slot order here looks
// like it must follow the PyTypeObject field order - confirm before reordering.
PyTypeObject PyCallableType = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    INIT_MEMBER(tp_name , "yql.TCallable"),
    INIT_MEMBER(tp_basicsize , sizeof(TPyCallableObject)),
    INIT_MEMBER(tp_itemsize , 0),
    INIT_MEMBER(tp_dealloc , CallableDealloc),
#if PY_VERSION_HEX < 0x030800b4
    INIT_MEMBER(tp_print , nullptr),
#else
    INIT_MEMBER(tp_vectorcall_offset, 0),
#endif
    INIT_MEMBER(tp_getattr , nullptr),
    INIT_MEMBER(tp_setattr , nullptr),
#if PY_MAJOR_VERSION >= 3
    INIT_MEMBER(tp_as_async , nullptr),
#else
    INIT_MEMBER(tp_compare , nullptr),
#endif
    INIT_MEMBER(tp_repr , CallableRepr),
    INIT_MEMBER(tp_as_number , nullptr),
    INIT_MEMBER(tp_as_sequence , nullptr),
    INIT_MEMBER(tp_as_mapping , nullptr),
    INIT_MEMBER(tp_hash , nullptr),
    // Makes instances callable from Python.
    INIT_MEMBER(tp_call , CallableCall),
    INIT_MEMBER(tp_str , nullptr),
    INIT_MEMBER(tp_getattro , nullptr),
    INIT_MEMBER(tp_setattro , nullptr),
    INIT_MEMBER(tp_as_buffer , nullptr),
    INIT_MEMBER(tp_flags , 0),
    INIT_MEMBER(tp_doc , "yql.TCallable object"),
    INIT_MEMBER(tp_traverse , nullptr),
    INIT_MEMBER(tp_clear , nullptr),
    INIT_MEMBER(tp_richcompare , nullptr),
    INIT_MEMBER(tp_weaklistoffset , 0),
    INIT_MEMBER(tp_iter , nullptr),
    INIT_MEMBER(tp_iternext , nullptr),
    INIT_MEMBER(tp_methods , nullptr),
    INIT_MEMBER(tp_members , nullptr),
    INIT_MEMBER(tp_getset , nullptr),
    INIT_MEMBER(tp_base , nullptr),
    INIT_MEMBER(tp_dict , nullptr),
    INIT_MEMBER(tp_descr_get , nullptr),
    INIT_MEMBER(tp_descr_set , nullptr),
    INIT_MEMBER(tp_dictoffset , 0),
    INIT_MEMBER(tp_init , nullptr),
    INIT_MEMBER(tp_alloc , nullptr),
    // No tp_new: instances are created from C++ only (see ToPyCallable).
    INIT_MEMBER(tp_new , nullptr),
    INIT_MEMBER(tp_free , nullptr),
    INIT_MEMBER(tp_is_gc , nullptr),
    INIT_MEMBER(tp_bases , nullptr),
    INIT_MEMBER(tp_mro , nullptr),
    INIT_MEMBER(tp_cache , nullptr),
    INIT_MEMBER(tp_subclasses , nullptr),
    INIT_MEMBER(tp_weaklist , nullptr),
    INIT_MEMBER(tp_del , nullptr),
    INIT_MEMBER(tp_version_tag , 0),
#if PY_MAJOR_VERSION >= 3
    INIT_MEMBER(tp_finalize , nullptr),
#endif
#if PY_VERSION_HEX >= 0x030800b1
    INIT_MEMBER(tp_vectorcall , nullptr),
#endif
#if PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000
    INIT_MEMBER(tp_print , nullptr),
#endif
};
  131. //////////////////////////////////////////////////////////////////////////////
  132. // TPyCallable
  133. //////////////////////////////////////////////////////////////////////////////
  134. class TPyCallable: public NUdf::TBoxedValue
  135. {
  136. public:
  137. TPyCallable(
  138. PyObject* function,
  139. const NUdf::TType* functionType,
  140. const TPyCastContext::TPtr& castCtx)
  141. : Function_(function, TPyObjectPtr::ADD_REF)
  142. , FunctionType_(functionType)
  143. , CastCtx_(castCtx)
  144. , Inspector_(*castCtx->PyCtx->TypeInfoHelper, functionType)
  145. {
  146. // keep ownership of function closure if any
  147. if (PyFunction_Check(function)) {
  148. PyObject* closure = PyFunction_GetClosure(function);
  149. if (closure) {
  150. Closure_ = TPyObjectPtr(closure, TPyObjectPtr::ADD_REF);
  151. }
  152. }
  153. }
  154. ~TPyCallable() {
  155. TPyGilLocker lock;
  156. Closure_.Reset();
  157. Function_.Reset();
  158. CastCtx_.Reset();
  159. }
  160. private:
  161. NUdf::TUnboxedValue Run(
  162. const NUdf::IValueBuilder*,
  163. const NUdf::TUnboxedValuePod* args) const final
  164. {
  165. TPyGilLocker lock;
  166. try {
  167. TPyObjectPtr pyArgs = ToPyArgs(CastCtx_, FunctionType_, args, Inspector_);
  168. TPyObjectPtr resultObj =
  169. PyObject_CallObject(Function_.Get(), pyArgs.Get());
  170. if (!resultObj) {
  171. UdfTerminate((TStringBuilder() << CastCtx_->PyCtx->Pos << "Failed to execute:\n" << GetLastErrorAsString()).data());
  172. }
  173. auto returnType = Inspector_.GetReturnType();
  174. if (CastCtx_->PyCtx->TypeInfoHelper->GetTypeKind(returnType) == NUdf::ETypeKind::Stream) {
  175. return FromPyStream(CastCtx_, returnType, resultObj, Function_, Closure_, pyArgs);
  176. }
  177. return FromPyObject(CastCtx_, returnType, resultObj.Get());
  178. } catch (const yexception& e) {
  179. UdfTerminate((TStringBuilder() << CastCtx_->PyCtx->Pos << "Failed to cast arguments or result\n" << e.what()).data());
  180. }
  181. }
  182. TPyObjectPtr Function_;
  183. TPyObjectPtr Closure_;
  184. const NUdf::TType* FunctionType_;
  185. TPyCastContext::TPtr CastCtx_;
  186. NUdf::TCallableTypeInspector Inspector_;
  187. };
  188. TPyObjectPtr ToPyCallable(
  189. const TPyCastContext::TPtr& castCtx,
  190. const NUdf::TType* type,
  191. const NUdf::TUnboxedValuePod& value)
  192. {
  193. TPyCallableObject* callable = new TPyCallableObject(castCtx, type);
  194. PyObject_INIT(callable, &PyCallableType);
  195. callable->Value.Set(castCtx->PyCtx, value.AsBoxed());
  196. return reinterpret_cast<PyObject*>(callable);
  197. }
  198. NUdf::TUnboxedValue FromPyCallable(
  199. const TPyCastContext::TPtr& castCtx,
  200. const NUdf::TType* type,
  201. PyObject* value)
  202. {
  203. return NUdf::TUnboxedValuePod(new TPyCallable(value, type, castCtx));
  204. }
  205. TMaybe<TPyObjectPtr> GetOptionalAttribute(PyObject* value, const char* attrName) {
  206. if (TPyObjectPtr attr = PyObject_GetAttrString(value, attrName)) {
  207. return attr;
  208. } else {
  209. if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
  210. PyErr_Clear();
  211. return Nothing();
  212. } else {
  213. throw yexception() << "Cannot get attribute '" << attrName << "', error: " << GetLastErrorAsString();
  214. }
  215. }
  216. }
// Python object behind "yql.TSecureParam": a callable that resolves secure
// parameters through the cast context's value builder (see SecureParamCall).
//
// PyObject_HEAD is the first member so that CastToSecureParam() can
// reinterpret a PyObject* as a TPySecureParam*.
struct TPySecureParam
{
    PyObject_HEAD;
    // Cast context whose ValueBuilder is queried for secure parameters.
    TPyCastContext::TPtr CastCtx;

    TPySecureParam(const TPyCastContext::TPtr& castCtx) : CastCtx(castCtx) {}
};
  223. inline TPySecureParam* CastToSecureParam(PyObject* o)
  224. {
  225. return reinterpret_cast<TPySecureParam*>(o);
  226. }
  227. void SecureParamDealloc(PyObject* self)
  228. {
  229. delete CastToSecureParam(self);
  230. }
  231. PyObject* SecureParamRepr(PyObject*)
  232. {
  233. return PyRepr("<yql.TSecureParam>").Release();
  234. }
  235. PyObject* SecureParamCall(PyObject* self, PyObject* args, PyObject* kwargs)
  236. {
  237. Y_UNUSED(kwargs);
  238. struct PyBufDeleter {
  239. void operator() (Py_buffer* view) { PyBuffer_Release(view); }
  240. };
  241. Py_buffer input;
  242. if (!PyArg_ParseTuple(args, "s*", &input)) {
  243. return nullptr;
  244. }
  245. std::unique_ptr<Py_buffer, PyBufDeleter> bufPtr(&input);
  246. auto valueBuilder = CastToSecureParam(self)->CastCtx->ValueBuilder;
  247. NUdf::TStringRef key(static_cast<const char*>(input.buf), input.len);
  248. PY_TRY {
  249. if (!valueBuilder->GetSecureParam(key, key)) {
  250. throw yexception() << "Cannot get secure parameter for key: " << key;
  251. }
  252. return PyRepr(TStringBuf(key.Data(), key.Size())).Release();
  253. } PY_CATCH(nullptr)
  254. }
// Python type object for "yql.TSecureParam" (instances are TPySecureParam).
//
// Only tp_name, tp_basicsize, tp_dealloc, tp_repr, tp_call and tp_doc carry
// non-null values; every other slot is explicitly zeroed. The preprocessor
// branches select the slots that exist in the Python version being built
// against.
// NOTE(review): INIT_MEMBER is defined elsewhere; the slot order here looks
// like it must follow the PyTypeObject field order - confirm before reordering.
static PyTypeObject PySecureParamType = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    INIT_MEMBER(tp_name , "yql.TSecureParam"),
    INIT_MEMBER(tp_basicsize , sizeof(TPySecureParam)),
    INIT_MEMBER(tp_itemsize , 0),
    INIT_MEMBER(tp_dealloc , SecureParamDealloc),
#if PY_VERSION_HEX < 0x030800b4
    INIT_MEMBER(tp_print , nullptr),
#else
    INIT_MEMBER(tp_vectorcall_offset, 0),
#endif
    INIT_MEMBER(tp_getattr , nullptr),
    INIT_MEMBER(tp_setattr , nullptr),
#if PY_MAJOR_VERSION >= 3
    INIT_MEMBER(tp_as_async , nullptr),
#else
    INIT_MEMBER(tp_compare , nullptr),
#endif
    INIT_MEMBER(tp_repr , SecureParamRepr),
    INIT_MEMBER(tp_as_number , nullptr),
    INIT_MEMBER(tp_as_sequence , nullptr),
    INIT_MEMBER(tp_as_mapping , nullptr),
    INIT_MEMBER(tp_hash , nullptr),
    // Makes instances callable from Python.
    INIT_MEMBER(tp_call , SecureParamCall),
    INIT_MEMBER(tp_str , nullptr),
    INIT_MEMBER(tp_getattro , nullptr),
    INIT_MEMBER(tp_setattro , nullptr),
    INIT_MEMBER(tp_as_buffer , nullptr),
    INIT_MEMBER(tp_flags , 0),
    INIT_MEMBER(tp_doc , "yql.TSecureParam object"),
    INIT_MEMBER(tp_traverse , nullptr),
    INIT_MEMBER(tp_clear , nullptr),
    INIT_MEMBER(tp_richcompare , nullptr),
    INIT_MEMBER(tp_weaklistoffset , 0),
    INIT_MEMBER(tp_iter , nullptr),
    INIT_MEMBER(tp_iternext , nullptr),
    INIT_MEMBER(tp_methods , nullptr),
    INIT_MEMBER(tp_members , nullptr),
    INIT_MEMBER(tp_getset , nullptr),
    INIT_MEMBER(tp_base , nullptr),
    INIT_MEMBER(tp_dict , nullptr),
    INIT_MEMBER(tp_descr_get , nullptr),
    INIT_MEMBER(tp_descr_set , nullptr),
    INIT_MEMBER(tp_dictoffset , 0),
    INIT_MEMBER(tp_init , nullptr),
    INIT_MEMBER(tp_alloc , nullptr),
    // No tp_new: instances are created from C++ only (see ToPySecureParam).
    INIT_MEMBER(tp_new , nullptr),
    INIT_MEMBER(tp_free , nullptr),
    INIT_MEMBER(tp_is_gc , nullptr),
    INIT_MEMBER(tp_bases , nullptr),
    INIT_MEMBER(tp_mro , nullptr),
    INIT_MEMBER(tp_cache , nullptr),
    INIT_MEMBER(tp_subclasses , nullptr),
    INIT_MEMBER(tp_weaklist , nullptr),
    INIT_MEMBER(tp_del , nullptr),
    INIT_MEMBER(tp_version_tag , 0),
#if PY_MAJOR_VERSION >= 3
    INIT_MEMBER(tp_finalize , nullptr),
#endif
#if PY_VERSION_HEX >= 0x030800b1
    INIT_MEMBER(tp_vectorcall , nullptr),
#endif
#if PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000
    INIT_MEMBER(tp_print , nullptr),
#endif
};
  321. TPyObjectPtr ToPySecureParam(const TPyCastContext::TPtr& castCtx)
  322. {
  323. TPySecureParam* ret = new TPySecureParam(castCtx);
  324. PyObject_INIT(ret, &PySecureParamType);
  325. return reinterpret_cast<PyObject*>(ret);
  326. }
  327. void SetupCallableSettings(const TPyCastContext::TPtr& castCtx, PyObject* value) {
  328. if (const auto lazyInput = GetOptionalAttribute(value, "_yql_lazy_input")) try {
  329. castCtx->LazyInputObjects = PyCast<bool>(lazyInput->Get());
  330. } catch (const yexception& e) {
  331. throw yexception() << "Cannot parse attribute '_yql_lazy_input', error: " << e.what();
  332. }
  333. if (const auto convertYson = GetOptionalAttribute(value, "_yql_convert_yson")) try {
  334. Py_ssize_t itemsCount = PyTuple_GET_SIZE(convertYson->Get());
  335. if (itemsCount != 2) {
  336. throw yexception() << "Expected tuple of 2 callables";
  337. }
  338. castCtx->YsonConverterIn.ResetAddRef(PyTuple_GET_ITEM(convertYson->Get(), 0));
  339. castCtx->YsonConverterOut.ResetAddRef(PyTuple_GET_ITEM(convertYson->Get(), 1));
  340. if (!PyCallable_Check(castCtx->YsonConverterIn.Get()) || !PyCallable_Check(castCtx->YsonConverterOut.Get())) {
  341. throw yexception() << "Expected tuple of 2 callables";
  342. }
  343. } catch (const yexception& e) {
  344. throw yexception() << "Cannot parse attribute '_yql_convert_yson', error: " << e.what();
  345. }
  346. if (const auto bytesDecodeMode = GetOptionalAttribute(value, "_yql_bytes_decode_mode")) try {
  347. PyObject* bytesValue = nullptr;
  348. if (PyBytes_Check(bytesDecodeMode->Get())) {
  349. bytesValue = PyObject_Bytes(bytesDecodeMode->Get());
  350. } else if (PyUnicode_Check(bytesDecodeMode->Get())) {
  351. bytesValue = PyUnicode_AsUTF8String(bytesDecodeMode->Get());
  352. } else {
  353. throw yexception() << "Expected bytes or unicode";
  354. }
  355. if (!bytesValue) {
  356. PyErr_Clear();
  357. throw yexception() << "Failed to convert to bytes";
  358. }
  359. TStringBuf view(PyBytes_AS_STRING(bytesValue));
  360. if (view == "never") {
  361. castCtx->BytesDecodeMode = EBytesDecodeMode::Never;
  362. } else if (view == "strict") {
  363. castCtx->BytesDecodeMode = EBytesDecodeMode::Strict;
  364. } else {
  365. Py_DECREF(bytesValue);
  366. throw yexception() << "Expected values 'never' or 'strict'";
  367. }
  368. Py_DECREF(bytesValue);
  369. } catch (const yexception& e) {
  370. throw yexception() << "Cannot parse attribute '_yql_bytes_decode_mode', error: " << e.what();
  371. }
  372. if (PyObject_SetAttrString(value, "_yql_secure_param", ToPySecureParam(castCtx).Get()) != 0) {
  373. throw yexception() << "Cannot set attribute '_yql_secure_param'";
  374. }
  375. }
  376. } // namespace NPython