Browse Source

goprotobuf: Optimize tag lookups and stype handling.

R=r, adg
CC=golang-dev
http://codereview.appspot.com/6453158
David Symonds 13 years ago
parent
commit
6a6f82cf8c
5 changed files with 65 additions and 32 deletions
  1. 1 0
      CONTRIBUTORS
  2. 9 13
      proto/decode.go
  3. 6 11
      proto/encode.go
  4. 5 1
      proto/lib.go
  5. 44 7
      proto/properties.go

+ 1 - 0
CONTRIBUTORS

@@ -8,6 +8,7 @@
 
 Dave Cheney <dave@cheney.net>
 David Symonds <dsymonds@golang.org>
+Jonathan Hseu <jhseu@google.com>
 Ken Thompson <ken@golang.org>
 Mikkel Krautz <mikkel@krautz.dk> <krautz@gmail.com>
 Nigel Tao <nigeltao@golang.org>

+ 9 - 13
proto/decode.go

@@ -333,7 +333,7 @@ func (p *Buffer) Unmarshal(pb Message) error {
 		return err
 	}
 
-	err = p.unmarshalType(typ, GetProperties(typ.Elem()), false, base)
+	err = p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), false, base)
 
 	if collectStats {
 		stats.Decode++
@@ -343,8 +343,7 @@ func (p *Buffer) Unmarshal(pb Message) error {
 }
 
 // unmarshalType does the work of unmarshaling a structure.
-func (o *Buffer) unmarshalType(t reflect.Type, prop *StructProperties, is_group bool, base uintptr) error {
-	st := t.Elem()
+func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group bool, base uintptr) error {
 	required, reqFields := prop.reqCount, uint64(0)
 
 	var err error
@@ -366,7 +365,7 @@ func (o *Buffer) unmarshalType(t reflect.Type, prop *StructProperties, is_group
 		if tag <= 0 {
 			return fmt.Errorf("proto: illegal tag %d", tag)
 		}
-		fieldnum, ok := prop.tags[tag]
+		fieldnum, ok := prop.tags.get(tag)
 		if !ok {
 			// Maybe it's an extension?
 			iv := reflect.NewAt(st, unsafe.Pointer(base)).Interface()
@@ -384,7 +383,7 @@ func (o *Buffer) unmarshalType(t reflect.Type, prop *StructProperties, is_group
 		p := prop.Prop[fieldnum]
 
 		if p.dec == nil {
-			fmt.Fprintf(os.Stderr, "proto: no protobuf decoder for %s.%s\n", t, st.Field(fieldnum).Name)
+			fmt.Fprintf(os.Stderr, "proto: no protobuf decoder for %s.%s\n", st, st.Field(fieldnum).Name)
 			continue
 		}
 		dec := p.dec
@@ -650,8 +649,7 @@ func (o *Buffer) dec_slice_slice_byte(p *Properties, base uintptr) error {
 // Decode a group.
 func (o *Buffer) dec_struct_group(p *Properties, base uintptr) error {
 	ptr := (**struct{})(unsafe.Pointer(base + p.offset))
-	typ := p.stype.Elem()
-	bas := reflect.New(typ).Pointer()
+	bas := reflect.New(p.stype).Pointer()
 	structv := unsafe.Pointer(bas)
 	*ptr = (*struct{})(structv)
 
@@ -668,14 +666,13 @@ func (o *Buffer) dec_struct_message(p *Properties, base uintptr) (err error) {
 	}
 
 	ptr := (**struct{})(unsafe.Pointer(base + p.offset))
-	typ := p.stype.Elem()
-	bas := reflect.New(typ).Pointer()
+	bas := reflect.New(p.stype).Pointer()
 	structp := unsafe.Pointer(bas)
 	*ptr = (*struct{})(structp)
 
 	// If the object can unmarshal itself, let it.
 	if p.isMarshaler {
-		iv := reflect.NewAt(p.stype.Elem(), structp).Interface()
+		iv := reflect.NewAt(p.stype, structp).Interface()
 		return iv.(Unmarshaler).Unmarshal(raw)
 	}
 
@@ -707,8 +704,7 @@ func (o *Buffer) dec_slice_struct(p *Properties, is_group bool, base uintptr) er
 	v := (*[]*struct{})(unsafe.Pointer(base + p.offset))
 	y := *v
 
-	typ := p.stype.Elem()
-	bas := reflect.New(typ).Pointer()
+	bas := reflect.New(p.stype).Pointer()
 	structp := unsafe.Pointer(bas)
 	y = append(y, (*struct{})(structp))
 	*v = y
@@ -725,7 +721,7 @@ func (o *Buffer) dec_slice_struct(p *Properties, is_group bool, base uintptr) er
 
 	// If the object can unmarshal itself, let it.
 	if p.isUnmarshaler {
-		iv := reflect.NewAt(typ, structp).Interface()
+		iv := reflect.NewAt(p.stype, structp).Interface()
 		return iv.(Unmarshaler).Unmarshal(raw)
 	}
 

+ 6 - 11
proto/encode.go

@@ -276,11 +276,9 @@ func (o *Buffer) enc_struct_message(p *Properties, base uintptr) error {
 		return ErrNil
 	}
 
-	typ := p.stype.Elem()
-
 	// Can the object marshal itself?
 	if p.isMarshaler {
-		m := reflect.NewAt(typ, structp).Interface().(Marshaler)
+		m := reflect.NewAt(p.stype, structp).Interface().(Marshaler)
 		data, err := m.Marshal()
 		if err != nil {
 			return err
@@ -295,7 +293,7 @@ func (o *Buffer) enc_struct_message(p *Properties, base uintptr) error {
 	obuf := o.buf
 	o.buf = o.bufalloc()
 
-	err := o.enc_struct(typ, p.sprop, uintptr(structp))
+	err := o.enc_struct(p.stype, p.sprop, uintptr(structp))
 
 	nbuf := o.buf
 	o.buf = obuf
@@ -318,8 +316,7 @@ func (o *Buffer) enc_struct_group(p *Properties, base uintptr) error {
 
 	o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup))
 	b := uintptr(unsafe.Pointer(v))
-	typ := p.stype.Elem()
-	err := o.enc_struct(typ, p.sprop, b)
+	err := o.enc_struct(p.stype, p.sprop, b)
 	if err != nil {
 		return err
 	}
@@ -472,7 +469,6 @@ func (o *Buffer) enc_slice_string(p *Properties, base uintptr) error {
 func (o *Buffer) enc_slice_struct_message(p *Properties, base uintptr) error {
 	s := *(*[]unsafe.Pointer)(unsafe.Pointer(base + p.offset))
 	l := len(s)
-	typ := p.stype.Elem()
 
 	for i := 0; i < l; i++ {
 		structp := s[i]
@@ -482,7 +478,7 @@ func (o *Buffer) enc_slice_struct_message(p *Properties, base uintptr) error {
 
 		// Can the object marshal itself?
 		if p.isMarshaler {
-			m := reflect.NewAt(typ, structp).Interface().(Marshaler)
+			m := reflect.NewAt(p.stype, structp).Interface().(Marshaler)
 			data, err := m.Marshal()
 			if err != nil {
 				return err
@@ -495,7 +491,7 @@ func (o *Buffer) enc_slice_struct_message(p *Properties, base uintptr) error {
 		obuf := o.buf
 		o.buf = o.bufalloc()
 
-		err := o.enc_struct(typ, p.sprop, uintptr(structp))
+		err := o.enc_struct(p.stype, p.sprop, uintptr(structp))
 
 		nbuf := o.buf
 		o.buf = obuf
@@ -518,7 +514,6 @@ func (o *Buffer) enc_slice_struct_message(p *Properties, base uintptr) error {
 func (o *Buffer) enc_slice_struct_group(p *Properties, base uintptr) error {
 	s := *(*[]*struct{})(unsafe.Pointer(base + p.offset))
 	l := len(s)
-	typ := p.stype.Elem()
 
 	for i := 0; i < l; i++ {
 		v := s[i]
@@ -529,7 +524,7 @@ func (o *Buffer) enc_slice_struct_group(p *Properties, base uintptr) error {
 		o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup))
 
 		b := uintptr(unsafe.Pointer(v))
-		err := o.enc_struct(typ, p.sprop, b)
+		err := o.enc_struct(p.stype, p.sprop, b)
 
 		if err != nil {
 			if err == ErrNil {

+ 5 - 1
proto/lib.go

@@ -690,7 +690,11 @@ type scalarField struct {
 func buildDefaultMessage(t reflect.Type) (dm defaultMessage) {
 	sprop := GetProperties(t)
 	for _, prop := range sprop.Prop {
-		fi := sprop.tags[prop.Tag]
+		fi, ok := sprop.tags.get(prop.Tag)
+		if !ok {
+			// XXX_unrecognized
+			continue
+		}
 		ft := t.Field(fi).Type
 
 		// nested messages

+ 44 - 7
proto/properties.go

@@ -75,11 +75,49 @@ type decoder func(p *Buffer, prop *Properties, base uintptr) error
 // A valueDecoder decodes a single integer in a particular encoding.
 type valueDecoder func(o *Buffer) (x uint64, err error)
 
+// tagMap is an optimization over map[int]int for typical protocol buffer
+// use-cases. Encoded protocol buffers are often in tag order with small tag
+// numbers, so small tags are looked up by slice index instead of by map.
+type tagMap struct {
+	fastTags []int       // field number indexed by tag; -1 marks a tag with no field
+	slowTags map[int]int // field number keyed by tag, for tags outside the fast range
+}
+
+// tagMapFastLimit is the exclusive upper bound on tag numbers stored in the
+// tagMap slice; tags outside (0, tagMapFastLimit) are stored in its map.
+const tagMapFastLimit = 1024
+
+func (p *tagMap) get(t int) (int, bool) { // returns the field number stored for tag t and whether one exists
+	if t > 0 && t < tagMapFastLimit {
+		if t >= len(p.fastTags) {
+			return 0, false // slice was never grown this far, so t was never put
+		}
+		fi := p.fastTags[t]
+		return fi, fi >= 0 // -1 is the filler put writes for tags it skipped over
+	}
+	fi, ok := p.slowTags[t] // reading a nil map is safe and yields ok == false
+	return fi, ok
+}
+
+func (p *tagMap) put(t int, fi int) { // records field number fi for tag t
+	if t > 0 && t < tagMapFastLimit {
+		for len(p.fastTags) < t+1 {
+			p.fastTags = append(p.fastTags, -1) // pad with -1 so get reports skipped tags as absent
+		}
+		p.fastTags[t] = fi
+		return
+	}
+	if p.slowTags == nil {
+		p.slowTags = make(map[int]int) // allocate lazily; writing to a nil map would panic
+	}
+	p.slowTags[t] = fi
+}
+
 // StructProperties represents properties for all the fields of a struct.
 type StructProperties struct {
 	Prop      []*Properties  // properties for each field
 	reqCount  int            // required count
-	tags      map[int]int    // map from proto tag to struct field number
+	tags      tagMap         // map from proto tag to struct field number
 	origNames map[string]int // map from original name to struct field number
 	order     []int          // list of struct field numbers in tag order
 
@@ -266,7 +304,7 @@ func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) {
 			p.enc = (*Buffer).enc_string
 			p.dec = (*Buffer).dec_string
 		case reflect.Struct:
-			p.stype = t1
+			p.stype = t1.Elem()
 			p.isMarshaler = isMarshaler(t1)
 			p.isUnmarshaler = isUnmarshaler(t1)
 			if p.Wire == "bytes" {
@@ -351,7 +389,7 @@ func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) {
 				fmt.Fprintf(os.Stderr, "proto: no ptr oenc for %T -> %T -> %T\n", t1, t2, t3)
 				break
 			case reflect.Struct:
-				p.stype = t2
+				p.stype = t2.Elem()
 				p.isMarshaler = isMarshaler(t2)
 				p.isUnmarshaler = isUnmarshaler(t2)
 				p.enc = (*Buffer).enc_slice_struct_group
@@ -389,9 +427,9 @@ func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) {
 
 	if p.stype != nil {
 		if lockGetProp {
-			p.sprop = GetProperties(p.stype.Elem())
+			p.sprop = GetProperties(p.stype)
 		} else {
-			p.sprop = getPropertiesLocked(p.stype.Elem())
+			p.sprop = getPropertiesLocked(p.stype)
 		}
 	}
 }
@@ -504,7 +542,6 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
 	// build required counts
 	// build tags
 	reqCount := 0
-	prop.tags = make(map[int]int)
 	prop.origNames = make(map[string]int)
 	for i, p := range prop.Prop {
 		if strings.HasPrefix(p.Name, "XXX_") {
@@ -515,7 +552,7 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
 		if p.Required {
 			reqCount++
 		}
-		prop.tags[p.Tag] = i
+		prop.tags.put(p.Tag, i)
 		prop.origNames[p.OrigName] = i
 	}
 	prop.reqCount = reqCount