Browse Source

pg cancel request

xenoxeno 1 year ago
parent
commit
5e44523f6b

+ 5 - 0
.mapping.json

@@ -4989,6 +4989,11 @@
   "ydb/core/pgproxy/CMakeLists.linux-x86_64.txt":"",
   "ydb/core/pgproxy/CMakeLists.txt":"",
   "ydb/core/pgproxy/CMakeLists.windows-x86_64.txt":"",
+  "ydb/core/pgproxy/protos/CMakeLists.darwin-x86_64.txt":"",
+  "ydb/core/pgproxy/protos/CMakeLists.linux-aarch64.txt":"",
+  "ydb/core/pgproxy/protos/CMakeLists.linux-x86_64.txt":"",
+  "ydb/core/pgproxy/protos/CMakeLists.txt":"",
+  "ydb/core/pgproxy/protos/CMakeLists.windows-x86_64.txt":"",
   "ydb/core/pgproxy/ut/CMakeLists.darwin-x86_64.txt":"",
   "ydb/core/pgproxy/ut/CMakeLists.linux-aarch64.txt":"",
   "ydb/core/pgproxy/ut/CMakeLists.linux-x86_64.txt":"",

+ 1 - 0
ydb/apps/pgwire/pg_ydb_proxy.cpp

@@ -93,6 +93,7 @@ public:
         TActorId actorId = Register(actor);
         PgToYdbConnection[ev->Sender] = actorId;
         BLOG_D("Created ydb connection " << actorId);
+        Send(ev->Sender, new NPG::TEvPGEvents::TEvFinishHandshake(), 0, ev->Cookie);
     }
 
     void Handle(NPG::TEvPGEvents::TEvConnectionClosed::TPtr& ev) {

+ 60 - 32
ydb/core/local_pgwire/local_pgwire.cpp

@@ -1,4 +1,6 @@
 #include "log_impl.h"
+#include "local_pgwire.h"
+#include "local_pgwire_util.h"
 #include <ydb/core/pgproxy/pg_proxy_events.h>
 #include <ydb/core/grpc_services/local_rpc/local_rpc.h>
 #include <ydb/public/api/grpc/ydb_auth_v1.grpc.pb.h>
@@ -10,7 +12,7 @@ namespace NLocalPgWire {
 using namespace NActors;
 using namespace NKikimr;
 
-extern IActor* CreateConnection(std::unordered_map<TString, TString> params);
+extern IActor* CreateConnection(std::unordered_map<TString, TString> params, NPG::TEvPGEvents::TEvConnectionOpened::TPtr&& event, const TConnectionState& connection);
 
 class TPgYdbProxy : public TActor<TPgYdbProxy> {
     using TBase = TActor<TPgYdbProxy>;
@@ -45,9 +47,15 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
         };
     };
 
-    std::unordered_map<TActorId, TActorId> PgToYdbConnection;
+    struct TConnectionState {
+        TActorId YdbConnection;
+        uint32_t ConnectionNum;
+    };
+
+    std::unordered_map<TActorId, TConnectionState> ConnectionState;
     std::unordered_map<TActorId, TSecurityState> SecurityState;
     std::unordered_map<TString, TTokenState> TokenState;
+    uint32_t ConnectionNum = 0;
 
 public:
     TPgYdbProxy()
@@ -97,7 +105,7 @@ public:
 
     void Handle(NPG::TEvPGEvents::TEvAuth::TPtr& ev) {
         std::unordered_map<TString, TString> clientParams = ev->Get()->InitialMessage->GetClientParams();
-        BLOG_D("TEvAuth " << ev->Get()->InitialMessage->Dump());
+        BLOG_D("TEvAuth " << ev->Get()->InitialMessage->Dump() << " cookie " << ev->Cookie);
         Ydb::Auth::LoginRequest request;
         request.set_user(clientParams["user"]);
         if (ev->Get()->PasswordMessage) {
@@ -132,7 +140,7 @@ public:
     }
 
     void Handle(NPG::TEvPGEvents::TEvConnectionOpened::TPtr& ev) {
-        BLOG_D("TEvConnectionOpened " << ev->Sender);
+        BLOG_D("TEvConnectionOpened " << ev->Sender << " cookie " << ev->Cookie);
         auto params = ev->Get()->Message->GetClientParams();
         auto itSecurityState = SecurityState.find(ev->Sender);
         if (itSecurityState != SecurityState.end()) {
@@ -143,64 +151,83 @@ public:
                 params["ydb-serialized-token"] = itSecurityState->second.SerializedToken;
             }
         }
-        IActor* actor = CreateConnection(std::move(params));
+        auto& connectionState = ConnectionState[ev->Sender];
+        connectionState.ConnectionNum = ++ConnectionNum;
+        IActor* actor = CreateConnection(std::move(params), std::move(ev), {.ConnectionNum = connectionState.ConnectionNum});
         TActorId actorId = Register(actor);
-        PgToYdbConnection[ev->Sender] = actorId;
-        BLOG_D("Created ydb connection " << actorId);
+        connectionState.YdbConnection = actorId;
+        BLOG_D("Created ydb connection " << actorId << " num " << connectionState.ConnectionNum);
     }
 
     void Handle(NPG::TEvPGEvents::TEvConnectionClosed::TPtr& ev) {
-        BLOG_D("TEvConnectionClosed " << ev->Sender);
-        auto itConnection = PgToYdbConnection.find(ev->Sender);
-        if (itConnection != PgToYdbConnection.end()) {
-            Send(itConnection->second, new TEvents::TEvPoisonPill());
-            BLOG_D("Destroyed ydb connection " << itConnection->second);
+        BLOG_D("TEvConnectionClosed " << ev->Sender << " cookie " << ev->Cookie);
+        auto itConnection = ConnectionState.find(ev->Sender);
+        if (itConnection != ConnectionState.end()) {
+            Send(itConnection->second.YdbConnection, new TEvents::TEvPoisonPill());
+            BLOG_D("Destroyed ydb connection " << itConnection->second.YdbConnection << " num " << itConnection->second.ConnectionNum);
         }
         SecurityState.erase(ev->Sender);
+        ConnectionState.erase(itConnection);
         // TODO: cleanup TokenState too
     }
 
     void Handle(NPG::TEvPGEvents::TEvQuery::TPtr& ev) {
-        auto itConnection = PgToYdbConnection.find(ev->Sender);
-        if (itConnection != PgToYdbConnection.end()) {
-            Forward(ev, itConnection->second);
+        auto itConnection = ConnectionState.find(ev->Sender);
+        if (itConnection != ConnectionState.end()) {
+            Forward(ev, itConnection->second.YdbConnection);
         }
     }
 
     void Handle(NPG::TEvPGEvents::TEvParse::TPtr& ev) {
-        auto itConnection = PgToYdbConnection.find(ev->Sender);
-        if (itConnection != PgToYdbConnection.end()) {
-            Forward(ev, itConnection->second);
+        auto itConnection = ConnectionState.find(ev->Sender);
+        if (itConnection != ConnectionState.end()) {
+            Forward(ev, itConnection->second.YdbConnection);
         }
     }
 
     void Handle(NPG::TEvPGEvents::TEvBind::TPtr& ev) {
-        auto itConnection = PgToYdbConnection.find(ev->Sender);
-        if (itConnection != PgToYdbConnection.end()) {
-            Forward(ev, itConnection->second);
+        auto itConnection = ConnectionState.find(ev->Sender);
+        if (itConnection != ConnectionState.end()) {
+            Forward(ev, itConnection->second.YdbConnection);
         }
     }
 
     void Handle(NPG::TEvPGEvents::TEvDescribe::TPtr& ev) {
-        auto itConnection = PgToYdbConnection.find(ev->Sender);
-        if (itConnection != PgToYdbConnection.end()) {
-            Forward(ev, itConnection->second);
+        auto itConnection = ConnectionState.find(ev->Sender);
+        if (itConnection != ConnectionState.end()) {
+            Forward(ev, itConnection->second.YdbConnection);
         }
     }
 
     void Handle(NPG::TEvPGEvents::TEvExecute::TPtr& ev) {
-        auto itConnection = PgToYdbConnection.find(ev->Sender);
-        if (itConnection != PgToYdbConnection.end()) {
-            Forward(ev, itConnection->second);
-            BLOG_D("Forwarded to ydb connection " << itConnection->second);
+        auto itConnection = ConnectionState.find(ev->Sender);
+        if (itConnection != ConnectionState.end()) {
+            Forward(ev, itConnection->second.YdbConnection);
         }
     }
 
     void Handle(NPG::TEvPGEvents::TEvClose::TPtr& ev) {
-        auto itConnection = PgToYdbConnection.find(ev->Sender);
-        if (itConnection != PgToYdbConnection.end()) {
-            Forward(ev, itConnection->second);
-            BLOG_D("Forwarded to ydb connection " << itConnection->second);
+        auto itConnection = ConnectionState.find(ev->Sender);
+        if (itConnection != ConnectionState.end()) {
+            Forward(ev, itConnection->second.YdbConnection);
+        }
+    }
+
+    void Handle(NPG::TEvPGEvents::TEvCancelRequest::TPtr& ev) {
+        uint32_t nodeId = ev->Get()->Record.GetProcessId();
+        if (nodeId == SelfId().NodeId()) {
+            uint32_t connectionNum = ev->Get()->Record.GetSecretKey();
+            for (const auto& [pgConnectionId, connectionState] : ConnectionState) {
+                if (connectionState.ConnectionNum == connectionNum) {
+                    BLOG_D("Cancelling ConnectionNum " << connectionNum);
+                    Forward(ev, connectionState.YdbConnection);
+                    return;
+                }
+            }
+            BLOG_W("Cancelling ConnectionNum " << connectionNum << " - connection not found");
+        } else {
+            BLOG_D("Forwarding TEvCancelRequest to Node " << nodeId);
+            Forward(ev, CreateLocalPgWireProxyId(nodeId));
         }
     }
 
@@ -215,6 +242,7 @@ public:
             hFunc(NPG::TEvPGEvents::TEvDescribe, Handle);
             hFunc(NPG::TEvPGEvents::TEvExecute, Handle);
             hFunc(NPG::TEvPGEvents::TEvClose, Handle);
+            hFunc(NPG::TEvPGEvents::TEvCancelRequest, Handle);
             hFunc(TEvPrivate::TEvTokenReady, Handle);
             hFunc(TEvTicketParser::TEvAuthorizeTicketResult, Handle);
         }

+ 1 - 1
ydb/core/local_pgwire/local_pgwire.h

@@ -2,7 +2,7 @@
 
 namespace NLocalPgWire {
 
-inline NActors::TActorId CreateLocalPgWireProxyId() { return NActors::TActorId(0, "localpgwire"); }
+inline NActors::TActorId CreateLocalPgWireProxyId(uint32_t nodeId = 0) { return NActors::TActorId(nodeId, "localpgwire"); }
 NActors::IActor* CreateLocalPgWireProxy();
 
 }

+ 67 - 9
ydb/core/local_pgwire/local_pgwire_connection.cpp

@@ -22,22 +22,62 @@ namespace NLocalPgWire {
 using namespace NActors;
 using namespace NKikimr;
 
-class TPgYdbConnection : public TActor<TPgYdbConnection> {
-    using TBase = TActor<TPgYdbConnection>;
+class TPgYdbConnection : public TActorBootstrapped<TPgYdbConnection> {
+    using TBase = TActorBootstrapped<TPgYdbConnection>;
 
     std::unordered_map<TString, TString> ConnectionParams;
+    NPG::TEvPGEvents::TEvConnectionOpened::TPtr ConnectionEvent;
     std::unordered_map<TString, TParsedStatement> ParsedStatements;
     std::unordered_map<TString, TPortal> Portals;
     TConnectionState Connection;
     std::deque<TAutoPtr<IEventHandle>> Events;
     ui32 Inflight = 0;
+    std::unordered_set<TActorId> CurrentRunningQueries;
 
 public:
-    TPgYdbConnection(std::unordered_map<TString, TString> params)
-        : TActor<TPgYdbConnection>(&TPgYdbConnection::StateSchedule)
-        , ConnectionParams(std::move(params))
+    TPgYdbConnection(std::unordered_map<TString, TString> params, NPG::TEvPGEvents::TEvConnectionOpened::TPtr&& event, const TConnectionState& connection)
+        : ConnectionParams(std::move(params))
+        , ConnectionEvent(std::move(event))
+        , Connection(connection)
     {}
 
+    void Bootstrap() {
+        TString database;
+        if (ConnectionParams.count("database")) {
+            database = ConnectionParams["database"];
+        }
+        auto ev = MakeHolder<NKqp::TEvKqp::TEvCreateSessionRequest>();
+        NKikimrKqp::TCreateSessionRequest& request = *ev->Record.MutableRequest();
+        request.SetDatabase(database);
+        BLOG_D("Sent CreateSessionRequest to kqpProxy " << ev->Record.ShortDebugString());
+        Send(NKqp::MakeKqpProxyID(SelfId().NodeId()), ev.Release());
+        TBase::Become(&TPgYdbConnection::StateCreateSession);
+    }
+
+    void Handle(NKqp::TEvKqp::TEvCreateSessionResponse::TPtr& ev) {
+        const auto& record(ev->Get()->Record);
+        BLOG_D("Received TEvCreateSessionResponse " << record.ShortDebugString());
+        if (record.GetYdbStatus() == Ydb::StatusIds::SUCCESS) {
+            BLOG_D("Session id is " << record.GetResponse().GetSessionId());
+            Connection.SessionId = record.GetResponse().GetSessionId();
+
+            auto response = MakeHolder<NPG::TEvPGEvents::TEvFinishHandshake>();
+            response->BackendData.Pid = SelfId().NodeId();
+            response->BackendData.Key = Connection.ConnectionNum;
+            Send(ConnectionEvent->Sender, response.Release(), 0, ev->Cookie);
+            TBase::Become(&TPgYdbConnection::StateSchedule);
+            ConnectionEvent.Destroy(); // don't need it anymore
+        } else {
+            BLOG_W("Failed to create session: " << record.ShortDebugString());
+            auto response = MakeHolder<NPG::TEvPGEvents::TEvFinishHandshake>();
+            // TODO: report actuall error
+            response->ErrorFields.push_back({'E', "ERROR"});
+            response->ErrorFields.push_back({'M', record.GetError()});
+            Send(ConnectionEvent->Sender, response.Release(), 0, ev->Cookie);
+            return PassAway();
+        }
+    }
+
     void ProcessEventsQueue() {
         while (!Events.empty() && Inflight == 0) {
             StateWork(Events.front());
@@ -58,6 +98,7 @@ public:
         ++Inflight;
         TActorId actorId = RegisterWithSameMailbox(CreatePgwireKqpProxyQuery(SelfId(), ConnectionParams, Connection, std::move(ev)));
         BLOG_D("Created pgwireKqpProxyQuery: " << actorId);
+        CurrentRunningQueries.insert(actorId);
     }
 
     void Handle(NPG::TEvPGEvents::TEvQuery::TPtr& ev) {
@@ -85,7 +126,7 @@ public:
         ++Inflight;
         TActorId actorId = RegisterWithSameMailbox(CreatePgwireKqpProxyParse(SelfId(), ConnectionParams, Connection, std::move(ev)));
         BLOG_D("Created pgwireKqpProxyParse: " << actorId);
-        return;
+        CurrentRunningQueries.insert(actorId);
     }
 
     void Handle(NPG::TEvPGEvents::TEvBind::TPtr& ev) {
@@ -183,6 +224,7 @@ public:
         ++Inflight;
         TActorId actorId = RegisterWithSameMailbox(CreatePgwireKqpProxyExecute(SelfId(), ConnectionParams, Connection, std::move(ev), it->second));
         BLOG_D("Created pgwireKqpProxyExecute: " << actorId);
+        CurrentRunningQueries.insert(actorId);
     }
 
     void Handle(TEvEvents::TEvUpdateStatement::TPtr& ev) {
@@ -216,23 +258,39 @@ public:
             BLOG_D("Session id is " << connection.SessionId);
             Connection.SessionId = connection.SessionId;
         }
+        CurrentRunningQueries.erase(ev->Sender);
         ProcessEventsQueue();
     }
 
+    void Handle(NPG::TEvPGEvents::TEvCancelRequest::TPtr&) {
+        BLOG_D("Received TEvCancelRequest");
+        for (const TActorId& actor : CurrentRunningQueries) {
+            Send(actor, new TEvEvents::TEvCancelRequest());
+        }
+    }
+
     void PassAway() override {
         if (Connection.SessionId) {
-            BLOG_D("Closing session " << Connection.SessionId);
             auto ev = MakeHolder<NKqp::TEvKqp::TEvCloseSessionRequest>();
             ev->Record.MutableRequest()->SetSessionId(Connection.SessionId);
+            BLOG_D("Closing session " << Connection.SessionId << ", sent event to kqpProxy " << ev->Record.ShortDebugString());
             Send(NKqp::MakeKqpProxyID(SelfId().NodeId()), ev.Release());
         }
         TBase::PassAway();
     }
 
+    STATEFN(StateCreateSession) {
+        switch (ev->GetTypeRewrite()) {
+            hFunc(NKqp::TEvKqp::TEvCreateSessionResponse, Handle);
+            cFunc(TEvents::TEvPoisonPill::EventType, PassAway);
+        }
+    }
+
     STATEFN(StateSchedule) {
         switch (ev->GetTypeRewrite()) {
             hFunc(TEvEvents::TEvProxyCompleted, Handle);
             hFunc(TEvEvents::TEvUpdateStatement, Handle);
+            hFunc(NPG::TEvPGEvents::TEvCancelRequest, Handle);
             cFunc(TEvents::TEvPoisonPill::EventType, PassAway);
             default: {
                 if (Inflight == 0) {
@@ -259,8 +317,8 @@ public:
 };
 
 
-NActors::IActor* CreateConnection(std::unordered_map<TString, TString> params) {
-    return new TPgYdbConnection(std::move(params));
+NActors::IActor* CreateConnection(std::unordered_map<TString, TString> params, NPG::TEvPGEvents::TEvConnectionOpened::TPtr&& event, const TConnectionState& connection) {
+    return new TPgYdbConnection(std::move(params), std::move(event), connection);
 }
 
 }

+ 6 - 0
ydb/core/local_pgwire/local_pgwire_util.h

@@ -27,6 +27,7 @@ struct TTransactionState {
 struct TConnectionState {
     TString SessionId;
     TTransactionState Transaction;
+    uint32_t ConnectionNum = 0;
 };
 
 struct TParsedStatement {
@@ -54,6 +55,7 @@ struct TEvEvents {
         EvProxyCompleted = EventSpaceBegin(NActors::TEvents::ES_PRIVATE),
         EvUpdateStatement,
         EvSingleQuery,
+        EvCancelRequest,
         EvEnd
     };
 
@@ -92,6 +94,10 @@ struct TEvEvents {
             return std::make_unique<NPG::TEvPGEvents::TEvQueryResponse>();
         }
     };
+
+    struct TEvCancelRequest : NActors::TEventLocal<TEvCancelRequest, EvCancelRequest> {
+        TEvCancelRequest() = default;
+    };
 };
 
 TString ColumnPrimitiveValueToString(NYdb::TValueParser& valueParser);

+ 17 - 1
ydb/core/local_pgwire/pgwire_kqp_proxy.cpp

@@ -241,7 +241,7 @@ protected:
     void ReplyWithResponseAndPassAway() {
         Response_->TransactionStatus = Connection_.Transaction.Status;
         TBase::Send(Owner_, new TEvEvents::TEvProxyCompleted(Connection_));
-        BLOG_D("Finally replying to " << EventRequest_->Sender);
+        BLOG_D("Finally replying to " << EventRequest_->Sender << " cookie " << EventRequest_->Cookie);
         TBase::Send(EventRequest_->Sender, Response_.release(), 0, EventRequest_->Cookie);
         TBase::PassAway();
     }
@@ -297,7 +297,20 @@ protected:
         return ReplyWithResponseAndPassAway();
     }
 
+    void Handle(TEvEvents::TEvCancelRequest::TPtr&) {
+        auto ev = MakeHolder<NKqp::TEvKqp::TEvCancelQueryRequest>();
+        if (Connection_.SessionId) {
+            ev->Record.MutableRequest()->SetSessionId(Connection_.SessionId);
+        }
+        BLOG_D("Sent CancelQueryRequest to kqpProxy " << ev->Record.ShortDebugString());
+        TBase::Send(NKqp::MakeKqpProxyID(TBase::SelfId().NodeId()), ev.Release());
 
+        Response_->ErrorFields.push_back({'S', "ERROR"});
+        Response_->ErrorFields.push_back({'V', "ERROR"});
+        Response_->ErrorFields.push_back({'C', "57014"});
+        Response_->ErrorFields.push_back({'M', "Cancelling statement due to user request"});
+        return ReplyWithResponseAndPassAway();
+    }
 };
 
 class TPgwireKqpProxyQuery : public TPgwireKqpProxy<TPgwireKqpProxyQuery, TEvEvents::TEvSingleQuery> {
@@ -330,6 +343,7 @@ public:
         switch (ev->GetTypeRewrite()) {
             hFunc(NKqp::TEvKqp::TEvQueryResponse, TBase::Handle);
             hFunc(NKqp::TEvKqpExecuter::TEvStreamData, TBase::Handle);
+            hFunc(TEvEvents::TEvCancelRequest, Handle);
         }
     }
 };
@@ -416,6 +430,7 @@ public:
     STATEFN(StateWork) {
         switch (ev->GetTypeRewrite()) {
             hFunc(NKqp::TEvKqp::TEvQueryResponse, Handle);
+            hFunc(TEvEvents::TEvCancelRequest, TBase::Handle);
         }
     }
 };
@@ -460,6 +475,7 @@ public:
         switch (ev->GetTypeRewrite()) {
             hFunc(NKqp::TEvKqp::TEvQueryResponse, TBase::Handle);
             hFunc(NKqp::TEvKqpExecuter::TEvStreamData, TBase::Handle);
+            hFunc(TEvEvents::TEvCancelRequest, Handle);
         }
     }
 };

+ 2 - 0
ydb/core/pgproxy/CMakeLists.darwin-x86_64.txt

@@ -6,6 +6,7 @@
 # original buildsystem will not be accepted.
 
 
+add_subdirectory(protos)
 add_subdirectory(ut)
 
 add_library(ydb-core-pgproxy)
@@ -16,6 +17,7 @@ target_link_libraries(ydb-core-pgproxy PUBLIC
   cpp-actors-protos
   cpp-string_utils-base64
   ydb-core-base
+  core-pgproxy-protos
   ydb-core-protos
   ydb-core-raw_socket
 )

+ 2 - 0
ydb/core/pgproxy/CMakeLists.linux-aarch64.txt

@@ -6,6 +6,7 @@
 # original buildsystem will not be accepted.
 
 
+add_subdirectory(protos)
 add_subdirectory(ut)
 
 add_library(ydb-core-pgproxy)
@@ -17,6 +18,7 @@ target_link_libraries(ydb-core-pgproxy PUBLIC
   cpp-actors-protos
   cpp-string_utils-base64
   ydb-core-base
+  core-pgproxy-protos
   ydb-core-protos
   ydb-core-raw_socket
 )

+ 2 - 0
ydb/core/pgproxy/CMakeLists.linux-x86_64.txt

@@ -6,6 +6,7 @@
 # original buildsystem will not be accepted.
 
 
+add_subdirectory(protos)
 add_subdirectory(ut)
 
 add_library(ydb-core-pgproxy)
@@ -17,6 +18,7 @@ target_link_libraries(ydb-core-pgproxy PUBLIC
   cpp-actors-protos
   cpp-string_utils-base64
   ydb-core-base
+  core-pgproxy-protos
   ydb-core-protos
   ydb-core-raw_socket
 )

Some files were not shown because too many files changed in this diff