grpc.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. // Copyright 2018 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Package gengogrpc contains the gRPC code generator.
  5. package gengogrpc
  6. import (
  7. "fmt"
  8. "strconv"
  9. "strings"
  10. "google.golang.org/protobuf/compiler/protogen"
  11. "google.golang.org/protobuf/types/descriptorpb"
  12. )
  13. const (
  14. contextPackage = protogen.GoImportPath("context")
  15. grpcPackage = protogen.GoImportPath("google.golang.org/grpc")
  16. codesPackage = protogen.GoImportPath("google.golang.org/grpc/codes")
  17. statusPackage = protogen.GoImportPath("google.golang.org/grpc/status")
  18. )
  19. // GenerateFile generates a _grpc.pb.go file containing gRPC service definitions.
  20. func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile {
  21. if len(file.Services) == 0 {
  22. return nil
  23. }
  24. filename := file.GeneratedFilenamePrefix + "_grpc.pb.go"
  25. g := gen.NewGeneratedFile(filename, file.GoImportPath)
  26. g.P("// Code generated by protoc-gen-go-grpc. DO NOT EDIT.")
  27. g.P()
  28. g.P("package ", file.GoPackageName)
  29. g.P()
  30. GenerateFileContent(gen, file, g)
  31. return g
  32. }
  33. // GenerateFileContent generates the gRPC service definitions, excluding the package statement.
  34. func GenerateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) {
  35. if len(file.Services) == 0 {
  36. return
  37. }
  38. // TODO: Remove this. We don't need to include these references any more.
  39. g.P("// Reference imports to suppress errors if they are not otherwise used.")
  40. g.P("var _ ", contextPackage.Ident("Context"))
  41. g.P("var _ ", grpcPackage.Ident("ClientConnInterface"))
  42. g.P()
  43. g.P("// This is a compile-time assertion to ensure that this generated file")
  44. g.P("// is compatible with the grpc package it is being compiled against.")
  45. g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion6"))
  46. g.P()
  47. for _, service := range file.Services {
  48. genService(gen, file, g, service)
  49. }
  50. }
  51. func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) {
  52. clientName := service.GoName + "Client"
  53. g.P("// ", clientName, " is the client API for ", service.GoName, " service.")
  54. g.P("//")
  55. g.P("// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.")
  56. // Client interface.
  57. if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
  58. g.P("//")
  59. g.P(deprecationComment)
  60. }
  61. g.Annotate(clientName, service.Location)
  62. g.P("type ", clientName, " interface {")
  63. for _, method := range service.Methods {
  64. g.Annotate(clientName+"."+method.GoName, method.Location)
  65. if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
  66. g.P(deprecationComment)
  67. }
  68. g.P(method.Comments.Leading,
  69. clientSignature(g, method))
  70. }
  71. g.P("}")
  72. g.P()
  73. // Client structure.
  74. g.P("type ", unexport(clientName), " struct {")
  75. g.P("cc ", grpcPackage.Ident("ClientConnInterface"))
  76. g.P("}")
  77. g.P()
  78. // NewClient factory.
  79. if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
  80. g.P(deprecationComment)
  81. }
  82. g.P("func New", clientName, " (cc ", grpcPackage.Ident("ClientConnInterface"), ") ", clientName, " {")
  83. g.P("return &", unexport(clientName), "{cc}")
  84. g.P("}")
  85. g.P()
  86. var methodIndex, streamIndex int
  87. // Client method implementations.
  88. for _, method := range service.Methods {
  89. if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
  90. // Unary RPC method
  91. genClientMethod(gen, file, g, method, methodIndex)
  92. methodIndex++
  93. } else {
  94. // Streaming RPC method
  95. genClientMethod(gen, file, g, method, streamIndex)
  96. streamIndex++
  97. }
  98. }
  99. // Server interface.
  100. serverType := service.GoName + "Server"
  101. g.P("// ", serverType, " is the server API for ", service.GoName, " service.")
  102. if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
  103. g.P("//")
  104. g.P(deprecationComment)
  105. }
  106. g.Annotate(serverType, service.Location)
  107. g.P("type ", serverType, " interface {")
  108. for _, method := range service.Methods {
  109. g.Annotate(serverType+"."+method.GoName, method.Location)
  110. if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
  111. g.P(deprecationComment)
  112. }
  113. g.P(method.Comments.Leading,
  114. serverSignature(g, method))
  115. }
  116. g.P("}")
  117. g.P()
  118. // Server Unimplemented struct for forward compatibility.
  119. g.P("// Unimplemented", serverType, " can be embedded to have forward compatible implementations.")
  120. g.P("type Unimplemented", serverType, " struct {")
  121. g.P("}")
  122. g.P()
  123. for _, method := range service.Methods {
  124. nilArg := ""
  125. if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
  126. nilArg = "nil,"
  127. }
  128. g.P("func (*Unimplemented", serverType, ") ", serverSignature(g, method), "{")
  129. g.P("return ", nilArg, statusPackage.Ident("Errorf"), "(", codesPackage.Ident("Unimplemented"), `, "method `, method.GoName, ` not implemented")`)
  130. g.P("}")
  131. }
  132. g.P()
  133. // Server registration.
  134. if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
  135. g.P(deprecationComment)
  136. }
  137. serviceDescVar := "_" + service.GoName + "_serviceDesc"
  138. g.P("func Register", service.GoName, "Server(s *", grpcPackage.Ident("Server"), ", srv ", serverType, ") {")
  139. g.P("s.RegisterService(&", serviceDescVar, `, srv)`)
  140. g.P("}")
  141. g.P()
  142. // Server handler implementations.
  143. var handlerNames []string
  144. for _, method := range service.Methods {
  145. hname := genServerMethod(gen, file, g, method)
  146. handlerNames = append(handlerNames, hname)
  147. }
  148. // Service descriptor.
  149. g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {")
  150. g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",")
  151. g.P("HandlerType: (*", serverType, ")(nil),")
  152. g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{")
  153. for i, method := range service.Methods {
  154. if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
  155. continue
  156. }
  157. g.P("{")
  158. g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",")
  159. g.P("Handler: ", handlerNames[i], ",")
  160. g.P("},")
  161. }
  162. g.P("},")
  163. g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{")
  164. for i, method := range service.Methods {
  165. if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
  166. continue
  167. }
  168. g.P("{")
  169. g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",")
  170. g.P("Handler: ", handlerNames[i], ",")
  171. if method.Desc.IsStreamingServer() {
  172. g.P("ServerStreams: true,")
  173. }
  174. if method.Desc.IsStreamingClient() {
  175. g.P("ClientStreams: true,")
  176. }
  177. g.P("},")
  178. }
  179. g.P("},")
  180. g.P("Metadata: \"", file.Desc.Path(), "\",")
  181. g.P("}")
  182. g.P()
  183. }
  184. func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
  185. s := method.GoName + "(ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context"))
  186. if !method.Desc.IsStreamingClient() {
  187. s += ", in *" + g.QualifiedGoIdent(method.Input.GoIdent)
  188. }
  189. s += ", opts ..." + g.QualifiedGoIdent(grpcPackage.Ident("CallOption")) + ") ("
  190. if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
  191. s += "*" + g.QualifiedGoIdent(method.Output.GoIdent)
  192. } else {
  193. s += method.Parent.GoName + "_" + method.GoName + "Client"
  194. }
  195. s += ", error)"
  196. return s
  197. }
  198. func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, index int) {
  199. service := method.Parent
  200. sname := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())
  201. if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
  202. g.P(deprecationComment)
  203. }
  204. g.P("func (c *", unexport(service.GoName), "Client) ", clientSignature(g, method), "{")
  205. if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
  206. g.P("out := new(", method.Output.GoIdent, ")")
  207. g.P(`err := c.cc.Invoke(ctx, "`, sname, `", in, out, opts...)`)
  208. g.P("if err != nil { return nil, err }")
  209. g.P("return out, nil")
  210. g.P("}")
  211. g.P()
  212. return
  213. }
  214. streamType := unexport(service.GoName) + method.GoName + "Client"
  215. serviceDescVar := "_" + service.GoName + "_serviceDesc"
  216. g.P("stream, err := c.cc.NewStream(ctx, &", serviceDescVar, ".Streams[", index, `], "`, sname, `", opts...)`)
  217. g.P("if err != nil { return nil, err }")
  218. g.P("x := &", streamType, "{stream}")
  219. if !method.Desc.IsStreamingClient() {
  220. g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }")
  221. g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
  222. }
  223. g.P("return x, nil")
  224. g.P("}")
  225. g.P()
  226. genSend := method.Desc.IsStreamingClient()
  227. genRecv := method.Desc.IsStreamingServer()
  228. genCloseAndRecv := !method.Desc.IsStreamingServer()
  229. // Stream auxiliary types and methods.
  230. g.P("type ", service.GoName, "_", method.GoName, "Client interface {")
  231. if genSend {
  232. g.P("Send(*", method.Input.GoIdent, ") error")
  233. }
  234. if genRecv {
  235. g.P("Recv() (*", method.Output.GoIdent, ", error)")
  236. }
  237. if genCloseAndRecv {
  238. g.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)")
  239. }
  240. g.P(grpcPackage.Ident("ClientStream"))
  241. g.P("}")
  242. g.P()
  243. g.P("type ", streamType, " struct {")
  244. g.P(grpcPackage.Ident("ClientStream"))
  245. g.P("}")
  246. g.P()
  247. if genSend {
  248. g.P("func (x *", streamType, ") Send(m *", method.Input.GoIdent, ") error {")
  249. g.P("return x.ClientStream.SendMsg(m)")
  250. g.P("}")
  251. g.P()
  252. }
  253. if genRecv {
  254. g.P("func (x *", streamType, ") Recv() (*", method.Output.GoIdent, ", error) {")
  255. g.P("m := new(", method.Output.GoIdent, ")")
  256. g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
  257. g.P("return m, nil")
  258. g.P("}")
  259. g.P()
  260. }
  261. if genCloseAndRecv {
  262. g.P("func (x *", streamType, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
  263. g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
  264. g.P("m := new(", method.Output.GoIdent, ")")
  265. g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
  266. g.P("return m, nil")
  267. g.P("}")
  268. g.P()
  269. }
  270. }
  271. func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
  272. var reqArgs []string
  273. ret := "error"
  274. if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
  275. reqArgs = append(reqArgs, g.QualifiedGoIdent(contextPackage.Ident("Context")))
  276. ret = "(*" + g.QualifiedGoIdent(method.Output.GoIdent) + ", error)"
  277. }
  278. if !method.Desc.IsStreamingClient() {
  279. reqArgs = append(reqArgs, "*"+g.QualifiedGoIdent(method.Input.GoIdent))
  280. }
  281. if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
  282. reqArgs = append(reqArgs, method.Parent.GoName+"_"+method.GoName+"Server")
  283. }
  284. return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
  285. }
  286. func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method) string {
  287. service := method.Parent
  288. hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)
  289. if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
  290. g.P("func ", hname, "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {")
  291. g.P("in := new(", method.Input.GoIdent, ")")
  292. g.P("if err := dec(in); err != nil { return nil, err }")
  293. g.P("if interceptor == nil { return srv.(", service.GoName, "Server).", method.GoName, "(ctx, in) }")
  294. g.P("info := &", grpcPackage.Ident("UnaryServerInfo"), "{")
  295. g.P("Server: srv,")
  296. g.P("FullMethod: ", strconv.Quote(fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.GoName)), ",")
  297. g.P("}")
  298. g.P("handler := func(ctx ", contextPackage.Ident("Context"), ", req interface{}) (interface{}, error) {")
  299. g.P("return srv.(", service.GoName, "Server).", method.GoName, "(ctx, req.(*", method.Input.GoIdent, "))")
  300. g.P("}")
  301. g.P("return interceptor(ctx, in, info, handler)")
  302. g.P("}")
  303. g.P()
  304. return hname
  305. }
  306. streamType := unexport(service.GoName) + method.GoName + "Server"
  307. g.P("func ", hname, "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
  308. if !method.Desc.IsStreamingClient() {
  309. g.P("m := new(", method.Input.GoIdent, ")")
  310. g.P("if err := stream.RecvMsg(m); err != nil { return err }")
  311. g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamType, "{stream})")
  312. } else {
  313. g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamType, "{stream})")
  314. }
  315. g.P("}")
  316. g.P()
  317. genSend := method.Desc.IsStreamingServer()
  318. genSendAndClose := !method.Desc.IsStreamingServer()
  319. genRecv := method.Desc.IsStreamingClient()
  320. // Stream auxiliary types and methods.
  321. g.P("type ", service.GoName, "_", method.GoName, "Server interface {")
  322. if genSend {
  323. g.P("Send(*", method.Output.GoIdent, ") error")
  324. }
  325. if genSendAndClose {
  326. g.P("SendAndClose(*", method.Output.GoIdent, ") error")
  327. }
  328. if genRecv {
  329. g.P("Recv() (*", method.Input.GoIdent, ", error)")
  330. }
  331. g.P(grpcPackage.Ident("ServerStream"))
  332. g.P("}")
  333. g.P()
  334. g.P("type ", streamType, " struct {")
  335. g.P(grpcPackage.Ident("ServerStream"))
  336. g.P("}")
  337. g.P()
  338. if genSend {
  339. g.P("func (x *", streamType, ") Send(m *", method.Output.GoIdent, ") error {")
  340. g.P("return x.ServerStream.SendMsg(m)")
  341. g.P("}")
  342. g.P()
  343. }
  344. if genSendAndClose {
  345. g.P("func (x *", streamType, ") SendAndClose(m *", method.Output.GoIdent, ") error {")
  346. g.P("return x.ServerStream.SendMsg(m)")
  347. g.P("}")
  348. g.P()
  349. }
  350. if genRecv {
  351. g.P("func (x *", streamType, ") Recv() (*", method.Input.GoIdent, ", error) {")
  352. g.P("m := new(", method.Input.GoIdent, ")")
  353. g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }")
  354. g.P("return m, nil")
  355. g.P("}")
  356. g.P()
  357. }
  358. return hname
  359. }
  360. const deprecationComment = "// Deprecated: Do not use."
  361. func unexport(s string) string { return strings.ToLower(s[:1]) + s[1:] }