Browse Source

codec, codecgen: Support (M|Unm)arshalJSON exactly as encoding/json when using JSONHandle

To do this, we added support for explicitly tracking the bytes which have been read.
This way, when we see JSONUnmarshal, we can read the next set of bytes representing a value,
and pass that to the UnmarshalJSON method for custom decoding.

We also explicitly allow JSONMarshal to output JSON representation,
and expect it to be compliant.

This brings feature-parity with how JSON(M|Unm)arshal works,
which is very different from how encoding.Text(M|Unm)arshal works.

We also now only handle (M|Unm)arshalJSON when using JSONHandle,
and we check for it before encoding.Text(M|Unm)arshal.

This support applies to reflection-based and codecgen.

Fixes #97
Ugorji Nwoke 10 years ago
parent
commit
3199d19e57
5 changed files with 113 additions and 27 deletions
  1. 58 5
      codec/decode.go
  2. 13 8
      codec/encode.go
  3. 19 4
      codec/gen-helper.generated.go
  4. 17 4
      codec/gen-helper.go.tmpl
  5. 6 6
      codec/gen.go

+ 58 - 5
codec/decode.go

@@ -38,6 +38,9 @@ type decReader interface {
 	readb([]byte)
 	readn1() uint8
 	readn1eof() (v uint8, eof bool)
+
+	track()
+	stopTrack() []byte
 }
 
 type decReaderByteScanner interface {
@@ -183,6 +186,9 @@ type ioDecReader struct {
 	// shares buffer with Decoder, so we keep size of struct within 8 words.
 	x  *[scratchByteArrayLen]byte
 	bs ioDecByteScanner
+
+	tr  []byte // tracking bytes read
+	trb bool
 }
 
 func (z *ioDecReader) readx(n int) (bs []byte) {
@@ -197,6 +203,9 @@ func (z *ioDecReader) readx(n int) (bs []byte) {
 	if _, err := io.ReadAtLeast(z.br, bs, n); err != nil {
 		panic(err)
 	}
+	if z.trb {
+		z.tr = append(z.tr, bs...)
+	}
 	return
 }
 
@@ -207,6 +216,9 @@ func (z *ioDecReader) readb(bs []byte) {
 	if _, err := io.ReadAtLeast(z.br, bs, len(bs)); err != nil {
 		panic(err)
 	}
+	if z.trb {
+		z.tr = append(z.tr, bs...)
+	}
 }
 
 func (z *ioDecReader) readn1() (b uint8) {
@@ -214,6 +226,9 @@ func (z *ioDecReader) readn1() (b uint8) {
 	if err != nil {
 		panic(err)
 	}
+	if z.trb {
+		z.tr = append(z.tr, b)
+	}
 	return b
 }
 
@@ -225,6 +240,9 @@ func (z *ioDecReader) readn1eof() (b uint8, eof bool) {
 	} else {
 		panic(err)
 	}
+	if z.trb {
+		z.tr = append(z.tr, b)
+	}
 	return
 }
 
@@ -232,6 +250,23 @@ func (z *ioDecReader) unreadn1() {
 	if err := z.br.UnreadByte(); err != nil {
 		panic(err)
 	}
+	if z.trb {
+		if l := len(z.tr) - 1; l >= 0 {
+			z.tr = z.tr[:l]
+		}
+	}
+}
+
+func (z *ioDecReader) track() {
+	if z.tr != nil {
+		z.tr = z.tr[:0]
+	}
+	z.trb = true
+}
+
+func (z *ioDecReader) stopTrack() (bs []byte) {
+	z.trb = false
+	return z.tr
 }
 
 // ------------------------------------
@@ -243,6 +278,7 @@ type bytesDecReader struct {
 	b []byte // data
 	c int    // cursor
 	a int    // available
+	t int    // track start
 }
 
 func (z *bytesDecReader) unreadn1() {
@@ -298,6 +334,14 @@ func (z *bytesDecReader) readb(bs []byte) {
 	copy(bs, z.readx(len(bs)))
 }
 
+func (z *bytesDecReader) track() {
+	z.t = z.c
+}
+
+func (z *bytesDecReader) stopTrack() (bs []byte) {
+	return z.b[z.t:z.c]
+}
+
 // ------------------------------------
 
 type decFnInfoX struct {
@@ -388,7 +432,13 @@ func (f decFnInfo) textUnmarshal(rv reflect.Value) {
 
 func (f decFnInfo) jsonUnmarshal(rv reflect.Value) {
 	tm := f.getValueForUnmarshalInterface(rv, f.ti.junmIndir).(jsonUnmarshaler)
-	fnerr := tm.UnmarshalJSON(f.dd.DecodeBytes(f.d.b[:], true, true))
+	// bs := f.dd.DecodeBytes(f.d.b[:], true, true)
+	// grab the bytes to be read
+	f.d.r.track()
+	f.d.swallow()
+	bs := f.d.r.stopTrack()
+	// fmt.Printf(">>>>>> REFLECTION JSON: %s\n", bs)
+	fnerr := tm.UnmarshalJSON(bs)
 	if fnerr != nil {
 		panic(fnerr)
 	}
@@ -904,6 +954,7 @@ type Decoder struct {
 	hh    Handle
 	be    bool // is binary encoding
 	bytes bool // is bytes reader
+	js    bool // is json handle
 
 	ri ioDecReader
 	f  map[uintptr]decFn
@@ -927,6 +978,7 @@ func NewDecoder(r io.Reader, h Handle) (d *Decoder) {
 		d.ri.br = &d.ri.bs
 	}
 	d.r = &d.ri
+	_, d.js = h.(*JsonHandle)
 	d.d = h.newDecDriver(d)
 	return
 }
@@ -939,6 +991,7 @@ func NewDecoderBytes(in []byte, h Handle) (d *Decoder) {
 	d.rb.b = in
 	d.rb.a = len(in)
 	d.r = &d.rb
+	_, d.js = h.(*JsonHandle)
 	d.d = h.newDecDriver(d)
 	// d.d = h.newDecDriver(decReaderT{true, &d.rb, &d.ri})
 	return
@@ -1280,13 +1333,13 @@ func (d *Decoder) getDecFn(rt reflect.Type, checkFastpath, checkCodecSelfer bool
 	} else if supportMarshalInterfaces && d.be && ti.bunm {
 		fi.decFnInfoX = &decFnInfoX{d: d, ti: ti}
 		fn.f = (decFnInfo).binaryUnmarshal
+	} else if supportMarshalInterfaces && !d.be && d.js && ti.junm {
+		//If JSON, we should check JSONUnmarshal before textUnmarshal
+		fi.decFnInfoX = &decFnInfoX{d: d, ti: ti}
+		fn.f = (decFnInfo).jsonUnmarshal
 	} else if supportMarshalInterfaces && !d.be && ti.tunm {
 		fi.decFnInfoX = &decFnInfoX{d: d, ti: ti}
 		fn.f = (decFnInfo).textUnmarshal
-	} else if supportMarshalInterfaces && !d.be && ti.junm {
-		//TODO: This only works NOW, as JSON is the ONLY text format.
-		fi.decFnInfoX = &decFnInfoX{d: d, ti: ti}
-		fn.f = (decFnInfo).jsonUnmarshal
 	} else {
 		rk := rt.Kind()
 		if fastpathEnabled && checkFastpath && (rk == reflect.Map || rk == reflect.Slice) {

+ 13 - 8
codec/encode.go

@@ -337,7 +337,7 @@ func (f encFnInfo) selferMarshal(rv reflect.Value) {
 func (f encFnInfo) binaryMarshal(rv reflect.Value) {
 	if v, proceed := f.getValueForMarshalInterface(rv, f.ti.bmIndir); proceed {
 		bs, fnerr := v.(encoding.BinaryMarshaler).MarshalBinary()
-		f.e.marshal(bs, fnerr, c_RAW)
+		f.e.marshal(bs, fnerr, false, c_RAW)
 	}
 }
 
@@ -345,14 +345,14 @@ func (f encFnInfo) textMarshal(rv reflect.Value) {
 	if v, proceed := f.getValueForMarshalInterface(rv, f.ti.tmIndir); proceed {
 		// debugf(">>>> encoding.TextMarshaler: %T", rv.Interface())
 		bs, fnerr := v.(encoding.TextMarshaler).MarshalText()
-		f.e.marshal(bs, fnerr, c_UTF8)
+		f.e.marshal(bs, fnerr, false, c_UTF8)
 	}
 }
 
 func (f encFnInfo) jsonMarshal(rv reflect.Value) {
 	if v, proceed := f.getValueForMarshalInterface(rv, f.ti.jmIndir); proceed {
 		bs, fnerr := v.(jsonMarshaler).MarshalJSON()
-		f.e.marshal(bs, fnerr, c_UTF8)
+		f.e.marshal(bs, fnerr, true, c_UTF8)
 	}
 }
 
@@ -794,6 +794,7 @@ type Encoder struct {
 	w  encWriter
 	s  []rtidEncFn
 	be bool // is binary encoding
+	js bool // is json handle
 
 	wi ioEncWriter
 	wb bytesEncWriter
@@ -820,6 +821,7 @@ func NewEncoder(w io.Writer, h Handle) *Encoder {
 	}
 	e.wi.w = ww
 	e.w = &e.wi
+	_, e.js = h.(*JsonHandle)
 	e.e = h.newEncDriver(e)
 	return e
 }
@@ -837,6 +839,7 @@ func NewEncoderBytes(out *[]byte, h Handle) *Encoder {
 	}
 	e.wb.b, e.wb.out = in, out
 	e.w = &e.wb
+	_, e.js = h.(*JsonHandle)
 	e.e = h.newEncDriver(e)
 	return e
 }
@@ -1100,13 +1103,13 @@ func (e *Encoder) getEncFn(rtid uintptr, rt reflect.Type, checkFastpath, checkCo
 	} else if supportMarshalInterfaces && e.be && ti.bm {
 		fi.encFnInfoX = &encFnInfoX{e: e, ti: ti}
 		fn.f = (encFnInfo).binaryMarshal
+	} else if supportMarshalInterfaces && !e.be && e.js && ti.jm {
+		//If JSON, we should check JSONMarshal before textMarshal
+		fi.encFnInfoX = &encFnInfoX{e: e, ti: ti}
+		fn.f = (encFnInfo).jsonMarshal
 	} else if supportMarshalInterfaces && !e.be && ti.tm {
 		fi.encFnInfoX = &encFnInfoX{e: e, ti: ti}
 		fn.f = (encFnInfo).textMarshal
-	} else if supportMarshalInterfaces && !e.be && ti.jm {
-		//TODO: This only works NOW, as JSON is the ONLY text format.
-		fi.encFnInfoX = &encFnInfoX{e: e, ti: ti}
-		fn.f = (encFnInfo).jsonMarshal
 	} else {
 		rk := rt.Kind()
 		// if fastpathEnabled && checkFastpath && (rk == reflect.Map || rk == reflect.Slice) {
@@ -1193,12 +1196,14 @@ func (e *Encoder) getEncFn(rtid uintptr, rt reflect.Type, checkFastpath, checkCo
 	return
 }
 
-func (e *Encoder) marshal(bs []byte, fnerr error, c charEncoding) {
+func (e *Encoder) marshal(bs []byte, fnerr error, asis bool, c charEncoding) {
 	if fnerr != nil {
 		panic(fnerr)
 	}
 	if bs == nil {
 		e.e.EncodeNil()
+	} else if asis {
+		e.w.writeb(bs)
 	} else {
 		e.e.EncodeStringBytes(c, bs)
 	}

+ 19 - 4
codec/gen-helper.generated.go

@@ -68,19 +68,19 @@ func (f genHelperEncoder) EncFallback(iv interface{}) {
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
 func (f genHelperEncoder) EncTextMarshal(iv encoding.TextMarshaler) {
 	bs, fnerr := iv.MarshalText()
-	f.e.marshal(bs, fnerr, c_UTF8)
+	f.e.marshal(bs, fnerr, false, c_UTF8)
 }
 
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
 func (f genHelperEncoder) EncJSONMarshal(iv jsonMarshaler) {
 	bs, fnerr := iv.MarshalJSON()
-	f.e.marshal(bs, fnerr, c_UTF8)
+	f.e.marshal(bs, fnerr, true, c_UTF8)
 }
 
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
 func (f genHelperEncoder) EncBinaryMarshal(iv encoding.BinaryMarshaler) {
 	bs, fnerr := iv.MarshalBinary()
-	f.e.marshal(bs, fnerr, c_RAW)
+	f.e.marshal(bs, fnerr, false, c_RAW)
 }
 
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
@@ -91,6 +91,11 @@ func (f genHelperEncoder) TimeRtidIfBinc() uintptr {
 	return 0
 }
 
+// FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
+func (f genHelperEncoder) IsJSONHandle() bool {
+	return f.e.js
+}
+
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
 func (f genHelperEncoder) HasExtensions() bool {
 	return len(f.e.h.extHandle) != 0
@@ -161,7 +166,12 @@ func (f genHelperDecoder) DecTextUnmarshal(tm encoding.TextUnmarshaler) {
 
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
 func (f genHelperDecoder) DecJSONUnmarshal(tm jsonUnmarshaler) {
-	fnerr := tm.UnmarshalJSON(f.d.d.DecodeBytes(f.d.b[:], true, true))
+	// bs := f.dd.DecodeBytes(f.d.b[:], true, true)
+	f.d.r.track()
+	f.d.swallow()
+	bs := f.d.r.stopTrack()
+	// fmt.Printf(">>>>>> CODECGEN JSON: %s\n", bs)
+	fnerr := tm.UnmarshalJSON(bs)
 	if fnerr != nil {
 		panic(fnerr)
 	}
@@ -183,6 +193,11 @@ func (f genHelperDecoder) TimeRtidIfBinc() uintptr {
 	return 0
 }
 
+// FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
+func (f genHelperDecoder) IsJSONHandle() bool {
+	return f.d.js
+}
+
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
 func (f genHelperDecoder) HasExtensions() bool {
 	return len(f.d.h.extHandle) != 0

+ 17 - 4
codec/gen-helper.go.tmpl

@@ -66,17 +66,17 @@ func (f genHelperEncoder) EncFallback(iv interface{}) {
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
 func (f genHelperEncoder) EncTextMarshal(iv encoding.TextMarshaler) {
 	bs, fnerr := iv.MarshalText()
-	f.e.marshal(bs, fnerr, c_UTF8)
+	f.e.marshal(bs, fnerr, false, c_UTF8)
 }
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
 func (f genHelperEncoder) EncJSONMarshal(iv jsonMarshaler) {
 	bs, fnerr := iv.MarshalJSON()
-	f.e.marshal(bs, fnerr, c_UTF8)
+	f.e.marshal(bs, fnerr, true, c_UTF8)
 }
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
 func (f genHelperEncoder) EncBinaryMarshal(iv encoding.BinaryMarshaler) {
 	bs, fnerr := iv.MarshalBinary()
-	f.e.marshal(bs, fnerr, c_RAW)
+	f.e.marshal(bs, fnerr, false, c_RAW)
 }
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
 func (f genHelperEncoder) TimeRtidIfBinc() uintptr {
@@ -86,6 +86,10 @@ func (f genHelperEncoder) TimeRtidIfBinc() uintptr {
 	return 0
 }
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
+func (f genHelperEncoder) IsJSONHandle() bool {
+	return f.e.js
+}
+// FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
 func (f genHelperEncoder) HasExtensions() bool {
 	return len(f.e.h.extHandle) != 0
 }
@@ -145,7 +149,12 @@ func (f genHelperDecoder) DecTextUnmarshal(tm encoding.TextUnmarshaler) {
 }
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
 func (f genHelperDecoder) DecJSONUnmarshal(tm jsonUnmarshaler) {
-	fnerr := tm.UnmarshalJSON(f.d.d.DecodeBytes(f.d.b[:], true, true))
+	// bs := f.dd.DecodeBytes(f.d.b[:], true, true)
+	f.d.r.track()
+	f.d.swallow()
+	bs := f.d.r.stopTrack()
+	// fmt.Printf(">>>>>> CODECGEN JSON: %s\n", bs)
+	fnerr := tm.UnmarshalJSON(bs)
 	if fnerr != nil {
 		panic(fnerr)
 	}
@@ -165,6 +174,10 @@ func (f genHelperDecoder) TimeRtidIfBinc() uintptr {
 	return 0
 }
 // FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
+func (f genHelperDecoder) IsJSONHandle() bool {
+	return f.d.js 
+}
+// FOR USE BY CODECGEN ONLY. IT *WILL* CHANGE WITHOUT NOTICE. *DO NOT USE*
 func (f genHelperDecoder) HasExtensions() bool {
 	return len(f.d.h.extHandle) != 0
 }

+ 6 - 6
codec/gen.go

@@ -620,10 +620,10 @@ func (x *genRunner) enc(varname string, t reflect.Type) {
 	if t.Implements(binaryMarshalerTyp) || tptr.Implements(binaryMarshalerTyp) {
 		x.linef("} else if %sm%s { z.EncBinaryMarshal(%v) ", genTempVarPfx, mi, varname)
 	}
-	if t.Implements(textMarshalerTyp) || tptr.Implements(textMarshalerTyp) {
+	if t.Implements(jsonMarshalerTyp) || tptr.Implements(jsonMarshalerTyp) {
+		x.linef("} else if !%sm%s && z.IsJSONHandle() { z.EncJSONMarshal(%v) ", genTempVarPfx, mi, varname)
+	} else if t.Implements(textMarshalerTyp) || tptr.Implements(textMarshalerTyp) {
 		x.linef("} else if !%sm%s { z.EncTextMarshal(%v) ", genTempVarPfx, mi, varname)
-	} else if t.Implements(jsonMarshalerTyp) || tptr.Implements(jsonMarshalerTyp) {
-		x.linef("} else if !%sm%s { z.EncJSONMarshal(%v) ", genTempVarPfx, mi, varname)
 	}
 
 	x.line("} else {")
@@ -1080,10 +1080,10 @@ func (x *genRunner) dec(varname string, t reflect.Type) {
 	if t.Implements(binaryUnmarshalerTyp) || tptr.Implements(binaryUnmarshalerTyp) {
 		x.linef("} else if %sm%s { z.DecBinaryUnmarshal(%v) ", genTempVarPfx, mi, varname)
 	}
-	if t.Implements(textUnmarshalerTyp) || tptr.Implements(textUnmarshalerTyp) {
+	if t.Implements(jsonUnmarshalerTyp) || tptr.Implements(jsonUnmarshalerTyp) {
+		x.linef("} else if !%sm%s && z.IsJSONHandle() { z.DecJSONUnmarshal(%v)", genTempVarPfx, mi, varname)
+	} else if t.Implements(textUnmarshalerTyp) || tptr.Implements(textUnmarshalerTyp) {
 		x.linef("} else if !%sm%s { z.DecTextUnmarshal(%v)", genTempVarPfx, mi, varname)
-	} else if t.Implements(jsonUnmarshalerTyp) || tptr.Implements(jsonUnmarshalerTyp) {
-		x.linef("} else if !%sm%s { z.DecJSONUnmarshal(%v)", genTempVarPfx, mi, varname)
 	}
 
 	x.line("} else {")