Browse Source

[dq] Refactor dq gateway session handling

YQL-16599
udovichenko-r 1 year ago
parent
commit
74c6728ab9

+ 10 - 7
ydb/library/yql/providers/dq/provider/exec/yql_dq_exectransformer.cpp

@@ -1429,14 +1429,17 @@ private:
     }
 
     IDqGateway::TDqProgressWriter MakeDqProgressWriter(const TPublicIds::TPtr& publicIds) const {
-        IDqGateway::TDqProgressWriter dqProgressWriter = [progressWriter = State->ProgressWriter, publicIds](const TString& stage) {
-            for (const auto& publicId : publicIds->AllPublicIds) {
-                auto p = TOperationProgress(TString(DqProviderName), publicId.first, TOperationProgress::EState::InProgress, stage);
-                if (publicId.second) {
-                    p.Counters.ConstructInPlace();
-                    p.Counters->Running = p.Counters->Total = publicId.second;
+        IDqGateway::TDqProgressWriter dqProgressWriter = [progressWriter = State->ProgressWriter, publicIds, current = std::make_shared<TString>()](const TString& stage) {
+            if (*current != stage) {
+                for (const auto& publicId : publicIds->AllPublicIds) {
+                    auto p = TOperationProgress(TString(DqProviderName), publicId.first, TOperationProgress::EState::InProgress, stage);
+                    if (publicId.second) {
+                        p.Counters.ConstructInPlace();
+                        p.Counters->Running = p.Counters->Total = publicId.second;
+                    }
+                    progressWriter(p);
                 }
-                progressWriter(p);
+                *current = stage;
             }
         };
         return dqProgressWriter;

+ 250 - 219
ydb/library/yql/providers/dq/provider/yql_dq_gateway.cpp

@@ -10,10 +10,13 @@
 #include <ydb/public/lib/yson_value/ydb_yson_value.h>
 
 #include <ydb/library/grpc/client/grpc_client_low.h>
+
 #include <library/cpp/yson/node/node_io.h>
 #include <library/cpp/threading/task_scheduler/task_scheduler.h>
 
-#include <util/system/thread.h>
+#include <util/system/mutex.h>
+#include <util/generic/hash.h>
+#include <util/string/builder.h>
 
 #include <utility>
 
@@ -21,36 +24,59 @@ namespace NYql {
 
 using namespace NThreading;
 
-class TDqGatewayImpl: public std::enable_shared_from_this<TDqGatewayImpl>
-{
+class TDqTaskScheduler : public TTaskScheduler { // TTaskScheduler configured with one thread, extended with a future-based Delay() helper.
+private:
+    struct TDelay: public TTaskScheduler::ITask { // Task whose only job is to fulfil the promise it holds when it is processed.
+        TDelay(TPromise<void> p)
+            : Promise(std::move(p))
+        { }
+
+        TInstant Process() override {
+            Promise.SetValue();
+            return TInstant::Max(); // NOTE(review): presumably tells the scheduler not to run this task again - confirm against TTaskScheduler.
+        }
+
+        TPromise<void> Promise;
+    };
+
+public:
+    TDqTaskScheduler()
+        : TTaskScheduler(1) // threads
+    {}
+
+    TFuture<void> Delay(TDuration duration) { // Returns a future that is set once a TDelay task queued for (now + duration) gets processed.
+        TPromise<void> promise = NewPromise();
+
+        auto future = promise.GetFuture();
+
+        if (!Add(MakeIntrusive<TDelay>(promise), TInstant::Now() + duration)) {
+            promise.SetException("cannot delay"); // Add() rejected the task: surface the failure through the future instead of leaving it unset.
+        }
+
+        return future;
+    }
+};
+
+class TDqGatewaySession: public std::enable_shared_from_this<TDqGatewaySession> {
 public:
     using TResult = IDqGateway::TResult;
     using TDqProgressWriter = IDqGateway::TDqProgressWriter;
 
-    TDqGatewayImpl(const TString& host, int port, const TString& vanillaJobPath, const TString& vanillaJobMd5, TDuration timeout, TDuration requestTimeout)
-        : GrpcConf(TStringBuilder() << host << ":" << port, requestTimeout)
-        , GrpcClient(1)
-        , Service(GrpcClient.CreateGRpcServiceConnection<Yql::DqsProto::DqService>(GrpcConf))
-        , VanillaJobPath(vanillaJobPath)
-        , VanillaJobMd5(vanillaJobMd5)
-        , TaskScheduler(1)
-        , OpenSessionTimeout(timeout)
+    TDqGatewaySession(const TString& sessionId, TDqTaskScheduler& taskScheduler, NYdbGrpc::TServiceConnection<Yql::DqsProto::DqService>& service, TFuture<void>&& openSessionFuture) // Binds one session id to a scheduler and a service connection supplied by the caller.
+        : SessionId(sessionId)
+        , TaskScheduler(taskScheduler) // held by reference, not owned: the scheduler has to outlive this session
+        , Service(service) // held by reference, not owned: the connection has to outlive this session
+        , OpenSessionFuture(std::move(openSessionFuture)) // ExecutePlan chains its work on this future
     {
-        TaskScheduler.Start();
     }
 
-    TString GetVanillaJobPath() {
-        return VanillaJobPath;
-    }
-
-    TString GetVanillaJobMd5() {
-        return VanillaJobMd5;
+    const TString& GetSessionId() const { // Read-only accessor for the id this session was constructed with.
+        return SessionId;
     }
 
     template<typename RespType>
-    void OnResponse(TPromise<TResult> promise, TString sessionId, NYdbGrpc::TGrpcStatus&& status, RespType&& resp, const THashMap<TString, TString>& modulesMapping, bool alwaysFallback = false)
-    {
-        YQL_LOG_CTX_ROOT_SESSION_SCOPE(sessionId);
+    void OnResponse(TPromise<TResult> promise, NYdbGrpc::TGrpcStatus&& status, RespType&& resp, const THashMap<TString, TString>& modulesMapping, bool alwaysFallback = false) {
+        YQL_LOG_CTX_ROOT_SESSION_SCOPE(SessionId);
         YQL_CLOG(TRACE, ProviderDq) << "TDqGateway::callback";
 
         TResult result;
@@ -89,7 +115,7 @@ public:
                     auto& message = *queue.front();
                     queue.pop_front();
                     message.Setmessage(NBacktrace::Symbolize(message.Getmessage(), modulesMapping));
-                    for (auto &subMsg : *message.Mutableissues()) {
+                    for (auto& subMsg : *message.Mutableissues()) {
                         queue.push_back(&subMsg);
                     }
                 }
@@ -153,56 +179,34 @@ public:
         promise.SetValue(result);
     }
 
-    TFuture<void> Delay(TDuration duration) {
-        TPromise<void> promise = NewPromise();
-
-        auto future = promise.GetFuture();
-
-        if (!TaskScheduler.Add(MakeIntrusive<TDelay>(promise), TInstant::Now() + duration)) {
-            promise.SetException("cannot delay");
-        }
-
-        return future;
-    }
-
     template <typename TResponse, typename TRequest, typename TStub>
     TFuture<TResult> WithRetry(
-        const TString& sessionId,
         const TRequest& queryPB,
         TStub stub,
         int retry,
         const TDqSettings::TPtr& settings,
-        const THashMap<TString, TString>& modulesMapping
+        const THashMap<TString, TString>& modulesMapping,
+        const TDqProgressWriter& progressWriter
     ) {
         auto backoff = TDuration::MilliSeconds(settings->RetryBackoffMs.Get().GetOrElse(1000));
         auto promise = NewPromise<TResult>();
         const auto fallbackPolicy = settings->FallbackPolicy.Get().GetOrElse(EFallbackPolicy::Default);
         const auto alwaysFallback = EFallbackPolicy::Always == fallbackPolicy;
         auto self = weak_from_this();
-        auto callback = [self, promise, sessionId, alwaysFallback, modulesMapping](NYdbGrpc::TGrpcStatus&& status, TResponse&& resp) mutable {
+        auto callback = [self, promise, sessionId = SessionId, alwaysFallback, modulesMapping](NYdbGrpc::TGrpcStatus&& status, TResponse&& resp) mutable {
             auto this_ = self.lock();
             if (!this_) {
-                YQL_CLOG(DEBUG, ProviderDq) << "Gateway was closed: " << sessionId;
-                promise.SetException("Gateway was closed");
+                YQL_CLOG(DEBUG, ProviderDq) << "Session was closed: " << sessionId;
+                promise.SetException("Session was closed");
                 return;
             }
 
-            this_->OnResponse(std::move(promise), std::move(sessionId), std::move(status), std::move(resp), modulesMapping, alwaysFallback);
+            this_->OnResponse(std::move(promise), std::move(status), std::move(resp), modulesMapping, alwaysFallback);
         };
 
-        Service->DoRequest<TRequest, TResponse>(queryPB, callback, stub);
+        Service.DoRequest<TRequest, TResponse>(queryPB, callback, stub);
 
-        {
-            TGuard<TMutex> lock(ProgressMutex);
-            auto i = RunningQueries.find(sessionId);
-            if (i != RunningQueries.end()) {
-                if (i->second.ProgressWriter) {
-                    ScheduleQueryStatusRequest(sessionId);
-                }
-            } else {
-                return MakeFuture(TResult());
-            }
-        }
+        ScheduleQueryStatusRequest(progressWriter);
 
         return promise.GetFuture().Apply([=](const TFuture<TResult>& result) {
             if (result.HasException()) {
@@ -215,31 +219,31 @@ public:
                 return result;
             }
 
-            return this_->Delay(backoff)
-                .Apply([=](const TFuture<void>& result) {
+            return this_->TaskScheduler.Delay(backoff)
+                .Apply([=, sessionId = this_->GetSessionId()](const TFuture<void>& result) {
                     auto this_ = self.lock();
                     try {
                         result.TryRethrow();
                         if (!this_) {
-                            YQL_CLOG(DEBUG, ProviderDq) << "Gateway was closed: " << sessionId;
-                            throw std::runtime_error("Gateway was closed");
+                            YQL_CLOG(DEBUG, ProviderDq) << "Session was closed: " << sessionId;
+                            throw std::runtime_error("Session was closed");
                         }
                     } catch (...) {
                         return MakeErrorFuture<TResult>(std::current_exception());
                     }
-                    return this_->WithRetry<TResponse>(sessionId, queryPB, stub, retry - 1, settings, modulesMapping);
+                    return this_->WithRetry<TResponse>(queryPB, stub, retry - 1, settings, modulesMapping, progressWriter);
                 });
         });
     }
 
     TFuture<TResult>
-    ExecutePlan(const TString& sessionId, NDqs::TPlan&& plan, const TVector<TString>& columns,
+    ExecutePlan(NDqs::TPlan&& plan, const TVector<TString>& columns,
                 const THashMap<TString, TString>& secureParams, const THashMap<TString, TString>& graphParams,
                 const TDqSettings::TPtr& settings,
                 const TDqProgressWriter& progressWriter, const THashMap<TString, TString>& modulesMapping,
                 bool discard)
     {
-        YQL_LOG_CTX_ROOT_SESSION_SCOPE(sessionId);
+        YQL_LOG_CTX_ROOT_SESSION_SCOPE(SessionId);
 
         Yql::DqsProto::ExecuteGraphRequest queryPB;
         for (const auto& task : plan.Tasks) {
@@ -253,7 +257,7 @@ public:
                 YQL_ENSURE(!file.GetObjectId().empty());
             }
         }
-        queryPB.SetSession(sessionId);
+        queryPB.SetSession(SessionId);
         queryPB.SetResultType(plan.ResultType);
         queryPB.SetSourceId(plan.SourceID.NodeId()-1);
         for (const auto& column : columns) {
@@ -279,179 +283,197 @@ public:
 
         int retry = settings->MaxRetries.Get().GetOrElse(5);
 
-        TFuture<void> sessionFuture;
-        {
-            TGuard<TMutex> lock(ProgressMutex);
-            auto it = RunningQueries.find(sessionId);
-            if (it == RunningQueries.end()) {
-                YQL_CLOG(DEBUG, ProviderDq) << "Session was closed: " << sessionId;
-                return MakeErrorFuture<TResult>(std::make_exception_ptr(std::runtime_error("Session was closed")));
-            }
-            it->second.ProgressWriter = progressWriter;
-            sessionFuture = it->second.OpenSessionFuture;
-        }
-
         YQL_CLOG(DEBUG, ProviderDq) << "Send query of size " << queryPB.ByteSizeLong();
 
         auto self = weak_from_this();
-        return sessionFuture.Apply([self, sessionId, queryPB, retry, settings, modulesMapping](const TFuture<void>& ) {
+        return OpenSessionFuture.Apply([self, sessionId = SessionId, queryPB, retry, settings, modulesMapping, progressWriter](const TFuture<void>& f) {
+            f.TryRethrow();
             auto this_ = self.lock();
             if (!this_) {
-                YQL_CLOG(DEBUG, ProviderDq) << "Gateway was closed: " << sessionId;
-                return MakeErrorFuture<TResult>(std::make_exception_ptr(std::runtime_error("Gateway was closed")));
+                YQL_CLOG(DEBUG, ProviderDq) << "Session was closed: " << sessionId;
+                return MakeErrorFuture<TResult>(std::make_exception_ptr(std::runtime_error("Session was closed")));
             }
 
             return this_->WithRetry<Yql::DqsProto::ExecuteGraphResponse>(
-                sessionId,
                 queryPB,
                 &Yql::DqsProto::DqService::Stub::AsyncExecuteGraph,
                 retry,
                 settings,
-                modulesMapping);
+                modulesMapping,
+                progressWriter);
         });
     }
 
-    TFuture<void> OpenSession(const TString& sessionId, const TString& username) {
-        YQL_LOG_CTX_ROOT_SESSION_SCOPE(sessionId);
-        YQL_CLOG(INFO, ProviderDq) << "OpenSession";
-        Yql::DqsProto::OpenSessionRequest request;
-        request.SetSession(sessionId);
-        request.SetUsername(username);
-
-        {
-            TGuard<TMutex> lock(ProgressMutex);
-            if (RunningQueries.find(sessionId) != RunningQueries.end()) {
-                return MakeFuture();
-            }
-        }
-
-        NYdbGrpc::TCallMeta meta;
-        meta.Timeout = OpenSessionTimeout;
+    TFuture<void> Close() { // Sends CloseSession for this session id; the returned future is completed from the RPC callback.
+        Yql::DqsProto::CloseSessionRequest request;
+        request.SetSession(SessionId);
 
         auto promise = NewPromise<void>();
-        auto self = weak_from_this();
-        auto callback = [self, promise, sessionId](NYdbGrpc::TGrpcStatus&& status, Yql::DqsProto::OpenSessionResponse&& resp) mutable {
+        auto callback = [promise, sessionId = SessionId](NYdbGrpc::TGrpcStatus&& status, Yql::DqsProto::CloseSessionResponse&& resp) mutable { // captures only the promise and a copy of the id, so the callback never touches the session object
             Y_UNUSED(resp);
             YQL_LOG_CTX_ROOT_SESSION_SCOPE(sessionId);
-            auto this_ = self.lock();
-            if (!this_) {
-                YQL_CLOG(DEBUG, ProviderDq) << "Gateway was closed: " << sessionId;
-                promise.SetException("Gateway was closed");
-                return;
-            }
             if (status.Ok()) {
-                YQL_CLOG(INFO, ProviderDq) << "OpenSession OK";
-                this_->SchedulePingSessionRequest(sessionId);
+                YQL_CLOG(DEBUG, ProviderDq) << "Async close session OK";
                 promise.SetValue();
             } else {
-                YQL_CLOG(ERROR, ProviderDq) << "OpenSession error: " << status.Msg;
-                promise.SetException(status.Msg);
+                YQL_CLOG(ERROR, ProviderDq) << "Async close session error: " << status.GRpcStatusCode << ", message: " << status.Msg;
+                promise.SetException(TStringBuilder() << "Async close session error: " << status.GRpcStatusCode << ", message: " << status.Msg); // a failed close is reported to the caller through the future
             }
         };
 
-        Service->DoRequest<Yql::DqsProto::OpenSessionRequest, Yql::DqsProto::OpenSessionResponse>(
-            request, callback, &Yql::DqsProto::DqService::Stub::AsyncOpenSession, meta);
-
-        {
-            TGuard<TMutex> lock(ProgressMutex);
-            RunningQueries.emplace(sessionId, TSession {
-                    std::optional<TDqProgressWriter> {},
-                        "",
-                        promise.GetFuture()
-                        });
-        }
-
-        return MakeFuture();
-    }
-
-    TFuture<void> CloseSession(const TString& sessionId) {
-        Yql::DqsProto::CloseSessionRequest request;
-        request.SetSession(sessionId);
-
-        auto callback = [](NYdbGrpc::TGrpcStatus&& status, Yql::DqsProto::CloseSessionResponse&& resp) {
-            Y_UNUSED(resp);
-            Y_UNUSED(status);
-        };
-
-        {
-            TGuard<TMutex> lock(ProgressMutex);
-            RunningQueries.erase(sessionId);
-        }
-
-        Service->DoRequest<Yql::DqsProto::CloseSessionRequest, Yql::DqsProto::CloseSessionResponse>(
+        Service.DoRequest<Yql::DqsProto::CloseSessionRequest, Yql::DqsProto::CloseSessionResponse>(
             request, callback, &Yql::DqsProto::DqService::Stub::AsyncCloseSession);
-
-        return MakeFuture();
+        return promise.GetFuture(); // unlike the removed CloseSession, the result now reflects the RPC outcome
     }
 
-    void OnRequestQueryStatus(const TString& sessionId, const TString& status, bool ok) {
-        TGuard<TMutex> lock(ProgressMutex);
-        TString stage;
-        TDqProgressWriter* dqProgressWriter = nullptr;
-        auto it = RunningQueries.find(sessionId);
-        if (it != RunningQueries.end() && ok) {
-            dqProgressWriter = it->second.ProgressWriter ? &*it->second.ProgressWriter:nullptr;
-            auto lastStatus = it->second.Status;
-            if (dqProgressWriter && lastStatus != status) {
-                stage = status;
-                it->second.Status = stage;
+    void OnRequestQueryStatus(const TDqProgressWriter& progressWriter, const TString& status, bool ok) { // Handles one QueryStatus reply: on success it re-arms polling and forwards a non-empty status.
+        if (ok) {
+            ScheduleQueryStatusRequest(progressWriter); // keep polling; when ok is false nothing further is scheduled from here
+            if (!status.empty()) {
+                progressWriter(status); // no comparison with the previous status is made in this method
             }
-
-            ScheduleQueryStatusRequest(sessionId);
-        } else if (it != RunningQueries.end()) {
-            it->second.ProgressWriter = {};
-        }
-
-        if (!stage.empty() && dqProgressWriter) {
-            (*dqProgressWriter)(stage);
         }
     }
 
-    void RequestQueryStatus(const TString& sessionId) {
+    void RequestQueryStatus(const TDqProgressWriter& progressWriter) { // Issues one QueryStatus RPC for this session and routes the reply to OnRequestQueryStatus.
         Yql::DqsProto::QueryStatusRequest request;
-        request.SetSession(sessionId);
+        request.SetSession(SessionId);
         auto self = weak_from_this();
-        auto callback = [self, sessionId](NYdbGrpc::TGrpcStatus&& status, Yql::DqsProto::QueryStatusResponse&& resp) {
+        auto callback = [self, progressWriter](NYdbGrpc::TGrpcStatus&& status, Yql::DqsProto::QueryStatusResponse&& resp) { // weak capture: the reply is ignored if the session object is already gone
             auto this_ = self.lock();
             if (!this_) {
                 return;
             }
 
-            this_->OnRequestQueryStatus(sessionId, resp.GetStatus(), status.Ok());
+            this_->OnRequestQueryStatus(progressWriter, resp.GetStatus(), status.Ok());
         };
 
-        Service->DoRequest<Yql::DqsProto::QueryStatusRequest, Yql::DqsProto::QueryStatusResponse>(
+        Service.DoRequest<Yql::DqsProto::QueryStatusRequest, Yql::DqsProto::QueryStatusResponse>(
             request, callback, &Yql::DqsProto::DqService::Stub::AsyncQueryStatus, {}, nullptr);
     }
 
-    void StartQueryStatusRequest(const TString& sessionId, bool ok) {
-        TGuard<TMutex> lock(ProgressMutex);
-        auto it = RunningQueries.find(sessionId);
-        if (it != RunningQueries.end() && ok) {
-            RequestQueryStatus(sessionId);
-        } else if (it != RunningQueries.end()) {
-            it->second.ProgressWriter = {};
+    void ScheduleQueryStatusRequest(const TDqProgressWriter& progressWriter) { // Arms a 1000 ms delay after which one QueryStatus request is sent for this session.
+        auto self = weak_from_this();
+        TaskScheduler.Delay(TDuration::MilliSeconds(1000)).Subscribe([self, progressWriter](const TFuture<void>& f) {
+            auto this_ = self.lock();
+            if (!this_) {
+                return; // the session object no longer exists, so polling ends here
+            }
+
+            if (!f.HasException()) { // a delay future carrying an exception means the delay itself failed, so no request is sent
+                this_->RequestQueryStatus(progressWriter);
+            }
+        });
+    }
+
+private:
+    const TString SessionId;
+    TDqTaskScheduler& TaskScheduler;
+    NYdbGrpc::TServiceConnection<Yql::DqsProto::DqService>& Service;
+
+    TMutex ProgressMutex;
+
+    std::optional<TDqProgressWriter> ProgressWriter;
+    TString Status;
+    TFuture<void> OpenSessionFuture;
+};
+
+class TDqGatewayImpl: public std::enable_shared_from_this<TDqGatewayImpl> {
+    using TResult = IDqGateway::TResult;
+    using TDqProgressWriter = IDqGateway::TDqProgressWriter;
+
+public:
+    TDqGatewayImpl(const TString& host, int port, TDuration timeout = TDuration::Minutes(60), TDuration requestTimeout = TDuration::Max()) // timeout is stored as OpenSessionTimeout; requestTimeout is passed to the gRPC config.
+        : GrpcConf(TStringBuilder() << host << ":" << port, requestTimeout)
+        , GrpcClient(1)
+        , Service(GrpcClient.CreateGRpcServiceConnection<Yql::DqsProto::DqService>(GrpcConf))
+        , TaskScheduler()
+        , OpenSessionTimeout(timeout)
+    {
+        TaskScheduler.Start(); // the scheduler is started here and stopped in Stop()
+    }
+
+    ~TDqGatewayImpl() { // Runs the full shutdown sequence on destruction.
+        Stop(); // TDqGateway::Stop() calls Impl->Stop() too, so Stop() can run more than once on the same object
+    }
+
+    void Stop() { // Closes every registered session synchronously, then stops the task scheduler and the gRPC client.
+        decltype(Sessions) sessions;
+        with_lock (Mutex) {
+            sessions = std::move(Sessions); // take the map out under the lock so the Close() calls below run without holding Mutex
+        }
+        for (auto& pair: sessions) {
+            try {
+                pair.second->Close().GetValueSync(); // waits for the close future of each session in turn
+            } catch (...) {
+                YQL_LOG_CTX_ROOT_SESSION_SCOPE(pair.first);
+                YQL_CLOG(ERROR, ProviderDq) << "Error closing session " << pair.first << ": " << CurrentExceptionMessage(); // logged and swallowed so the remaining sessions are still closed
+            }
+        }
+        sessions.clear(); // Destroy session objects explicitly before stopping grpc
+        TaskScheduler.Stop();
+        try {
+            GrpcClient.Stop();
+        } catch (...) {
+            YQL_CLOG(ERROR, ProviderDq) << "Error while stopping GRPC client: " << CurrentExceptionMessage(); // caught here because Stop() is also called from the destructor
+        }
+    }
+
+    void DropSession(const TString& sessionId) { // Removes the session entry for sessionId from Sessions under Mutex; no CloseSession request is sent here.
+        with_lock (Mutex) {
+            Sessions.erase(sessionId);
         }
     }
 
-    void ScheduleQueryStatusRequest(const TString& sessionId) {
+    TFuture<void> OpenSession(const TString& sessionId, const TString& username) { // Registers a new session object and starts the OpenSession RPC without waiting for its reply.
+        YQL_LOG_CTX_ROOT_SESSION_SCOPE(sessionId);
+        YQL_CLOG(INFO, ProviderDq) << "OpenSession";
+
+        auto promise = NewPromise<void>();
+        std::shared_ptr<TDqGatewaySession> session = std::make_shared<TDqGatewaySession>(sessionId, TaskScheduler, *Service, promise.GetFuture()); // the session keeps the future of this promise as its OpenSessionFuture
+        with_lock (Mutex) {
+            if (!Sessions.emplace(sessionId, session).second) {
+                return MakeErrorFuture<void>(std::make_exception_ptr(yexception() << "Duplicate session id: " << sessionId)); // an id that is already registered is rejected before any RPC is sent
+            }
+        }
+
+        Yql::DqsProto::OpenSessionRequest request;
+        request.SetSession(sessionId);
+        request.SetUsername(username);
+
+        NYdbGrpc::TCallMeta meta;
+        meta.Timeout = OpenSessionTimeout;
+
         auto self = weak_from_this();
-        Delay(TDuration::MilliSeconds(1000)).Subscribe([self, sessionId](TFuture<void> fut) {
+        auto callback = [self, promise, sessionId](NYdbGrpc::TGrpcStatus&& status, Yql::DqsProto::OpenSessionResponse&& resp) mutable {
+            Y_UNUSED(resp);
+            YQL_LOG_CTX_ROOT_SESSION_SCOPE(sessionId);
             auto this_ = self.lock();
             if (!this_) {
+                YQL_CLOG(ERROR, ProviderDq) << "Session was closed: " << sessionId;
+                promise.SetException("Session was closed");
                 return;
             }
+            if (status.Ok()) {
+                YQL_CLOG(INFO, ProviderDq) << "OpenSession OK";
+                this_->SchedulePingSessionRequest(sessionId); // periodic pings start only after a successful open
+                promise.SetValue();
+            } else {
+                YQL_CLOG(ERROR, ProviderDq) << "OpenSession error: " << status.Msg;
+                this_->DropSession(sessionId); // a session that failed to open is removed from the registry
+                promise.SetException(status.Msg);
+            }
+        };
 
-            this_->StartQueryStatusRequest(sessionId, !fut.HasException());
-        });
+        Service->DoRequest<Yql::DqsProto::OpenSessionRequest, Yql::DqsProto::OpenSessionResponse>(
+            request, callback, &Yql::DqsProto::DqService::Stub::AsyncOpenSession, meta);
+
+       return MakeFuture(); // already-completed future: the RPC outcome is delivered through the promise handed to the session above
     }
 
     void SchedulePingSessionRequest(const TString& sessionId) {
         auto self = weak_from_this();
-        auto callback = [self, sessionId](
-            NYdbGrpc::TGrpcStatus&& status,
-            Yql::DqsProto::PingSessionResponse&&) mutable
-        {
+        auto callback = [self, sessionId] (NYdbGrpc::TGrpcStatus&& status, Yql::DqsProto::PingSessionResponse&&) mutable {
             auto this_ = self.lock();
             if (!this_) {
                 return;
@@ -459,11 +481,12 @@ public:
 
             if (status.GRpcStatusCode == grpc::INVALID_ARGUMENT) {
                 YQL_CLOG(INFO, ProviderDq) << "Session closed " << sessionId;
+                this_->DropSession(sessionId);
             } else {
                 this_->SchedulePingSessionRequest(sessionId);
             }
         };
-        Delay(TDuration::Seconds(10)).Subscribe([self, callback, sessionId](const TFuture<void>&) {
+        TaskScheduler.Delay(TDuration::Seconds(10)).Subscribe([self, callback, sessionId](const TFuture<void>&) {
             auto this_ = self.lock();
             if (!this_) {
                 return;
@@ -479,21 +502,37 @@ public:
         });
     }
 
-    struct TDelay: public TTaskScheduler::ITask {
-        TDelay(TPromise<void> p)
-            : Promise(std::move(p))
-        { }
-
-        TInstant Process() override {
-            Promise.SetValue();
-            return TInstant::Max();
+    TFuture<void> CloseSessionAsync(const TString& sessionId) { // Unregisters the session and returns the future of its Close(); an unknown id yields an already-completed future.
+        std::shared_ptr<TDqGatewaySession> session;
+        with_lock (Mutex) {
+            auto it = Sessions.find(sessionId);
+            if (it != Sessions.end()) {
+                session = it->second; // keep the session alive past the erase so Close() can be called outside the lock
+                Sessions.erase(it);
+            }
         }
+        if (session) {
+            return session->Close();
+        }
+        return MakeFuture();
+    }
 
-        TPromise<void> Promise;
-    };
-
-    void Stop() {
-        GrpcClient.Stop();
+    TFuture<TResult> ExecutePlan(const TString& sessionId, NDqs::TPlan&& plan, const TVector<TString>& columns, // Looks up the session by id and forwards the plan to it; an unknown id yields an error future.
+        const THashMap<TString, TString>& secureParams, const THashMap<TString, TString>& graphParams,
+        const TDqSettings::TPtr& settings,
+        const TDqProgressWriter& progressWriter, const THashMap<TString, TString>& modulesMapping,
+        bool discard)
+    {
+        std::shared_ptr<TDqGatewaySession> session;
+        with_lock(Mutex) {
+            auto it = Sessions.find(sessionId);
+            if (it == Sessions.end()) {
+                YQL_CLOG(ERROR, ProviderDq) << "Session was closed: " << sessionId;
+                return MakeErrorFuture<TResult>(std::make_exception_ptr(std::runtime_error("Session was closed")));
+            }
+            session = it->second; // copy the shared_ptr so the session call below runs without holding Mutex
+        }
+        return session->ExecutePlan(std::move(plan), columns, secureParams, graphParams, settings, progressWriter, modulesMapping, discard);
     }
 
 private:
@@ -501,66 +540,58 @@ private:
     NYdbGrpc::TGRpcClientLow GrpcClient;
     std::unique_ptr<NYdbGrpc::TServiceConnection<Yql::DqsProto::DqService>> Service;
 
-    TMutex ProgressMutex;
-    TMutex Mutex;
-
-    struct TSession {
-        std::optional<TDqProgressWriter> ProgressWriter;
-        TString Status;
-        TFuture<void> OpenSessionFuture;
-    };
-    THashMap<TString, TSession> RunningQueries;
-    TString VanillaJobPath;
-    TString VanillaJobMd5;
-
-    TTaskScheduler TaskScheduler;
+    TDqTaskScheduler TaskScheduler;
     const TDuration OpenSessionTimeout;
+
+    TMutex Mutex;
+    THashMap<TString, std::shared_ptr<TDqGatewaySession>> Sessions;
 };
 
 class TDqGateway: public IDqGateway { // IDqGateway implementation that forwards session and plan calls to a shared TDqGatewayImpl.
 public:
     TDqGateway(const TString& host, int port, const TString& vanillaJobPath, const TString& vanillaJobMd5, TDuration timeout = TDuration::Minutes(60), TDuration requestTimeout = TDuration::Max())
-        : Impl(std::make_shared<TDqGatewayImpl>(host, port, vanillaJobPath, vanillaJobMd5, timeout, requestTimeout))
-    { }
-
-    ~TDqGateway()
+        : Impl(std::make_shared<TDqGatewayImpl>(host, port, timeout, requestTimeout))
+        , VanillaJobPath(vanillaJobPath) // the vanilla job path and md5 are now stored on the wrapper instead of the impl
+        , VanillaJobMd5(vanillaJobMd5)
     {
-        Stop();
+    }
+
+    ~TDqGateway() { // no explicit Stop() here any more; the destructor of TDqGatewayImpl calls Stop() itself
     }
 
     void Stop() override {
         Impl->Stop();
     }
 
-    TFuture<void> OpenSession(const TString& sessionId, const TString& username) override
-    {
+    TFuture<void> OpenSession(const TString& sessionId, const TString& username) override {
         return Impl->OpenSession(sessionId, username);
     }
 
-    TFuture<void> CloseSessionAsync(const TString& sessionId) override
-    {
-        return Impl->CloseSession(sessionId);
+    TFuture<void> CloseSessionAsync(const TString& sessionId) override {
+        return Impl->CloseSessionAsync(sessionId);
     }
 
     TFuture<TResult> ExecutePlan(const TString& sessionId, NDqs::TPlan&& plan, const TVector<TString>& columns,
-                const THashMap<TString, TString>& secureParams, const THashMap<TString, TString>& graphParams,
-                const TDqSettings::TPtr& settings,
-                const TDqProgressWriter& progressWriter, const THashMap<TString, TString>& modulesMapping,
-                bool discard) override
+        const THashMap<TString, TString>& secureParams, const THashMap<TString, TString>& graphParams,
+        const TDqSettings::TPtr& settings,
+        const TDqProgressWriter& progressWriter, const THashMap<TString, TString>& modulesMapping,
+        bool discard) override
     {
         return Impl->ExecutePlan(sessionId, std::move(plan), columns, secureParams, graphParams, settings, progressWriter, modulesMapping, discard);
     }
 
     TString GetVanillaJobPath() override {
-        return Impl->GetVanillaJobPath();
+        return VanillaJobPath;
     }
 
     TString GetVanillaJobMd5() override {
-        return Impl->GetVanillaJobMd5();
+        return VanillaJobMd5;
     }
 
 private:
     std::shared_ptr<TDqGatewayImpl> Impl;
+    TString VanillaJobPath;
+    TString VanillaJobMd5;
 };
 
 TIntrusivePtr<IDqGateway> CreateDqGateway(const TString& host, int port) {