Ver Fonte

internal/impl: support legacy unknown fields

Add wrapper data structures to get legacy XXX_unrecognized fields to support
the new protoreflect.UnknownFields interface. This is a challenge since the
field is a []byte, which does not give us much flexibility to work with
in terms of choice of data structures.

This implementation is relatively naive: every operation is O(n) since
it needs to scan through the entire []byte each time. The Range operation
operates slightly differently from ranging over Go maps since it presents a
stale version of RawFields should a mutation occur while ranging.
This distinction is unlikely to affect anyone in practice.

Change-Id: Ib3247cb827f9a0dd6c2192cd59830dca5eef8257
Reviewed-on: https://go-review.googlesource.com/c/144697
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai há 7 anos atrás
pai
commit
e2afdc27e7

+ 123 - 1
internal/impl/legacy_test.go

@@ -5,9 +5,12 @@
 package impl
 
 import (
+	"bytes"
+	"math"
 	"reflect"
 	"testing"
 
+	"github.com/golang/protobuf/v2/internal/encoding/pack"
 	"github.com/golang/protobuf/v2/internal/pragma"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
 	ptype "github.com/golang/protobuf/v2/reflect/prototype"
@@ -25,7 +28,7 @@ func mustLoadFileDesc(b []byte) pref.FileDescriptor {
 var fileDescLP2 = mustLoadFileDesc(LP2FileDescriptor)
 var fileDescLP3 = mustLoadFileDesc(LP3FileDescriptor)
 
-func TestLegacy(t *testing.T) {
+func TestLegacyDescriptor(t *testing.T) {
 	tests := []struct {
 		got  pref.Descriptor
 		want pref.Descriptor
@@ -133,3 +136,122 @@ func TestLegacy(t *testing.T) {
 		})
 	}
 }
+
+// TestLegacyUnknown exercises legacyUnknownBytes through Len, Get, Set and
+// Range, comparing the underlying bytes against hand-assembled wire data
+// after every mutation.
+func TestLegacyUnknown(t *testing.T) {
+	rawOf := func(toks ...pack.Token) pref.RawFields {
+		return pref.RawFields(pack.Message(toks).Marshal())
+	}
+	raw1a := rawOf(pack.Tag{1, pack.VarintType}, pack.Svarint(-4321))                // 08c143
+	raw1b := rawOf(pack.Tag{1, pack.Fixed32Type}, pack.Uint32(0xdeadbeef))           // 0defbeadde
+	raw1c := rawOf(pack.Tag{1, pack.Fixed64Type}, pack.Float64(math.Pi))             // 09182d4454fb210940
+	raw2a := rawOf(pack.Tag{2, pack.BytesType}, pack.String("hello, world!"))        // 120d68656c6c6f2c20776f726c6421
+	raw2b := rawOf(pack.Tag{2, pack.VarintType}, pack.Uvarint(1234))                 // 10d209
+	raw3a := rawOf(pack.Tag{3, pack.StartGroupType}, pack.Tag{3, pack.EndGroupType}) // 1b1c
+	raw3b := rawOf(pack.Tag{3, pack.BytesType}, pack.Bytes("\xde\xad\xbe\xef"))      // 1a04deadbeef
+
+	// joinRaw concatenates raw fields into one freshly allocated byte slice.
+	joinRaw := func(bs ...pref.RawFields) (out []byte) {
+		for _, b := range bs {
+			out = append(out, b...)
+		}
+		return out
+	}
+
+	// The zero value must behave as an empty set of unknown fields.
+	var fs legacyUnknownBytes
+	if got, want := fs.Len(), 0; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	// Grow field 1 by reading it back, appending, and storing it again.
+	fs.Set(1, raw1a)
+	fs.Set(1, append(fs.Get(1), raw1b...))
+	fs.Set(1, append(fs.Get(1), raw1c...))
+	if got, want := fs.Len(), 1; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(raw1a, raw1b, raw1c); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	fs.Set(2, raw2a)
+	if got, want := fs.Len(), 2; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(raw1a, raw1b, raw1c, raw2a); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	if got, want := fs.Get(1), joinRaw(raw1a, raw1b, raw1c); !bytes.Equal(got, want) {
+		t.Errorf("Get(%d) = %x, want %x", 1, got, want)
+	}
+	if got, want := fs.Get(2), joinRaw(raw2a); !bytes.Equal(got, want) {
+		t.Errorf("Get(%d) = %x, want %x", 2, got, want)
+	}
+	if got, want := fs.Get(3), joinRaw(); !bytes.Equal(got, want) {
+		t.Errorf("Get(%d) = %x, want %x", 3, got, want)
+	}
+
+	fs.Set(1, nil) // remove field 1
+	if got, want := fs.Len(), 1; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(raw2a); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	// Simulate manual appending of raw field data.
+	fs = append(fs, joinRaw(raw3a, raw1a, raw1b, raw3b, raw2b, raw1c)...)
+	if got, want := fs.Len(), 3; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+
+	// Verify range iteration order.
+	var i int
+	want := []struct {
+		num pref.FieldNumber
+		raw pref.RawFields
+	}{
+		{3, joinRaw(raw3a, raw3b)},
+		{2, joinRaw(raw2a, raw2b)},
+		{1, joinRaw(raw1a, raw1b, raw1c)},
+	}
+	fs.Range(func(num pref.FieldNumber, raw pref.RawFields) bool {
+		if i < len(want) {
+			if num != want[i].num || !bytes.Equal(raw, want[i].raw) {
+				t.Errorf("Range(%d) = (%d, %x), want (%d, %x)", i, num, raw, want[i].num, want[i].raw)
+			}
+		} else {
+			t.Errorf("unexpected Range iteration: %d", i)
+		}
+		i++
+		return true
+	})
+	// Without this check a Range that visits too few entries (or none at
+	// all) would pass silently, since the callback only flags extra ones.
+	if i != len(want) {
+		t.Errorf("Range iterated %d times, want %d", i, len(want))
+	}
+
+	fs.Set(2, fs.Get(2)) // moves field 2 to the end
+	if got, want := fs.Len(), 3; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(raw3a, raw1a, raw1b, raw3b, raw1c, raw2a, raw2b); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+	fs.Set(1, nil) // remove field 1
+	if got, want := fs.Len(), 2; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(raw3a, raw3b, raw2a, raw2b); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	// Remove all fields. Mutating inside the callback is safe because Range
+	// iterates over a snapshot taken before the first callback runs.
+	fs.Range(func(n pref.FieldNumber, b pref.RawFields) bool {
+		fs.Set(n, nil)
+		return true
+	})
+	if got, want := fs.Len(), 0; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+}

+ 142 - 0
internal/impl/legacy_unknown.go

@@ -0,0 +1,142 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+	"reflect"
+
+	protoV1 "github.com/golang/protobuf/proto"
+	"github.com/golang/protobuf/v2/internal/encoding/wire"
+	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+var (
+	extTypeA = reflect.TypeOf(map[int32]protoV1.Extension(nil)) // compared against the type of an XXX_extensions field
+	extTypeB = reflect.TypeOf(protoV1.XXX_InternalExtensions{}) // compared against the type of an XXX_InternalExtensions field
+)
+
+// generateLegacyUnknownFieldFuncs builds an accessor that exposes the
+// XXX_unrecognized field of struct type t as protoreflect.UnknownFields.
+//
+// It returns nil when t has no XXX_unrecognized field of type bytesType.
+// It panics when t also declares an XXX_extensions or XXX_InternalExtensions
+// field of the proto v1 extension types, because that case is not
+// implemented yet.
+func generateLegacyUnknownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) func(p *messageDataType) pref.UnknownFields {
+	unknownField, found := t.FieldByName("XXX_unrecognized")
+	if !found || unknownField.Type != bytesType {
+		return nil
+	}
+
+	// A missing field yields a zero StructField whose nil Type matches
+	// neither extension type, so the lookup results need no ok-check.
+	extMapField, _ := t.FieldByName("XXX_extensions")
+	extStructField, _ := t.FieldByName("XXX_InternalExtensions")
+	if extMapField.Type == extTypeA || extStructField.Type == extTypeB {
+		// TODO: In proto v1, the unknown fields are split between both
+		// XXX_unrecognized and XXX_InternalExtensions. If the message supports
+		// extensions, then we will need to create a wrapper data structure
+		// that presents unknown fields in both lists as a single ordered list.
+		panic("not implemented")
+	}
+
+	offset := offsetOf(unknownField)
+	return func(p *messageDataType) pref.UnknownFields {
+		unknownPtr := p.p.apply(offset).asType(bytesType).Interface().(*[]byte)
+		return (*legacyUnknownBytes)(unknownPtr)
+	}
+}
+
+// legacyUnknownBytes wraps the raw bytes of a legacy XXX_unrecognized field
+// and implements the protoreflect.UnknownFields interface on top of them.
+// The backing store is fixed to a []byte, so a data structure better suited
+// to these operations cannot be used; most methods re-parse the bytes.
+type legacyUnknownBytes []byte
+
+// Len reports the number of distinct field numbers found in the unknown
+// bytes. Parsing stops at the first malformed field, so only the
+// well-formed prefix of the data is counted.
+func (fs *legacyUnknownBytes) Len() int {
+	// Runtime complexity: O(n)
+	b := *fs
+	m := map[pref.FieldNumber]bool{}
+	for len(b) > 0 {
+		num, _, n := wire.ConsumeField(b)
+		if n < 0 {
+			// ConsumeField signals malformed data with a negative length;
+			// slicing with it would panic, so stop at the valid prefix.
+			break
+		}
+		m[num] = true
+		b = b[n:]
+	}
+	return len(m)
+}
+
+// Get returns a copy of every raw field whose number equals num,
+// concatenated in the order they appear in the unknown bytes. It returns
+// nil when no such field exists. Parsing stops at the first malformed
+// field, so only the well-formed prefix of the data is searched.
+func (fs *legacyUnknownBytes) Get(num pref.FieldNumber) (raw pref.RawFields) {
+	// Runtime complexity: O(n)
+	b := *fs
+	for len(b) > 0 {
+		num2, _, n := wire.ConsumeField(b)
+		if n < 0 {
+			// ConsumeField signals malformed data with a negative length;
+			// slicing with it would panic, so stop at the valid prefix.
+			break
+		}
+		if num == num2 {
+			// Appending to the (initially nil) result copies the bytes, so
+			// the returned slice never aliases the backing store.
+			raw = append(raw, b[:n]...)
+		}
+		b = b[n:]
+	}
+	return raw
+}
+
+// Set replaces all stored fields numbered num with raw, which is appended
+// at the end of the unknown bytes. An empty raw removes the field.
+//
+// It panics with "invalid raw fields" if raw is non-empty and is either not
+// valid wire data or does not start with a tag for num. It panics with
+// "invalid unknown fields" if the stored bytes are themselves malformed,
+// since they cannot be rewritten safely; in that case the buffer is left
+// holding the already-filtered prefix followed by the untouched remainder.
+func (fs *legacyUnknownBytes) Set(num pref.FieldNumber, raw pref.RawFields) {
+	num2, _, _ := wire.ConsumeTag(raw)
+	if len(raw) > 0 && (!raw.IsValid() || num != num2) {
+		panic("invalid raw fields")
+	}
+
+	// Remove all current fields of num.
+	// The filter runs in place: out reuses the backing array of *fs and
+	// only ever writes to positions that b has already moved past.
+	// Runtime complexity: O(n)
+	b := *fs
+	out := (*fs)[:0]
+	for len(b) > 0 {
+		num2, _, n := wire.ConsumeField(b)
+		if n < 0 {
+			// ConsumeField signals malformed data with a negative length;
+			// slicing with it would panic with an opaque bounds error after
+			// bytes may already have been shifted. Restore a well-defined
+			// buffer first, then fail with an explicit message.
+			*fs = append(out, b...)
+			panic("invalid unknown fields")
+		}
+		if num != num2 {
+			out = append(out, b[:n]...)
+		}
+		b = b[n:]
+	}
+	*fs = out
+
+	// Append new fields of num.
+	*fs = append(*fs, raw...)
+}
+
+// Range calls f once per distinct field number with the concatenation of
+// all raw fields of that number, stopping early if f returns false.
+// Iteration runs over a snapshot, so mutations made by f are not observed.
+// Parsing stops at the first malformed field, so only the well-formed
+// prefix of the data is presented.
+func (fs *legacyUnknownBytes) Range(f func(pref.FieldNumber, pref.RawFields) bool) {
+	type entry struct {
+		num pref.FieldNumber
+		raw pref.RawFields
+	}
+	var xs []entry
+
+	// Collect up a list of all the raw fields.
+	// We preserve the order such that the latest encountered fields
+	// are presented at the end.
+	//
+	// Runtime complexity: O(n)
+	b := *fs
+	m := map[pref.FieldNumber]int{} // field number -> index of its entry in xs
+	for len(b) > 0 {
+		num, _, n := wire.ConsumeField(b)
+		if n < 0 {
+			// ConsumeField signals malformed data with a negative length;
+			// slicing with it would panic, so stop at the valid prefix.
+			break
+		}
+
+		// Ensure the most recently updated entry is always at the end of xs.
+		x := entry{num: num}
+		if i, ok := m[num]; ok {
+			j := len(xs) - 1
+			xs[i], xs[j] = xs[j], xs[i] // swap current entry with last entry
+			m[xs[i].num] = i            // update index of swapped entry
+			x = xs[j]                   // retrieve the last entry
+			xs = xs[:j]                 // truncate off the last entry
+		}
+		m[num] = len(xs)
+		// A new entry starts with a nil raw, so this append copies the bytes
+		// and the snapshot never aliases the backing store.
+		x.raw = append(x.raw, b[:n]...)
+		xs = append(xs, x)
+
+		b = b[n:]
+	}
+
+	// Iterate over all the raw fields.
+	// This ranges over a snapshot of the current state such that mutations
+	// while ranging are not observable.
+	//
+	// Runtime complexity: O(n)
+	for _, x := range xs {
+		if !f(x.num, x.raw) {
+			return
+		}
+	}
+}
+
+func (fs *legacyUnknownBytes) IsSupported() bool {
+	return true // unconditionally reports true for this []byte-backed implementation
+}

+ 4 - 1
internal/impl/message.go

@@ -163,7 +163,10 @@ fieldLoop:
 }
 
 func (mi *MessageType) generateUnknownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) {
-	// TODO
+	if f := generateLegacyUnknownFieldFuncs(t, md); f != nil {
+		mi.unknownFields = f
+		return
+	}
 	mi.unknownFields = func(*messageDataType) pref.UnknownFields {
 		return emptyUnknownFields{}
 	}

+ 7 - 4
reflect/protoreflect/value.go

@@ -159,17 +159,20 @@ type UnknownFields interface {
 // and also the wire data itself.
 //
 // Once stored, the content of a RawFields must be treated as immutable.
-// (e.g., raw[:len(raw)] is immutable, but raw[len(raw):cap(raw)] is mutable).
-// Thus, appending to RawFields (with valid wire data) is permitted.
+// The capacity of RawFields may be treated as mutable only for the use-case of
+// appending additional data to store back into UnknownFields.
 type RawFields []byte
 
 // IsValid reports whether RawFields is syntactically correct wire format.
+// All fields must belong to the same field number.
 func (b RawFields) IsValid() bool {
+	var want FieldNumber
 	for len(b) > 0 {
-		_, _, n := wire.ConsumeField(b)
-		if n < 0 {
+		got, _, n := wire.ConsumeField(b)
+		if n < 0 || (want > 0 && got != want) {
 			return false
 		}
+		want = got
 		b = b[n:]
 	}
 	return true