Преглед на файлове

Big sync from internal version:
- include MessageSet support code
- support message_set_wire_format for extensions
- use append throughout encode.go

R=r
CC=golang-dev
http://codereview.appspot.com/3023041

David Symonds преди 15 години
родител
ревизия
4fee3b12e7

+ 1 - 0
.hgignore

@@ -15,5 +15,6 @@ syntax:glob
 core
 _obj
 _test
+_testmain.go
 compiler/protoc-gen-go
 compiler/testdata/extension_test

+ 22 - 2
compiler/generator/generator.go

@@ -630,8 +630,9 @@ func (g *Generator) generateHeader() {
 func (g *Generator) generateImports() {
 	// We almost always need a proto import.  Rather than computing when we
 	// do, which is tricky when there's a plugin, just import it and
-	// reference it later.
+	// reference it later. The same argument applies to the os package.
 	g.P("import " + g.ProtoPkg + " " + Quote(g.ImportPrefix+"goprotobuf.googlecode.com/hg/proto"))
+	g.P(`import "os"`)
 	for _, s := range g.file.Dependency {
 		// Need to find the descriptor for this file
 		for _, fd := range g.allFiles {
@@ -663,8 +664,9 @@ func (g *Generator) generateImports() {
 		p.GenerateImports(g.file)
 		g.P()
 	}
-	g.P("// Reference proto import to suppress error if it's not otherwise used.")
+	g.P("// Reference proto & os imports to suppress error if it's not otherwise used.")
 	g.P("var _ = ", g.ProtoPkg, ".GetString")
+	g.P("var _ os.Error")
 	g.P()
 }
 
@@ -898,6 +900,24 @@ func (g *Generator) generateMessage(message *Descriptor) {
 
 	// Extension support methods
 	if len(message.ExtensionRange) > 0 {
+		// message_set_wire_format only makes sense when extensions are defined.
+		if opts := message.Options; opts != nil && proto.GetBool(opts.MessageSetWireFormat) {
+			g.P()
+			g.P("func (this *", ccTypeName, ") Marshal() ([]byte, os.Error) {")
+			g.In()
+			g.P("return ", g.ProtoPkg, ".MarshalMessageSet(this.ExtensionMap())")
+			g.Out()
+			g.P("}")
+			g.P("func (this *", ccTypeName, ") Unmarshal(buf []byte) os.Error {")
+			g.In()
+			g.P("return ", g.ProtoPkg, ".UnmarshalMessageSet(buf, this.ExtensionMap())")
+			g.Out()
+			g.P("}")
+			g.P("// ensure ", ccTypeName, " satisfies proto.Marshaler and proto.Unmarshaler")
+			g.P("var _ ", g.ProtoPkg, ".Marshaler = (*", ccTypeName, ")(nil)")
+			g.P("var _ ", g.ProtoPkg, ".Unmarshaler = (*", ccTypeName, ")(nil)")
+		}
+
 		g.P()
 		g.P("var extRange_", ccTypeName, " = []", g.ProtoPkg, ".ExtensionRange{")
 		g.In()

+ 6 - 0
compiler/testdata/extension_base.proto

@@ -36,3 +36,9 @@ message BaseMessage {
   extensions 4 to 9;
   extensions 16 to max;
 }
+
+// Another message that may be extended, using message_set_wire_format.
+message OldStyleMessage {
+  option message_set_wire_format = true;
+  extensions 100 to max;
+}

+ 42 - 0
compiler/testdata/extension_test.go

@@ -34,6 +34,7 @@
 package main
 
 import (
+	"bytes"
 	"regexp"
 	"testing"
 
@@ -152,6 +153,47 @@ func TestTopLevelExtension(t *testing.T) {
 	}
 }
 
+func TestMessageSetWireFormat(t *testing.T) {
+	osm := new(base.OldStyleMessage)
+	osp := &user.OldStyleParcel{
+		Name:   proto.String("Dave"),
+		Height: proto.Int32(178),
+	}
+
+	err := proto.SetExtension(osm, user.E_OldStyleParcel_MessageSetExtension, osp)
+	if err != nil {
+		t.Fatal("Failed setting extension:", err)
+	}
+
+	buf, err := proto.Marshal(osm)
+	if err != nil {
+		t.Fatal("Failed encoding message:", err)
+	}
+
+	// Data generated from Python implementation.
+	expected := []byte{
+		11, 16, 209, 15, 26, 9, 10, 4, 68, 97, 118, 101, 16, 178, 1, 12,
+	}
+
+	if !bytes.Equal(expected, buf) {
+		t.Errorf("Encoding mismatch.\nwant %+v\n got %+v", expected, buf)
+	}
+
+	// Check that it is restored correctly.
+	osm = new(base.OldStyleMessage)
+	if err := proto.Unmarshal(buf, osm); err != nil {
+		t.Fatal("Failed decoding message:", err)
+	}
+	osp_out, err := proto.GetExtension(osm, user.E_OldStyleParcel_MessageSetExtension)
+	if err != nil {
+		t.Fatal("Failed getting extension:", err)
+	}
+	osp = osp_out.(*user.OldStyleParcel)
+	if *osp.Name != "Dave" || *osp.Height != 178 {
+		t.Errorf("Retrieved extension from decoded message is not correct: %+v", osp)
+	}
+}
+
 func main() {
 	// simpler than rigging up gotest
 	testing.Main(regexp.MatchString, []testing.InternalTest{

+ 10 - 0
compiler/testdata/extension_user.proto

@@ -71,3 +71,13 @@ message Announcement {
     optional Announcement loud_ext = 100;
   }
 }
+
+// Something that can be put in a message set.
+message OldStyleParcel {
+  extend extension_base.OldStyleMessage {
+    optional OldStyleParcel message_set_extension = 2001;
+  }
+
+  required string name = 1;
+  optional int32 height = 2;
+}

+ 3 - 1
compiler/testdata/test.pb.go.golden

@@ -4,10 +4,12 @@
 package my_test
 
 import proto "goprotobuf.googlecode.com/hg/proto"
+import "os"
 import imp "imp.pb"
 
-// Reference proto import to suppress error if it's not otherwise used.
+// Reference proto & os imports to suppress error if it's not otherwise used.
 var _ = proto.GetString
+var _ os.Error
 
 type HatType int32
 const (

+ 1 - 0
proto/Makefile

@@ -38,6 +38,7 @@ GOFILES=\
 	decode.go\
 	extensions.go\
 	lib.go\
+	message_set.go\
 	properties.go\
 	text.go\
 	text_parser.go\

+ 24 - 30
proto/encode.go

@@ -61,6 +61,8 @@ var ErrNil = os.NewError("marshal called with nil")
 // Those that take integer types all accept uint64 and are
 // therefore of type valueEncoder.
 
+const maxVarintBytes = 10 // maximum length of a varint
+
 // EncodeVarint returns the varint encoding of x.
 // This is the format for the
 // int32, int64, uint32, uint64, bool, and enum
@@ -68,7 +70,7 @@ var ErrNil = os.NewError("marshal called with nil")
 // Not used by the package itself, but helpful to clients
 // wishing to use the same encoding.
 func EncodeVarint(x uint64) []byte {
-	var buf [16]byte
+	var buf [maxVarintBytes]byte
 	var n int
 	for n = 0; x > 127; n++ {
 		buf[n] = 0x80 | uint8(x&0x7F)
@@ -79,31 +81,27 @@ func EncodeVarint(x uint64) []byte {
 	return buf[0:n]
 }
 
+var emptyBytes [maxVarintBytes]byte
+
 // EncodeVarint writes a varint-encoded integer to the Buffer.
 // This is the format for the
 // int32, int64, uint32, uint64, bool, and enum
 // protocol buffer types.
 func (p *Buffer) EncodeVarint(x uint64) os.Error {
 	l := len(p.buf)
-	c := cap(p.buf)
-	if l+10 > c {
-		c += c/2 + 10
-		obuf := make([]byte, c)
-		copy(obuf, p.buf)
-		p.buf = obuf
-	}
-	p.buf = p.buf[0:c]
-
-	for {
-		if x < 1<<7 {
-			break
-		}
+	if l+maxVarintBytes > cap(p.buf) { // not necessary except for performance
+		p.buf = append(p.buf, emptyBytes[:]...)
+	} else {
+		p.buf = p.buf[:l+maxVarintBytes]
+	}
+
+	for x >= 1<<7 {
 		p.buf[l] = uint8(x&0x7f | 0x80)
 		l++
 		x >>= 7
 	}
 	p.buf[l] = uint8(x)
-	p.buf = p.buf[0 : l+1]
+	p.buf = p.buf[:l+1]
 	return nil
 }
 
@@ -111,15 +109,13 @@ func (p *Buffer) EncodeVarint(x uint64) os.Error {
 // This is the format for the
 // fixed64, sfixed64, and double protocol buffer types.
 func (p *Buffer) EncodeFixed64(x uint64) os.Error {
+	const fixed64Bytes = 8
 	l := len(p.buf)
-	c := cap(p.buf)
-	if l+8 > c {
-		c += c/2 + 8
-		obuf := make([]byte, c)
-		copy(obuf, p.buf)
-		p.buf = obuf
+	if l+fixed64Bytes > cap(p.buf) { // not necessary except for performance
+		p.buf = append(p.buf, emptyBytes[:fixed64Bytes]...)
+	} else {
+		p.buf = p.buf[:l+fixed64Bytes]
 	}
-	p.buf = p.buf[0 : l+8]
 
 	p.buf[l] = uint8(x)
 	p.buf[l+1] = uint8(x >> 8)
@@ -136,15 +132,13 @@ func (p *Buffer) EncodeFixed64(x uint64) os.Error {
 // This is the format for the
 // fixed32, sfixed32, and float protocol buffer types.
 func (p *Buffer) EncodeFixed32(x uint64) os.Error {
+	const fixed32Bytes = 4
 	l := len(p.buf)
-	c := cap(p.buf)
-	if l+4 > c {
-		c += c/2 + 4
-		obuf := make([]byte, c)
-		copy(obuf, p.buf)
-		p.buf = obuf
-	}
-	p.buf = p.buf[0 : l+4]
+	if l+fixed32Bytes > cap(p.buf) { // not necessary except for performance
+		p.buf = append(p.buf, emptyBytes[:fixed32Bytes]...)
+	} else {
+		p.buf = p.buf[:l+fixed32Bytes]
+	}
 
 	p.buf[l] = uint8(x)
 	p.buf[l+1] = uint8(x >> 8)

+ 180 - 0
proto/message_set.go

@@ -0,0 +1,180 @@
+// Go support for Protocol Buffers - Google's data interchange format
+//
+// Copyright 2010 Google Inc.  All rights reserved.
+// http://code.google.com/p/goprotobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+package proto
+
+/*
+ * Support for message sets.
+ */
+
+import (
+	"bytes"
+	"os"
+)
+
+// ErrNoMessageTypeId occurs when a protocol buffer does not have a message type ID.
+// A message type ID is required for storing a protocol buffer in a message set.
+var ErrNoMessageTypeId = os.NewError("proto does not have a message type ID")
+
+// The first two types (_MessageSet_Item and MessageSet)
+// model what the protocol compiler produces for the following protocol message:
+//   message MessageSet {
+//     repeated group Item = 1 {
+//       required int32 type_id = 2;
+//       required string message = 3;
+//     };
+//   }
+// That is the MessageSet wire format. We can't use a proto to generate these
+// because that would introduce a circular dependency between it and this package.
+//
+// When a proto1 proto has a field that looks like:
+//   optional message<MessageSet> info = 3;
+// the protocol compiler produces a field in the generated struct that looks like:
+//   Info *_proto_.MessageSet  "PB(bytes,3,opt,name=info)"
+// The package is automatically inserted so there is no need for that proto file to
+// import this package.
+
+type _MessageSet_Item struct {
+	TypeId  *int32 "PB(varint,2,req,name=type_id)"
+	Message []byte "PB(bytes,3,req,name=message)"
+}
+
+type MessageSet struct {
+	Item             []*_MessageSet_Item "PB(group,1,rep)"
+	XXX_unrecognized *bytes.Buffer
+	// TODO: caching?
+}
+
+// messageTypeIder is an interface satisfied by a protocol buffer type
+// that may be stored in a MessageSet.
+type messageTypeIder interface {
+	MessageTypeId() int32
+}
+
+func (ms *MessageSet) find(pb interface{}) *_MessageSet_Item {
+	mti, ok := pb.(messageTypeIder)
+	if !ok {
+		return nil
+	}
+	id := mti.MessageTypeId()
+	for _, item := range ms.Item {
+		if *item.TypeId == id {
+			return item
+		}
+	}
+	return nil
+}
+
+func (ms *MessageSet) Has(pb interface{}) bool {
+	if ms.find(pb) != nil {
+		return true
+	}
+	return false
+}
+
+func (ms *MessageSet) Unmarshal(pb interface{}) os.Error {
+	if item := ms.find(pb); item != nil {
+		return Unmarshal(item.Message, pb)
+	}
+	if _, ok := pb.(messageTypeIder); !ok {
+		return ErrNoMessageTypeId
+	}
+	return nil // TODO: return error instead?
+}
+
+func (ms *MessageSet) Marshal(pb interface{}) os.Error {
+	msg, err := Marshal(pb)
+	if err != nil {
+		return err
+	}
+	if item := ms.find(pb); item != nil {
+		// reuse existing item
+		item.Message = msg
+		return nil
+	}
+
+	mti, ok := pb.(messageTypeIder)
+	if !ok {
+		return ErrWrongType // TODO: custom error?
+	}
+
+	mtid := mti.MessageTypeId()
+	ms.Item = append(ms.Item, &_MessageSet_Item{
+		TypeId:  &mtid,
+		Message: msg,
+	})
+	return nil
+}
+
+// Support for the message_set_wire_format message option.
+
+func skipVarint(buf []byte) []byte {
+	i := 0
+	for ; buf[i]&0x80 != 0; i++ {
+	}
+	return buf[i+1:]
+}
+
+// MarshalMessageSet encodes the extension map represented by m in the message set wire format.
+// It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
+func MarshalMessageSet(m map[int32][]byte) ([]byte, os.Error) {
+	ms := &MessageSet{Item: make([]*_MessageSet_Item, len(m))}
+	i := 0
+	for k, v := range m {
+		// Remove the wire type and field number varint, as well as the length varint.
+		v = skipVarint(skipVarint(v))
+
+		ms.Item[i] = &_MessageSet_Item{
+			TypeId:  Int32(k),
+			Message: v,
+		}
+		i++
+	}
+	return Marshal(ms)
+}
+
+// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
+// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
+func UnmarshalMessageSet(buf []byte, m map[int32][]byte) os.Error {
+	ms := new(MessageSet)
+	if err := Unmarshal(buf, ms); err != nil {
+		return err
+	}
+	for _, item := range ms.Item {
+		// restore wire type and field number varint, plus length varint.
+		b := EncodeVarint(uint64(*item.TypeId)<<3 | WireBytes)
+		b = append(b, EncodeVarint(uint64(len(item.Message)))...)
+		b = append(b, item.Message...)
+
+		m[*item.TypeId] = b
+	}
+	return nil
+}

+ 3 - 1
proto/testdata/test.pb.go

@@ -4,9 +4,11 @@
 package test_proto
 
 import proto "goprotobuf.googlecode.com/hg/proto"
+import "os"
 
-// Reference proto import to suppress error if it's not otherwise used.
+// Reference proto & os imports to suppress error if it's not otherwise used.
 var _ = proto.GetString
+var _ os.Error
 
 type FOO int32
 const (

+ 1 - 1
proto/text.go

@@ -33,7 +33,7 @@ package proto
 
 // Functions for writing the Text protocol buffer format.
 // TODO:
-//	- groups.
+//	- Message sets, groups.
 
 import (
 	"bytes"

+ 1 - 1
proto/text_parser.go

@@ -33,7 +33,7 @@ package proto
 
 // Functions for parsing the Text protocol buffer format.
 // TODO:
-//     - groups.
+//     - message sets, groups.
 
 import (
 	"fmt"