Browse Source

KIKIMR-19791: Fix data race in ldap mock

molotkov-and 1 year ago
parent
commit
db7beb649a

+ 48 - 44
ydb/core/security/ticket_parser_ut.cpp

@@ -72,9 +72,9 @@ void InitLdapSettingsWithInvalidFilter(NKikimrProto::TLdapAuthentication* ldapSe
     ldapSettings->SetSearchFilter("&(uid=$username)()");
     ldapSettings->SetSearchFilter("&(uid=$username)()");
 }
 }
 
 
-void InitLdapSettingsWithUnavaliableHost(NKikimrProto::TLdapAuthentication* ldapSettings, ui16 ldapPort, TTempFileHandle& certificateFile) {
+void InitLdapSettingsWithUnavailableHost(NKikimrProto::TLdapAuthentication* ldapSettings, ui16 ldapPort, TTempFileHandle& certificateFile) {
     InitLdapSettings(ldapSettings, ldapPort, certificateFile);
     InitLdapSettings(ldapSettings, ldapPort, certificateFile);
-    ldapSettings->SetHost("unavaliablehost");
+    ldapSettings->SetHost("unavailablehost");
 }
 }
 
 
 void InitLdapSettingsWithCustomGroupAttribute(NKikimrProto::TLdapAuthentication* ldapSettings, ui16 ldapPort, TTempFileHandle& certificateFile) {
 void InitLdapSettingsWithCustomGroupAttribute(NKikimrProto::TLdapAuthentication* ldapSettings, ui16 ldapPort, TTempFileHandle& certificateFile) {
@@ -710,8 +710,8 @@ Y_UNIT_TEST_SUITE(TTicketParserTest) {
         ldapServer.Stop();
         ldapServer.Stop();
     }
     }
 
 
-    Y_UNIT_TEST(LdapServerIsUnavaliable) {
-        TLdapKikimrServer server(InitLdapSettingsWithUnavaliableHost);
+    Y_UNIT_TEST(LdapServerIsUnavailable) {
+        TLdapKikimrServer server(InitLdapSettingsWithUnavailableHost);
 
 
         LdapMock::TLdapMockResponses responses;
         LdapMock::TLdapMockResponses responses;
         LdapMock::TLdapSimpleServer ldapServer(server.GetLdapPort(), responses);
         LdapMock::TLdapSimpleServer ldapServer(server.GetLdapPort(), responses);
@@ -731,39 +731,14 @@ Y_UNIT_TEST_SUITE(TTicketParserTest) {
         TString login = "ldapuser";
         TString login = "ldapuser";
         TString password = "ldapUserPassword";
         TString password = "ldapUserPassword";
 
 
-        TLdapKikimrServer server(InitLdapSettings);
-        auto responses = TCorrectLdapResponse::GetResponses(login);
-        LdapMock::TLdapSimpleServer ldapServer(server.GetLdapPort(), responses);
-
-        auto loginResponse = GetLoginResponse(server, login, password);
-        TTestActorRuntime* runtime = server.GetRuntime();
-        TActorId sender = runtime->AllocateEdgeActor();
-        runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);
-        TAutoPtr<IEventHandle> handle;
-        TEvTicketParser::TEvAuthorizeTicketResult* ticketParserResult = runtime->GrabEdgeEvent<TEvTicketParser::TEvAuthorizeTicketResult>(handle);
-
-        UNIT_ASSERT_C(ticketParserResult->Error.empty(), ticketParserResult->Error);
-        UNIT_ASSERT(ticketParserResult->Token != nullptr);
-        const TString ldapDomain = "@ldap";
-        UNIT_ASSERT_VALUES_EQUAL(ticketParserResult->Token->GetUserSID(), login + ldapDomain);
-        const auto& fetchedGroups = ticketParserResult->Token->GetGroupSIDs();
-        THashSet<TString> groups(fetchedGroups.begin(), fetchedGroups.end());
 
 
-        THashSet<TString> expectedGroups;
-        std::transform(TCorrectLdapResponse::Groups.begin(), TCorrectLdapResponse::Groups.end(), std::inserter(expectedGroups, expectedGroups.end()), [&ldapDomain](TString& group) {
-            return group.append(ldapDomain);
-        });
-        expectedGroups.insert("all-users@well-known");
-
-        UNIT_ASSERT_VALUES_EQUAL(fetchedGroups.size(), expectedGroups.size());
-        for (const auto& expectedGroup : expectedGroups) {
-            UNIT_ASSERT_C(groups.contains(expectedGroup), "Can not find " + expectedGroup);
-        }
+        auto responses = TCorrectLdapResponse::GetResponses(login);
+        LdapMock::TLdapMockResponses updatedResponses = responses;
 
 
         std::vector<TString> newLdapGroups {
         std::vector<TString> newLdapGroups {
             "ou=groups,dc=search,dc=yandex,dc=net",
             "ou=groups,dc=search,dc=yandex,dc=net",
             "cn=people,ou=groups,dc=search,dc=yandex,dc=net",
             "cn=people,ou=groups,dc=search,dc=yandex,dc=net",
-            "cn=desiners,ou=groups,dc=search,dc=yandex,dc=net"
+            "cn=designers,ou=groups,dc=search,dc=yandex,dc=net"
         };
         };
         std::vector<LdapMock::TSearchEntry> newFetchGroupsSearchResponseEntries {
         std::vector<LdapMock::TSearchEntry> newFetchGroupsSearchResponseEntries {
             {
             {
@@ -774,6 +749,7 @@ Y_UNIT_TEST_SUITE(TTicketParserTest) {
             }
             }
         };
         };
 
 
+        const TString ldapDomain = "@ldap";
         THashSet<TString> newExpectedGroups;
         THashSet<TString> newExpectedGroups;
         std::transform(newLdapGroups.begin(), newLdapGroups.end(), std::inserter(newExpectedGroups, newExpectedGroups.end()), [&ldapDomain](TString& group) {
         std::transform(newLdapGroups.begin(), newLdapGroups.end(), std::inserter(newExpectedGroups, newExpectedGroups.end()), [&ldapDomain](TString& group) {
             return group.append(ldapDomain);
             return group.append(ldapDomain);
@@ -785,9 +761,37 @@ Y_UNIT_TEST_SUITE(TTicketParserTest) {
             .ResponseDone = {.Status = LdapMock::EStatus::SUCCESS}
             .ResponseDone = {.Status = LdapMock::EStatus::SUCCESS}
         };
         };
 
 
-        auto& searchresponse = responses.SearchResponses.front();
-        searchresponse.second = newFetchGroupsSearchResponseInfo;
-        ldapServer.SetSearchReasponse(searchresponse);
+        auto& searchResponse = updatedResponses.SearchResponses.front();
+        searchResponse.second = newFetchGroupsSearchResponseInfo;
+
+        TLdapKikimrServer server(InitLdapSettings);
+        LdapMock::TLdapSimpleServer ldapServer(server.GetLdapPort(), {responses, updatedResponses});
+
+        auto loginResponse = GetLoginResponse(server, login, password);
+        TTestActorRuntime* runtime = server.GetRuntime();
+        TActorId sender = runtime->AllocateEdgeActor();
+        runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);
+        TAutoPtr<IEventHandle> handle;
+        TEvTicketParser::TEvAuthorizeTicketResult* ticketParserResult = runtime->GrabEdgeEvent<TEvTicketParser::TEvAuthorizeTicketResult>(handle);
+
+        UNIT_ASSERT_C(ticketParserResult->Error.empty(), ticketParserResult->Error);
+        UNIT_ASSERT(ticketParserResult->Token != nullptr);
+        UNIT_ASSERT_VALUES_EQUAL(ticketParserResult->Token->GetUserSID(), login + ldapDomain);
+        const auto& fetchedGroups = ticketParserResult->Token->GetGroupSIDs();
+        THashSet<TString> groups(fetchedGroups.begin(), fetchedGroups.end());
+
+        THashSet<TString> expectedGroups;
+        std::transform(TCorrectLdapResponse::Groups.begin(), TCorrectLdapResponse::Groups.end(), std::inserter(expectedGroups, expectedGroups.end()), [&ldapDomain](TString& group) {
+            return group.append(ldapDomain);
+        });
+        expectedGroups.insert("all-users@well-known");
+
+        UNIT_ASSERT_VALUES_EQUAL(fetchedGroups.size(), expectedGroups.size());
+        for (const auto& expectedGroup : expectedGroups) {
+            UNIT_ASSERT_C(groups.contains(expectedGroup), "Can not find " + expectedGroup);
+        }
+
+        ldapServer.UpdateResponses();
         Sleep(TDuration::Seconds(10));
         Sleep(TDuration::Seconds(10));
 
 
         runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);
         runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);
@@ -812,8 +816,15 @@ Y_UNIT_TEST_SUITE(TTicketParserTest) {
 
 
         TLdapKikimrServer server(InitLdapSettings);
         TLdapKikimrServer server(InitLdapSettings);
         auto responses = TCorrectLdapResponse::GetResponses(login);
         auto responses = TCorrectLdapResponse::GetResponses(login);
-        LdapMock::TLdapSimpleServer ldapServer(server.GetLdapPort(), responses);
+        LdapMock::TLdapMockResponses updatedResponses = responses;
+        LdapMock::TSearchResponseInfo newFetchGroupsSearchResponseInfo {
+            .ResponseEntries = {}, // User has been removed. Return empty entries list
+            .ResponseDone = {.Status = LdapMock::EStatus::SUCCESS}
+        };
 
 
+        auto& searchResponse = updatedResponses.SearchResponses.front();
+        searchResponse.second = newFetchGroupsSearchResponseInfo;
+        LdapMock::TLdapSimpleServer ldapServer(server.GetLdapPort(), {responses, updatedResponses});
 
 
         auto loginResponse = GetLoginResponse(server, login, password);
         auto loginResponse = GetLoginResponse(server, login, password);
         TTestActorRuntime* runtime = server.GetRuntime();
         TTestActorRuntime* runtime = server.GetRuntime();
@@ -840,14 +851,7 @@ Y_UNIT_TEST_SUITE(TTicketParserTest) {
             UNIT_ASSERT_C(groups.contains(expectedGroup), "Can not find " + expectedGroup);
             UNIT_ASSERT_C(groups.contains(expectedGroup), "Can not find " + expectedGroup);
         }
         }
 
 
-        LdapMock::TSearchResponseInfo newFetchGroupsSearchResponseInfo {
-            .ResponseEntries = {}, // User has been removed. Return empty entries list
-            .ResponseDone = {.Status = LdapMock::EStatus::SUCCESS}
-        };
-
-        auto& searchresponse = responses.SearchResponses.front();
-        searchresponse.second = newFetchGroupsSearchResponseInfo;
-        ldapServer.SetSearchReasponse(searchresponse);
+        ldapServer.UpdateResponses();
         Sleep(TDuration::Seconds(10));
         Sleep(TDuration::Seconds(10));
 
 
         runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);
         runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);

+ 34 - 34
ydb/library/testlib/service_mocks/ldap_mock/ldap_message_processor.cpp

@@ -15,7 +15,7 @@ struct TResponseInfo {
     TString DiagnosticMsg = "";
     TString DiagnosticMsg = "";
 };
 };
 
 
-TString CreateResopnse(const TResponseInfo& responseInfo) {
+TString CreateResponse(const TResponseInfo& responseInfo) {
     TString result = EncodeEnum(static_cast<int>(responseInfo.Status));
     TString result = EncodeEnum(static_cast<int>(responseInfo.Status));
     result += EncodeString(responseInfo.MatchedDN);
     result += EncodeString(responseInfo.MatchedDN);
     result += EncodeString(responseInfo.DiagnosticMsg);
     result += EncodeString(responseInfo.DiagnosticMsg);
@@ -23,7 +23,7 @@ TString CreateResopnse(const TResponseInfo& responseInfo) {
 }
 }
 
 
 TString CreateExtendedResponse(const TResponseInfo& responseInfo, const TString& oid = "", const TString& oidValue = "") {
 TString CreateExtendedResponse(const TResponseInfo& responseInfo, const TString& oid = "", const TString& oidValue = "") {
-    TString result = CreateResopnse(responseInfo);
+    TString result = CreateResponse(responseInfo);
     result += EncodeString(oid);
     result += EncodeString(oid);
     result += EncodeString(oidValue);
     result += EncodeString(oidValue);
     return result;
     return result;
@@ -70,11 +70,11 @@ std::vector<TLdapRequestProcessor::TProtocolOpData> CreateSearchEntryResponses(c
     return result;
     return result;
 }
 }
 
 
-}
+} // namespace
 
 
 TLdapRequestProcessor::TLdapRequestProcessor(TAtomicSharedPtr<TLdapSocketWrapper> socket)
 TLdapRequestProcessor::TLdapRequestProcessor(TAtomicSharedPtr<TLdapSocketWrapper> socket)
-        : Socket(socket)
-    {}
+    : Socket(socket)
+{}
 
 
 void TLdapRequestProcessor::SslAccept() {
 void TLdapRequestProcessor::SslAccept() {
     Socket->SslAccept();
     Socket->SslAccept();
@@ -152,7 +152,7 @@ std::vector<TLdapRequestProcessor::TProtocolOpData> TLdapRequestProcessor::Proce
             return ProcessExtendedRequest();
             return ProcessExtendedRequest();
         }
         }
         default: {
         default: {
-            return {{.Type = EProtocolOp::UNKNOWN_OP, .Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR})}};
+            return {{.Type = EProtocolOp::UNKNOWN_OP, .Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR})}};
         }
         }
     }
     }
 }
 }
@@ -161,10 +161,10 @@ std::vector<TLdapRequestProcessor::TProtocolOpData> TLdapRequestProcessor::Proce
     TProtocolOpData responseOpData;
     TProtocolOpData responseOpData;
     responseOpData.Type  = EProtocolOp::EXTENDED_OP_RESPONSE;
     responseOpData.Type  = EProtocolOp::EXTENDED_OP_RESPONSE;
 
 
-    size_t lenght = GetLength();
+    size_t length = GetLength();
 
 
-    if (lenght == 0) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+    if (length == 0) {
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
 
 
@@ -188,22 +188,22 @@ std::vector<TLdapRequestProcessor::TProtocolOpData> TLdapRequestProcessor::Proce
     TProtocolOpData responseOpData;
     TProtocolOpData responseOpData;
     responseOpData.Type = EProtocolOp::BIND_OP_RESPONSE;
     responseOpData.Type = EProtocolOp::BIND_OP_RESPONSE;
 
 
-    size_t lenght = GetLength();
+    size_t length = GetLength();
 
 
-    if (lenght == 0) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+    if (length == 0) {
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
 
 
     int version = GetInt();
     int version = GetInt();
     if (version > 127) {
     if (version > 127) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
 
 
     unsigned char elementType = GetByte();
     unsigned char elementType = GetByte();
     if (elementType != EElementType::STRING) {
     if (elementType != EElementType::STRING) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
 
 
@@ -220,11 +220,11 @@ std::vector<TLdapRequestProcessor::TProtocolOpData> TLdapRequestProcessor::Proce
     });
     });
 
 
     if (it == responses.end()) {
     if (it == responses.end()) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
 
 
-    responseOpData.Data = CreateResopnse({.Status = it->second.Status, .MatchedDN = it->second.MatchedDN, .DiagnosticMsg = it->second.DiagnosticMsg});
+    responseOpData.Data = CreateResponse({.Status = it->second.Status, .MatchedDN = it->second.MatchedDN, .DiagnosticMsg = it->second.DiagnosticMsg});
     return {responseOpData};
     return {responseOpData};
 }
 }
 
 
@@ -236,41 +236,41 @@ std::vector<TLdapRequestProcessor::TProtocolOpData> TLdapRequestProcessor::Proce
 
 
     size_t length = GetLength();
     size_t length = GetLength();
     if (length == 0) {
     if (length == 0) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
 
 
-    // extract BaseDn
+    // Extract BaseDn
     unsigned char elementType = GetByte();
     unsigned char elementType = GetByte();
     if (elementType != EElementType::STRING) {
     if (elementType != EElementType::STRING) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
 
 
     requestInfo.BaseDn = GetString();
     requestInfo.BaseDn = GetString();
 
 
-    // Extarct scope
+    // Extract scope
     elementType = GetByte();
     elementType = GetByte();
     if (elementType != EElementType::ENUMERATED) {
     if (elementType != EElementType::ENUMERATED) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
     length = GetLength();
     length = GetLength();
     if (length == 0) {
     if (length == 0) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
     requestInfo.Scope = GetByte();
     requestInfo.Scope = GetByte();
 
 
-    // Extract derefAlliases
+    // Extract derefAliases
     elementType = GetByte();
     elementType = GetByte();
     if (elementType != EElementType::ENUMERATED) {
     if (elementType != EElementType::ENUMERATED) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
     length = GetLength();
     length = GetLength();
     if (length == 0) {
     if (length == 0) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
     requestInfo.DerefAliases = GetByte();
     requestInfo.DerefAliases = GetByte();
@@ -279,29 +279,29 @@ std::vector<TLdapRequestProcessor::TProtocolOpData> TLdapRequestProcessor::Proce
     int sizeLimit = GetInt();
     int sizeLimit = GetInt();
     Y_UNUSED(sizeLimit);
     Y_UNUSED(sizeLimit);
 
 
-    // Extarct timeLimit
+    // Extract timeLimit
     int timeLimit = GetInt();
     int timeLimit = GetInt();
     Y_UNUSED(timeLimit);
     Y_UNUSED(timeLimit);
 
 
-    // Extact typesOnly
+    // Extract typesOnly
     elementType = GetByte();
     elementType = GetByte();
     if (elementType != EElementType::BOOL) {
     if (elementType != EElementType::BOOL) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
     length = GetLength();
     length = GetLength();
     if (length == 0) {
     if (length == 0) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
     requestInfo.TypesOnly = GetByte();
     requestInfo.TypesOnly = GetByte();
 
 
     requestInfo.Filter = ProcessFilter();
     requestInfo.Filter = ProcessFilter();
 
 
-    // Extarct Attributes
+    // Extract Attributes
     elementType = GetByte();
     elementType = GetByte();
     if (elementType != EElementType::SEQUENCE) {
     if (elementType != EElementType::SEQUENCE) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
     length = GetLength();
     length = GetLength();
@@ -309,7 +309,7 @@ std::vector<TLdapRequestProcessor::TProtocolOpData> TLdapRequestProcessor::Proce
     while (ReadBytes < limit) {
     while (ReadBytes < limit) {
         elementType = GetByte();
         elementType = GetByte();
         if (elementType != EElementType::STRING) {
         if (elementType != EElementType::STRING) {
-            responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+            responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
             return {responseOpData};
             return {responseOpData};
         }
         }
         requestInfo.Attributes.push_back(GetString());
         requestInfo.Attributes.push_back(GetString());
@@ -321,14 +321,14 @@ std::vector<TLdapRequestProcessor::TProtocolOpData> TLdapRequestProcessor::Proce
     });
     });
 
 
     if (it == responses.end()) {
     if (it == responses.end()) {
-        responseOpData.Data = CreateResopnse({.Status = EStatus::PROTOCOL_ERROR});
+        responseOpData.Data = CreateResponse({.Status = EStatus::PROTOCOL_ERROR});
         return {responseOpData};
         return {responseOpData};
     }
     }
 
 
     std::vector<TLdapRequestProcessor::TProtocolOpData> res = CreateSearchEntryResponses(it->second.ResponseEntries);
     std::vector<TLdapRequestProcessor::TProtocolOpData> res = CreateSearchEntryResponses(it->second.ResponseEntries);
 
 
     const auto& responseDoneInfo = it->second.ResponseDone;
     const auto& responseDoneInfo = it->second.ResponseDone;
-    responseOpData.Data = CreateResopnse({.Status = responseDoneInfo.Status, .MatchedDN = responseDoneInfo.MatchedDN, .DiagnosticMsg = responseDoneInfo.DiagnosticMsg});
+    responseOpData.Data = CreateResponse({.Status = responseDoneInfo.Status, .MatchedDN = responseDoneInfo.MatchedDN, .DiagnosticMsg = responseDoneInfo.DiagnosticMsg});
     res.push_back(std::move(responseOpData));
     res.push_back(std::move(responseOpData));
     return res;
     return res;
 }
 }

+ 14 - 34
ydb/library/testlib/service_mocks/ldap_mock/ldap_simple_server.cpp

@@ -12,6 +12,10 @@
 namespace LdapMock {
 namespace LdapMock {
 
 
 TLdapSimpleServer::TLdapSimpleServer(ui16 port, const TLdapMockResponses& responses)
 TLdapSimpleServer::TLdapSimpleServer(ui16 port, const TLdapMockResponses& responses)
+    : TLdapSimpleServer(port, {responses, {}})
+{}
+
+TLdapSimpleServer::TLdapSimpleServer(ui16 port, const std::pair<TLdapMockResponses, TLdapMockResponses>& responses)
     : Port(port)
     : Port(port)
     , Responses(responses)
     , Responses(responses)
 {
 {
@@ -34,7 +38,7 @@ TLdapSimpleServer::TLdapSimpleServer(ui16 port, const TLdapMockResponses& respon
     ThreadPool->Start(1);
     ThreadPool->Start(1);
 
 
     auto receiveFinish = MakeAtomicShared<TInetStreamSocket>(socketPair[0]);
     auto receiveFinish = MakeAtomicShared<TInetStreamSocket>(socketPair[0]);
-    ListenerThread = ThreadPool->Run([listenSocket, receiveFinish, &responses = this->Responses] {
+    ListenerThread = ThreadPool->Run([listenSocket, receiveFinish, &useFirstSetResponses = this->UseFirstSetResponses, &responses = this->Responses] {
         TSocketPoller socketPoller;
         TSocketPoller socketPoller;
         socketPoller.WaitRead(*receiveFinish, nullptr);
         socketPoller.WaitRead(*receiveFinish, nullptr);
         socketPoller.WaitRead(*listenSocket, (void*)1);
         socketPoller.WaitRead(*listenSocket, (void*)1);
@@ -51,8 +55,8 @@ TLdapSimpleServer::TLdapSimpleServer(ui16 port, const TLdapMockResponses& respon
                     socket->OnAccept();
                     socket->OnAccept();
 
 
                     SystemThreadFactory()->Run(
                     SystemThreadFactory()->Run(
-                        [socket, &responses] {
-                            LdapRequestHandler(socket, responses);
+                        [socket, &useFirstSetResponses, &responses] {
+                            LdapRequestHandler(socket, useFirstSetResponses ? responses.first : responses.second);
                             socket->Close();
                             socket->Close();
                         });
                         });
                 }
                 }
@@ -61,8 +65,7 @@ TLdapSimpleServer::TLdapSimpleServer(ui16 port, const TLdapMockResponses& respon
     });
     });
 }
 }
 
 
-TLdapSimpleServer::~TLdapSimpleServer()
-{
+TLdapSimpleServer::~TLdapSimpleServer() {
     try {
     try {
         if (ThreadPool) {
         if (ThreadPool) {
             Stop();
             Stop();
@@ -71,8 +74,7 @@ TLdapSimpleServer::~TLdapSimpleServer()
     }
     }
 }
 }
 
 
-void TLdapSimpleServer::Stop()
-{
+void TLdapSimpleServer::Stop() {
     // Just send something to indicate shutdown.
     // Just send something to indicate shutdown.
     SendFinishSocket->Send("X", 1);
     SendFinishSocket->Send("X", 1);
     ListenerThread->Join();
     ListenerThread->Join();
@@ -80,38 +82,16 @@ void TLdapSimpleServer::Stop()
     ThreadPool.Destroy();
     ThreadPool.Destroy();
 }
 }
 
 
-int TLdapSimpleServer::GetPort() const
-{
+int TLdapSimpleServer::GetPort() const {
     return Port;
     return Port;
 }
 }
 
 
-TString TLdapSimpleServer::GetAddress() const
-{
+TString TLdapSimpleServer::GetAddress() const {
     return TStringBuilder() << "localhost:" << Port;
     return TStringBuilder() << "localhost:" << Port;
 }
 }
 
 
-void TLdapSimpleServer::SetBindResponse(const std::pair<TBindRequestInfo, TBindResponseInfo>& response) {
-    auto it = std::find_if(Responses.BindResponses.begin(), Responses.BindResponses.end(), [&response](const std::pair<TBindRequestInfo, TBindResponseInfo>& expectedResponse){
-        return response.first == expectedResponse.first;
-    });
-
-    if (it != Responses.BindResponses.end()) {
-        it->second = response.second;
-    } else {
-        Responses.BindResponses.push_back(response);
-    }
-}
-
-void TLdapSimpleServer::SetSearchReasponse(const std::pair<TSearchRequestInfo, TSearchResponseInfo>& response) {
-    auto it = std::find_if(Responses.SearchResponses.begin(), Responses.SearchResponses.end(), [&response](const std::pair<TSearchRequestInfo, TSearchResponseInfo>& expectedResponse){
-        return response.first == expectedResponse.first;
-    });
-
-    if (it != Responses.SearchResponses.end()) {
-        it->second = response.second;
-    } else {
-        Responses.SearchResponses.push_back(response);
-    }
+void TLdapSimpleServer::UpdateResponses() {
+    UseFirstSetResponses = !UseFirstSetResponses;
 }
 }
 
 
-}
+} // namespace LdapMock

+ 4 - 3
ydb/library/testlib/service_mocks/ldap_mock/ldap_simple_server.h

@@ -19,6 +19,7 @@ public:
     using TRequestHandler = std::function<void(TAtomicSharedPtr<TStreamSocket> socket)>;
     using TRequestHandler = std::function<void(TAtomicSharedPtr<TStreamSocket> socket)>;
 
 
 public:
 public:
+    TLdapSimpleServer(ui16 port, const std::pair<TLdapMockResponses, TLdapMockResponses>& responses);
     TLdapSimpleServer(ui16 port, const TLdapMockResponses& responses);
     TLdapSimpleServer(ui16 port, const TLdapMockResponses& responses);
     ~TLdapSimpleServer();
     ~TLdapSimpleServer();
 
 
@@ -27,8 +28,7 @@ public:
     int GetPort() const;
     int GetPort() const;
     TString GetAddress() const;
     TString GetAddress() const;
 
 
-    void SetBindResponse(const std::pair<TBindRequestInfo, TBindResponseInfo>& response);
-    void SetSearchReasponse(const std::pair<TSearchRequestInfo, TSearchResponseInfo>& response);
+    void UpdateResponses();
 
 
 private:
 private:
     const int Port;
     const int Port;
@@ -36,7 +36,8 @@ private:
     THolder<IThreadFactory::IThread> ListenerThread;
     THolder<IThreadFactory::IThread> ListenerThread;
     THolder<TInetStreamSocket> SendFinishSocket;
     THolder<TInetStreamSocket> SendFinishSocket;
 
 
-    TLdapMockResponses Responses;
+    std::pair<TLdapMockResponses, TLdapMockResponses> Responses;
+    std::atomic_bool UseFirstSetResponses = true;
 };
 };
 
 
 }
 }