Browse Source

Перенести проверку TTL до парсинга заголовков

conterouz 1 year ago
parent
commit
6a7def716e

+ 5 - 0
library/cpp/http/server/http.cpp

@@ -744,6 +744,11 @@ void TClientRequest::Process(void* ThreadSpecificResource) {
             HttpConn_->Output()->EnableCompression(HttpServ()->Options().CompressionEnabled);
         }
 
+        if (!BeforeParseRequestOk(ThreadSpecificResource)) {
+            ReleaseConnection();
+            return;
+        }
+
         if (ParsedHeaders.empty()) {
             RequestString = Input().FirstLine();
 

+ 4 - 0
library/cpp/http/server/http.h

@@ -137,6 +137,10 @@ private:
      * 'true' otherwise ('this' will be deleted)
      */
     virtual bool Reply(void* ThreadSpecificResource);
+    virtual bool BeforeParseRequestOk(void* ThreadSpecificResource) {
+        Y_UNUSED(ThreadSpecificResource);
+        return true;
+    }
     void Process(void* ThreadSpecificResource) override;
 
 public:

+ 66 - 0
library/cpp/http/server/http_ut.cpp

@@ -62,6 +62,15 @@ Y_UNIT_TEST_SUITE(THttpServerTest) {
             {
             }
 
+            bool BeforeParseRequestOk(void*) override {
+                if (Server->Ttl && (TInstant::Now() - CreateTime > TDuration::MilliSeconds(Server->Ttl))) {
+                    Output().Write("HTTP/1.0 503 Created\nX-Server: sleeping server\n\nTTL Exceed");
+                    return false;
+                } else {
+                    return true;
+                }
+            }
+
             bool DoReply(const TReplyParams& params) override {
                 ++Server->Replies;
                 with_lock (Server->Lock) {
@@ -71,11 +80,17 @@ Y_UNIT_TEST_SUITE(THttpServerTest) {
                 return true;
             }
 
+            using TClientRequest::Output;
+
         private:
             TSleepingServer* Server = nullptr;
+            TInstant CreateTime = TInstant::Now();
         };
 
     public:
+        TSleepingServer(size_t ttl = 0)
+        : Ttl(ttl) {}
+
         TClientRequest* CreateClient() override {
             return new TReplier(this);
         }
@@ -88,6 +103,7 @@ Y_UNIT_TEST_SUITE(THttpServerTest) {
 
         std::atomic<size_t> Replies;
         std::atomic<size_t> MaxConns;
+        size_t Ttl;
     };
 
     static const TString CrLf = "\r\n";
@@ -953,4 +969,54 @@ Y_UNIT_TEST_SUITE(THttpServerTest) {
         }
 
     }
+
+    inline TString ToString(const THashSet<TString>& hs) {
+        TString res = "";
+        for (auto s : hs) {
+            if (res) {
+                res.append(",");
+            }
+            res.append("\"").append(s).append("\"");
+        }
+        return res;
+    }
+
+    Y_UNIT_TEST(TestTTLExceed) {
+        // Checks that one of request returns "TTL Exceed"
+        // First request waits for server.Lock.Release() for one threaded TSleepingServer
+        // So second request in queue should fail with TTL Exceed, because fist one lock thread pool for (ttl + 1) ms
+        TPortManager portManager;
+        const ui16 port = portManager.GetPort();
+        TString res = TestData(25);
+        const size_t ttl = 10;
+        TSleepingServer server{ttl};
+        THttpServer::TOptions options(port);
+        options.nThreads = 1;
+        options.MaxConnections = 2;
+        THttpServer srv(&server, options);
+
+        UNIT_ASSERT(srv.Start());
+        UNIT_ASSERT(server.Lock.TryAcquire());
+
+        THashSet<TString> results;
+        TMutex resultLock;
+        auto func = [port, &resultLock, &results]() {
+            try {
+                TTestRequest r(port);
+                TString result = r.Execute();
+                with_lock(resultLock) {
+                    results.insert(result);
+                }
+            } catch (...) {
+            }
+        };
+
+        auto t1 = SystemThreadFactory()->Run(func);
+        auto t2 = SystemThreadFactory()->Run(func);
+        Sleep(TDuration::MilliSeconds(ttl + 1));
+        server.Lock.Release();
+        t1->Join();
+        t2->Join();
+        UNIT_ASSERT_EQUAL_C(results, (THashSet<TString>({"Zoooo", "TTL Exceed"})), "Results is {" + ToString(results) + "}");
+    }
 }