main.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. /*
  2. *
  3. * Copyright 2016 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  // client starts an interop client to run a stress test, and a metrics server to report QPS.
  19. package main
  20. import (
  21. "context"
  22. "flag"
  23. "fmt"
  24. "math/rand"
  25. "net"
  26. "strconv"
  27. "strings"
  28. "sync"
  29. "time"
  30. "google.golang.org/grpc"
  31. "google.golang.org/grpc/codes"
  32. "google.golang.org/grpc/credentials"
  33. "google.golang.org/grpc/credentials/insecure"
  34. "google.golang.org/grpc/grpclog"
  35. "google.golang.org/grpc/interop"
  36. "google.golang.org/grpc/status"
  37. "google.golang.org/grpc/testdata"
  38. testgrpc "google.golang.org/grpc/interop/grpc_testing"
  39. metricspb "google.golang.org/grpc/stress/grpc_testing"
  40. )
  41. var (
  42. serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
  43. testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights")
  44. testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds")
  45. numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
  46. numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
  47. metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
  48. useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
  49. testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
  50. tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
  51. caFile = flag.String("ca_file", "", "The file containing the CA root cert file")
  52. logger = grpclog.Component("stress")
  53. )
  // testCaseWithWeight pairs an interop test case name with its relative
  // selection weight, as given in the --test_cases flag.
  // Keep the field order: values of this type are printed with %v.
  type testCaseWithWeight struct {
  	name   string // interop test case name, e.g. "empty_unary"
  	weight int    // non-negative relative weight; larger is chosen more often
  }

  // parseTestCases parses a comma-separated list of "name:weight" entries,
  // e.g. "empty_unary:20,large_unary:10", into testCaseWithWeight values in
  // input order. Whitespace around entries, names and weights is ignored.
  //
  // It panics if an entry is not exactly name:weight, names an unsupported
  // test case, or has a weight that is not a non-negative integer.
  func parseTestCases(testCaseString string) []testCaseWithWeight {
  	entries := strings.Split(testCaseString, ",")
  	testCases := make([]testCaseWithWeight, len(entries))
  	for i, entry := range entries {
  		parts := strings.Split(entry, ":")
  		if len(parts) != 2 {
  			panic(fmt.Sprintf("invalid test case with weight: %s", entry))
  		}
  		name := strings.TrimSpace(parts[0])
  		// Check if test case is supported.
  		switch name {
  		case
  			"empty_unary",
  			"large_unary",
  			"client_streaming",
  			"server_streaming",
  			"ping_pong",
  			"empty_stream",
  			"timeout_on_sleeping_server",
  			"cancel_after_begin",
  			"cancel_after_first_response",
  			"status_code_and_message",
  			"custom_metadata":
  		default:
  			panic(fmt.Sprintf("unknown test type: %s", name))
  		}
  		w, err := strconv.Atoi(strings.TrimSpace(parts[1]))
  		if err != nil {
  			panic(fmt.Sprintf("invalid weight for test case %q: %v", name, err))
  		}
  		// A negative weight would make the cumulative sums used by the
  		// weighted selector non-monotonic and skew or break selection.
  		if w < 0 {
  			panic(fmt.Sprintf("negative weight for test case %q: %d", name, w))
  		}
  		testCases[i] = testCaseWithWeight{name: name, weight: w}
  	}
  	return testCases
  }
  94. // weightedRandomTestSelector defines a weighted random selector for test case types.
  95. type weightedRandomTestSelector struct {
  96. tests []testCaseWithWeight
  97. totalWeight int
  98. }
  99. // newWeightedRandomTestSelector constructs a weightedRandomTestSelector with the given list of testCaseWithWeight.
  100. func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector {
  101. var totalWeight int
  102. for _, t := range tests {
  103. totalWeight += t.weight
  104. }
  105. rand.Seed(time.Now().UnixNano())
  106. return &weightedRandomTestSelector{tests, totalWeight}
  107. }
  108. func (selector weightedRandomTestSelector) getNextTest() string {
  109. random := rand.Intn(selector.totalWeight)
  110. var weightSofar int
  111. for _, test := range selector.tests {
  112. weightSofar += test.weight
  113. if random < weightSofar {
  114. return test.name
  115. }
  116. }
  117. panic("no test case selected by weightedRandomTestSelector")
  118. }
  // gauge records the latest QPS value reported by one interop client stub.
  // Reads and writes are serialized by mutex so a worker goroutine can publish
  // while metrics RPCs read.
  type gauge struct {
  	mutex sync.RWMutex
  	val   int64
  }

  // set replaces the gauge's current value with v.
  func (g *gauge) set(v int64) {
  	g.mutex.Lock()
  	g.val = v
  	g.mutex.Unlock()
  }

  // get reports the gauge's current value.
  func (g *gauge) get() (current int64) {
  	g.mutex.RLock()
  	current = g.val
  	g.mutex.RUnlock()
  	return current
  }
  134. // server implements metrics server functions.
  135. type server struct {
  136. metricspb.UnimplementedMetricsServiceServer
  137. mutex sync.RWMutex
  138. // gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge.
  139. gauges map[string]*gauge
  140. }
  141. // newMetricsServer returns a new metrics server.
  142. func newMetricsServer() *server {
  143. return &server{gauges: make(map[string]*gauge)}
  144. }
  145. // GetAllGauges returns all gauges.
  146. func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error {
  147. s.mutex.RLock()
  148. defer s.mutex.RUnlock()
  149. for name, gauge := range s.gauges {
  150. if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil {
  151. return err
  152. }
  153. }
  154. return nil
  155. }
  156. // GetGauge returns the gauge for the given name.
  157. func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) {
  158. s.mutex.RLock()
  159. defer s.mutex.RUnlock()
  160. if g, ok := s.gauges[in.Name]; ok {
  161. return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil
  162. }
  163. return nil, status.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
  164. }
  165. // createGauge creates a gauge using the given name in metrics server.
  166. func (s *server) createGauge(name string) *gauge {
  167. s.mutex.Lock()
  168. defer s.mutex.Unlock()
  169. if _, ok := s.gauges[name]; ok {
  170. // gauge already exists.
  171. panic(fmt.Sprintf("gauge %s already exists", name))
  172. }
  173. var g gauge
  174. s.gauges[name] = &g
  175. return &g
  176. }
  177. func startServer(server *server, port int) {
  178. lis, err := net.Listen("tcp", ":"+strconv.Itoa(port))
  179. if err != nil {
  180. logger.Fatalf("failed to listen: %v", err)
  181. }
  182. s := grpc.NewServer()
  183. metricspb.RegisterMetricsServiceServer(s, server)
  184. s.Serve(lis)
  185. }
  186. // performRPCs uses weightedRandomTestSelector to select test case and runs the tests.
  187. func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) {
  188. client := testgrpc.NewTestServiceClient(conn)
  189. var numCalls int64
  190. startTime := time.Now()
  191. for {
  192. test := selector.getNextTest()
  193. switch test {
  194. case "empty_unary":
  195. interop.DoEmptyUnaryCall(client, grpc.WaitForReady(true))
  196. case "large_unary":
  197. interop.DoLargeUnaryCall(client, grpc.WaitForReady(true))
  198. case "client_streaming":
  199. interop.DoClientStreaming(client, grpc.WaitForReady(true))
  200. case "server_streaming":
  201. interop.DoServerStreaming(client, grpc.WaitForReady(true))
  202. case "ping_pong":
  203. interop.DoPingPong(client, grpc.WaitForReady(true))
  204. case "empty_stream":
  205. interop.DoEmptyStream(client, grpc.WaitForReady(true))
  206. case "timeout_on_sleeping_server":
  207. interop.DoTimeoutOnSleepingServer(client, grpc.WaitForReady(true))
  208. case "cancel_after_begin":
  209. interop.DoCancelAfterBegin(client, grpc.WaitForReady(true))
  210. case "cancel_after_first_response":
  211. interop.DoCancelAfterFirstResponse(client, grpc.WaitForReady(true))
  212. case "status_code_and_message":
  213. interop.DoStatusCodeAndMessage(client, grpc.WaitForReady(true))
  214. case "custom_metadata":
  215. interop.DoCustomMetadata(client, grpc.WaitForReady(true))
  216. }
  217. numCalls++
  218. gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds()))
  219. select {
  220. case <-stop:
  221. return
  222. default:
  223. }
  224. }
  225. }
  226. func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
  227. logger.Infof("server_addresses: %s", *serverAddresses)
  228. logger.Infof("test_cases: %s", *testCases)
  229. logger.Infof("test_duration_secs: %d", *testDurationSecs)
  230. logger.Infof("num_channels_per_server: %d", *numChannelsPerServer)
  231. logger.Infof("num_stubs_per_channel: %d", *numStubsPerChannel)
  232. logger.Infof("metrics_port: %d", *metricsPort)
  233. logger.Infof("use_tls: %t", *useTLS)
  234. logger.Infof("use_test_ca: %t", *testCA)
  235. logger.Infof("server_host_override: %s", *tlsServerName)
  236. logger.Infoln("addresses:")
  237. for i, addr := range addresses {
  238. logger.Infof("%d. %s\n", i+1, addr)
  239. }
  240. logger.Infoln("tests:")
  241. for i, test := range tests {
  242. logger.Infof("%d. %v\n", i+1, test)
  243. }
  244. }
  245. func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
  246. var opts []grpc.DialOption
  247. if useTLS {
  248. var sn string
  249. if tlsServerName != "" {
  250. sn = tlsServerName
  251. }
  252. var creds credentials.TransportCredentials
  253. if testCA {
  254. var err error
  255. if *caFile == "" {
  256. *caFile = testdata.Path("x509/server_ca_cert.pem")
  257. }
  258. creds, err = credentials.NewClientTLSFromFile(*caFile, sn)
  259. if err != nil {
  260. logger.Fatalf("Failed to create TLS credentials: %v", err)
  261. }
  262. } else {
  263. creds = credentials.NewClientTLSFromCert(nil, sn)
  264. }
  265. opts = append(opts, grpc.WithTransportCredentials(creds))
  266. } else {
  267. opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
  268. }
  269. return grpc.Dial(address, opts...)
  270. }
  271. func main() {
  272. flag.Parse()
  273. addresses := strings.Split(*serverAddresses, ",")
  274. tests := parseTestCases(*testCases)
  275. logParameterInfo(addresses, tests)
  276. testSelector := newWeightedRandomTestSelector(tests)
  277. metricsServer := newMetricsServer()
  278. var wg sync.WaitGroup
  279. wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel)
  280. stop := make(chan bool)
  281. for serverIndex, address := range addresses {
  282. for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ {
  283. conn, err := newConn(address, *useTLS, *testCA, *tlsServerName)
  284. if err != nil {
  285. logger.Fatalf("Fail to dial: %v", err)
  286. }
  287. defer conn.Close()
  288. for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ {
  289. name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1)
  290. go func() {
  291. defer wg.Done()
  292. g := metricsServer.createGauge(name)
  293. performRPCs(g, conn, testSelector, stop)
  294. }()
  295. }
  296. }
  297. }
  298. go startServer(metricsServer, *metricsPort)
  299. if *testDurationSecs > 0 {
  300. time.Sleep(time.Duration(*testDurationSecs) * time.Second)
  301. close(stop)
  302. }
  303. wg.Wait()
  304. logger.Infof(" ===== ALL DONE ===== ")
  305. }