yql_simple_arrow_resolver.cpp 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. #include "yql_simple_arrow_resolver.h"
  2. #include <yql/essentials/minikql/arrow/mkql_functions.h>
  3. #include <yql/essentials/minikql/mkql_program_builder.h>
  4. #include <yql/essentials/minikql/mkql_type_builder.h>
  5. #include <yql/essentials/minikql/mkql_function_registry.h>
  6. #include <yql/essentials/providers/common/mkql/yql_type_mkql.h>
  7. #include <util/stream/null.h>
  8. namespace NYql {
  9. using namespace NKikimr::NMiniKQL;
  10. class TSimpleArrowResolver: public IArrowResolver {
  11. public:
  12. TSimpleArrowResolver(const IFunctionRegistry& functionRegistry)
  13. : FunctionRegistry_(functionRegistry)
  14. {}
  15. private:
  16. EStatus LoadFunctionMetadata(const TPosition& pos, TStringBuf name, const TVector<const TTypeAnnotationNode*>& argTypes,
  17. const TTypeAnnotationNode* returnType, TExprContext& ctx) const override
  18. {
  19. try {
  20. TScopedAlloc alloc(__LOCATION__);
  21. TTypeEnvironment env(alloc);
  22. TTypeBuilder typeBuilder(env);
  23. TNullOutput null;
  24. TVector<TType*> mkqlInputTypes;
  25. NCommon::TMemoizedTypesMap typeMemoization;
  26. for (const auto& type : argTypes) {
  27. auto mkqlType = NCommon::BuildType(*type, typeBuilder, typeMemoization, null);
  28. YQL_ENSURE(mkqlType, "Failed to convert type " << *type << " to MKQL type");
  29. mkqlInputTypes.emplace_back(mkqlType);
  30. }
  31. TType* mkqlOutputType = NCommon::BuildType(*returnType, typeBuilder, typeMemoization, null);
  32. YQL_ENSURE(mkqlOutputType, "Failed to convert type " << *returnType << " to MKQL type");
  33. bool found = FindArrowFunction(name, mkqlInputTypes, mkqlOutputType, *FunctionRegistry_.GetBuiltins());
  34. return found ? EStatus::OK : EStatus::NOT_FOUND;
  35. } catch (const std::exception& e) {
  36. ctx.AddError(TIssue(pos, e.what()));
  37. return EStatus::ERROR;
  38. }
  39. }
  40. EStatus HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, TExprContext& ctx) const override {
  41. try {
  42. TScopedAlloc alloc(__LOCATION__);
  43. TTypeEnvironment env(alloc);
  44. TTypeBuilder typeBuilder(env);
  45. NCommon::TMemoizedTypesMap typeMemoization;
  46. TNullOutput null;
  47. auto mkqlFromType = NCommon::BuildType(*from, typeBuilder, typeMemoization, null);
  48. YQL_ENSURE(mkqlFromType, "Failed to convert type " << *from << " to MKQL type");
  49. auto mkqlToType = NCommon::BuildType(*to, typeBuilder, typeMemoization, null);
  50. YQL_ENSURE(mkqlToType, "Failed to convert type " << *to << " to MKQL type");
  51. return HasArrowCast(mkqlFromType, mkqlToType) ? EStatus::OK : EStatus::NOT_FOUND;
  52. } catch (const std::exception& e) {
  53. ctx.AddError(TIssue(pos, e.what()));
  54. return EStatus::ERROR;
  55. }
  56. }
  57. EStatus AreTypesSupported(const TPosition& pos, const TVector<const TTypeAnnotationNode*>& types, TExprContext& ctx,
  58. const TUnsupportedTypeCallback& onUnsupported = {}) const override
  59. {
  60. try {
  61. TScopedAlloc alloc(__LOCATION__);
  62. TTypeEnvironment env(alloc);
  63. TTypeBuilder typeBuilder(env);
  64. NCommon::TMemoizedTypesMap typeMemoization;
  65. TNullOutput null;
  66. bool allOk = true;
  67. TArrowConvertFailedCallback cb;
  68. if (onUnsupported) {
  69. cb = [&](TType* failed) {
  70. if (failed->IsData()) {
  71. auto slot = static_cast<TDataType*>(failed)->GetDataSlot();
  72. YQL_ENSURE(slot);
  73. onUnsupported(*slot);
  74. } else {
  75. onUnsupported(NYql::NCommon::ConvertMiniKQLTypeKind(failed));
  76. }
  77. };
  78. }
  79. for (const auto& type : types) {
  80. YQL_ENSURE(type);
  81. auto mkqlType = NCommon::BuildType(*type, typeBuilder, typeMemoization, null);
  82. YQL_ENSURE(mkqlType);
  83. std::shared_ptr<arrow::DataType> arrowType;
  84. if (!ConvertArrowType(mkqlType, arrowType, cb)) {
  85. allOk = false;
  86. if (!cb) {
  87. break;
  88. }
  89. }
  90. }
  91. return allOk ? EStatus::OK : EStatus::NOT_FOUND;
  92. } catch (const std::exception& e) {
  93. ctx.AddError(TIssue(pos, e.what()));
  94. return EStatus::ERROR;
  95. }
  96. }
  97. private:
  98. const IFunctionRegistry& FunctionRegistry_;
  99. };
  100. IArrowResolver::TPtr MakeSimpleArrowResolver(const IFunctionRegistry& functionRegistry) {
  101. return new TSimpleArrowResolver(functionRegistry);
  102. }
  103. } // namespace NYql