浏览代码

codec: Selfer implementation should not call itself.

This effectively calls itself recursively and leads to an expected stackoverflow.
To fix, check if the variable name == the constant variable name of the type in the
Selfer implementation (x). If it is, then do not attempt to call the Selfer method.

In addition, the following fixes were applied:

When encoding extensions, use an address if the kind is not Struct or Array.
This way, we don't copy unnecessarily.

When encoding based on the known interfaces which may be implemented
i.e. encoding.(Binary|Text)(M|Unm)arshaler, json.(M|Unm)arshaler, codec.Selfer,
the value to encode may not be addressable but the implementation is on a pointer.
In that situation, copy the value to an addressable value and call the implementation there.

Updates #100
Ugorji Nwoke 10 年之前
父节点
当前提交
28f38c8e36
共有 2 个文件被更改,包括 54 次插入35 次删除
  1. 12 3
      codec/encode.go
  2. 42 32
      codec/gen.go

+ 12 - 3
codec/encode.go

@@ -303,8 +303,8 @@ func (f encFnInfo) rawExt(rv reflect.Value) {
 }
 }
 
 
 func (f encFnInfo) ext(rv reflect.Value) {
 func (f encFnInfo) ext(rv reflect.Value) {
-	// if this is a struct and it was addressable, then pass the address directly (not the value)
-	if rv.CanAddr() && rv.Kind() == reflect.Struct {
+	// if this is a struct|array and it was addressable, then pass the address directly (not the value)
+	if k := rv.Kind(); (k == reflect.Struct || k == reflect.Array) && rv.CanAddr() {
 		rv = rv.Addr()
 		rv = rv.Addr()
 	}
 	}
 	f.ee.EncodeExt(rv.Interface(), f.xfTag, f.xfFn, f.e)
 	f.ee.EncodeExt(rv.Interface(), f.xfTag, f.xfFn, f.e)
@@ -314,7 +314,16 @@ func (f encFnInfo) getValueForMarshalInterface(rv reflect.Value, indir int8) (v
 	if indir == 0 {
 	if indir == 0 {
 		v = rv.Interface()
 		v = rv.Interface()
 	} else if indir == -1 {
 	} else if indir == -1 {
-		v = rv.Addr().Interface()
+		// If a non-pointer was passed to Encode(), then that value is not addressable.
+		// Take addr if addresable, else copy value to an addressable value.
+		if rv.CanAddr() {
+			v = rv.Addr().Interface()
+		} else {
+			rv2 := reflect.New(rv.Type())
+			rv2.Elem().Set(rv)
+			v = rv2.Interface()
+			// fmt.Printf("rv.Type: %v, rv2.Type: %v, v: %v\n", rv.Type(), rv2.Type(), v)
+		}
 	} else {
 	} else {
 		for j := int8(0); j < indir; j++ {
 		for j := int8(0); j < indir; j++ {
 			if rv.IsNil() {
 			if rv.IsNil() {

+ 42 - 32
codec/gen.go

@@ -80,8 +80,9 @@ import (
 const GenVersion = 2 // increment this value each time codecgen changes fundamentally.
 const GenVersion = 2 // increment this value each time codecgen changes fundamentally.
 
 
 const (
 const (
-	genCodecPkg   = "codec1978"
-	genTempVarPfx = "yy"
+	genCodecPkg        = "codec1978"
+	genTempVarPfx      = "yy"
+	genTopLevelVarName = "x"
 
 
 	// ignore canBeNil parameter, and always set to true.
 	// ignore canBeNil parameter, and always set to true.
 	// This is because nil can appear anywhere, so we should always check.
 	// This is because nil can appear anywhere, so we should always check.
@@ -289,6 +290,12 @@ func Gen(w io.Writer, buildTags, pkgName string, useUnsafe bool, typ ...reflect.
 	x.line("")
 	x.line("")
 }
 }
 
 
+func (x *genRunner) checkForSelfer(t reflect.Type, varname string) bool {
+	// return varname != genTopLevelVarName && t != x.tc
+	// the only time we checkForSelfer is if we are not at the TOP of the generated code.
+	return varname != genTopLevelVarName
+}
+
 func (x *genRunner) arr2str(t reflect.Type, s string) string {
 func (x *genRunner) arr2str(t reflect.Type, s string) string {
 	if t.Kind() == reflect.Array {
 	if t.Kind() == reflect.Array {
 		return s
 		return s
@@ -450,16 +457,16 @@ func (x *genRunner) selfer(encode bool) {
 	if encode {
 	if encode {
 		x.line(") CodecEncodeSelf(e *" + x.cpfx + "Encoder) {")
 		x.line(") CodecEncodeSelf(e *" + x.cpfx + "Encoder) {")
 		x.genRequiredMethodVars(true)
 		x.genRequiredMethodVars(true)
-		// x.enc("x", t)
-		x.encVar("x", t)
+		// x.enc(genTopLevelVarName, t)
+		x.encVar(genTopLevelVarName, t)
 	} else {
 	} else {
 		x.line(") CodecDecodeSelf(d *" + x.cpfx + "Decoder) {")
 		x.line(") CodecDecodeSelf(d *" + x.cpfx + "Decoder) {")
 		x.genRequiredMethodVars(false)
 		x.genRequiredMethodVars(false)
 		// do not use decVar, as there is no need to check TryDecodeAsNil
 		// do not use decVar, as there is no need to check TryDecodeAsNil
 		// or way to elegantly handle that, and also setting it to a
 		// or way to elegantly handle that, and also setting it to a
 		// non-nil value doesn't affect the pointer passed.
 		// non-nil value doesn't affect the pointer passed.
-		// x.decVar("x", t, false)
-		x.dec("x", t0)
+		// x.decVar(genTopLevelVarName, t, false)
+		x.dec(genTopLevelVarName, t0)
 	}
 	}
 	x.line("}")
 	x.line("}")
 	x.line("")
 	x.line("")
@@ -473,21 +480,21 @@ func (x *genRunner) selfer(encode bool) {
 		x.out(fnSigPfx)
 		x.out(fnSigPfx)
 		x.line(") codecDecodeSelfFromMap(l int, d *" + x.cpfx + "Decoder) {")
 		x.line(") codecDecodeSelfFromMap(l int, d *" + x.cpfx + "Decoder) {")
 		x.genRequiredMethodVars(false)
 		x.genRequiredMethodVars(false)
-		x.decStructMap("x", "l", reflect.ValueOf(t0).Pointer(), t0, 0)
+		x.decStructMap(genTopLevelVarName, "l", reflect.ValueOf(t0).Pointer(), t0, 0)
 		x.line("}")
 		x.line("}")
 		x.line("")
 		x.line("")
 	} else {
 	} else {
 		x.out(fnSigPfx)
 		x.out(fnSigPfx)
 		x.line(") codecDecodeSelfFromMapLenPrefix(l int, d *" + x.cpfx + "Decoder) {")
 		x.line(") codecDecodeSelfFromMapLenPrefix(l int, d *" + x.cpfx + "Decoder) {")
 		x.genRequiredMethodVars(false)
 		x.genRequiredMethodVars(false)
-		x.decStructMap("x", "l", reflect.ValueOf(t0).Pointer(), t0, 1)
+		x.decStructMap(genTopLevelVarName, "l", reflect.ValueOf(t0).Pointer(), t0, 1)
 		x.line("}")
 		x.line("}")
 		x.line("")
 		x.line("")
 
 
 		x.out(fnSigPfx)
 		x.out(fnSigPfx)
 		x.line(") codecDecodeSelfFromMapCheckBreak(l int, d *" + x.cpfx + "Decoder) {")
 		x.line(") codecDecodeSelfFromMapCheckBreak(l int, d *" + x.cpfx + "Decoder) {")
 		x.genRequiredMethodVars(false)
 		x.genRequiredMethodVars(false)
-		x.decStructMap("x", "l", reflect.ValueOf(t0).Pointer(), t0, 2)
+		x.decStructMap(genTopLevelVarName, "l", reflect.ValueOf(t0).Pointer(), t0, 2)
 		x.line("}")
 		x.line("}")
 		x.line("")
 		x.line("")
 	}
 	}
@@ -496,7 +503,7 @@ func (x *genRunner) selfer(encode bool) {
 	x.out(fnSigPfx)
 	x.out(fnSigPfx)
 	x.line(") codecDecodeSelfFromArray(l int, d *" + x.cpfx + "Decoder) {")
 	x.line(") codecDecodeSelfFromArray(l int, d *" + x.cpfx + "Decoder) {")
 	x.genRequiredMethodVars(false)
 	x.genRequiredMethodVars(false)
-	x.decStructArray("x", "l", "return", reflect.ValueOf(t0).Pointer(), t0)
+	x.decStructArray(genTopLevelVarName, "l", "return", reflect.ValueOf(t0).Pointer(), t0)
 	x.line("}")
 	x.line("}")
 	x.line("")
 	x.line("")
 
 
@@ -517,6 +524,7 @@ func (x *genRunner) xtraSM(varname string, encode bool, t reflect.Type) {
 // encVar will encode a variable.
 // encVar will encode a variable.
 // The parameter, t, is the reflect.Type of the variable itself
 // The parameter, t, is the reflect.Type of the variable itself
 func (x *genRunner) encVar(varname string, t reflect.Type) {
 func (x *genRunner) encVar(varname string, t reflect.Type) {
+	// fmt.Printf(">>>>>> varname: %s, t: %v\n", varname, t)
 	var checkNil bool
 	var checkNil bool
 	switch t.Kind() {
 	switch t.Kind() {
 	case reflect.Ptr, reflect.Interface, reflect.Slice, reflect.Map, reflect.Chan:
 	case reflect.Ptr, reflect.Interface, reflect.Slice, reflect.Map, reflect.Chan:
@@ -560,31 +568,31 @@ func (x *genRunner) enc(varname string, t reflect.Type) {
 	//   - the type is in the list of the ones we will generate for, but it is not currently being generated
 	//   - the type is in the list of the ones we will generate for, but it is not currently being generated
 
 
 	tptr := reflect.PtrTo(t)
 	tptr := reflect.PtrTo(t)
-	if t.Implements(selferTyp) {
-		x.line(varname + ".CodecEncodeSelf(e)")
-		return
-	}
-	// if t.Kind() == reflect.Struct && tptr.Implements(selferTyp) { //TODO: verify that no need to check struct
-	if tptr.Implements(selferTyp) {
-		x.line(varname + ".CodecEncodeSelf(e)")
-		return
-	}
-	if _, ok := x.te[rtid]; ok {
-		x.line(varname + ".CodecEncodeSelf(e)")
-		return
+	tk := t.Kind()
+	if x.checkForSelfer(t, varname) {
+		if t.Implements(selferTyp) || (tptr.Implements(selferTyp) && (tk == reflect.Array || tk == reflect.Struct)) {
+			x.line(varname + ".CodecEncodeSelf(e)")
+			return
+		}
+
+		if _, ok := x.te[rtid]; ok {
+			x.line(varname + ".CodecEncodeSelf(e)")
+			return
+		}
 	}
 	}
 
 
 	inlist := false
 	inlist := false
 	for _, t0 := range x.t {
 	for _, t0 := range x.t {
 		if t == t0 {
 		if t == t0 {
 			inlist = true
 			inlist = true
-			if t != x.tc {
+			if x.checkForSelfer(t, varname) {
 				x.line(varname + ".CodecEncodeSelf(e)")
 				x.line(varname + ".CodecEncodeSelf(e)")
 				return
 				return
 			}
 			}
 			break
 			break
 		}
 		}
 	}
 	}
+
 	var rtidAdded bool
 	var rtidAdded bool
 	if t == x.tc {
 	if t == x.tc {
 		x.te[rtid] = true
 		x.te[rtid] = true
@@ -1022,27 +1030,29 @@ func (x *genRunner) dec(varname string, t reflect.Type) {
 	//   - t is always a baseType T (not a *T, etc).
 	//   - t is always a baseType T (not a *T, etc).
 	rtid := reflect.ValueOf(t).Pointer()
 	rtid := reflect.ValueOf(t).Pointer()
 	tptr := reflect.PtrTo(t)
 	tptr := reflect.PtrTo(t)
-	if t.Implements(selferTyp) || (t.Kind() == reflect.Struct &&
-		reflect.PtrTo(t).Implements(selferTyp)) {
-		x.line(varname + ".CodecDecodeSelf(d)")
-		return
-	}
-	if _, ok := x.td[rtid]; ok {
-		x.line(varname + ".CodecDecodeSelf(d)")
-		return
+	if x.checkForSelfer(t, varname) {
+		if t.Implements(selferTyp) || tptr.Implements(selferTyp) {
+			x.line(varname + ".CodecDecodeSelf(d)")
+			return
+		}
+		if _, ok := x.td[rtid]; ok {
+			x.line(varname + ".CodecDecodeSelf(d)")
+			return
+		}
 	}
 	}
 
 
 	inlist := false
 	inlist := false
 	for _, t0 := range x.t {
 	for _, t0 := range x.t {
 		if t == t0 {
 		if t == t0 {
 			inlist = true
 			inlist = true
-			if t != x.tc {
+			if x.checkForSelfer(t, varname) {
 				x.line(varname + ".CodecDecodeSelf(d)")
 				x.line(varname + ".CodecDecodeSelf(d)")
 				return
 				return
 			}
 			}
 			break
 			break
 		}
 		}
 	}
 	}
+
 	var rtidAdded bool
 	var rtidAdded bool
 	if t == x.tc {
 	if t == x.tc {
 		x.td[rtid] = true
 		x.td[rtid] = true