|
|
@@ -14,6 +14,11 @@ import (
|
|
|
"github.com/golang/protobuf/v2/protogen"
|
|
|
)
|
|
|
|
|
|
+const (
|
|
|
+ contextPackage = protogen.GoImportPath("context")
|
|
|
+ grpcPackage = protogen.GoImportPath("google.golang.org/grpc")
|
|
|
+)
|
|
|
+
|
|
|
// GenerateFile generates a _grpc.pb.go file containing gRPC service definitions.
|
|
|
func GenerateFile(gen *protogen.Plugin, file *protogen.File) {
|
|
|
if len(file.Services) == 0 {
|
|
|
@@ -36,13 +41,13 @@ func GenerateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.
|
|
|
|
|
|
// TODO: Remove this. We don't need to include these references any more.
|
|
|
g.P("// Reference imports to suppress errors if they are not otherwise used.")
|
|
|
- g.P("var _ ", ident("context.Context"))
|
|
|
- g.P("var _ ", ident("grpc.ClientConn"))
|
|
|
+ g.P("var _ ", contextPackage.Ident("Context"))
|
|
|
+ g.P("var _ ", grpcPackage.Ident("ClientConn"))
|
|
|
g.P()
|
|
|
|
|
|
g.P("// This is a compile-time assertion to ensure that this generated file")
|
|
|
g.P("// is compatible with the grpc package it is being compiled against.")
|
|
|
- g.P("const _ = ", ident("grpc.SupportPackageIsVersion4"))
|
|
|
+ g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion4"))
|
|
|
g.P()
|
|
|
for _, service := range file.Services {
|
|
|
genService(gen, file, g, service)
|
|
|
@@ -73,7 +78,7 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated
|
|
|
|
|
|
// Client structure.
|
|
|
g.P("type ", unexport(clientName), " struct {")
|
|
|
- g.P("cc *", ident("grpc.ClientConn"))
|
|
|
+ g.P("cc *", grpcPackage.Ident("ClientConn"))
|
|
|
g.P("}")
|
|
|
g.P()
|
|
|
|
|
|
@@ -81,7 +86,7 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated
|
|
|
if service.Desc.Options().(*descpb.ServiceOptions).GetDeprecated() {
|
|
|
g.P(deprecationComment)
|
|
|
}
|
|
|
- g.P("func New", clientName, " (cc *", ident("grpc.ClientConn"), ") ", clientName, " {")
|
|
|
+ g.P("func New", clientName, " (cc *", grpcPackage.Ident("ClientConn"), ") ", clientName, " {")
|
|
|
g.P("return &", unexport(clientName), "{cc}")
|
|
|
g.P("}")
|
|
|
g.P()
|
|
|
@@ -122,7 +127,7 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated
|
|
|
g.P(deprecationComment)
|
|
|
}
|
|
|
serviceDescVar := "_" + service.GoName + "_serviceDesc"
|
|
|
- g.P("func Register", service.GoName, "Server(s *", ident("grpc.Server"), ", srv ", serverType, ") {")
|
|
|
+ g.P("func Register", service.GoName, "Server(s *", grpcPackage.Ident("Server"), ", srv ", serverType, ") {")
|
|
|
g.P("s.RegisterService(&", serviceDescVar, `, srv)`)
|
|
|
g.P("}")
|
|
|
g.P()
|
|
|
@@ -135,10 +140,10 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated
|
|
|
}
|
|
|
|
|
|
// Service descriptor.
|
|
|
- g.P("var ", serviceDescVar, " = ", ident("grpc.ServiceDesc"), " {")
|
|
|
+ g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {")
|
|
|
g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",")
|
|
|
g.P("HandlerType: (*", serverType, ")(nil),")
|
|
|
- g.P("Methods: []", ident("grpc.MethodDesc"), "{")
|
|
|
+ g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{")
|
|
|
for i, method := range service.Methods {
|
|
|
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
|
|
|
continue
|
|
|
@@ -149,7 +154,7 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated
|
|
|
g.P("},")
|
|
|
}
|
|
|
g.P("},")
|
|
|
- g.P("Streams: []", ident("grpc.StreamDesc"), "{")
|
|
|
+ g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{")
|
|
|
for i, method := range service.Methods {
|
|
|
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
|
|
|
continue
|
|
|
@@ -172,11 +177,11 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated
|
|
|
}
|
|
|
|
|
|
func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
|
|
|
- s := method.GoName + "(ctx " + g.QualifiedGoIdent(ident("context.Context"))
|
|
|
+ s := method.GoName + "(ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context"))
|
|
|
if !method.Desc.IsStreamingClient() {
|
|
|
s += ", in *" + g.QualifiedGoIdent(method.InputType.GoIdent)
|
|
|
}
|
|
|
- s += ", opts ..." + g.QualifiedGoIdent(ident("grpc.CallOption")) + ") ("
|
|
|
+ s += ", opts ..." + g.QualifiedGoIdent(grpcPackage.Ident("CallOption")) + ") ("
|
|
|
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
|
|
|
s += "*" + g.QualifiedGoIdent(method.OutputType.GoIdent)
|
|
|
} else {
|
|
|
@@ -231,12 +236,12 @@ func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
|
|
|
if genCloseAndRecv {
|
|
|
g.P("CloseAndRecv() (*", method.OutputType.GoIdent, ", error)")
|
|
|
}
|
|
|
- g.P(ident("grpc.ClientStream"))
|
|
|
+ g.P(grpcPackage.Ident("ClientStream"))
|
|
|
g.P("}")
|
|
|
g.P()
|
|
|
|
|
|
g.P("type ", streamType, " struct {")
|
|
|
- g.P(ident("grpc.ClientStream"))
|
|
|
+ g.P(grpcPackage.Ident("ClientStream"))
|
|
|
g.P("}")
|
|
|
g.P()
|
|
|
|
|
|
@@ -269,7 +274,7 @@ func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string
|
|
|
var reqArgs []string
|
|
|
ret := "error"
|
|
|
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
|
|
|
- reqArgs = append(reqArgs, g.QualifiedGoIdent(ident("context.Context")))
|
|
|
+ reqArgs = append(reqArgs, g.QualifiedGoIdent(contextPackage.Ident("Context")))
|
|
|
ret = "(*" + g.QualifiedGoIdent(method.OutputType.GoIdent) + ", error)"
|
|
|
}
|
|
|
if !method.Desc.IsStreamingClient() {
|
|
|
@@ -286,15 +291,15 @@ func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
|
|
|
hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)
|
|
|
|
|
|
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
|
|
|
- g.P("func ", hname, "(srv interface{}, ctx ", ident("context.Context"), ", dec func(interface{}) error, interceptor ", ident("grpc.UnaryServerInterceptor"), ") (interface{}, error) {")
|
|
|
+ g.P("func ", hname, "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {")
|
|
|
g.P("in := new(", method.InputType.GoIdent, ")")
|
|
|
g.P("if err := dec(in); err != nil { return nil, err }")
|
|
|
g.P("if interceptor == nil { return srv.(", service.GoName, "Server).", method.GoName, "(ctx, in) }")
|
|
|
- g.P("info := &", ident("grpc.UnaryServerInfo"), "{")
|
|
|
+ g.P("info := &", grpcPackage.Ident("UnaryServerInfo"), "{")
|
|
|
g.P("Server: srv,")
|
|
|
g.P("FullMethod: ", strconv.Quote(fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())), ",")
|
|
|
g.P("}")
|
|
|
- g.P("handler := func(ctx ", ident("context.Context"), ", req interface{}) (interface{}, error) {")
|
|
|
+ g.P("handler := func(ctx ", contextPackage.Ident("Context"), ", req interface{}) (interface{}, error) {")
|
|
|
g.P("return srv.(", service.GoName, "Server).", method.GoName, "(ctx, req.(*", method.InputType.GoIdent, "))")
|
|
|
g.P("}")
|
|
|
g.P("return interceptor(ctx, in, info, handler)")
|
|
|
@@ -303,7 +308,7 @@ func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
|
|
|
return hname
|
|
|
}
|
|
|
streamType := unexport(service.GoName) + method.GoName + "Server"
|
|
|
- g.P("func ", hname, "(srv interface{}, stream ", ident("grpc.ServerStream"), ") error {")
|
|
|
+ g.P("func ", hname, "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
|
|
|
if !method.Desc.IsStreamingClient() {
|
|
|
g.P("m := new(", method.InputType.GoIdent, ")")
|
|
|
g.P("if err := stream.RecvMsg(m); err != nil { return err }")
|
|
|
@@ -329,12 +334,12 @@ func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
|
|
|
if genRecv {
|
|
|
g.P("Recv() (*", method.InputType.GoIdent, ", error)")
|
|
|
}
|
|
|
- g.P(ident("grpc.ServerStream"))
|
|
|
+ g.P(grpcPackage.Ident("ServerStream"))
|
|
|
g.P("}")
|
|
|
g.P()
|
|
|
|
|
|
g.P("type ", streamType, " struct {")
|
|
|
- g.P(ident("grpc.ServerStream"))
|
|
|
+ g.P(grpcPackage.Ident("ServerStream"))
|
|
|
g.P("}")
|
|
|
g.P()
|
|
|
|
|
|
@@ -362,19 +367,6 @@ func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
|
|
|
return hname
|
|
|
}
|
|
|
|
|
|
-var packages = map[string]protogen.GoImportPath{
|
|
|
- "context": "golang.org/x/net/context",
|
|
|
- "grpc": "google.golang.org/grpc",
|
|
|
-}
|
|
|
-
|
|
|
-func ident(name string) protogen.GoIdent {
|
|
|
- idx := strings.LastIndex(name, ".")
|
|
|
- return protogen.GoIdent{
|
|
|
- GoImportPath: packages[name[:idx]],
|
|
|
- GoName: name[idx+1:],
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
const deprecationComment = "// Deprecated: Do not use."
|
|
|
|
|
|
func unexport(s string) string { return strings.ToLower(s[:1]) + s[1:] }
|