Browse Source

codec: remove buffering from encoder/decoder. Add buffering to rpc. fix binc floatZero bug.

This codec library does not internally do buffering during encode or decode, leaving
that to the caller of the NewEncoder or NewDecoder function to pass a buffered stream
if desired.

However, RPC owns the connection and is the caller of the NewEncoder/NewDecoder. The RPC
support now internally passes buffered streams to the NewEncoder/NewDecoder calls.

Fixes #9 .

In addition, there was a floatZero bug in binc, where floats with zero value could
mess up the decoder state. This is now fixed.
Ugorji Nwoke 12 years ago
parent
commit
8f3b3ef741
7 changed files with 123 additions and 88 deletions
  1. 12 6
      codec/binc.go
  2. 3 2
      codec/decode.go
  3. 58 45
      codec/encode.go
  4. 3 0
      codec/helper.go
  5. 10 3
      codec/msgpack.go
  6. 31 25
      codec/rpc.go
  7. 6 7
      codec/time.go

+ 12 - 6
codec/binc.go

@@ -126,6 +126,7 @@ func (e *bincEncDriver) encodeBool(b bool) {
 }
 }
 
 
 func (e *bincEncDriver) encodeFloat32(f float32) {
 func (e *bincEncDriver) encodeFloat32(f float32) {
+	//println("encodeFloat32")
 	if f == 0 {
 	if f == 0 {
 		e.w.writen1(bincVdSpecial<<4 | bincSpZeroFloat)
 		e.w.writen1(bincVdSpecial<<4 | bincSpZeroFloat)
 		return
 		return
@@ -386,14 +387,17 @@ func (d *bincDecDriver) decFloatPre(vs, defaultLen byte) {
 }
 }
 
 
 func (d *bincDecDriver) decFloat() (f float64) {
 func (d *bincDecDriver) decFloat() (f float64) {
+	//println("decFloat")
 	//if true { f = math.Float64frombits(d.r.readUint64()); break; }
 	//if true { f = math.Float64frombits(d.r.readUint64()); break; }
 	switch vs := d.vs; vs & 0x7 {
 	switch vs := d.vs; vs & 0x7 {
 	case bincFlBin32:
 	case bincFlBin32:
+		//println("decodeFloat32")
 		d.decFloatPre(vs, 4)
 		d.decFloatPre(vs, 4)
 		f = float64(math.Float32frombits(bigen.Uint32(d.b[0:4])))
 		f = float64(math.Float32frombits(bigen.Uint32(d.b[0:4])))
 	case bincFlBin64:
 	case bincFlBin64:
+		//println("decodeFloat64")
 		d.decFloatPre(vs, 8)
 		d.decFloatPre(vs, 8)
-		f = math.Float64frombits(bigen.Uint64(d.b[:]))
+		f = math.Float64frombits(bigen.Uint64(d.b[0:8]))
 	default:
 	default:
 		decErr("only float32 and float64 are supported. d.vd: 0x%x, d.vs: 0x%x", d.vd, d.vs)
 		decErr("only float32 and float64 are supported. d.vd: 0x%x, d.vs: 0x%x", d.vd, d.vs)
 	}
 	}
@@ -530,19 +534,21 @@ func (d *bincDecDriver) decodeUint(bitsize uint8) (ui uint64) {
 }
 }
 
 
 func (d *bincDecDriver) decodeFloat(chkOverflow32 bool) (f float64) {
 func (d *bincDecDriver) decodeFloat(chkOverflow32 bool) (f float64) {
-	if d.vd == bincVdSpecial {
+	switch d.vd {
+	case bincVdSpecial:
+		d.bdRead = false
 		switch d.vs {
 		switch d.vs {
 		case bincSpNan:
 		case bincSpNan:
 			return math.NaN()
 			return math.NaN()
 		case bincSpPosInf:
 		case bincSpPosInf:
 			return math.Inf(1)
 			return math.Inf(1)
-		case bincSpZeroFloat:
+		case bincSpZeroFloat, bincSpZero:
 			return
 			return
 		case bincSpNegInf:
 		case bincSpNegInf:
 			return math.Inf(-1)
 			return math.Inf(-1)
-		}
-	}
-	switch d.vd {
+		default:
+			decErr("Invalid d.vs decoding float where d.vd=bincVdSpecial: %v", d.vs)
+		}		
 	case bincVdFloat:
 	case bincVdFloat:
 		f = d.decFloat()
 		f = d.decFloat()
 	case bincVdUint:
 	case bincVdUint:

+ 3 - 2
codec/decode.go

@@ -6,8 +6,6 @@ package codec
 import (
 import (
 	"io"
 	"io"
 	"reflect"
 	"reflect"
-	//"math"
-	//"fmt"
 )
 )
 
 
 // Some tagging information for error messages.
 // Some tagging information for error messages.
@@ -174,6 +172,9 @@ func (o *decHandle) getDecodeExt(rt reflect.Type) (tag byte, fn func(reflect.Val
 }
 }
 
 
 // NewDecoder returns a Decoder for decoding a stream of bytes from an io.Reader.
 // NewDecoder returns a Decoder for decoding a stream of bytes from an io.Reader.
+// 
+// For efficiency, Users are encouraged to pass in a memory buffered writer
+// (eg bufio.Reader, bytes.Buffer). 
 func NewDecoder(r io.Reader, h Handle) *Decoder {
 func NewDecoder(r io.Reader, h Handle) *Decoder {
 	z := ioDecReader{
 	z := ioDecReader{
 		r: r,
 		r: r,

+ 58 - 45
codec/encode.go

@@ -4,7 +4,7 @@
 package codec
 package codec
 
 
 import (
 import (
-	"bufio"
+	//"bufio"
 	"io"
 	"io"
 	"reflect"
 	"reflect"
 	//"fmt"
 	//"fmt"
@@ -27,9 +27,9 @@ type encWriter interface {
 	writestr(string)
 	writestr(string)
 	writen1(byte)
 	writen1(byte)
 	writen2(byte, byte)
 	writen2(byte, byte)
-	writen3(byte, byte, byte)
-	writen4(byte, byte, byte, byte)
-	flush()
+	//writen3(byte, byte, byte)
+	//writen4(byte, byte, byte, byte)
+	atEndOfEncode()
 }
 }
 
 
 type encDriver interface {
 type encDriver interface {
@@ -69,8 +69,14 @@ type ioEncWriterWriter interface {
 	Write(p []byte) (n int, err error)
 	Write(p []byte) (n int, err error)
 }
 }
 
 
-type ioEncWriterFlusher interface {
-	Flush() error
+type ioEncStringWriter interface {
+	WriteString(s string) (n int, err error)
+}
+
+type simpleIoEncWriterWriter struct {
+	w io.Writer 
+	bw io.ByteWriter
+	sw ioEncStringWriter
 }
 }
 
 
 // ioEncWriter implements encWriter and can write to an io.Writer implementation
 // ioEncWriter implements encWriter and can write to an io.Writer implementation
@@ -84,7 +90,7 @@ type ioEncWriter struct {
 type bytesEncWriter struct {
 type bytesEncWriter struct {
 	b   []byte
 	b   []byte
 	c   int     // cursor
 	c   int     // cursor
-	out *[]byte // write out on flush
+	out *[]byte // write out on atEndOfEncode
 }
 }
 
 
 type encExtTagFn struct {
 type encExtTagFn struct {
@@ -103,6 +109,26 @@ type encHandle struct {
 	exts     []encExtTypeTagFn
 	exts     []encExtTypeTagFn
 }
 }
 
 
+func (o *simpleIoEncWriterWriter) WriteByte(c byte) (err error) {
+	if o.bw != nil {
+		return o.bw.WriteByte(c)
+	}
+	_, err = o.w.Write([]byte{c})
+	return
+}
+
+func (o *simpleIoEncWriterWriter) WriteString(s string) (n int, err error) {
+	if o.sw != nil {
+		return o.sw.WriteString(s)
+	}
+	return o.w.Write([]byte(s))
+}
+
+func (o *simpleIoEncWriterWriter) Write(p []byte) (n int, err error) {
+	return o.w.Write(p)
+}
+
+
 // addEncodeExt registers a function to handle encoding a given type as an extension
 // addEncodeExt registers a function to handle encoding a given type as an extension
 // with a specific specific tag byte.
 // with a specific specific tag byte.
 // To remove an extension, pass fn=nil.
 // To remove an extension, pass fn=nil.
@@ -148,12 +174,17 @@ func (o *encHandle) getEncodeExt(rt reflect.Type) (tag byte, fn func(reflect.Val
 }
 }
 
 
 // NewEncoder returns an Encoder for encoding into an io.Writer.
 // NewEncoder returns an Encoder for encoding into an io.Writer.
+// 
 // For efficiency, Users are encouraged to pass in a memory buffered writer
 // For efficiency, Users are encouraged to pass in a memory buffered writer
-// (eg bufio.Writer, bytes.Buffer). This implementation *may* use one internally.
+// (eg bufio.Writer, bytes.Buffer). 
 func NewEncoder(w io.Writer, h Handle) *Encoder {
 func NewEncoder(w io.Writer, h Handle) *Encoder {
 	ww, ok := w.(ioEncWriterWriter)
 	ww, ok := w.(ioEncWriterWriter)
 	if !ok {
 	if !ok {
-		ww = bufio.NewWriterSize(w, defEncByteBufSize)
+		sww := simpleIoEncWriterWriter{w: w}
+		sww.bw, _ = w.(io.ByteWriter)
+		sww.sw, _ = w.(ioEncStringWriter)
+		ww = &sww
+		//ww = bufio.NewWriterSize(w, defEncByteBufSize)
 	}
 	}
 	z := ioEncWriter{
 	z := ioEncWriter{
 		w: ww,
 		w: ww,
@@ -216,7 +247,7 @@ func NewEncoderBytes(out *[]byte, h Handle) *Encoder {
 func (e *Encoder) Encode(v interface{}) (err error) {
 func (e *Encoder) Encode(v interface{}) (err error) {
 	defer panicToErr(&err)
 	defer panicToErr(&err)
 	e.encode(v)
 	e.encode(v)
-	e.w.flush()
+	e.w.atEndOfEncode()
 	return
 	return
 }
 }
 
 
@@ -318,7 +349,7 @@ func (e *Encoder) encodeValue(rv reflect.Value) {
 		}
 		}
 		return
 		return
 	}
 	}
-
+	
 	// ensure more common cases appear early in switch.
 	// ensure more common cases appear early in switch.
 	rk := rv.Kind()
 	rk := rv.Kind()
 	switch rk {
 	switch rk {
@@ -475,26 +506,7 @@ func (z *ioEncWriter) writen2(b1 byte, b2 byte) {
 	z.writen1(b2)
 	z.writen1(b2)
 }
 }
 
 
-func (z *ioEncWriter) writen3(b1, b2, b3 byte) {
-	z.writen1(b1)
-	z.writen1(b2)
-	z.writen1(b3)
-}
-
-func (z *ioEncWriter) writen4(b1, b2, b3, b4 byte) {
-	z.writen1(b1)
-	z.writen1(b2)
-	z.writen1(b3)
-	z.writen1(b4)
-}
-
-func (z *ioEncWriter) flush() {
-	if f, ok := z.w.(ioEncWriterFlusher); ok {
-		if err := f.Flush(); err != nil {
-			panic(err)
-		}
-	}
-}
+func (z *ioEncWriter) atEndOfEncode() { }
 
 
 // ----------------------------------------
 // ----------------------------------------
 
 
@@ -545,22 +557,22 @@ func (z *bytesEncWriter) writen2(b1 byte, b2 byte) {
 	z.b[c+1] = b2
 	z.b[c+1] = b2
 }
 }
 
 
-func (z *bytesEncWriter) writen3(b1 byte, b2 byte, b3 byte) {
-	c := z.grow(3)
-	z.b[c] = b1
-	z.b[c+1] = b2
-	z.b[c+2] = b3
-}
+// func (z *bytesEncWriter) writen3(b1 byte, b2 byte, b3 byte) {
+// 	c := z.grow(3)
+// 	z.b[c] = b1
+// 	z.b[c+1] = b2
+// 	z.b[c+2] = b3
+// }
 
 
-func (z *bytesEncWriter) writen4(b1 byte, b2 byte, b3 byte, b4 byte) {
-	c := z.grow(4)
-	z.b[c] = b1
-	z.b[c+1] = b2
-	z.b[c+2] = b3
-	z.b[c+3] = b4
-}
+// func (z *bytesEncWriter) writen4(b1 byte, b2 byte, b3 byte, b4 byte) {
+// 	c := z.grow(4)
+// 	z.b[c] = b1
+// 	z.b[c+1] = b2
+// 	z.b[c+2] = b3
+// 	z.b[c+3] = b4
+// }
 
 
-func (z *bytesEncWriter) flush() {
+func (z *bytesEncWriter) atEndOfEncode() {
 	*(z.out) = z.b[:z.c]
 	*(z.out) = z.b[:z.c]
 }
 }
 
 
@@ -585,3 +597,4 @@ func (z *bytesEncWriter) grow(n int) (oldcursor int) {
 func encErr(format string, params ...interface{}) {
 func encErr(format string, params ...interface{}) {
 	doPanic(msgTagEnc, format, params...)
 	doPanic(msgTagEnc, format, params...)
 }
 }
+

+ 3 - 0
codec/helper.go

@@ -55,6 +55,9 @@ var (
 
 
 	intBitsize  uint8 = uint8(reflect.TypeOf(int(0)).Bits())
 	intBitsize  uint8 = uint8(reflect.TypeOf(int(0)).Bits())
 	uintBitsize uint8 = uint8(reflect.TypeOf(uint(0)).Bits())
 	uintBitsize uint8 = uint8(reflect.TypeOf(uint(0)).Bits())
+
+	bsAll0x00 = []byte{0, 0, 0, 0, 0, 0, 0, 0}
+	bsAll0xff = []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
 )
 )
 
 
 type encdecHandle struct {
 type encdecHandle struct {

+ 10 - 3
codec/msgpack.go

@@ -200,7 +200,8 @@ func (e *msgpackEncDriver) encodeExtPreamble(xtag byte, l int) {
 	case l <= 8:
 	case l <= 8:
 		e.w.writen2(0xc0|byte(l), xtag)
 		e.w.writen2(0xc0|byte(l), xtag)
 	case l < 256:
 	case l < 256:
-		e.w.writen3(mpXv4Fixext5, xtag, byte(l))
+		e.w.writen2(mpXv4Fixext5, xtag)
+		e.w.writen1(byte(l))
 	case l < 65536:
 	case l < 65536:
 		e.w.writen2(mpXv4Ext16, xtag)
 		e.w.writen2(mpXv4Ext16, xtag)
 		e.w.writeUint16(uint16(l))
 		e.w.writeUint16(uint16(l))
@@ -662,13 +663,19 @@ func (c msgpackSpecRpcCodec) parseCustomHeader(expectTypeByte byte, msgid *uint6
 		return
 		return
 	}
 	}
 	var b byte
 	var b byte
-	if err = c.read(&b, msgid, methodOrError); err != nil {
+	if err = c.read(&b); err != nil {
 		return
 		return
 	}
 	}
 	if b != expectTypeByte {
 	if b != expectTypeByte {
 		err = fmt.Errorf("Unexpected byte descriptor in header. Expecting %v. Received %v", expectTypeByte, b)
 		err = fmt.Errorf("Unexpected byte descriptor in header. Expecting %v. Received %v", expectTypeByte, b)
 		return
 		return
 	}
 	}
+	if err = c.read(msgid); err != nil {
+		return
+	}
+	if err = c.read(methodOrError); err != nil {
+		return
+	}
 	return
 	return
 }
 }
 
 
@@ -684,7 +691,7 @@ func (c msgpackSpecRpcCodec) writeCustomBody(typeByte byte, msgid uint64, method
 		}
 		}
 	}
 	}
 	r2 := []interface{}{typeByte, uint32(msgid), moe, body}
 	r2 := []interface{}{typeByte, uint32(msgid), moe, body}
-	return c.enc.Encode(r2)
+	return c.write(r2, nil, false, true)
 }
 }
 
 
 //--------------------------------------------------
 //--------------------------------------------------

+ 31 - 25
codec/rpc.go

@@ -11,6 +11,7 @@ package codec
 
 
 import (
 import (
 	"io"
 	"io"
+	"bufio"
 	"net/rpc"
 	"net/rpc"
 )
 )
 
 
@@ -27,6 +28,7 @@ type rpcCodec struct {
 	rwc io.ReadWriteCloser
 	rwc io.ReadWriteCloser
 	dec *Decoder
 	dec *Decoder
 	enc *Encoder
 	enc *Encoder
+	encbuf *bufio.Writer 
 }
 }
 
 
 type goRpcCodec struct {
 type goRpcCodec struct {
@@ -46,45 +48,50 @@ func (x goRpc) ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec {
 }
 }
 
 
 func newRPCCodec(conn io.ReadWriteCloser, h Handle) rpcCodec {
 func newRPCCodec(conn io.ReadWriteCloser, h Handle) rpcCodec {
+	encbuf := bufio.NewWriter(conn)
 	return rpcCodec{
 	return rpcCodec{
 		rwc: conn,
 		rwc: conn,
-		dec: NewDecoder(conn, h),
-		enc: NewEncoder(conn, h),
+		encbuf: encbuf,
+		enc: NewEncoder(encbuf, h),
+		dec: NewDecoder(bufio.NewReader(conn), h),
+		//enc: NewEncoder(conn, h),
+		//dec: NewDecoder(conn, h),
 	}
 	}
 }
 }
 
 
 // /////////////// RPC Codec Shared Methods ///////////////////
 // /////////////// RPC Codec Shared Methods ///////////////////
-func (c rpcCodec) write(objs ...interface{}) (err error) {
-	for _, obj := range objs {
-		if err = c.enc.Encode(obj); err != nil {
+func (c rpcCodec) write(obj1, obj2 interface{}, writeObj2, doFlush bool) (err error) {
+	if err = c.enc.Encode(obj1); err != nil {
+		return
+	}
+	if writeObj2 {
+		if err = c.enc.Encode(obj2); err != nil {
 			return
 			return
 		}
 		}
 	}
 	}
+	if doFlush && c.encbuf != nil {
+		//println("rpc flushing")
+		return c.encbuf.Flush()
+	}
 	return
 	return
 }
 }
 
 
-func (c rpcCodec) read(objs ...interface{}) (err error) {
-	for _, obj := range objs {
-		//If nil is passed in, we should still attempt to read content to nowhere.
-		if obj == nil {
-			//obj = &obj //This bombs/uses all memory up. Dunno why (maybe because obj is not addressable???).
-			var n interface{}
-			obj = &n
-		}
-		if err = c.dec.Decode(obj); err != nil {
-			return
-		}
+
+func (c rpcCodec) read(obj interface{}) (err error) {
+	//If nil is passed in, we should still attempt to read content to nowhere.
+	if obj == nil {
+		var obj2 interface{}
+		return c.dec.Decode(&obj2)
 	}
 	}
-	return
+	return c.dec.Decode(obj)
 }
 }
 
 
 func (c rpcCodec) Close() error {
 func (c rpcCodec) Close() error {
 	return c.rwc.Close()
 	return c.rwc.Close()
 }
 }
 
 
-func (c rpcCodec) ReadResponseBody(body interface{}) (err error) {
-	err = c.read(body)
-	return
+func (c rpcCodec) ReadResponseBody(body interface{}) error {
+	return c.read(body)
 }
 }
 
 
 func (c rpcCodec) ReadRequestBody(body interface{}) error {
 func (c rpcCodec) ReadRequestBody(body interface{}) error {
@@ -93,16 +100,15 @@ func (c rpcCodec) ReadRequestBody(body interface{}) error {
 
 
 // /////////////// Go RPC Codec ///////////////////
 // /////////////// Go RPC Codec ///////////////////
 func (c goRpcCodec) WriteRequest(r *rpc.Request, body interface{}) error {
 func (c goRpcCodec) WriteRequest(r *rpc.Request, body interface{}) error {
-	return c.write(r, body)
+	return c.write(r, body, true, true)
 }
 }
 
 
 func (c goRpcCodec) WriteResponse(r *rpc.Response, body interface{}) error {
 func (c goRpcCodec) WriteResponse(r *rpc.Response, body interface{}) error {
-	return c.write(r, body)
+	return c.write(r, body, true, true)
 }
 }
 
 
-func (c goRpcCodec) ReadResponseHeader(r *rpc.Response) (err error) {
-	err = c.read(r)
-	return
+func (c goRpcCodec) ReadResponseHeader(r *rpc.Response) error {
+	return c.read(r)
 }
 }
 
 
 func (c goRpcCodec) ReadRequestHeader(r *rpc.Request) error {
 func (c goRpcCodec) ReadRequestHeader(r *rpc.Request) error {

+ 6 - 7
codec/time.go

@@ -8,8 +8,7 @@ import (
 )
 )
 
 
 var (
 var (
-	timeBs0xff = []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
-	digits = [...]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
+	timeDigits = [...]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
 )
 )
 
 
 // encodeTime encodes a time.Time as a []byte, including
 // encodeTime encodes a time.Time as a []byte, including
@@ -79,7 +78,7 @@ func decodeTime(bs []byte) (tt time.Time, err error) {
 		copy(btmp[8-n:], bs[i:i2])
 		copy(btmp[8-n:], bs[i:i2])
 		//if first bit of bs[i] is set, then fill btmp[0..8-n] with 0xff (ie sign extend it)
 		//if first bit of bs[i] is set, then fill btmp[0..8-n] with 0xff (ie sign extend it)
 		if bs[i] & (1 << 7) != 0 {
 		if bs[i] & (1 << 7) != 0 {
-			copy(btmp[0:8-n], timeBs0xff)
+			copy(btmp[0:8-n], bsAll0xff)
 			//for j,k := byte(0), 8-n; j < k; j++ {	btmp[j] = 0xff }
 			//for j,k := byte(0), 8-n; j < k; j++ {	btmp[j] = 0xff }
 		}
 		}
 		i = i2
 		i = i2
@@ -122,10 +121,10 @@ func decodeTime(bs []byte) (tt time.Time, err error) {
 	} else {
 	} else {
 		tzhr, tzmin = tzint/60, tzint%60
 		tzhr, tzmin = tzint/60, tzint%60
 	}
 	}
-	tzname[4] = digits[tzhr/10]
-	tzname[5] = digits[tzhr%10]
-	tzname[7] = digits[tzmin/10]
-	tzname[8] = digits[tzmin%10]
+	tzname[4] = timeDigits[tzhr/10]
+	tzname[5] = timeDigits[tzhr%10]
+	tzname[7] = timeDigits[tzmin/10]
+	tzname[8] = timeDigits[tzmin%10]
 
 
 	tt = time.Unix(tsec, int64(tnsec)).In(time.FixedZone(string(tzname), int(tzint)*60))
 	tt = time.Unix(tsec, int64(tnsec)).In(time.FixedZone(string(tzname), int(tzint)*60))
 	return
 	return