Browse Source

jsonpb: add support for custom resolution of Any messages to/from JSON (#410)

Introduce type AnyResolver with func Resolve that takes in Any's type_url value and returns corresponding proto.Message.  Add AnyResolver field to both Marshaler and Unmarshaler for use when Any is encountered.
Joshua Humphries 8 years ago
parent
commit
5afd06f9d8
2 changed files with 107 additions and 15 deletions
  1. 48 15
      jsonpb/jsonpb.go
  2. 59 0
      jsonpb/jsonpb_test.go

+ 48 - 15
jsonpb/jsonpb.go

@@ -73,6 +73,31 @@ type Marshaler struct {
 
 
 	// Whether to use the original (.proto) name for fields.
 	// Whether to use the original (.proto) name for fields.
 	OrigName bool
 	OrigName bool
+
+	// A custom URL resolver to use when marshaling Any messages to JSON.
+	// If unset, the default resolution strategy is to extract the
+	// fully-qualified type name from the type URL and pass that to
+	// proto.MessageType(string).
+	AnyResolver AnyResolver
+}
+
+// AnyResolver takes a type URL, present in an Any message, and resolves it into
+// an instance of the associated message.
+type AnyResolver interface {
+	Resolve(typeUrl string) (proto.Message, error)
+}
+
+func defaultResolveAny(typeUrl string) (proto.Message, error) {
+	// Only the part of typeUrl after the last slash is relevant.
+	mname := typeUrl
+	if slash := strings.LastIndex(mname, "/"); slash >= 0 {
+		mname = mname[slash+1:]
+	}
+	mt := proto.MessageType(mname)
+	if mt == nil {
+		return nil, fmt.Errorf("unknown message type %q", mname)
+	}
+	return reflect.New(mt.Elem()).Interface().(proto.Message), nil
 }
 }
 
 
 // JSONPBMarshaler is implemented by protobuf messages that customize the
 // JSONPBMarshaler is implemented by protobuf messages that customize the
@@ -344,16 +369,17 @@ func (m *Marshaler) marshalAny(out *errWriter, any proto.Message, indent string)
 	turl := v.Field(0).String()
 	turl := v.Field(0).String()
 	val := v.Field(1).Bytes()
 	val := v.Field(1).Bytes()
 
 
-	// Only the part of type_url after the last slash is relevant.
-	mname := turl
-	if slash := strings.LastIndex(mname, "/"); slash >= 0 {
-		mname = mname[slash+1:]
+	var msg proto.Message
+	var err error
+	if m.AnyResolver != nil {
+		msg, err = m.AnyResolver.Resolve(turl)
+	} else {
+		msg, err = defaultResolveAny(turl)
 	}
 	}
-	mt := proto.MessageType(mname)
-	if mt == nil {
-		return fmt.Errorf("unknown message type %q", mname)
+	if err != nil {
+		return err
 	}
 	}
-	msg := reflect.New(mt.Elem()).Interface().(proto.Message)
+
 	if err := proto.Unmarshal(val, msg); err != nil {
 	if err := proto.Unmarshal(val, msg); err != nil {
 		return err
 		return err
 	}
 	}
@@ -590,6 +616,12 @@ type Unmarshaler struct {
 	// Whether to allow messages to contain unknown fields, as opposed to
 	// Whether to allow messages to contain unknown fields, as opposed to
 	// failing to unmarshal.
 	// failing to unmarshal.
 	AllowUnknownFields bool
 	AllowUnknownFields bool
+
+	// A custom URL resolver to use when unmarshaling Any messages from JSON.
+	// If unset, the default resolution strategy is to extract the
+	// fully-qualified type name from the type URL and pass that to
+	// proto.MessageType(string).
+	AnyResolver AnyResolver
 }
 }
 
 
 // UnmarshalNext unmarshals the next protocol buffer from a JSON object stream.
 // UnmarshalNext unmarshals the next protocol buffer from a JSON object stream.
@@ -679,16 +711,17 @@ func (u *Unmarshaler) unmarshalValue(target reflect.Value, inputValue json.RawMe
 			}
 			}
 			target.Field(0).SetString(turl)
 			target.Field(0).SetString(turl)
 
 
-			mname := turl
-			if slash := strings.LastIndex(mname, "/"); slash >= 0 {
-				mname = mname[slash+1:]
+			var m proto.Message
+			var err error
+			if u.AnyResolver != nil {
+				m, err = u.AnyResolver.Resolve(turl)
+			} else {
+				m, err = defaultResolveAny(turl)
 			}
 			}
-			mt := proto.MessageType(mname)
-			if mt == nil {
-				return fmt.Errorf("unknown message type %q", mname)
+			if err != nil {
+				return err
 			}
 			}
 
 
-			m := reflect.New(mt.Elem()).Interface().(proto.Message)
 			if _, ok := m.(wkt); ok {
 			if _, ok := m.(wkt); ok {
 				val, ok := jsonFields["value"]
 				val, ok := jsonFields["value"]
 				if !ok {
 				if !ok {

+ 59 - 0
jsonpb/jsonpb_test.go

@@ -721,6 +721,65 @@ func TestUnmarshalingBadInput(t *testing.T) {
 	}
 	}
 }
 }
 
 
+type funcResolver func(turl string) (proto.Message, error)
+
+func (fn funcResolver) Resolve(turl string) (proto.Message, error) {
+	return fn(turl)
+}
+
+func TestAnyWithCustomResolver(t *testing.T) {
+	var resolvedTypeUrls []string
+	resolver := funcResolver(func(turl string) (proto.Message, error) {
+		resolvedTypeUrls = append(resolvedTypeUrls, turl)
+		return new(pb.Simple), nil
+	})
+	msg := &pb.Simple{
+		OBytes:  []byte{1, 2, 3, 4},
+		OBool:   proto.Bool(true),
+		OString: proto.String("foobar"),
+		OInt64:  proto.Int64(1020304),
+	}
+	msgBytes, err := proto.Marshal(msg)
+	if err != nil {
+		t.Errorf("an unexpected error occurred when marshaling message: %v", err)
+	}
+	// make an Any with a type URL that won't resolve w/out custom resolver
+	any := &anypb.Any{
+		TypeUrl: "https://foobar.com/some.random.MessageKind",
+		Value:   msgBytes,
+	}
+
+	m := Marshaler{AnyResolver: resolver}
+	js, err := m.MarshalToString(any)
+	if err != nil {
+		t.Errorf("an unexpected error occurred when marshaling any to JSON: %v", err)
+	}
+	if len(resolvedTypeUrls) != 1 {
+		t.Errorf("custom resolver was not invoked during marshaling")
+	} else if resolvedTypeUrls[0] != "https://foobar.com/some.random.MessageKind" {
+		t.Errorf("custom resolver was invoked with wrong URL: got %q, wanted %q", resolvedTypeUrls[0], "https://foobar.com/some.random.MessageKind")
+	}
+	wanted := `{"@type":"https://foobar.com/some.random.MessageKind","oBool":true,"oInt64":"1020304","oString":"foobar","oBytes":"AQIDBA=="}`
+	if js != wanted {
+		t.Errorf("marshalling JSON produced incorrect output: got %s, wanted %s", js, wanted)
+	}
+
+	u := Unmarshaler{AnyResolver: resolver}
+	roundTrip := &anypb.Any{}
+	err = u.Unmarshal(bytes.NewReader([]byte(js)), roundTrip)
+	if err != nil {
+		t.Errorf("an unexpected error occurred when unmarshaling any from JSON: %v", err)
+	}
+	if len(resolvedTypeUrls) != 2 {
+		t.Errorf("custom resolver was not invoked during marshaling")
+	} else if resolvedTypeUrls[1] != "https://foobar.com/some.random.MessageKind" {
+		t.Errorf("custom resolver was invoked with wrong URL: got %q, wanted %q", resolvedTypeUrls[1], "https://foobar.com/some.random.MessageKind")
+	}
+	if !proto.Equal(any, roundTrip) {
+		t.Errorf("message contents not set correctly after unmarshalling JSON: got %s, wanted %s", roundTrip, any)
+	}
+}
+
 func TestUnmarshalJSONPBUnmarshaler(t *testing.T) {
 func TestUnmarshalJSONPBUnmarshaler(t *testing.T) {
 	rawJson := `{ "foo": "bar", "baz": [0, 1, 2, 3] }`
 	rawJson := `{ "foo": "bar", "baz": [0, 1, 2, 3] }`
 	var msg dynamicMessage
 	var msg dynamicMessage