// utils.cpp (5.7 KB)

  1. #include "pg_compat.h"
  2. extern "C" {
  3. #include "postgres.h"
  4. #include "fmgr.h"
  5. #include "utils/array.h"
  6. #include "utils/elog.h"
  7. #include "pgstat.h"
  8. #include "catalog/pg_namespace_d.h"
  9. }
  10. #undef Max
  11. constexpr auto PG_ERROR = ERROR;
  12. #undef ERROR
  13. #include "utils.h"
  14. #include <util/system/compiler.h>
  15. #include <yql/essentials/parser/pg_catalog/catalog.h>
  16. #include <util/system/dynlib.h>
  17. #define ERROR PG_ERROR
  18. namespace NYql {
  19. void PrepareVariadicArraySlow(FunctionCallInfoBaseData& callInfo, const NPg::TProcDesc& desc) {
  20. const auto& elemDesc = NPg::LookupType(desc.VariadicType);
  21. Datum varArgs[FUNC_MAX_ARGS];
  22. bool varArgsNulls[FUNC_MAX_ARGS];
  23. const ui32 nelems = callInfo.nargs - desc.ArgTypes.size();
  24. Y_ENSURE(nelems >= 1);
  25. for (ui32 i = desc.ArgTypes.size(); i < callInfo.nargs; ++i) {
  26. varArgs[i - desc.ArgTypes.size()] = callInfo.args[i].value;
  27. varArgsNulls[i - desc.ArgTypes.size()] = callInfo.args[i].isnull;
  28. }
  29. callInfo.nargs = desc.ArgTypes.size() + 1;
  30. int dims[MAXDIM];
  31. int lbs[MAXDIM];
  32. dims[0] = nelems;
  33. lbs[0] = 1;
  34. auto array = construct_md_array(varArgs, varArgsNulls, 1, dims, lbs,
  35. desc.VariadicType, elemDesc.TypeLen, elemDesc.PassByValue, elemDesc.TypeAlign);
  36. auto& argDatum = callInfo.args[callInfo.nargs - 1];
  37. argDatum.value = PointerGetDatum(array);
  38. argDatum.isnull = false;
  39. }
  40. void FreeVariadicArray(FunctionCallInfoBaseData& callInfo, ui32 originalArgs) {
  41. pfree(DatumGetPointer(callInfo.args[callInfo.nargs - 1].value));
  42. callInfo.nargs = originalArgs;
  43. }
  44. static __thread ui32 PgExtIndexMax = 0;
  45. class TExtensionsRegistry::TImpl {
  46. public:
  47. struct TExtension {
  48. TExtension(const TString& name, const TString& libraryPath) {
  49. Lib.Open(libraryPath.c_str(), RTLD_GLOBAL | RTLD_NOW);
  50. Lib.SetUnloadable(false);
  51. InitFunc = (TInitFunc)Lib.Sym("YqlPgThreadInit");
  52. CleanupFunc = (TInitFunc)Lib.Sym("YqlPgThreadCleanup");
  53. }
  54. using TInitFunc = void(*)(void);
  55. using TCleanupFunc = void(*)(void);
  56. TDynamicLibrary Lib;
  57. TInitFunc InitFunc = nullptr;
  58. TInitFunc CleanupFunc = nullptr;
  59. };
  60. void InitThread() {
  61. while (PgExtIndexMax < Extensions.size()) {
  62. Extensions[PgExtIndexMax]->InitFunc();
  63. ++PgExtIndexMax;
  64. }
  65. }
  66. void CleanupThread() {
  67. for (ui32 i = 0; i < PgExtIndexMax; ++i) {
  68. Extensions[i]->CleanupFunc();
  69. }
  70. }
  71. void Load(ui32 extensionIndex, const TString& name, const TString& path) {
  72. Cerr << "Loading PG extension " << name << " [" << extensionIndex << "]: " << path << "\n";
  73. Y_ENSURE(extensionIndex == Extensions.size() + 1);
  74. Extensions.emplace_back(std::make_unique<TExtension>(name, path));
  75. Extensions.back()->InitFunc();
  76. ++PgExtIndexMax;
  77. }
  78. PGFunction GetFuncAddr(ui32 extensionIndex, const TString& funcName) {
  79. Y_ENSURE(extensionIndex > 0 && extensionIndex <= Extensions.size());
  80. return (PGFunction)Extensions[extensionIndex - 1]->Lib.Sym(funcName.c_str());
  81. }
  82. private:
  83. TVector<std::unique_ptr<TExtension>> Extensions;
  84. };
  85. extern "C" ui64 TouchReadTableApi();
  86. TExtensionsRegistry::TExtensionsRegistry()
  87. : Impl_(std::make_unique<TImpl>())
  88. {
  89. Y_UNUSED(TouchReadTableApi());
  90. }
  91. TExtensionsRegistry& TExtensionsRegistry::Instance() {
  92. return *Singleton<TExtensionsRegistry>();
  93. }
  94. void TExtensionsRegistry::InitThread() {
  95. Impl_->InitThread();
  96. }
  97. void TExtensionsRegistry::CleanupThread() {
  98. Impl_->CleanupThread();
  99. }
  100. void TExtensionsRegistry::Load(ui32 extensionIndex, const TString& name, const TString& path) {
  101. Impl_->Load(extensionIndex, name, path);
  102. }
  103. PGFunction TExtensionsRegistry::GetFuncAddr(ui32 extensionIndex, const TString& funcName) {
  104. return Impl_->GetFuncAddr(extensionIndex, funcName);
  105. }
  106. bool GetPgFuncAddr(ui32 procOid, FmgrInfo& finfo) {
  107. Zero(finfo);
  108. const auto& desc = NPg::LookupProc(procOid);
  109. finfo.fn_strict = desc.IsStrict;
  110. finfo.fn_retset = desc.ReturnSet;
  111. finfo.fn_nargs = desc.ArgTypes.size() + (desc.VariadicArgType ? 1 : 0);
  112. finfo.fn_mcxt = CurrentMemoryContext;
  113. finfo.fn_oid = procOid;
  114. finfo.fn_stats = TRACK_FUNC_ALL;
  115. if (desc.Kind != NPg::EProcKind::Function) {
  116. return false;
  117. }
  118. if (desc.ExtensionIndex == 0) {
  119. if (desc.Lang != NPg::LangInternal) {
  120. return false;
  121. }
  122. FmgrInfo tmp;
  123. Zero(tmp);
  124. fmgr_info(procOid, &tmp);
  125. Y_ENSURE(tmp.fn_addr);
  126. Y_ENSURE(finfo.fn_strict == tmp.fn_strict);
  127. Y_ENSURE(finfo.fn_retset == tmp.fn_retset);
  128. Y_ENSURE(finfo.fn_nargs == tmp.fn_nargs);
  129. finfo.fn_addr = tmp.fn_addr;
  130. return true;
  131. }
  132. const auto& extension = NPg::LookupExtension(desc.ExtensionIndex);
  133. if (extension.TypesOnly) {
  134. return false;
  135. }
  136. if (desc.Lang != NPg::LangC) {
  137. return false;
  138. }
  139. finfo.fn_addr = TExtensionsRegistry::Instance().GetFuncAddr(desc.ExtensionIndex, desc.Src);
  140. return true;
  141. }
  142. extern "C" Oid get_extension_oid(const char *extname, bool missing_ok)
  143. {
  144. Oid result = InvalidOid;
  145. try {
  146. result = NPg::LookupExtensionByName(extname);
  147. } catch (const yexception&) {
  148. }
  149. if (!OidIsValid(result) && !missing_ok)
  150. ereport(ERROR,
  151. (errcode(ERRCODE_UNDEFINED_OBJECT),
  152. errmsg("extension \"%s\" does not exist",
  153. extname)));
  154. return result;
  155. }
  156. extern "C" char *get_extension_name(Oid ext_oid) {
  157. try {
  158. return pstrdup(NPg::LookupExtension(ext_oid).Name.c_str());
  159. } catch (const yexception&) {
  160. return nullptr;
  161. }
  162. }
  163. extern "C" Oid get_extension_schema(Oid) {
  164. return PG_CATALOG_NAMESPACE;
  165. }
  166. }