Browse Source

Fix incoming request profiling: increment counters even for non-accepted requests
commit_hash:1f97cd416e3407266f3cde961747a2d7386f3b60

babenko 4 months ago
parent
commit
30f2a784f7
3 changed files with 135 additions and 118 deletions
  1. 1 1
      yt/yt/core/rpc/config.h
  2. 107 97
      yt/yt/core/rpc/service_detail.cpp
  3. 27 20
      yt/yt/core/rpc/service_detail.h

+ 1 - 1
yt/yt/core/rpc/config.h

@@ -132,7 +132,7 @@ public:
     std::optional<bool> EnableErrorCodeCounter;
     std::optional<ERequestTracingMode> TracingMode;
     TTimeHistogramConfigPtr TimeHistogram;
-    THashMap<TString, TMethodConfigPtr> Methods;
+    THashMap<std::string, TMethodConfigPtr> Methods;
     std::optional<int> AuthenticationQueueSizeLimit;
     std::optional<TDuration> PendingPayloadsTimeout;
     std::optional<bool> Pooled;

+ 107 - 97
yt/yt/core/rpc/service_detail.cpp

@@ -50,6 +50,8 @@ static const auto InfiniteRequestThrottlerConfig = New<TThroughputThrottlerConfi
 static const auto DefaultLoggingSuppressionFailedRequestThrottlerConfig = TThroughputThrottlerConfig::Create(1'000);
 
 constexpr int MaxUserAgentLength = 200;
+constexpr TStringBuf UnknownUserAgent = "unknown";
+
 constexpr auto ServiceLivenessCheckPeriod = TDuration::MilliSeconds(100);
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -264,15 +266,15 @@ auto TServiceBase::TMethodDescriptor::SetHandleMethodError(bool value) const ->
 
 ////////////////////////////////////////////////////////////////////////////////
 
-TServiceBase::TErrorCodeCounter::TErrorCodeCounter(NProfiling::TProfiler profiler)
+TServiceBase::TErrorCodeCounters::TErrorCodeCounters(NProfiling::TProfiler profiler)
     : Profiler_(std::move(profiler))
 { }
 
-void TServiceBase::TErrorCodeCounter::Increment(TErrorCode code)
+NProfiling::TCounter* TServiceBase::TErrorCodeCounters::GetCounter(TErrorCode code)
 {
-    CodeToCounter_.FindOrInsert(code, [&] {
+    return CodeToCounter_.FindOrInsert(code, [&] {
         return Profiler_.WithTag("code", ToString(code)).Counter("/code_count");
-    }).first->Increment();
+    }).first;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -289,7 +291,7 @@ TServiceBase::TMethodPerformanceCounters::TMethodPerformanceCounters(
     , RequestMessageAttachmentSizeCounter(profiler.Counter("/request_message_attachment_bytes"))
     , ResponseMessageBodySizeCounter(profiler.Counter("/response_message_body_bytes"))
     , ResponseMessageAttachmentSizeCounter(profiler.Counter("/response_message_attachment_bytes"))
-    , ErrorCodeCounter(profiler)
+    , ErrorCodeCounters(profiler)
 {
     if (timeHistogramConfig && timeHistogramConfig->CustomBounds) {
         const auto& customBounds = *timeHistogramConfig->CustomBounds;
@@ -338,6 +340,19 @@ TRequestQueue* TServiceBase::TRuntimeMethodInfo::GetDefaultRequestQueue()
 
 ////////////////////////////////////////////////////////////////////////////////
 
+TServiceBase::TPerformanceCounters::TPerformanceCounters(const NProfiling::TProfiler& profiler)
+    : Profiler_(profiler.WithHot().WithSparse())
+{ }
+
+NProfiling::TCounter* TServiceBase::TPerformanceCounters::GetRequestsPerUserAgentCounter(TStringBuf userAgent)
+{
+    return RequestsPerUserAgent_.FindOrInsert(userAgent, [&] {
+        return Profiler_.WithRequiredTag("user_agent", TString(userAgent)).Counter("/user_agent");
+    }).first;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
 class TServiceBase::TServiceContext
     : public TServiceContextBase
 {
@@ -360,11 +375,9 @@ public:
         , TraceContext_(std::move(acceptedRequest.TraceContext))
         , RequestQueue_(acceptedRequest.RequestQueue)
         , ThrottledError_(std::move(acceptedRequest.ThrottledError))
-        , MethodPerformanceCounters_(Service_->GetMethodPerformanceCounters(
-            RuntimeInfo_,
-            {GetAuthenticationIdentity().UserTag, RequestQueue_}))
+        , MethodPerformanceCounters_(acceptedRequest.MethodPerformanceCounters)
         , PerformanceCounters_(Service_->GetPerformanceCounters())
-        , ArriveInstant_(NProfiling::GetInstant())
+        , ArriveInstant_(acceptedRequest.ArriveInstant)
     {
         YT_ASSERT(RequestMessage_);
         YT_ASSERT(ReplyBus_);
@@ -743,24 +756,6 @@ private:
 
     void Initialize()
     {
-        constexpr TStringBuf UnknownUserAgent = "unknown";
-        auto userAgent = RequestHeader_->has_user_agent()
-            ? TStringBuf(RequestHeader_->user_agent())
-            : UnknownUserAgent;
-        PerformanceCounters_->IncrementRequestsPerUserAgent(userAgent.SubString(0, MaxUserAgentLength));
-
-        MethodPerformanceCounters_->RequestCounter.Increment();
-        MethodPerformanceCounters_->RequestMessageBodySizeCounter.Increment(
-            GetMessageBodySize(RequestMessage_));
-        MethodPerformanceCounters_->RequestMessageAttachmentSizeCounter.Increment(
-            GetTotalMessageAttachmentSize(RequestMessage_));
-
-        if (RequestHeader_->has_start_time()) {
-            auto retryStart = FromProto<TInstant>(RequestHeader_->start_time());
-            auto now = NProfiling::GetInstant();
-            MethodPerformanceCounters_->RemoteWaitTimeCounter.Record(now - retryStart);
-        }
-
         // COMPAT(danilalexeev): legacy RPC codecs
         RequestCodec_ = RequestHeader_->has_request_codec()
             ? CheckedEnumCast<NCompression::ECodec>(RequestHeader_->request_codec())
@@ -1024,7 +1019,8 @@ private:
         MethodPerformanceCounters_->TotalTimeCounter.Record(*TotalTime_);
         if (!Error_.IsOK()) {
             if (Service_->EnableErrorCodeCounter_.load()) {
-                MethodPerformanceCounters_->ErrorCodeCounter.Increment(Error_.GetNonTrivialCode());
+                const auto* counter = MethodPerformanceCounters_->ErrorCodeCounters.GetCounter(Error_.GetNonTrivialCode());
+                counter->Increment();
             } else {
                 MethodPerformanceCounters_->FailedRequestCounter.Increment();
             }
@@ -1307,7 +1303,7 @@ private:
 
 ////////////////////////////////////////////////////////////////////////////////
 
-TRequestQueue::TRequestQueue(const std::string& name, NProfiling::TProfiler profiler)
+TRequestQueue::TRequestQueue(const std::string& name, const NProfiling::TProfiler& profiler)
     : Name_(name)
     , BytesThrottler_{CreateReconfigurableThroughputThrottler(InfiniteRequestThrottlerConfig,
         NLogging::TLogger(),
@@ -1610,23 +1606,20 @@ void TRequestQueue::SubscribeToThrottlers()
 
 ////////////////////////////////////////////////////////////////////////////////
 
-struct TServiceBase::TRuntimeMethodInfo::TPerformanceCountersKeyEquals
+bool TServiceBase::TRuntimeMethodInfo::TPerformanceCountersKeyEquals::operator()(
+    const TNonowningPerformanceCountersKey& lhs,
+    const TNonowningPerformanceCountersKey& rhs) const
 {
-    bool operator()(
-        const TNonowningPerformanceCountersKey& lhs,
-        const TNonowningPerformanceCountersKey& rhs) const
-    {
-        return lhs == rhs;
-    }
+    return lhs == rhs;
+}
 
-    bool operator()(
-        const TOwningPerformanceCountersKey& lhs,
-        const TNonowningPerformanceCountersKey& rhs) const
-    {
-        const auto& [lhsUserTag, lhsRequestQueue] = lhs;
-        return TNonowningPerformanceCountersKey{lhsUserTag, lhsRequestQueue} == rhs;
-    }
-};
+bool TServiceBase::TRuntimeMethodInfo::TPerformanceCountersKeyEquals::operator()(
+    const TOwningPerformanceCountersKey& lhs,
+    const TNonowningPerformanceCountersKey& rhs) const
+{
+    const auto& [lhsUserTag, lhsRequestQueue] = lhs;
+    return TNonowningPerformanceCountersKey{lhsUserTag, lhsRequestQueue} == rhs;
+}
 
 ////////////////////////////////////////////////////////////////////////////////
 
@@ -1675,8 +1668,15 @@ void TServiceBase::HandleRequest(
 {
     SetActive();
 
-    auto method = FromProto<TString>(header->method());
+    auto arriveInstant = NProfiling::GetInstant();
+
+    const auto& method = header->method();
     auto requestId = FromProto<TRequestId>(header->request_id());
+    auto userAgent = header->has_user_agent()
+        ? TStringBuf(header->user_agent()).SubString(0, MaxUserAgentLength)
+        : UnknownUserAgent;
+    const auto& user = header->has_user() ? header->user() : RootUserName;
+    const auto& userTag = header->has_user_tag() ? header->user_tag() : user;
 
     auto replyError = [&] (TError error) {
         ReplyError(std::move(error), *header, replyBus);
@@ -1689,11 +1689,6 @@ void TServiceBase::HandleRequest(
         return;
     }
 
-    if (auto error = DoCheckRequestCompatibility(*header); !error.IsOK()) {
-        replyError(std::move(error));
-        return;
-    }
-
     auto* runtimeInfo = FindMethodInfo(method);
     if (!runtimeInfo) {
         replyError(TError(
@@ -1702,8 +1697,29 @@ void TServiceBase::HandleRequest(
         return;
     }
 
+    auto* requestQueue = GetRequestQueue(runtimeInfo, *header);
+
+    const auto* requestsPerUserAgentCounter = PerformanceCounters_->GetRequestsPerUserAgentCounter(userAgent);
+    requestsPerUserAgentCounter->Increment();
+
+    auto* methodPerformanceCounters = GetMethodPerformanceCounters(runtimeInfo, {userTag, requestQueue});
+    methodPerformanceCounters->RequestCounter.Increment();
+    methodPerformanceCounters->RequestMessageBodySizeCounter.Increment(GetMessageBodySize(message));
+    methodPerformanceCounters->RequestMessageAttachmentSizeCounter.Increment(GetTotalMessageAttachmentSize(message));
+
+    if (header->has_start_time()) {
+        auto retryStart = FromProto<TInstant>(header->start_time());
+        methodPerformanceCounters->RemoteWaitTimeCounter.Record(arriveInstant - retryStart);
+    }
+
+    if (auto error = DoCheckRequestCompatibility(*header); !error.IsOK()) {
+        replyError(std::move(error));
+        return;
+    }
+
     auto memoryGuard = TMemoryUsageTrackerGuard::Acquire(MemoryUsageTracker_, TypicalRequestSize);
     message = TrackMemory(MemoryUsageTracker_, std::move(message));
+
     if (MemoryUsageTracker_ && MemoryUsageTracker_->IsExceeded()) {
         return replyError(TError(
             NRpc::EErrorCode::MemoryPressure,
@@ -1714,14 +1730,12 @@ void TServiceBase::HandleRequest(
     auto traceContext = tracingMode == ERequestTracingMode::Disable
         ? NTracing::TTraceContextPtr()
         : GetOrCreateHandlerTraceContext(*header, tracingMode == ERequestTracingMode::Force);
+
     if (traceContext && traceContext->IsRecorded()) {
         traceContext->AddTag(EndpointAnnotation, replyBus->GetEndpointDescription());
     }
 
-    auto* requestQueue = GetRequestQueue(runtimeInfo, *header);
-    RegisterRequestQueue(runtimeInfo, requestQueue);
-
-    auto maybeThrottled = GetThrottledError(*header);
+    auto throttledError = GetThrottledError(*header);
 
     if (requestQueue->IsQueueSizeLimitExceeded()) {
         runtimeInfo->RequestQueueSizeLimitErrorCounter.Increment();
@@ -1730,7 +1744,7 @@ void TServiceBase::HandleRequest(
             "Request queue size limit exceeded")
             << TErrorAttribute("limit", runtimeInfo->QueueSizeLimit.load())
             << TErrorAttribute("queue", requestQueue->GetName())
-            << maybeThrottled);
+            << throttledError);
         return;
     }
 
@@ -1741,7 +1755,7 @@ void TServiceBase::HandleRequest(
             "Request queue bytes size limit exceeded")
             << TErrorAttribute("limit", runtimeInfo->QueueByteSizeLimit.load())
             << TErrorAttribute("queue", requestQueue->GetName())
-            << maybeThrottled);
+            << throttledError);
         return;
     }
 
@@ -1749,6 +1763,7 @@ void TServiceBase::HandleRequest(
 
     // NOTE: Do not use replyError() after this line.
     TAcceptedRequest acceptedRequest{
+        .ArriveInstant = arriveInstant,
         .RequestId = requestId,
         .ReplyBus = std::move(replyBus),
         .RuntimeInfo = std::move(runtimeInfo),
@@ -1756,7 +1771,8 @@ void TServiceBase::HandleRequest(
         .Header = std::move(header),
         .Message = std::move(message),
         .RequestQueue = requestQueue,
-        .ThrottledError = maybeThrottled,
+        .MethodPerformanceCounters = methodPerformanceCounters,
+        .ThrottledError = throttledError,
         .MemoryGuard = std::move(memoryGuard),
         .MemoryUsageTracker = MemoryUsageTracker_,
     };
@@ -1909,55 +1925,49 @@ TRequestQueue* TServiceBase::GetRequestQueue(
     const NRpc::NProto::TRequestHeader& requestHeader)
 {
     TRequestQueue* requestQueue = nullptr;
-    if (auto& provider = runtimeInfo->Descriptor.RequestQueueProvider) {
+    if (const auto& provider = runtimeInfo->Descriptor.RequestQueueProvider) {
         requestQueue = provider->GetQueue(requestHeader);
     }
     if (!requestQueue) {
         requestQueue = runtimeInfo->DefaultRequestQueue.Get();
     }
-    return requestQueue;
-}
 
-void TServiceBase::RegisterRequestQueue(
-    TRuntimeMethodInfo* runtimeInfo,
-    TRequestQueue* requestQueue)
-{
-    if (!requestQueue->Register(this, runtimeInfo)) {
-        return;
-    }
+    if (requestQueue->Register(this, runtimeInfo)) {
+        const auto& method = runtimeInfo->Descriptor.Method;
+        YT_LOG_DEBUG("Request queue registered (Method: %v, Queue: %v)",
+            method,
+            requestQueue->GetName());
 
-    const auto& method = runtimeInfo->Descriptor.Method;
-    YT_LOG_DEBUG("Request queue registered (Method: %v, Queue: %v)",
-        method,
-        requestQueue->GetName());
+        auto profiler = runtimeInfo->Profiler.WithSparse();
+        if (runtimeInfo->Descriptor.RequestQueueProvider) {
+            profiler = profiler.WithTag("queue", requestQueue->GetName());
+        }
+        profiler.AddFuncGauge("/request_queue_size", MakeStrong(this), [=] {
+            return requestQueue->GetQueueSize();
+        });
+        profiler.AddFuncGauge("/request_queue_byte_size", MakeStrong(this), [=] {
+            return requestQueue->GetQueueByteSize();
+        });
+        profiler.AddFuncGauge("/concurrency", MakeStrong(this), [=] {
+            return requestQueue->GetConcurrency();
+        });
+        profiler.AddFuncGauge("/concurrency_byte", MakeStrong(this), [=] {
+            return requestQueue->GetConcurrencyByte();
+        });
 
-    auto profiler = runtimeInfo->Profiler.WithSparse();
-    if (runtimeInfo->Descriptor.RequestQueueProvider) {
-        profiler = profiler.WithTag("queue", requestQueue->GetName());
-    }
-    profiler.AddFuncGauge("/request_queue_size", MakeStrong(this), [=] {
-        return requestQueue->GetQueueSize();
-    });
-    profiler.AddFuncGauge("/request_queue_byte_size", MakeStrong(this), [=] {
-        return requestQueue->GetQueueByteSize();
-    });
-    profiler.AddFuncGauge("/concurrency", MakeStrong(this), [=] {
-        return requestQueue->GetConcurrency();
-    });
-    profiler.AddFuncGauge("/concurrency_byte", MakeStrong(this), [=] {
-        return requestQueue->GetConcurrencyByte();
-    });
+        TMethodConfigPtr methodConfig;
+        if (auto config = Config_.Acquire()) {
+            methodConfig = GetOrDefault(config->Methods, method);
+        }
+        ConfigureRequestQueue(runtimeInfo, requestQueue, methodConfig);
 
-    TMethodConfigPtr methodConfig;
-    if (auto config = Config_.Acquire()) {
-        methodConfig = GetOrDefault(config->Methods, method);
+        {
+            auto guard = Guard(runtimeInfo->RequestQueuesLock);
+            runtimeInfo->RequestQueues.push_back(requestQueue);
+        }
     }
-    ConfigureRequestQueue(runtimeInfo, requestQueue, methodConfig);
 
-    {
-        auto guard = Guard(runtimeInfo->RequestQueuesLock);
-        runtimeInfo->RequestQueues.push_back(requestQueue);
-    }
+    return requestQueue;
 }
 
 void TServiceBase::ConfigureRequestQueue(
@@ -2684,13 +2694,13 @@ TFuture<void> TServiceBase::Stop()
     return StopResult_.ToFuture();
 }
 
-TServiceBase::TRuntimeMethodInfo* TServiceBase::FindMethodInfo(const TString& method)
+TServiceBase::TRuntimeMethodInfo* TServiceBase::FindMethodInfo(const std::string& method)
 {
     auto it = MethodMap_.find(method);
     return it == MethodMap_.end() ? nullptr : it->second.Get();
 }
 
-TServiceBase::TRuntimeMethodInfo* TServiceBase::GetMethodInfoOrThrow(const TString& method)
+TServiceBase::TRuntimeMethodInfo* TServiceBase::GetMethodInfoOrThrow(const std::string& method)
 {
     auto* runtimeInfo = FindMethodInfo(method);
     if (!runtimeInfo) {

+ 27 - 20
yt/yt/core/rpc/service_detail.h

@@ -657,11 +657,12 @@ protected:
         TMethodDescriptor SetHandleMethodError(bool value) const;
     };
 
-    struct TErrorCodeCounter
+    class TErrorCodeCounters
     {
-        explicit TErrorCodeCounter(NProfiling::TProfiler profiler);
+    public:
+        explicit TErrorCodeCounters(NProfiling::TProfiler profiler);
 
-        void Increment(TErrorCode code);
+        NProfiling::TCounter* GetCounter(TErrorCode code);
 
     private:
         const NProfiling::TProfiler Profiler_;
@@ -717,7 +718,7 @@ protected:
         NProfiling::TCounter ResponseMessageAttachmentSizeCounter;
 
         //! Counts the number of errors, per error code.
-        TErrorCodeCounter ErrorCodeCounter;
+        TErrorCodeCounters ErrorCodeCounters;
     };
 
     using TMethodPerformanceCountersPtr = TIntrusivePtr<TMethodPerformanceCounters>;
@@ -763,13 +764,24 @@ protected:
         using TNonowningPerformanceCountersKey = std::tuple<TStringBuf, TRequestQueue*>;
         using TOwningPerformanceCountersKey = std::tuple<TString, TRequestQueue*>;
         using TPerformanceCountersKeyHash = THash<TNonowningPerformanceCountersKey>;
-        struct TPerformanceCountersKeyEquals;
+
+        struct TPerformanceCountersKeyEquals
+        {
+            bool operator()(
+                const TNonowningPerformanceCountersKey& lhs,
+                const TNonowningPerformanceCountersKey& rhs) const;
+            bool operator()(
+                const TOwningPerformanceCountersKey& lhs,
+                const TNonowningPerformanceCountersKey& rhs) const;
+        };
+
         using TPerformanceCountersMap = NConcurrency::TSyncMap<
             TOwningPerformanceCountersKey,
             TMethodPerformanceCountersPtr,
             TPerformanceCountersKeyHash,
             TPerformanceCountersKeyEquals
         >;
+
         TPerformanceCountersMap PerformanceCountersMap;
         TMethodPerformanceCountersPtr BasePerformanceCounters;
         TMethodPerformanceCountersPtr RootPerformanceCounters;
@@ -789,16 +801,9 @@ protected:
         : public TRefCounted
     {
     public:
-        explicit TPerformanceCounters(const NProfiling::TProfiler& profiler)
-            : Profiler_(profiler.WithHot().WithSparse())
-        { }
+        explicit TPerformanceCounters(const NProfiling::TProfiler& profiler);
 
-        void IncrementRequestsPerUserAgent(TStringBuf userAgent)
-        {
-            RequestsPerUserAgent_.FindOrInsert(userAgent, [&] {
-                return Profiler_.WithRequiredTag("user_agent", TString(userAgent)).Counter("/user_agent");
-            }).first->Increment();
-        }
+        NProfiling::TCounter* GetRequestsPerUserAgentCounter(TStringBuf userAgent);
 
     private:
         const NProfiling::TProfiler Profiler_;
@@ -846,10 +851,10 @@ protected:
 
     //! Returns a (non-owning!) pointer to TRuntimeMethodInfo for a given method's name
     //! or |nullptr| if no such method is registered.
-    TRuntimeMethodInfo* FindMethodInfo(const TString& method);
+    TRuntimeMethodInfo* FindMethodInfo(const std::string& method);
 
     //! Similar to #FindMethodInfo but throws if no method is found.
-    TRuntimeMethodInfo* GetMethodInfoOrThrow(const TString& method);
+    TRuntimeMethodInfo* GetMethodInfoOrThrow(const std::string& method);
 
     //! Returns the default invoker passed during construction.
     const IInvokerPtr& GetDefaultInvoker() const;
@@ -991,6 +996,7 @@ private:
 
     struct TAcceptedRequest
     {
+        TInstant ArriveInstant;
         TRequestId RequestId;
         NYT::NBus::IBusPtr ReplyBus;
         TRuntimeMethodInfo* RuntimeInfo;
@@ -998,6 +1004,7 @@ private:
         std::unique_ptr<NRpc::NProto::TRequestHeader> Header;
         TSharedRefArray Message;
         TRequestQueue* RequestQueue;
+        TMethodPerformanceCounters* MethodPerformanceCounters;
         std::optional<TError> ThrottledError;
         TMemoryUsageTrackerGuard MemoryGuard;
         IMemoryUsageTrackerPtr MemoryUsageTracker;
@@ -1022,9 +1029,6 @@ private:
     TRequestQueue* GetRequestQueue(
         TRuntimeMethodInfo* runtimeInfo,
         const NRpc::NProto::TRequestHeader& requestHeader);
-    void RegisterRequestQueue(
-        TRuntimeMethodInfo* runtimeInfo,
-        TRequestQueue* requestQueue);
     void ConfigureRequestQueue(
         TRuntimeMethodInfo* runtimeInfo,
         TRequestQueue* requestQueue,
@@ -1071,6 +1075,7 @@ private:
     static TString GetDiscoverRequestPayload(const TCtxDiscoverPtr& context);
 
     void OnServiceLivenessCheck();
+
 };
 
 DEFINE_REFCOUNTED_TYPE(TServiceBase)
@@ -1081,7 +1086,9 @@ class TRequestQueue
     : public TRefCounted
 {
 public:
-    TRequestQueue(const std::string& name, NProfiling::TProfiler profiler);
+    TRequestQueue(
+        const std::string& name,
+        const NProfiling::TProfiler& profiler);
 
     bool Register(TServiceBase* service, TServiceBase::TRuntimeMethodInfo* runtimeInfo);
     void Configure(const TMethodConfigPtr& config);