|
@@ -166,7 +166,10 @@ package proto
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
"fmt"
|
|
"fmt"
|
|
|
|
|
+ "log"
|
|
|
|
|
+ "reflect"
|
|
|
"strconv"
|
|
"strconv"
|
|
|
|
|
+ "sync"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
// Stats records allocation details about the protocol buffer encoders
|
|
// Stats records allocation details about the protocol buffer encoders
|
|
@@ -535,3 +538,233 @@ out:
|
|
|
o.buf = obuf
|
|
o.buf = obuf
|
|
|
o.index = index
|
|
o.index = index
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+// SetDefaults sets unset protocol buffer fields to their default values.
|
|
|
|
|
+// It only modifies fields that are both unset and have defined defaults.
|
|
|
|
|
+// It recursively sets default values in any non-nil sub-messages.
|
|
|
|
|
+func SetDefaults(pb interface{}) {
|
|
|
|
|
+ v := reflect.ValueOf(pb)
|
|
|
|
|
+ if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
|
|
|
|
|
+ log.Printf("proto: hit non-pointer-to-struct %v", v)
|
|
|
|
|
+ }
|
|
|
|
|
+ setDefaults(v, true, false)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// v is a pointer to a struct.
|
|
|
|
|
+func setDefaults(v reflect.Value, recur, zeros bool) {
|
|
|
|
|
+ v = v.Elem()
|
|
|
|
|
+
|
|
|
|
|
+ defaultMu.Lock()
|
|
|
|
|
+ dm, ok := defaults[v.Type()]
|
|
|
|
|
+ defaultMu.Unlock()
|
|
|
|
|
+ if !ok {
|
|
|
|
|
+ dm = buildDefaultMessage(v.Type())
|
|
|
|
|
+ defaultMu.Lock()
|
|
|
|
|
+ defaults[v.Type()] = dm
|
|
|
|
|
+ defaultMu.Unlock()
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for _, sf := range dm.scalars {
|
|
|
|
|
+ f := v.Field(sf.index)
|
|
|
|
|
+ if !f.IsNil() {
|
|
|
|
|
+ // field already set
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ dv := sf.value
|
|
|
|
|
+ if dv == nil && !zeros {
|
|
|
|
|
+ // no explicit default, and don't want to set zeros
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ fptr := f.Addr().Interface() // **T
|
|
|
|
|
+ // TODO: Consider batching the allocations we do here.
|
|
|
|
|
+ switch sf.kind {
|
|
|
|
|
+ case reflect.Bool:
|
|
|
|
|
+ b := new(bool)
|
|
|
|
|
+ if dv != nil {
|
|
|
|
|
+ *b = dv.(bool)
|
|
|
|
|
+ }
|
|
|
|
|
+ *(fptr.(**bool)) = b
|
|
|
|
|
+ case reflect.Float32:
|
|
|
|
|
+ f := new(float32)
|
|
|
|
|
+ if dv != nil {
|
|
|
|
|
+ *f = dv.(float32)
|
|
|
|
|
+ }
|
|
|
|
|
+ *(fptr.(**float32)) = f
|
|
|
|
|
+ case reflect.Float64:
|
|
|
|
|
+ f := new(float64)
|
|
|
|
|
+ if dv != nil {
|
|
|
|
|
+ *f = dv.(float64)
|
|
|
|
|
+ }
|
|
|
|
|
+ *(fptr.(**float64)) = f
|
|
|
|
|
+ case reflect.Int32:
|
|
|
|
|
+ // might be an enum
|
|
|
|
|
+ if ft := f.Type(); ft != int32PtrType {
|
|
|
|
|
+ // enum
|
|
|
|
|
+ f.Set(reflect.New(ft.Elem()))
|
|
|
|
|
+ if dv != nil {
|
|
|
|
|
+ f.Elem().SetInt(int64(dv.(int32)))
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // int32 field
|
|
|
|
|
+ i := new(int32)
|
|
|
|
|
+ if dv != nil {
|
|
|
|
|
+ *i = dv.(int32)
|
|
|
|
|
+ }
|
|
|
|
|
+ *(fptr.(**int32)) = i
|
|
|
|
|
+ }
|
|
|
|
|
+ case reflect.Int64:
|
|
|
|
|
+ i := new(int64)
|
|
|
|
|
+ if dv != nil {
|
|
|
|
|
+ *i = dv.(int64)
|
|
|
|
|
+ }
|
|
|
|
|
+ *(fptr.(**int64)) = i
|
|
|
|
|
+ case reflect.String:
|
|
|
|
|
+ s := new(string)
|
|
|
|
|
+ if dv != nil {
|
|
|
|
|
+ *s = dv.(string)
|
|
|
|
|
+ }
|
|
|
|
|
+ *(fptr.(**string)) = s
|
|
|
|
|
+ case reflect.Uint8:
|
|
|
|
|
+ // exceptional case: []byte
|
|
|
|
|
+ var b []byte
|
|
|
|
|
+ if dv != nil {
|
|
|
|
|
+ db := dv.([]byte)
|
|
|
|
|
+ b = make([]byte, len(db))
|
|
|
|
|
+ copy(b, db)
|
|
|
|
|
+ } else {
|
|
|
|
|
+ b = []byte{}
|
|
|
|
|
+ }
|
|
|
|
|
+ *(fptr.(*[]byte)) = b
|
|
|
|
|
+ case reflect.Uint32:
|
|
|
|
|
+ u := new(uint32)
|
|
|
|
|
+ if dv != nil {
|
|
|
|
|
+ *u = dv.(uint32)
|
|
|
|
|
+ }
|
|
|
|
|
+ *(fptr.(**uint32)) = u
|
|
|
|
|
+ case reflect.Uint64:
|
|
|
|
|
+ u := new(uint64)
|
|
|
|
|
+ if dv != nil {
|
|
|
|
|
+ *u = dv.(uint64)
|
|
|
|
|
+ }
|
|
|
|
|
+ *(fptr.(**uint64)) = u
|
|
|
|
|
+ default:
|
|
|
|
|
+ log.Printf("proto: can't set default for field %v (sf.kind=%v)", f, sf.kind)
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for _, ni := range dm.nested {
|
|
|
|
|
+ setDefaults(v.Field(ni), recur, zeros)
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+var (
|
|
|
|
|
+ // defaults maps a protocol buffer struct type to a slice of the fields,
|
|
|
|
|
+ // with its scalar fields set to their proto-declared non-zero default values.
|
|
|
|
|
+ defaultMu sync.Mutex
|
|
|
|
|
+ defaults = make(map[reflect.Type]defaultMessage)
|
|
|
|
|
+
|
|
|
|
|
+ int32PtrType = reflect.TypeOf((*int32)(nil))
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+// defaultMessage represents information about the default values of a message.
|
|
|
|
|
+type defaultMessage struct {
|
|
|
|
|
+ scalars []scalarField
|
|
|
|
|
+ nested []int // struct field index of nested messages
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+type scalarField struct {
|
|
|
|
|
+ index int // struct field index
|
|
|
|
|
+ kind reflect.Kind // element type (the T in *T or []T)
|
|
|
|
|
+ value interface{} // the proto-declared default value, or nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// t is a struct type.
|
|
|
|
|
+func buildDefaultMessage(t reflect.Type) (dm defaultMessage) {
|
|
|
|
|
+ sprop := GetProperties(t)
|
|
|
|
|
+ for _, prop := range sprop.Prop {
|
|
|
|
|
+ fi := sprop.tags[prop.Tag]
|
|
|
|
|
+ ft := t.Field(fi).Type
|
|
|
|
|
+
|
|
|
|
|
+ // nested messages
|
|
|
|
|
+ if ft.Kind() == reflect.Ptr && ft.Elem().Kind() == reflect.Struct {
|
|
|
|
|
+ dm.nested = append(dm.nested, fi)
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ sf := scalarField{
|
|
|
|
|
+ index: fi,
|
|
|
|
|
+ kind: ft.Elem().Kind(),
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // scalar fields without defaults
|
|
|
|
|
+ if prop.Default == "" {
|
|
|
|
|
+ dm.scalars = append(dm.scalars, sf)
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // a scalar field: either *T or []byte
|
|
|
|
|
+ switch ft.Elem().Kind() {
|
|
|
|
|
+ case reflect.Bool:
|
|
|
|
|
+ x, err := strconv.Atob(prop.Default)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ log.Printf("proto: bad default bool %q: %v", prop.Default, err)
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ sf.value = x
|
|
|
|
|
+ case reflect.Float32:
|
|
|
|
|
+ x, err := strconv.Atof32(prop.Default)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ log.Printf("proto: bad default float32 %q: %v", prop.Default, err)
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ sf.value = x
|
|
|
|
|
+ case reflect.Float64:
|
|
|
|
|
+ x, err := strconv.Atof64(prop.Default)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ log.Printf("proto: bad default float64 %q: %v", prop.Default, err)
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ sf.value = x
|
|
|
|
|
+ case reflect.Int32:
|
|
|
|
|
+ x, err := strconv.Atoi64(prop.Default)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ log.Printf("proto: bad default int32 %q: %v", prop.Default, err)
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ sf.value = int32(x)
|
|
|
|
|
+ case reflect.Int64:
|
|
|
|
|
+ x, err := strconv.Atoi64(prop.Default)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ log.Printf("proto: bad default int64 %q: %v", prop.Default, err)
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ sf.value = x
|
|
|
|
|
+ case reflect.String:
|
|
|
|
|
+ sf.value = prop.Default
|
|
|
|
|
+ case reflect.Uint8:
|
|
|
|
|
+ // []byte (not *uint8)
|
|
|
|
|
+ sf.value = []byte(prop.Default)
|
|
|
|
|
+ case reflect.Uint32:
|
|
|
|
|
+ x, err := strconv.Atoui64(prop.Default)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ log.Printf("proto: bad default uint32 %q: %v", prop.Default, err)
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ sf.value = uint32(x)
|
|
|
|
|
+ case reflect.Uint64:
|
|
|
|
|
+ x, err := strconv.Atoui64(prop.Default)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ log.Printf("proto: bad default uint64 %q: %v", prop.Default, err)
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ sf.value = x
|
|
|
|
|
+ default:
|
|
|
|
|
+ log.Printf("proto: unhandled def kind %v", ft.Elem().Kind())
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ dm.scalars = append(dm.scalars, sf)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return dm
|
|
|
|
|
+}
|