123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340 |
- /*
- *
- * Copyright 2016 gRPC authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
- // client starts an interop client to do stress test and a metrics server to report qps.
- package main
- import (
- "context"
- "flag"
- "fmt"
- "math/rand"
- "net"
- "strconv"
- "strings"
- "sync"
- "time"
- "google.golang.org/grpc"
- "google.golang.org/grpc/codes"
- "google.golang.org/grpc/credentials"
- "google.golang.org/grpc/credentials/insecure"
- "google.golang.org/grpc/grpclog"
- "google.golang.org/grpc/interop"
- "google.golang.org/grpc/status"
- "google.golang.org/grpc/testdata"
- testgrpc "google.golang.org/grpc/interop/grpc_testing"
- metricspb "google.golang.org/grpc/stress/grpc_testing"
- )
- var (
- serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
- testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights")
- testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds")
- numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
- numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
- metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
- useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
- testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
- tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
- caFile = flag.String("ca_file", "", "The file containing the CA root cert file")
- logger = grpclog.Component("stress")
- )
// testCaseWithWeight pairs an interop test case name with its relative
// selection weight.
type testCaseWithWeight struct {
	// name is the interop test case to run, e.g. "large_unary".
	name string
	// weight is this test's share of the total weight used for random
	// selection.
	weight int
}
- // parseTestCases converts test case string to a list of struct testCaseWithWeight.
- func parseTestCases(testCaseString string) []testCaseWithWeight {
- testCaseStrings := strings.Split(testCaseString, ",")
- testCases := make([]testCaseWithWeight, len(testCaseStrings))
- for i, str := range testCaseStrings {
- testCase := strings.Split(str, ":")
- if len(testCase) != 2 {
- panic(fmt.Sprintf("invalid test case with weight: %s", str))
- }
- // Check if test case is supported.
- switch testCase[0] {
- case
- "empty_unary",
- "large_unary",
- "client_streaming",
- "server_streaming",
- "ping_pong",
- "empty_stream",
- "timeout_on_sleeping_server",
- "cancel_after_begin",
- "cancel_after_first_response",
- "status_code_and_message",
- "custom_metadata":
- default:
- panic(fmt.Sprintf("unknown test type: %s", testCase[0]))
- }
- testCases[i].name = testCase[0]
- w, err := strconv.Atoi(testCase[1])
- if err != nil {
- panic(fmt.Sprintf("%v", err))
- }
- testCases[i].weight = w
- }
- return testCases
- }
- // weightedRandomTestSelector defines a weighted random selector for test case types.
- type weightedRandomTestSelector struct {
- tests []testCaseWithWeight
- totalWeight int
- }
- // newWeightedRandomTestSelector constructs a weightedRandomTestSelector with the given list of testCaseWithWeight.
- func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector {
- var totalWeight int
- for _, t := range tests {
- totalWeight += t.weight
- }
- rand.Seed(time.Now().UnixNano())
- return &weightedRandomTestSelector{tests, totalWeight}
- }
- func (selector weightedRandomTestSelector) getNextTest() string {
- random := rand.Intn(selector.totalWeight)
- var weightSofar int
- for _, test := range selector.tests {
- weightSofar += test.weight
- if random < weightSofar {
- return test.name
- }
- }
- panic("no test case selected by weightedRandomTestSelector")
- }
// gauge holds the latest QPS value reported by one interop client (one stub).
// Reads and writes are serialized by an RWMutex.
type gauge struct {
	mutex sync.RWMutex
	val   int64
}

// set records v as the gauge's current value.
func (g *gauge) set(v int64) {
	g.mutex.Lock()
	g.val = v
	g.mutex.Unlock()
}

// get returns the gauge's current value.
func (g *gauge) get() int64 {
	g.mutex.RLock()
	current := g.val
	g.mutex.RUnlock()
	return current
}
- // server implements metrics server functions.
- type server struct {
- metricspb.UnimplementedMetricsServiceServer
- mutex sync.RWMutex
- // gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge.
- gauges map[string]*gauge
- }
- // newMetricsServer returns a new metrics server.
- func newMetricsServer() *server {
- return &server{gauges: make(map[string]*gauge)}
- }
- // GetAllGauges returns all gauges.
- func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error {
- s.mutex.RLock()
- defer s.mutex.RUnlock()
- for name, gauge := range s.gauges {
- if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil {
- return err
- }
- }
- return nil
- }
- // GetGauge returns the gauge for the given name.
- func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) {
- s.mutex.RLock()
- defer s.mutex.RUnlock()
- if g, ok := s.gauges[in.Name]; ok {
- return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil
- }
- return nil, status.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
- }
- // createGauge creates a gauge using the given name in metrics server.
- func (s *server) createGauge(name string) *gauge {
- s.mutex.Lock()
- defer s.mutex.Unlock()
- if _, ok := s.gauges[name]; ok {
- // gauge already exists.
- panic(fmt.Sprintf("gauge %s already exists", name))
- }
- var g gauge
- s.gauges[name] = &g
- return &g
- }
- func startServer(server *server, port int) {
- lis, err := net.Listen("tcp", ":"+strconv.Itoa(port))
- if err != nil {
- logger.Fatalf("failed to listen: %v", err)
- }
- s := grpc.NewServer()
- metricspb.RegisterMetricsServiceServer(s, server)
- s.Serve(lis)
- }
- // performRPCs uses weightedRandomTestSelector to select test case and runs the tests.
- func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) {
- client := testgrpc.NewTestServiceClient(conn)
- var numCalls int64
- startTime := time.Now()
- for {
- test := selector.getNextTest()
- switch test {
- case "empty_unary":
- interop.DoEmptyUnaryCall(client, grpc.WaitForReady(true))
- case "large_unary":
- interop.DoLargeUnaryCall(client, grpc.WaitForReady(true))
- case "client_streaming":
- interop.DoClientStreaming(client, grpc.WaitForReady(true))
- case "server_streaming":
- interop.DoServerStreaming(client, grpc.WaitForReady(true))
- case "ping_pong":
- interop.DoPingPong(client, grpc.WaitForReady(true))
- case "empty_stream":
- interop.DoEmptyStream(client, grpc.WaitForReady(true))
- case "timeout_on_sleeping_server":
- interop.DoTimeoutOnSleepingServer(client, grpc.WaitForReady(true))
- case "cancel_after_begin":
- interop.DoCancelAfterBegin(client, grpc.WaitForReady(true))
- case "cancel_after_first_response":
- interop.DoCancelAfterFirstResponse(client, grpc.WaitForReady(true))
- case "status_code_and_message":
- interop.DoStatusCodeAndMessage(client, grpc.WaitForReady(true))
- case "custom_metadata":
- interop.DoCustomMetadata(client, grpc.WaitForReady(true))
- }
- numCalls++
- gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds()))
- select {
- case <-stop:
- return
- default:
- }
- }
- }
- func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
- logger.Infof("server_addresses: %s", *serverAddresses)
- logger.Infof("test_cases: %s", *testCases)
- logger.Infof("test_duration_secs: %d", *testDurationSecs)
- logger.Infof("num_channels_per_server: %d", *numChannelsPerServer)
- logger.Infof("num_stubs_per_channel: %d", *numStubsPerChannel)
- logger.Infof("metrics_port: %d", *metricsPort)
- logger.Infof("use_tls: %t", *useTLS)
- logger.Infof("use_test_ca: %t", *testCA)
- logger.Infof("server_host_override: %s", *tlsServerName)
- logger.Infoln("addresses:")
- for i, addr := range addresses {
- logger.Infof("%d. %s\n", i+1, addr)
- }
- logger.Infoln("tests:")
- for i, test := range tests {
- logger.Infof("%d. %v\n", i+1, test)
- }
- }
- func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
- var opts []grpc.DialOption
- if useTLS {
- var sn string
- if tlsServerName != "" {
- sn = tlsServerName
- }
- var creds credentials.TransportCredentials
- if testCA {
- var err error
- if *caFile == "" {
- *caFile = testdata.Path("x509/server_ca_cert.pem")
- }
- creds, err = credentials.NewClientTLSFromFile(*caFile, sn)
- if err != nil {
- logger.Fatalf("Failed to create TLS credentials: %v", err)
- }
- } else {
- creds = credentials.NewClientTLSFromCert(nil, sn)
- }
- opts = append(opts, grpc.WithTransportCredentials(creds))
- } else {
- opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
- }
- return grpc.Dial(address, opts...)
- }
- func main() {
- flag.Parse()
- addresses := strings.Split(*serverAddresses, ",")
- tests := parseTestCases(*testCases)
- logParameterInfo(addresses, tests)
- testSelector := newWeightedRandomTestSelector(tests)
- metricsServer := newMetricsServer()
- var wg sync.WaitGroup
- wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel)
- stop := make(chan bool)
- for serverIndex, address := range addresses {
- for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ {
- conn, err := newConn(address, *useTLS, *testCA, *tlsServerName)
- if err != nil {
- logger.Fatalf("Fail to dial: %v", err)
- }
- defer conn.Close()
- for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ {
- name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1)
- go func() {
- defer wg.Done()
- g := metricsServer.createGauge(name)
- performRPCs(g, conn, testSelector, stop)
- }()
- }
- }
- }
- go startServer(metricsServer, *metricsPort)
- if *testDurationSecs > 0 {
- time.Sleep(time.Duration(*testDurationSecs) * time.Second)
- close(stop)
- }
- wg.Wait()
- logger.Infof(" ===== ALL DONE ===== ")
- }
|