Procházet zdrojové kódy

protogen: add GoImportPath.Ident helper

The GoImportPath.Ident helper creates a GoIdent using the receiver
as the GoImportPath in the GoIdent. This helper helps with the construction
of qualified identifiers.

Example usage:
	const protoPackage = protogen.GoImportPath("github.com/golang/protobuf/proto")
	protoPackage.Ident("ExtensionRange") // produces "proto.ExtensionRange"

The advantage of this helper is that usage of it looks similar to how
the identifier will eventually be rendered.

This is significantly more readable than the current approach:
	protogen.GoIdent{
		GoImportPath: protoPackage,
		GoName: "ExtensionRange",
	}

Change-Id: If7ecd7e60fad12bc491eee0dcb05f8fdebc9c94e
Reviewed-on: https://go-review.googlesource.com/c/150058
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai před 7 roky
rodič
revize
c1c17aa013

+ 26 - 66
cmd/protoc-gen-go/internal_gengo/main.go

@@ -29,7 +29,11 @@ import (
 // a constant, proto.ProtoPackageIsVersionN (where N is generatedCodeVersion).
 const generatedCodeVersion = 2
 
-const protoPackage = "github.com/golang/protobuf/proto"
+const (
+	fmtPackage   = protogen.GoImportPath("fmt")
+	mathPackage  = protogen.GoImportPath("math")
+	protoPackage = protogen.GoImportPath("github.com/golang/protobuf/proto")
+)
 
 type fileInfo struct {
 	*protogen.File
@@ -88,19 +92,17 @@ func GenerateFile(gen *protogen.Plugin, file *protogen.File, g *protogen.Generat
 	//
 	// TODO: Eventually remove this.
 	g.P("// Reference imports to suppress errors if they are not otherwise used.")
-	g.P("var _ = ", protogen.GoIdent{GoImportPath: protoPackage, GoName: "Marshal"})
-	g.P("var _ = ", protogen.GoIdent{GoImportPath: "fmt", GoName: "Errorf"})
-	g.P("var _ = ", protogen.GoIdent{GoImportPath: "math", GoName: "Inf"})
+	g.P("var _ = ", protoPackage.Ident("Marshal"))
+	g.P("var _ = ", fmtPackage.Ident("Errorf"))
+	g.P("var _ = ", mathPackage.Ident("Inf"))
 	g.P()
 
 	g.P("// This is a compile-time assertion to ensure that this generated file")
 	g.P("// is compatible with the proto package it is being compiled against.")
 	g.P("// A compilation error at this line likely means your copy of the")
 	g.P("// proto package needs to be updated.")
-	g.P("const _ = ", protogen.GoIdent{
-		GoImportPath: protoPackage,
-		GoName:       fmt.Sprintf("ProtoPackageIsVersion%d", generatedCodeVersion),
-	}, "// please upgrade the proto package")
+	g.P("const _ = ", protoPackage.Ident(fmt.Sprintf("ProtoPackageIsVersion%d", generatedCodeVersion)),
+		"// please upgrade the proto package")
 	g.P()
 
 	for i, imps := 0, f.Desc.Imports(); i < imps.Len(); i++ {
@@ -284,13 +286,13 @@ func genEnum(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, enum
 		g.P()
 	}
 	g.P("func (x ", enum.GoIdent, ") String() string {")
-	g.P("return ", protogen.GoIdent{GoImportPath: protoPackage, GoName: "EnumName"}, "(", enum.GoIdent, "_name, int32(x))")
+	g.P("return ", protoPackage.Ident("EnumName"), "(", enum.GoIdent, "_name, int32(x))")
 	g.P("}")
 	g.P()
 
 	if enum.Desc.Syntax() != protoreflect.Proto3 {
 		g.P("func (x *", enum.GoIdent, ") UnmarshalJSON(data []byte) error {")
-		g.P("value, err := ", protogen.GoIdent{GoImportPath: protoPackage, GoName: "UnmarshalJSONEnum"}, "(", enum.GoIdent, `_value, data, "`, enum.GoIdent, `")`)
+		g.P("value, err := ", protoPackage.Ident("UnmarshalJSONEnum"), "(", enum.GoIdent, `_value, data, "`, enum.GoIdent, `")`)
 		g.P("if err != nil {")
 		g.P("return err")
 		g.P("}")
@@ -389,10 +391,7 @@ func genMessage(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, me
 			tags = append(tags, `protobuf_messageset:"1"`)
 		}
 		tags = append(tags, `json:"-"`)
-		g.P(protogen.GoIdent{
-			GoImportPath: protoPackage,
-			GoName:       "XXX_InternalExtensions",
-		}, " `", strings.Join(tags, " "), "`")
+		g.P(protoPackage.Ident("XXX_InternalExtensions"), " `", strings.Join(tags, " "), "`")
 	}
 	// TODO XXX_InternalExtensions
 	g.P("XXX_unrecognized []byte `json:\"-\"`")
@@ -403,10 +402,7 @@ func genMessage(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, me
 	// Reset
 	g.P("func (m *", message.GoIdent, ") Reset() { *m = ", message.GoIdent, "{} }")
 	// String
-	g.P("func (m *", message.GoIdent, ") String() string { return ", protogen.GoIdent{
-		GoImportPath: protoPackage,
-		GoName:       "CompactTextString",
-	}, "(m) }")
+	g.P("func (m *", message.GoIdent, ") String() string { return ", protoPackage.Ident("CompactTextString"), "(m) }")
 	// ProtoMessage
 	g.P("func (*", message.GoIdent, ") ProtoMessage() {}")
 	// Descriptor
@@ -423,24 +419,15 @@ func genMessage(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, me
 	if extranges := message.Desc.ExtensionRanges(); extranges.Len() > 0 {
 		if message.Desc.Options().(*descpb.MessageOptions).GetMessageSetWireFormat() {
 			g.P("func (m *", message.GoIdent, ") MarshalJSON() ([]byte, error) {")
-			g.P("return ", protogen.GoIdent{
-				GoImportPath: protoPackage,
-				GoName:       "MarshalMessageSetJSON",
-			}, "(&m.XXX_InternalExtensions)")
+			g.P("return ", protoPackage.Ident("MarshalMessageSetJSON"), "(&m.XXX_InternalExtensions)")
 			g.P("}")
 			g.P("func (m *", message.GoIdent, ") UnmarshalJSON(buf []byte) error {")
-			g.P("return ", protogen.GoIdent{
-				GoImportPath: protoPackage,
-				GoName:       "UnmarshalMessageSetJSON",
-			}, "(buf, &m.XXX_InternalExtensions)")
+			g.P("return ", protoPackage.Ident("UnmarshalMessageSetJSON"), "(buf, &m.XXX_InternalExtensions)")
 			g.P("}")
 			g.P()
 		}
 
-		protoExtRange := protogen.GoIdent{
-			GoImportPath: protoPackage,
-			GoName:       "ExtensionRange",
-		}
+		protoExtRange := protoPackage.Ident("ExtensionRange")
 		extRangeVar := "extRange_" + message.GoIdent.GoName
 		g.P("var ", extRangeVar, " = []", protoExtRange, " {")
 		for i := 0; i < extranges.Len(); i++ {
@@ -486,10 +473,7 @@ func genMessage(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, me
 	g.P(messageInfoVar, ".DiscardUnknown(m)")
 	g.P("}")
 	g.P()
-	g.P("var ", messageInfoVar, " ", protogen.GoIdent{
-		GoImportPath: protoPackage,
-		GoName:       "InternalMessageInfo",
-	})
+	g.P("var ", messageInfoVar, " ", protoPackage.Ident("InternalMessageInfo"))
 	g.P()
 
 	// Constants and vars holding the default values of fields.
@@ -519,10 +503,7 @@ func genMessage(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, me
 			// funcCall returns a call to a function in the math package,
 			// possibly converting the result to float32.
 			funcCall := func(fn, param string) string {
-				s := g.QualifiedGoIdent(protogen.GoIdent{
-					GoImportPath: "math",
-					GoName:       fn,
-				}) + param
+				s := g.QualifiedGoIdent(mathPackage.Ident(fn)) + param
 				if goType != "float64" {
 					s = goType + "(" + s + ")"
 				}
@@ -707,10 +688,7 @@ func genExtension(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo,
 		name = n
 	}
 
-	g.P("var ", extensionVar(f.File, extension), " = &", protogen.GoIdent{
-		GoImportPath: protoPackage,
-		GoName:       "ExtensionDesc",
-	}, "{")
+	g.P("var ", extensionVar(f.File, extension), " = &", protoPackage.Ident("ExtensionDesc"), "{")
 	g.P("ExtendedType: (*", extension.ExtendedType.GoIdent, ")(nil),")
 	goType, pointer := fieldGoType(g, extension)
 	if pointer {
@@ -755,10 +733,7 @@ func extensionVar(f *protogen.File, extension *protogen.Extension) protogen.GoId
 		name += extension.ParentMessage.GoIdent.GoName + "_"
 	}
 	name += extension.GoName
-	return protogen.GoIdent{
-		GoImportPath: f.GoImportPath,
-		GoName:       name,
-	}
+	return f.GoImportPath.Ident(name)
 }
 
 // genInitFunction generates an init function that registers the types in the
@@ -771,10 +746,7 @@ func genInitFunction(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInf
 	g.P("func init() {")
 	for _, enum := range f.allEnums {
 		name := enum.GoIdent.GoName
-		g.P(protogen.GoIdent{
-			GoImportPath: protoPackage,
-			GoName:       "RegisterEnum",
-		}, fmt.Sprintf("(%q, %s_name, %s_value)", enumRegistryName(enum), name, name))
+		g.P(protoPackage.Ident("RegisterEnum"), fmt.Sprintf("(%q, %s_name, %s_value)", enumRegistryName(enum), name, name))
 	}
 	for _, message := range f.allMessages {
 		if message.Desc.IsMapEntry() {
@@ -786,10 +758,7 @@ func genInitFunction(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInf
 		}
 
 		name := message.GoIdent.GoName
-		g.P(protogen.GoIdent{
-			GoImportPath: protoPackage,
-			GoName:       "RegisterType",
-		}, fmt.Sprintf("((*%s)(nil), %q)", name, message.Desc.FullName()))
+		g.P(protoPackage.Ident("RegisterType"), fmt.Sprintf("((*%s)(nil), %q)", name, message.Desc.FullName()))
 
 		// Types of map fields, sorted by the name of the field message type.
 		var mapFields []*protogen.Field
@@ -806,10 +775,7 @@ func genInitFunction(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInf
 		for _, field := range mapFields {
 			typeName := string(field.MessageType.Desc.FullName())
 			goType, _ := fieldGoType(g, field)
-			g.P(protogen.GoIdent{
-				GoImportPath: protoPackage,
-				GoName:       "RegisterMapType",
-			}, fmt.Sprintf("((%v)(nil), %q)", goType, typeName))
+			g.P(protoPackage.Ident("RegisterMapType"), fmt.Sprintf("((%v)(nil), %q)", goType, typeName))
 		}
 	}
 	for _, extension := range f.Extensions {
@@ -820,19 +786,13 @@ func genInitFunction(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInf
 }
 
 func genRegisterExtension(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, extension *protogen.Extension) {
-	g.P(protogen.GoIdent{
-		GoImportPath: protoPackage,
-		GoName:       "RegisterExtension",
-	}, "(", extensionVar(f.File, extension), ")")
+	g.P(protoPackage.Ident("RegisterExtension"), "(", extensionVar(f.File, extension), ")")
 	if name, ok := isExtensionMessageSetElement(extension); ok {
 		goType, pointer := fieldGoType(g, extension)
 		if pointer {
 			goType = "*" + goType
 		}
-		g.P(protogen.GoIdent{
-			GoImportPath: protoPackage,
-			GoName:       "RegisterMessageSetType",
-		}, "((", goType, ")(nil), ", extension.Desc.Number(), ",", strconv.Quote(string(name)), ")")
+		g.P(protoPackage.Ident("RegisterMessageSetType"), "((", goType, ")(nil), ", extension.Desc.Number(), ",", strconv.Quote(string(name)), ")")
 	}
 }
 

+ 13 - 46
cmd/protoc-gen-go/internal_gengo/oneof.go

@@ -81,14 +81,8 @@ func oneofInterfaceName(oneof *protogen.Oneof) string {
 
 // genOneofFuncs generates the XXX_OneofFuncs method for a message.
 func genOneofFuncs(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, message *protogen.Message) {
-	protoMessage := g.QualifiedGoIdent(protogen.GoIdent{
-		GoImportPath: protoPackage,
-		GoName:       "Message",
-	})
-	protoBuffer := g.QualifiedGoIdent(protogen.GoIdent{
-		GoImportPath: protoPackage,
-		GoName:       "Buffer",
-	})
+	protoMessage := g.QualifiedGoIdent(protoPackage.Ident("Message"))
+	protoBuffer := g.QualifiedGoIdent(protoPackage.Ident("Buffer"))
 	encFunc := "_" + message.GoIdent.GoName + "_OneofMarshaler"
 	decFunc := "_" + message.GoIdent.GoName + "_OneofUnmarshaler"
 	sizeFunc := "_" + message.GoIdent.GoName + "_OneofSizer"
@@ -120,10 +114,7 @@ func genOneofFuncs(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo,
 		}
 		g.P("case nil:")
 		g.P("default:")
-		g.P("return ", protogen.GoIdent{
-			GoImportPath: "fmt",
-			GoName:       "Errorf",
-		}, `("`, message.GoIdent.GoName, ".", oneofFieldName(oneof), ` has unexpected type %T", x)`)
+		g.P("return ", fmtPackage.Ident("Errorf"), `("`, message.GoIdent.GoName, ".", oneofFieldName(oneof), ` has unexpected type %T", x)`)
 		g.P("}")
 	}
 	g.P("return nil")
@@ -156,10 +147,7 @@ func genOneofFuncs(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo,
 		}
 		g.P("case nil:")
 		g.P("default:")
-		g.P("panic(", protogen.GoIdent{
-			GoImportPath: "fmt",
-			GoName:       "Sprintf",
-		}, `("proto: unexpected type %T in oneof", x))`)
+		g.P("panic(", fmtPackage.Ident("Sprintf"), `("proto: unexpected type %T in oneof", x))`)
 		g.P("}")
 	}
 	g.P("return n")
@@ -171,10 +159,7 @@ func genOneofFuncs(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo,
 func genOneofFieldMarshal(g *protogen.GeneratedFile, field *protogen.Field) {
 	g.P("case *", fieldOneofType(field), ":")
 	encodeTag := func(wireType string) {
-		g.P("b.EncodeVarint(", field.Desc.Number(), "<<3|", protogen.GoIdent{
-			GoImportPath: protoPackage,
-			GoName:       wireType,
-		}, ")")
+		g.P("b.EncodeVarint(", field.Desc.Number(), "<<3|", protoPackage.Ident(wireType), ")")
 	}
 	switch field.Desc.Kind() {
 	case protoreflect.BoolKind:
@@ -196,19 +181,13 @@ func genOneofFieldMarshal(g *protogen.GeneratedFile, field *protogen.Field) {
 		g.P("b.EncodeFixed32(uint64(x.", field.GoName, "))")
 	case protoreflect.FloatKind:
 		encodeTag("WireFixed32")
-		g.P("b.EncodeFixed32(uint64(", protogen.GoIdent{
-			GoImportPath: "math",
-			GoName:       "Float32bits",
-		}, "(x.", field.GoName, ")))")
+		g.P("b.EncodeFixed32(uint64(", mathPackage.Ident("Float32bits"), "(x.", field.GoName, ")))")
 	case protoreflect.Sfixed64Kind, protoreflect.Fixed64Kind:
 		encodeTag("WireFixed64")
 		g.P("b.EncodeFixed64(uint64(x.", field.GoName, "))")
 	case protoreflect.DoubleKind:
 		encodeTag("WireFixed64")
-		g.P("b.EncodeFixed64(", protogen.GoIdent{
-			GoImportPath: "math",
-			GoName:       "Float64bits",
-		}, "(x.", field.GoName, "))")
+		g.P("b.EncodeFixed64(", mathPackage.Ident("Float64bits"), "(x.", field.GoName, "))")
 	case protoreflect.StringKind:
 		encodeTag("WireBytes")
 		g.P("b.EncodeStringBytes(x.", field.GoName, ")")
@@ -234,14 +213,8 @@ func genOneofFieldUnmarshal(g *protogen.GeneratedFile, field *protogen.Field) {
 	oneof := field.OneofType
 	g.P("case ", field.Desc.Number(), ": // ", oneof.Desc.Name(), ".", field.Desc.Name())
 	checkTag := func(wireType string) {
-		g.P("if wire != ", protogen.GoIdent{
-			GoImportPath: protoPackage,
-			GoName:       wireType,
-		}, " {")
-		g.P("return true, ", protogen.GoIdent{
-			GoImportPath: protoPackage,
-			GoName:       "ErrInternalBadWireType",
-		})
+		g.P("if wire != ", protoPackage.Ident(wireType), " {")
+		g.P("return true, ", protoPackage.Ident("ErrInternalBadWireType"))
 		g.P("}")
 	}
 	switch field.Desc.Kind() {
@@ -280,10 +253,7 @@ func genOneofFieldUnmarshal(g *protogen.GeneratedFile, field *protogen.Field) {
 	case protoreflect.FloatKind:
 		checkTag("WireFixed32")
 		g.P("x, err := b.DecodeFixed32()")
-		g.P("m.", oneofFieldName(oneof), " = &", fieldOneofType(field), "{", protogen.GoIdent{
-			GoImportPath: "math",
-			GoName:       "Float32frombits",
-		}, "(uint32(x))}")
+		g.P("m.", oneofFieldName(oneof), " = &", fieldOneofType(field), "{", mathPackage.Ident("Float32frombits"), "(uint32(x))}")
 	case protoreflect.Sfixed64Kind:
 		checkTag("WireFixed64")
 		g.P("x, err := b.DecodeFixed64()")
@@ -295,10 +265,7 @@ func genOneofFieldUnmarshal(g *protogen.GeneratedFile, field *protogen.Field) {
 	case protoreflect.DoubleKind:
 		checkTag("WireFixed64")
 		g.P("x, err := b.DecodeFixed64()")
-		g.P("m.", oneofFieldName(oneof), " = &", fieldOneofType(field), "{", protogen.GoIdent{
-			GoImportPath: "math",
-			GoName:       "Float64frombits",
-		}, "(x)}")
+		g.P("m.", oneofFieldName(oneof), " = &", fieldOneofType(field), "{", mathPackage.Ident("Float64frombits"), "(x)}")
 	case protoreflect.StringKind:
 		checkTag("WireBytes")
 		g.P("x, err := b.DecodeStringBytes()")
@@ -323,8 +290,8 @@ func genOneofFieldUnmarshal(g *protogen.GeneratedFile, field *protogen.Field) {
 
 // genOneofFieldSizer  generates the sizer case for a oneof subfield.
 func genOneofFieldSizer(g *protogen.GeneratedFile, field *protogen.Field) {
-	sizeProto := protogen.GoIdent{GoImportPath: protoPackage, GoName: "Size"}
-	sizeVarint := protogen.GoIdent{GoImportPath: protoPackage, GoName: "SizeVarint"}
+	sizeProto := protoPackage.Ident("Size")
+	sizeVarint := protoPackage.Ident("SizeVarint")
 	g.P("case *", fieldOneofType(field), ":")
 	if field.Desc.Kind() == protoreflect.MessageKind {
 		g.P("s := ", sizeProto, "(x.", field.GoName, ")")

+ 5 - 0
protogen/names.go

@@ -33,6 +33,11 @@ type GoImportPath string
 
 func (p GoImportPath) String() string { return strconv.Quote(string(p)) }
 
+// Ident returns a GoIdent with s as the GoName and p as the GoImportPath.
+func (p GoImportPath) Ident(s string) GoIdent {
+	return GoIdent{GoName: s, GoImportPath: p}
+}
+
 // A GoPackageName is the name of a Go package. e.g., "protobuf".
 type GoPackageName string
 

+ 2 - 5
protogen/protogen.go

@@ -748,11 +748,8 @@ func newEnumValue(gen *Plugin, f *File, message *Message, enum *Enum, desc proto
 	}
 	name := parentIdent.GoName + "_" + string(desc.Name())
 	return &EnumValue{
-		Desc: desc,
-		GoIdent: GoIdent{
-			GoName:       name,
-			GoImportPath: f.GoImportPath,
-		},
+		Desc:     desc,
+		GoIdent:  f.GoImportPath.Ident(name),
 		Location: enum.Location.appendPath(enumValueField, int32(desc.Index())),
 	}
 }