Przeglądaj źródła

cmd/protoc-gen-go: generate public imports by parsing the imported .pb.go

Rather than explicitly enumerating the set of symbols to import,
just parse the imported file and extract every exported symbol.

This is possibly a bit more code, but adapts much better to future
expansion.

Change-Id: I4429664f4c068a2a55949d46aefc19865b008a77
Reviewed-on: https://go-review.googlesource.com/c/155677
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
Damien Neil 7 lat temu
rodzic
commit
7bf3ce2145

+ 48 - 54
cmd/protoc-gen-go/internal_gengo/main.go

@@ -11,10 +11,15 @@ import (
 	"crypto/sha256"
 	"encoding/hex"
 	"fmt"
+	"go/ast"
+	"go/parser"
+	"go/token"
 	"math"
 	"sort"
 	"strconv"
 	"strings"
+	"unicode"
+	"unicode/utf8"
 
 	"github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/v2/internal/encoding/tag"
@@ -178,65 +183,54 @@ func genImport(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, imp
 	if !imp.IsPublic {
 		return
 	}
-	// TODO: An alternate approach to generating public imports might be
-	// to generate the imported file contents, parse it, and extract all
-	// exported identifiers from the AST to build a list of forwarding
-	// declarations.
-	//
-	// TODO: Consider whether this should generate recursive aliases. e.g.,
-	// if a.proto publicly imports b.proto publicly imports c.proto, should
-	// a.pb.go contain aliases for symbols defined in c.proto?
-	var enums []*protogen.Enum
-	enums = append(enums, impFile.Enums...)
-	walkMessages(impFile.Messages, func(message *protogen.Message) {
-		if message.Desc.IsMapEntry() {
+
+	// Generate public imports by generating the imported file, parsing it,
+	// and extracting every symbol that should receive a forwarding declaration.
+	impGen := gen.NewGeneratedFile("temp.go", impFile.GoImportPath)
+	impGen.Skip()
+	GenerateFile(gen, impFile, impGen)
+	b, err := impGen.Content()
+	if err != nil {
+		gen.Error(err)
+		return
+	}
+	fset := token.NewFileSet()
+	astFile, err := parser.ParseFile(fset, "", b, parser.ParseComments)
+	if err != nil {
+		gen.Error(err)
+		return
+	}
+	genForward := func(tok token.Token, name string) {
+		// Don't import unexported symbols.
+		r, _ := utf8.DecodeRuneInString(name)
+		if !unicode.IsUpper(r) {
 			return
 		}
-		enums = append(enums, message.Enums...)
-		for _, field := range message.Fields {
-			if !field.Desc.HasDefault() {
-				continue
-			}
-			defVar := protogen.GoIdent{
-				GoImportPath: message.GoIdent.GoImportPath,
-				GoName:       "Default_" + message.GoIdent.GoName + "_" + field.GoName,
-			}
-			decl := "const"
-			switch field.Desc.Kind() {
-			case protoreflect.BytesKind:
-				decl = "var"
-			case protoreflect.FloatKind, protoreflect.DoubleKind:
-				f := field.Desc.Default().Float()
-				if math.IsInf(f, -1) || math.IsInf(f, 1) || math.IsNaN(f) {
-					decl = "var"
-				}
-			}
-			g.P(decl, " ", defVar.GoName, " = ", defVar)
+		// Don't import the FileDescriptor.
+		if name == impFile.GoDescriptorIdent.GoName {
+			return
 		}
-		g.P("// ", message.GoIdent.GoName, " from public import ", imp.Path())
-		g.P("type ", message.GoIdent.GoName, " = ", message.GoIdent)
-		for _, oneof := range message.Oneofs {
-			for _, field := range oneof.Fields {
-				typ := fieldOneofType(field)
-				g.P("type ", typ.GoName, " = ", typ)
+		g.P(tok, " ", name, " = ", impFile.GoImportPath.Ident(name))
+	}
+	g.P("// Symbols defined in public import of ", imp.Path())
+	g.P()
+	for _, decl := range astFile.Decls {
+		switch decl := decl.(type) {
+		case *ast.GenDecl:
+			for _, spec := range decl.Specs {
+				switch spec := spec.(type) {
+				case *ast.TypeSpec:
+					genForward(decl.Tok, spec.Name.Name)
+				case *ast.ValueSpec:
+					for _, name := range spec.Names {
+						genForward(decl.Tok, name.Name)
+					}
+				case *ast.ImportSpec:
+				default:
+					panic(fmt.Sprintf("can't generate forward for spec type %T", spec))
+				}
 			}
 		}
-		g.P()
-	})
-	for _, enum := range enums {
-		g.P("// ", enum.GoIdent.GoName, " from public import ", imp.Path())
-		g.P("type ", enum.GoIdent.GoName, " = ", enum.GoIdent)
-		g.P("var ", enum.GoIdent.GoName, "_name = ", enum.GoIdent, "_name")
-		g.P("var ", enum.GoIdent.GoName, "_value = ", enum.GoIdent, "_value")
-		g.P()
-		for _, value := range enum.Values {
-			g.P("const ", value.GoIdent.GoName, " = ", enum.GoIdent.GoName, "(", value.GoIdent, ")")
-		}
-	}
-	for _, ext := range impFile.Extensions {
-		ident := extensionVar(impFile, ext)
-		g.P("var ", ident.GoName, " = ", ident)
-		g.P()
 	}
 	g.P()
 }

+ 19 - 22
cmd/protoc-gen-go/testdata/import_public/a.pb.go

@@ -17,44 +17,41 @@ import (
 // proto package needs to be updated.
 const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
 
-const Default_M_S = sub.Default_M_S
-
-var Default_M_B = sub.Default_M_B
-var Default_M_F = sub.Default_M_F
-
-// M from public import import_public/sub/a.proto
-type M = sub.M
-type M_OneofInt32 = sub.M_OneofInt32
-type M_OneofInt64 = sub.M_OneofInt64
+// Symbols defined in public import of import_public/sub/a.proto
 
-// M_Submessage from public import import_public/sub/a.proto
-type M_Submessage = sub.M_Submessage
-type M_Submessage_SubmessageOneofInt32 = sub.M_Submessage_SubmessageOneofInt32
-type M_Submessage_SubmessageOneofInt64 = sub.M_Submessage_SubmessageOneofInt64
-
-// E from public import import_public/sub/a.proto
 type E = sub.E
 
+const E_ZERO = sub.E_ZERO
+
 var E_name = sub.E_name
 var E_value = sub.E_value
 
-const E_ZERO = E(sub.E_ZERO)
-
-// M_Subenum from public import import_public/sub/a.proto
 type M_Subenum = sub.M_Subenum
 
+const M_M_ZERO = sub.M_M_ZERO
+
 var M_Subenum_name = sub.M_Subenum_name
 var M_Subenum_value = sub.M_Subenum_value
 
-const M_M_ZERO = M_Subenum(sub.M_M_ZERO)
-
-// M_Submessage_Submessage_Subenum from public import import_public/sub/a.proto
 type M_Submessage_Submessage_Subenum = sub.M_Submessage_Submessage_Subenum
 
+const M_Submessage_M_SUBMESSAGE_ZERO = sub.M_Submessage_M_SUBMESSAGE_ZERO
+
 var M_Submessage_Submessage_Subenum_name = sub.M_Submessage_Submessage_Subenum_name
 var M_Submessage_Submessage_Subenum_value = sub.M_Submessage_Submessage_Subenum_value
 
-const M_Submessage_M_SUBMESSAGE_ZERO = M_Submessage_Submessage_Subenum(sub.M_Submessage_M_SUBMESSAGE_ZERO)
+type M = sub.M
+
+const Default_M_S = sub.Default_M_S
+
+var Default_M_B = sub.Default_M_B
+var Default_M_F = sub.Default_M_F
+
+type M_OneofInt32 = sub.M_OneofInt32
+type M_OneofInt64 = sub.M_OneofInt64
+type M_Submessage = sub.M_Submessage
+type M_Submessage_SubmessageOneofInt32 = sub.M_Submessage_SubmessageOneofInt32
+type M_Submessage_SubmessageOneofInt64 = sub.M_Submessage_SubmessageOneofInt64
 
 var E_ExtensionField = sub.E_ExtensionField
 

+ 40 - 31
protogen/protogen.go

@@ -342,7 +342,10 @@ func (gen *Plugin) Response() *pluginpb.CodeGeneratorResponse {
 		return resp
 	}
 	for _, g := range gen.genFiles {
-		content, err := g.content()
+		if g.skip {
+			continue
+		}
+		content, err := g.Content()
 		if err != nil {
 			return &pluginpb.CodeGeneratorResponse{
 				Error: scalar.String(err.Error()),
@@ -761,34 +764,6 @@ func newEnumValue(gen *Plugin, f *File, message *Message, enum *Enum, desc proto
 	}
 }
 
-// A GeneratedFile is a generated file.
-type GeneratedFile struct {
-	gen              *Plugin
-	filename         string
-	goImportPath     GoImportPath
-	buf              bytes.Buffer
-	packageNames     map[GoImportPath]GoPackageName
-	usedPackageNames map[GoPackageName]bool
-	manualImports    map[GoImportPath]bool
-	annotations      map[string][]Location
-}
-
-// NewGeneratedFile creates a new generated file with the given filename
-// and import path.
-func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
-	g := &GeneratedFile{
-		gen:              gen,
-		filename:         filename,
-		goImportPath:     goImportPath,
-		packageNames:     make(map[GoImportPath]GoPackageName),
-		usedPackageNames: make(map[GoPackageName]bool),
-		manualImports:    make(map[GoImportPath]bool),
-		annotations:      make(map[string][]Location),
-	}
-	gen.genFiles = append(gen.genFiles, g)
-	return g
-}
-
 // A Service describes a service.
 type Service struct {
 	Desc protoreflect.ServiceDescriptor
@@ -851,6 +826,35 @@ func (method *Method) init(gen *Plugin) error {
 	return nil
 }
 
+// A GeneratedFile is a generated file.
+type GeneratedFile struct {
+	gen              *Plugin
+	skip             bool
+	filename         string
+	goImportPath     GoImportPath
+	buf              bytes.Buffer
+	packageNames     map[GoImportPath]GoPackageName
+	usedPackageNames map[GoPackageName]bool
+	manualImports    map[GoImportPath]bool
+	annotations      map[string][]Location
+}
+
+// NewGeneratedFile creates a new generated file with the given filename
+// and import path.
+func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
+	g := &GeneratedFile{
+		gen:              gen,
+		filename:         filename,
+		goImportPath:     goImportPath,
+		packageNames:     make(map[GoImportPath]GoPackageName),
+		usedPackageNames: make(map[GoPackageName]bool),
+		manualImports:    make(map[GoImportPath]bool),
+		annotations:      make(map[string][]Location),
+	}
+	gen.genFiles = append(gen.genFiles, g)
+	return g
+}
+
 // P prints a line to the generated output. It converts each parameter to a
 // string following the same rules as fmt.Print. It never inserts spaces
 // between parameters.
@@ -924,6 +928,11 @@ func (g *GeneratedFile) Write(p []byte) (n int, err error) {
 	return g.buf.Write(p)
 }
 
+// Skip removes the generated file from the plugin output.
+func (g *GeneratedFile) Skip() {
+	g.skip = true
+}
+
 // Annotate associates a symbol in a generated Go file with a location in a
 // source .proto file.
 //
@@ -934,8 +943,8 @@ func (g *GeneratedFile) Annotate(symbol string, loc Location) {
 	g.annotations[symbol] = append(g.annotations[symbol], loc)
 }
 
-// content returns the contents of the generated file.
-func (g *GeneratedFile) content() ([]byte, error) {
+// Content returns the contents of the generated file.
+func (g *GeneratedFile) Content() ([]byte, error) {
 	if !strings.HasSuffix(g.filename, ".go") {
 		return g.buf.Bytes(), nil
 	}

+ 4 - 4
protogen/protogen_test.go

@@ -297,9 +297,9 @@ var _ = bar1.X    // "golang.org/y/bar"
 var _ = baz.X     // "golang.org/x/baz"
 var _ = string1.X // "golang.org/z/string"
 `
-	got, err := g.content()
+	got, err := g.Content()
 	if err != nil {
-		t.Fatalf("g.content() = %v", err)
+		t.Fatalf("g.Content() = %v", err)
 	}
 	if want != string(got) {
 		t.Fatalf(`want:
@@ -333,9 +333,9 @@ import bar "prefix/golang.org/x/bar"
 
 var _ = bar.X
 `
-	got, err := g.content()
+	got, err := g.Content()
 	if err != nil {
-		t.Fatalf("g.content() = %v", err)
+		t.Fatalf("g.Content() = %v", err)
 	}
 	if want != string(got) {
 		t.Fatalf(`want: