123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514 |
- #pragma once
- #include <library/cpp/pybind/method.h>
- #include <library/cpp/pybind/typedesc.h>
- #include <library/cpp/pybind/module.h>
- #include <util/generic/hash.h>
- #include <util/generic/hash_set.h>
- #include <util/generic/string.h>
- namespace NPyBind {
// DEFINE_CONVERTERS_IMPL(TClass) expands to two free functions named BuildPyObject:
//   - one takes `typename TClass::TBase` by rvalue reference;
//   - one takes it by const reference.
// Both forward to the matching static TClass::BuildPyObject overload. TClass is expected to
// expose a nested TBase type and static BuildPyObject overloads, as TPyClass does.
// The functions are defined in whatever namespace the macro is expanded in and are not
// `inline`. Expanding the macro in a header that several translation units include would
// therefore define them more than once.
#define DEFINE_CONVERTERS_IMPL(TClass) \
    PyObject* BuildPyObject(typename TClass::TBase&& base) { \
        return TClass::BuildPyObject(std::move(base)); \
    } \
    PyObject* BuildPyObject(const typename TClass::TBase& base) { \
        return TClass::BuildPyObject(base); \
    }
// DEFINE_CONVERTERS(function) is DEFINE_CONVERTERS_IMPL with
// TClass = TFunctionResult<decltype(function)>. TFunctionResult is declared elsewhere;
// presumably it maps `function` to the TPyClass of its result (TODO confirm).
#define DEFINE_CONVERTERS(function) DEFINE_CONVERTERS_IMPL(TFunctionResult<decltype(function)>)
// DEFINE_TRANSFORMERS_IMPL(TClass) expands to two explicit specializations of
// ::NPyBind::FromPyObject (the primary template is declared elsewhere):
//   - one for `typename TClass::TBase*`;
//   - one for `typename TClass::TBase const*`.
// Each stores TClass::CastToObject(obj) into `res` and returns true iff the result is
// non-null. A null result is reported as a failed conversion.
// The specializations use a fully qualified name and are ordinary non-inline functions once
// specialized. Presumably the macro is expanded once, at a namespace scope enclosing NPyBind,
// in a single translation unit (TODO confirm at call sites).
#define DEFINE_TRANSFORMERS_IMPL(TClass) \
    template <> \
    bool ::NPyBind::FromPyObject<typename TClass::TBase*>(PyObject * obj, typename TClass::TBase * &res) { \
        res = TClass::CastToObject(obj); \
        return res != nullptr; \
    } \
    template <> \
    bool ::NPyBind::FromPyObject<typename TClass::TBase const*>(PyObject * obj, typename TClass::TBase const*& res) { \
        res = TClass::CastToObject(obj); \
        return res != nullptr; \
    }
// DEFINE_TRANSFORMERS(function) is DEFINE_TRANSFORMERS_IMPL with
// TClass = TFunctionResult<decltype(function)>.
#define DEFINE_TRANSFORMERS(function) DEFINE_TRANSFORMERS_IMPL(TFunctionResult<decltype(function)>)
namespace Detail {
    // Type-erased base for the per-class context getters stored in
    // TPyModuleDefinition::Class2ContextGetter.
    // Its only member is a virtual destructor. That makes the type polymorphic, so a stored
    // pointer can be dynamic_cast to the typed IGetContext<T>, which is how attributes are
    // inherited from a parent class.
    struct IGetContextBase {
        virtual ~IGetContextBase() = default;
    };
} //Detail
// Description of one Python extension module as tracked by Detail::TPyModuleRegistry.
// NOTE: TPyModuleDefinition::InitModule builds this struct with positional aggregate
// initialization ({Name, M, ClassName2Type, Class2ContextGetter}). Keep the member order and
// keep the type an aggregate.
struct TPyModuleDefinition {
    // Creates the Python module `name` through TModuleHolder, registers it under that name
    // in the registry and makes it the current module.
    static void InitModule(const TString& name);
    // Returns the registry's current module. An unnamed stub is created if none is current.
    static TPyModuleDefinition& GetModule();

    // Module name. It is used as the prefix of every class's full Python name ("<Name>.<class>").
    TString Name;
    // Python module object, as returned by TModuleHolder::InitModule.
    NPyBind::TPyObjectPtr M;
    // Short Python class name -> Python type object of that class.
    THashMap<TString, PyTypeObject*> ClassName2Type;
    // Short Python class name -> context getter, used to inherit attributes from a parent class.
    // Raw pointers, never freed here; presumably they point at per-class static holders
    // (TODO confirm).
    THashMap<TString, Detail::IGetContextBase*> Class2ContextGetter;
};
- namespace Detail {
- // Manages modules lifecycle
- // IMPORTANT!!! Don't use it in PyBind v1 environment, it will lead to inconsistent state of v1 module
- // UnnamedModule-> new unnamed module stub, this stub become current module. In this case you can add functions to it
- // InitModuleWithName -> convert unnamed module into named one, now you can switch to it in switch, this module remains current
- // SwitchToModule switches to the particular module in registry, this module becomes current.
- class TPyModuleRegistry {
- private:
- TPyModuleRegistry();
- TPyModuleRegistry(const TPyModuleRegistry&) = delete;
- TPyModuleRegistry& operator=(TPyModuleRegistry&) = delete;
- public:
- static TPyModuleRegistry& Get() {
- static TPyModuleRegistry registry;
- return registry;
- }
- TPyModuleDefinition& GetCurrentModule() {
- if (!CurrentModule) {
- GetUnnamedModule();
- }
- return *CurrentModule;
- }
- TPyModuleDefinition& GetUnnamedModule() {
- if (!UnnamedModule) {
- UnnamedModule = TPyModuleDefinition();
- CurrentModule = const_cast<TPyModuleDefinition*>(UnnamedModule.Get());
- }
- return *UnnamedModule;
- }
- TPyModuleDefinition& InitModuleWithName(const TString& name) {
- if (!UnnamedModule) {
- GetUnnamedModule();
- }
- Name2Def[name] = *UnnamedModule;
- UnnamedModule.Clear();
- CurrentModule = &Name2Def[name];
- return *CurrentModule;
- }
- TPyModuleDefinition& SwitchToModuleByName(const TString& name) {
- Y_ENSURE(Name2Def.contains(name));
- Y_ENSURE(UnnamedModule.Empty());
- CurrentModule = &Name2Def[name];
- return *CurrentModule;
- }
- private:
- TPyModuleDefinition* CurrentModule = nullptr;
- TMaybe<TPyModuleDefinition> UnnamedModule;//
- THashMap<TString, TPyModuleDefinition> Name2Def;
- };
- }//Detail
- inline void TPyModuleDefinition::InitModule(const TString& name) {
- Detail::TPyModuleRegistry::Get().GetUnnamedModule() = TPyModuleDefinition{name, TModuleHolder::Instance().InitModule(name), {}, {}};
- Detail::TPyModuleRegistry::Get().InitModuleWithName(name);
- }
- inline TPyModuleDefinition& TPyModuleDefinition::GetModule() {
- return Detail::TPyModuleRegistry::Get().GetCurrentModule();
- }
- namespace Detail {
- template <class TPythonType>
- struct TNameCtx {
- TString ClassShortName;
- static TNameCtx& GetNameCtx() {
- static TNameCtx result;
- return result;
- }
- };
- template <class TBase>
- struct TContextImpl {
- PyTypeObject* ParentType = nullptr;
- TString ClassShortName;
- TString ClassFullName;
- TString ClassDescription;
- TVector<std::pair<TString, typename TPythonTypeAttributes<TBase>::TCallerPtr>> ListCallers;
- TVector<std::pair<TString, typename TPythonTypeAttributes<TBase>::TGetterPtr>> ListGetters;
- TVector<std::pair<TString, typename TPythonTypeAttributes<TBase>::TSetterPtr>> ListSetters;
- };
- template <class TObject>
- struct IGetContext: public IGetContextBase {
- virtual ~IGetContext() = default;
- virtual const TContextImpl<TObject>& GetContext() const = 0;
- };
- template <typename THolderClass, typename TBaseClass, bool ShouldEnable, typename=std::enable_if_t<!ShouldEnable || !std::is_default_constructible_v<TBaseClass>>>
- THolderClass* DoInitPureObject(const TVector<TString>&) {
- ythrow yexception() << "Can't create this object in pure mode from python";
- }
- template <typename THolderClass, typename TBaseClass, bool ShouldEnable, typename=std::enable_if_t<ShouldEnable && std::is_default_constructible_v<TBaseClass>>, typename=void>
- THolderClass* DoInitPureObject(const TVector<TString>&) {
- return new THolderClass(MakeHolder<TBaseClass>());
- }
- template <typename T>
- PyTypeObject* GetParentType(const TPyModuleDefinition& m) {
- auto shortName = Detail::TNameCtx<T>::GetNameCtx().ClassShortName;
- auto it = m.ClassName2Type.find(shortName);
- return (it == m.ClassName2Type.end()) ? nullptr : it->second;
- }
- template <>
- PyTypeObject* GetParentType<void>(const TPyModuleDefinition&);
- template <bool InitEnabled>
- void UpdateClassNamesInModule(TPyModuleDefinition& M, const TString& name, PyTypeObject* pythonType);
- template <bool InitEnabled>
- void UpdateGetContextInModule(TPyModuleDefinition& M, const TString& name, IGetContextBase* base);
- }
// Common part of every TPyClass configuration: names the C++ parent class whose Python
// attributes are inherited. `void` means the class has no parent.
template <class TParentClass = void>
struct TPyParentClassTraits {
    using TParentPyClass = TParentClass;
};

// Standard configuration. IsInitEnabled decides whether Python code may construct the class,
// using the typed constructor arguments of TPyClass. Raw (args, kwargs) init is never used.
template <bool IsInitEnabled, class TParentClass = void>
struct TPyClassConfigTraits : TPyParentClassTraits<TParentClass> {
    static constexpr bool InitEnabled = IsInitEnabled;
    static constexpr bool RawInit = false;
};

// Raw-init configuration. Construction from Python is always enabled, and the wrapped class
// is constructed directly from the Python (args, kwargs) objects.
template <class TParentClass = void>
struct TPyClassRawInitConfigTraits : TPyParentClassTraits<TParentClass> {
    static constexpr bool InitEnabled = true;
    static constexpr bool RawInit = true;
};
- template <typename TBaseClass, typename TPyClassConfigTraits, typename... ConstructorArgs>
- class TPyClass {
- public:
- using TBase = TBaseClass;
- private:
- using TThisClass = TPyClass<TBaseClass, TPyClassConfigTraits, ConstructorArgs...>;
- using TContext = Detail::TContextImpl<TBase>;
- struct THolder {
- ::THolder<TBase> Holder;
- THolder(::THolder<TBase>&& right)
- : Holder(std::move(right))
- {
- }
- THolder(TBase&& right)
- : Holder(MakeHolder<TBase>(std::move(right)))
- {
- }
- };
- class TSelectedTraits: public NPyBind::TPythonType<THolder, TBase, TSelectedTraits> {
- private:
- using TParent = NPyBind::TPythonType<THolder, TBase, TSelectedTraits>;
- friend TParent;
- public:
- TSelectedTraits()
- : TParent(TThisClass::GetContext().ClassFullName.data(), TThisClass::GetContext().ClassDescription.data(), TThisClass::GetContext().ParentType)
- {
- for (const auto& caller : TThisClass::GetContext().ListCallers) {
- TParent::AddCaller(caller.first, caller.second);
- }
- for (const auto& getter : TThisClass::GetContext().ListGetters) {
- TParent::AddGetter(getter.first, getter.second);
- }
- for (const auto& setter : TThisClass::GetContext().ListSetters) {
- TParent::AddSetter(setter.first, setter.second);
- }
- }
- static TBase* GetObject(const THolder& holder) {
- return holder.Holder.Get();
- }
- static THolder* DoInitObject(PyObject* args, PyObject* kwargs) {
- if constexpr (TPyClassConfigTraits::InitEnabled) {
- if constexpr (TPyClassConfigTraits::RawInit) {
- static_assert(sizeof...(ConstructorArgs) == 0, "Do not pass construction args if use RawInit.");
- return new THolder(::MakeHolder<TBase>(args, kwargs));
- } else {
- if (args && (!PyTuple_Check(args) || PyTuple_Size(args) != sizeof...(ConstructorArgs))) {
- ythrow yexception() << "Method takes " << sizeof...(ConstructorArgs) << " arguments, " << PyTuple_Size(args) << " provided";
- }
- ::THolder<TBaseClass> basePtr{Apply([](auto&&... unpackedArgs) {return new TBase(std::forward<decltype(unpackedArgs)>(unpackedArgs)...); }, GetArguments<ConstructorArgs...>(args))};
- return new THolder(std::move(basePtr));
- }
- } else {
- ythrow yexception() << "Can't create this object from python";
- }
- }
- static THolder* DoInitPureObject(const TVector<TString>& properties) {
- return Detail::DoInitPureObject<THolder, TBase, TPyClassConfigTraits::InitEnabled>(properties);
- }
- static TBase* CastToObject(PyObject* obj) {
- return TParent::CastToObject(obj);
- }
- static PyTypeObject* GetType() {
- return TParent::GetPyTypePtr();
- }
- };
- class TContextHolder: public Detail::IGetContext<TBaseClass> {
- public:
- static TContextHolder& GetContextHolder() {
- static TContextHolder holder;
- return holder;
- }
- TContext& GetContext() {
- return Context;
- }
- const TContext& GetContext() const override {
- return Context;
- }
- private:
- TContext Context;
- };
- template <class TDerivedClass, class TSuperClass>
- class TCallerWrapper: public TBaseMethodCaller<TDerivedClass> {
- public:
- explicit TCallerWrapper(TSimpleSharedPtr<const TBaseMethodCaller<TSuperClass>> baseCaller)
- : BaseCaller(baseCaller) {
- Y_ENSURE(BaseCaller);
- }
- bool CallMethod(PyObject* owner, TDerivedClass* self, PyObject* args, PyObject* kwargs, PyObject*& res) const override {
- return BaseCaller->CallMethod(owner, static_cast<TSuperClass*>(self), args, kwargs, res);
- }
- private:
- TSimpleSharedPtr<const TBaseMethodCaller<TSuperClass>> BaseCaller;
- };
- template <class TDerivedClass, class TSuperClass>
- class TSetterWrapper: public TBaseAttrSetter<TDerivedClass> {
- public:
- explicit TSetterWrapper(TSimpleSharedPtr<TBaseAttrSetter<TSuperClass>> baseSetter)
- : BaseSetter(baseSetter) {
- Y_ENSURE(BaseSetter);
- }
- bool SetAttr(PyObject* owner, TDerivedClass& self, const TString& attr, PyObject* val) override {
- return BaseSetter->SetAttr(owner, static_cast<TSuperClass&>(self), attr, val);
- }
- private:
- TSimpleSharedPtr<TBaseAttrSetter<TSuperClass>> BaseSetter;
- };
- template <class TDerivedClass, class TSuperClass>
- class TGetterWrapper: public TBaseAttrGetter<TDerivedClass> {
- public:
- explicit TGetterWrapper(TSimpleSharedPtr<const TBaseAttrGetter<TSuperClass>> baseGetter)
- : BaseGetter(baseGetter) {
- Y_ENSURE(BaseGetter);
- }
- bool GetAttr(PyObject* owner, const TDerivedClass& self, const TString& attr, PyObject*& res) const override {
- return BaseGetter->GetAttr(owner, static_cast<const TSuperClass&>(self), attr, res);
- }
- private:
- TSimpleSharedPtr<const TBaseAttrGetter<TSuperClass>> BaseGetter;
- };
- template <class TSuperClass, typename=std::enable_if_t<!std::is_same_v<TSuperClass, void>>>
- void ReloadAttrsFromBase() {
- auto shortName = Detail::TNameCtx<TSuperClass>::GetNameCtx().ClassShortName;
- if (!M.Class2ContextGetter.count(shortName)) {
- return;
- }
- auto callerBasePtr = M.Class2ContextGetter[shortName];
- if (auto getContextPtr = dynamic_cast<const Detail::IGetContext<TSuperClass>*>(callerBasePtr)) {
- auto& ctx = getContextPtr->GetContext();
- auto getUniqueNames = [](const auto& collection) {
- THashSet<TString> uniqueNames;
- for (const auto& elem : collection) {
- uniqueNames.insert(elem.first);
- }
- return uniqueNames;
- };
- auto uniqueCallerNames = getUniqueNames(GetContext().ListCallers);
- using TConcreteCallerWrapper = TCallerWrapper<TBaseClass, TSuperClass>;
- for (const auto& caller : ctx.ListCallers) {
- if (uniqueCallerNames.contains(caller.first)) {
- continue;
- }
- GetContext().ListCallers.push_back(std::make_pair(caller.first, MakeSimpleShared<TConcreteCallerWrapper>(caller.second)));
- }
- auto uniqueGettersNames = getUniqueNames(GetContext().ListGetters);
- using TConcreteGetterWrapper = TGetterWrapper<TBaseClass, TSuperClass>;
- for (const auto& getter : ctx.ListGetters) {
- if (uniqueGettersNames.contains(getter.first)) {
- continue;
- }
- GetContext().ListGetters.push_back(std::make_pair(getter.first, MakeSimpleShared<TConcreteGetterWrapper>(getter.second)));
- }
- auto uniqueSetterNames = getUniqueNames(GetContext().ListSetters);
- using TConcreteSetterWrapper = TSetterWrapper<TBaseClass, TSuperClass>;
- for (auto& setter : ctx.ListSetters) {
- if (uniqueSetterNames.contains(setter.first)) {
- continue;
- }
- GetContext().ListSetters.push_back(std::make_pair(setter.first, MakeSimpleShared<TConcreteSetterWrapper>(setter.second)));
- }
- }
- }
- template <class TSuperClass, typename=std::enable_if_t<std::is_same_v<TSuperClass, void>>, typename=void>
- void ReloadAttrsFromBase() {
- }
- void CompleteImpl() {
- ReloadAttrsFromBase<typename TPyClassConfigTraits::TParentPyClass>();
- TSelectedTraits::Instance().Register(M.M, GetContext().ClassShortName);
- }
- static TContext& GetContext() {
- return TContextHolder::GetContextHolder().GetContext();
- }
- friend struct Detail::TContextImpl<TBase>;//instead of context
- friend struct THolder;
- friend class TSelectedTraits;
- using TCallerFunc = std::function<bool(PyObject*, TBaseClass*, PyObject*, PyObject*, PyObject*&)>;
- class TFuncCallerWrapper: public TBaseMethodCaller<TBaseClass> {
- public:
- explicit TFuncCallerWrapper(TCallerFunc func)
- : Func(func) {
- Y_ENSURE(func);
- }
- bool CallMethod(PyObject* owner, TBaseClass* self, PyObject* args, PyObject* kwargs, PyObject*& res) const override {
- return Func(owner, self, args, kwargs, res);
- }
- private:
- mutable TCallerFunc Func;
- };
- public:
- TPyClass(const TString& name, const TString& descr = "")
- : M(TPyModuleDefinition::GetModule())
- {
- Detail::UpdateClassNamesInModule<TPyClassConfigTraits::InitEnabled>(M, name, TSelectedTraits::GetType());
- Detail::UpdateGetContextInModule<TPyClassConfigTraits::InitEnabled>(M, name, &TContextHolder::GetContextHolder());
- GetContext().ClassFullName = TString::Join(M.Name, ".", name);
- GetContext().ClassShortName = name;
- GetContext().ClassDescription = descr;
- GetContext().ParentType = Detail::GetParentType<typename TPyClassConfigTraits::TParentPyClass>(M);
- Detail::TNameCtx<TBaseClass>::GetNameCtx().ClassShortName = name;
- }
- template <typename TMemberFuction, typename = std::enable_if_t<std::is_member_function_pointer_v<TMemberFuction>>, typename=std::enable_if_t<!TIsPointerToConstMemberFunction<TMemberFuction>::value>>
- TThisClass& Def(const TString& name, TMemberFuction t) {
- GetContext().ListCallers.push_back(std::make_pair(name, CreateMethodCaller<TBase>(t)));
- return *this;
- }
- template <typename TMemberFuction, typename = std::enable_if_t<std::is_member_function_pointer_v<TMemberFuction>>, typename=std::enable_if_t<TIsPointerToConstMemberFunction<TMemberFuction>::value>, typename=void>
- TThisClass& Def(const TString& name, TMemberFuction t) {
- GetContext().ListCallers.push_back(std::make_pair(name, CreateConstMethodCaller<TBase>(t)));
- return *this;
- }
- template <typename TMemberObject, typename = std::enable_if_t<std::is_member_object_pointer_v<TMemberObject>>>
- TThisClass& Def(const TString& name, TMemberObject t) {
- GetContext().ListGetters.push_back(std::make_pair(name, CreateAttrGetter<TBase>(t)));
- GetContext().ListSetters.push_back(std::make_pair(name, CreateAttrSetter<TBase>(t)));
- return *this;
- }
- template <typename TResultType, typename... Args>
- TThisClass& DefByFunc(const TString& name, std::function<TResultType(TBaseClass&, Args...)> func) {
- GetContext().ListCallers.push_back(std::make_pair(name, CreateFunctorCaller<TBase, TResultType, Args...>(func)));
- return *this;
- }
- TThisClass& DefByFunc(const TString& name, TCallerFunc origFunc) {
- GetContext().ListCallers.push_back(std::make_pair(name, MakeSimpleShared<TFuncCallerWrapper>(origFunc)));
- return *this;
- }
- template <typename TMemberObject>
- TThisClass& DefReadonly(const TString& name, TMemberObject t, std::enable_if_t<std::is_member_object_pointer<TMemberObject>::value>* = nullptr) {
- GetContext().ListGetters.push_back(std::make_pair(name, CreateAttrGetter<TBase>(t)));
- return *this;
- }
- template <typename TMethodGetter, typename TMethodSetter, typename=std::enable_if_t<std::is_member_function_pointer_v<TMethodGetter> && std::is_member_function_pointer_v<TMethodSetter>>>
- TThisClass& AsProperty(const TString& name, TMethodGetter getter, TMethodSetter setter) {
- GetContext().ListGetters.push_back(std::make_pair(name, CreateMethodAttrGetter<TBase>(getter)));
- GetContext().ListSetters.push_back(std::make_pair(name, CreateMethodAttrSetter<TBase>(setter)));
- return *this;
- }
- template <typename TMethodGetter, typename TMethodSetter, typename=std::enable_if_t<!std::is_member_function_pointer_v<TMethodGetter> && !std::is_member_function_pointer_v<TMethodSetter>>>
- TThisClass& AsPropertyByFunc(const TString& name, TMethodGetter getter, TMethodSetter setter) {
- GetContext().ListGetters.push_back(std::make_pair(name, CreateFunctorAttrGetter<TBase>(getter)));
- GetContext().ListSetters.push_back(std::make_pair(name, CreateFunctorAttrSetter<TBase>(setter)));
- return *this;
- }
- template <typename TMethodGetter, typename=std::enable_if_t<std::is_member_function_pointer_v<TMethodGetter>>>
- TThisClass& AsProperty(const TString& name, TMethodGetter getter) {
- GetContext().ListGetters.push_back(std::make_pair(name, CreateMethodAttrGetter<TBase>(getter)));
- return *this;
- }
- template <typename TMethodGetter>
- TThisClass& AsPropertyByFunc(const TString& name, TMethodGetter getter) {
- GetContext().ListGetters.push_back(std::make_pair(name, CreateFunctorAttrGetter<TBase>(getter)));
- return *this;
- }
- TThisClass& Complete() {
- if (!Completed) {
- CompleteImpl();
- Completed = true;
- }
- return *this;
- }
- public:
- static PyObject* BuildPyObject(TBase&& base) {
- return NPyBind::BuildPyObject(TSelectedTraits::Instance().CreatePyObject(new THolder(std::move(base))));
- }
- static PyObject* BuildPyObject(const TBase& base) {
- return NPyBind::BuildPyObject(TSelectedTraits::Instance().CreatePyObject(new THolder(TBase(base)))); // WARN - copy
- }
- static TBase* CastToObject(PyObject* obj) {
- return TSelectedTraits::CastToObject(obj);
- }
- private:
- TPyModuleDefinition& M;
- bool Completed = false;
- };
- template <typename TFunctionSignature, TFunctionSignature function>
- void DefImpl(const TString& name, const TString& descr = "") {
- NPyBind::TModuleHolder::Instance().AddModuleMethod<TModuleMethodCaller<TFunctionSignature, function>::Call>(name, descr);
- }
- #define DefFunc(NAME, FUNC) NPyBind::DefImpl<decltype(FUNC), FUNC>(NAME)
- #define DefFuncDescr(NAME, FUNC, DESCR) NPyBind::DefImpl<decltype(FUNC), FUNC>(NAME, DESCR)
- };
|