Browse Source

codec: remove buffering from encoder/decoder. Add buffering to rpc. fix binc floatZero bug.

This codec library does not internally do buffering during encode or decode, leaving
that to the caller of the NewEncoder or NewDecoder function to pass a buffered stream
if desired.

However, RPC owns the connection and is the caller of the NewEncoder/NewDecoder. The RPC
support now internally passes buffered streams to the NewEncoder/NewDecoder calls.

Fixes #9 .

In addition, there was a floatZero bug in binc, where floats with zero value could
mess up the decoder state. This is now fixed.
Ugorji Nwoke 12 years ago
parent
commit
8f3b3ef741
7 changed files with 123 additions and 88 deletions
  1. 12 6
      codec/binc.go
  2. 3 2
      codec/decode.go
  3. 58 45
      codec/encode.go
  4. 3 0
      codec/helper.go
  5. 10 3
      codec/msgpack.go
  6. 31 25
      codec/rpc.go
  7. 6 7
      codec/time.go

+ 12 - 6
codec/binc.go

@@ -126,6 +126,7 @@ func (e *bincEncDriver) encodeBool(b bool) {
 }
 
 func (e *bincEncDriver) encodeFloat32(f float32) {
+	//println("encodeFloat32")
 	if f == 0 {
 		e.w.writen1(bincVdSpecial<<4 | bincSpZeroFloat)
 		return
@@ -386,14 +387,17 @@ func (d *bincDecDriver) decFloatPre(vs, defaultLen byte) {
 }
 
 func (d *bincDecDriver) decFloat() (f float64) {
+	//println("decFloat")
 	//if true { f = math.Float64frombits(d.r.readUint64()); break; }
 	switch vs := d.vs; vs & 0x7 {
 	case bincFlBin32:
+		//println("decodeFloat32")
 		d.decFloatPre(vs, 4)
 		f = float64(math.Float32frombits(bigen.Uint32(d.b[0:4])))
 	case bincFlBin64:
+		//println("decodeFloat64")
 		d.decFloatPre(vs, 8)
-		f = math.Float64frombits(bigen.Uint64(d.b[:]))
+		f = math.Float64frombits(bigen.Uint64(d.b[0:8]))
 	default:
 		decErr("only float32 and float64 are supported. d.vd: 0x%x, d.vs: 0x%x", d.vd, d.vs)
 	}
@@ -530,19 +534,21 @@ func (d *bincDecDriver) decodeUint(bitsize uint8) (ui uint64) {
 }
 
 func (d *bincDecDriver) decodeFloat(chkOverflow32 bool) (f float64) {
-	if d.vd == bincVdSpecial {
+	switch d.vd {
+	case bincVdSpecial:
+		d.bdRead = false
 		switch d.vs {
 		case bincSpNan:
 			return math.NaN()
 		case bincSpPosInf:
 			return math.Inf(1)
-		case bincSpZeroFloat:
+		case bincSpZeroFloat, bincSpZero:
 			return
 		case bincSpNegInf:
 			return math.Inf(-1)
-		}
-	}
-	switch d.vd {
+		default:
+			decErr("Invalid d.vs decoding float where d.vd=bincVdSpecial: %v", d.vs)
+		}		
 	case bincVdFloat:
 		f = d.decFloat()
 	case bincVdUint:

+ 3 - 2
codec/decode.go

@@ -6,8 +6,6 @@ package codec
 import (
 	"io"
 	"reflect"
-	//"math"
-	//"fmt"
 )
 
 // Some tagging information for error messages.
@@ -174,6 +172,9 @@ func (o *decHandle) getDecodeExt(rt reflect.Type) (tag byte, fn func(reflect.Val
 }
 
 // NewDecoder returns a Decoder for decoding a stream of bytes from an io.Reader.
+// 
+// For efficiency, Users are encouraged to pass in a memory buffered writer
+// (eg bufio.Reader, bytes.Buffer). 
 func NewDecoder(r io.Reader, h Handle) *Decoder {
 	z := ioDecReader{
 		r: r,

+ 58 - 45
codec/encode.go

@@ -4,7 +4,7 @@
 package codec
 
 import (
-	"bufio"
+	//"bufio"
 	"io"
 	"reflect"
 	//"fmt"
@@ -27,9 +27,9 @@ type encWriter interface {
 	writestr(string)
 	writen1(byte)
 	writen2(byte, byte)
-	writen3(byte, byte, byte)
-	writen4(byte, byte, byte, byte)
-	flush()
+	//writen3(byte, byte, byte)
+	//writen4(byte, byte, byte, byte)
+	atEndOfEncode()
 }
 
 type encDriver interface {
@@ -69,8 +69,14 @@ type ioEncWriterWriter interface {
 	Write(p []byte) (n int, err error)
 }
 
-type ioEncWriterFlusher interface {
-	Flush() error
+type ioEncStringWriter interface {
+	WriteString(s string) (n int, err error)
+}
+
+type simpleIoEncWriterWriter struct {
+	w io.Writer 
+	bw io.ByteWriter
+	sw ioEncStringWriter
 }
 
 // ioEncWriter implements encWriter and can write to an io.Writer implementation
@@ -84,7 +90,7 @@ type ioEncWriter struct {
 type bytesEncWriter struct {
 	b   []byte
 	c   int     // cursor
-	out *[]byte // write out on flush
+	out *[]byte // write out on atEndOfEncode
 }
 
 type encExtTagFn struct {
@@ -103,6 +109,26 @@ type encHandle struct {
 	exts     []encExtTypeTagFn
 }
 
+func (o *simpleIoEncWriterWriter) WriteByte(c byte) (err error) {
+	if o.bw != nil {
+		return o.bw.WriteByte(c)
+	}
+	_, err = o.w.Write([]byte{c})
+	return
+}
+
+func (o *simpleIoEncWriterWriter) WriteString(s string) (n int, err error) {
+	if o.sw != nil {
+		return o.sw.WriteString(s)
+	}
+	return o.w.Write([]byte(s))
+}
+
+func (o *simpleIoEncWriterWriter) Write(p []byte) (n int, err error) {
+	return o.w.Write(p)
+}
+
+
 // addEncodeExt registers a function to handle encoding a given type as an extension
 // with a specific specific tag byte.
 // To remove an extension, pass fn=nil.
@@ -148,12 +174,17 @@ func (o *encHandle) getEncodeExt(rt reflect.Type) (tag byte, fn func(reflect.Val
 }
 
 // NewEncoder returns an Encoder for encoding into an io.Writer.
+// 
 // For efficiency, Users are encouraged to pass in a memory buffered writer
-// (eg bufio.Writer, bytes.Buffer). This implementation *may* use one internally.
+// (eg bufio.Writer, bytes.Buffer). 
 func NewEncoder(w io.Writer, h Handle) *Encoder {
 	ww, ok := w.(ioEncWriterWriter)
 	if !ok {
-		ww = bufio.NewWriterSize(w, defEncByteBufSize)
+		sww := simpleIoEncWriterWriter{w: w}
+		sww.bw, _ = w.(io.ByteWriter)
+		sww.sw, _ = w.(ioEncStringWriter)
+		ww = &sww
+		//ww = bufio.NewWriterSize(w, defEncByteBufSize)
 	}
 	z := ioEncWriter{
 		w: ww,
@@ -216,7 +247,7 @@ func NewEncoderBytes(out *[]byte, h Handle) *Encoder {
 func (e *Encoder) Encode(v interface{}) (err error) {
 	defer panicToErr(&err)
 	e.encode(v)
-	e.w.flush()
+	e.w.atEndOfEncode()
 	return
 }
 
@@ -318,7 +349,7 @@ func (e *Encoder) encodeValue(rv reflect.Value) {
 		}
 		return
 	}
-
+	
 	// ensure more common cases appear early in switch.
 	rk := rv.Kind()
 	switch rk {
@@ -475,26 +506,7 @@ func (z *ioEncWriter) writen2(b1 byte, b2 byte) {
 	z.writen1(b2)
 }
 
-func (z *ioEncWriter) writen3(b1, b2, b3 byte) {
-	z.writen1(b1)
-	z.writen1(b2)
-	z.writen1(b3)
-}
-
-func (z *ioEncWriter) writen4(b1, b2, b3, b4 byte) {
-	z.writen1(b1)
-	z.writen1(b2)
-	z.writen1(b3)
-	z.writen1(b4)
-}
-
-func (z *ioEncWriter) flush() {
-	if f, ok := z.w.(ioEncWriterFlusher); ok {
-		if err := f.Flush(); err != nil {
-			panic(err)
-		}
-	}
-}
+func (z *ioEncWriter) atEndOfEncode() { }
 
 // ----------------------------------------
 
@@ -545,22 +557,22 @@ func (z *bytesEncWriter) writen2(b1 byte, b2 byte) {
 	z.b[c+1] = b2
 }
 
-func (z *bytesEncWriter) writen3(b1 byte, b2 byte, b3 byte) {
-	c := z.grow(3)
-	z.b[c] = b1
-	z.b[c+1] = b2
-	z.b[c+2] = b3
-}
+// func (z *bytesEncWriter) writen3(b1 byte, b2 byte, b3 byte) {
+// 	c := z.grow(3)
+// 	z.b[c] = b1
+// 	z.b[c+1] = b2
+// 	z.b[c+2] = b3
+// }
 
-func (z *bytesEncWriter) writen4(b1 byte, b2 byte, b3 byte, b4 byte) {
-	c := z.grow(4)
-	z.b[c] = b1
-	z.b[c+1] = b2
-	z.b[c+2] = b3
-	z.b[c+3] = b4
-}
+// func (z *bytesEncWriter) writen4(b1 byte, b2 byte, b3 byte, b4 byte) {
+// 	c := z.grow(4)
+// 	z.b[c] = b1
+// 	z.b[c+1] = b2
+// 	z.b[c+2] = b3
+// 	z.b[c+3] = b4
+// }
 
-func (z *bytesEncWriter) flush() {
+func (z *bytesEncWriter) atEndOfEncode() {
 	*(z.out) = z.b[:z.c]
 }
 
@@ -585,3 +597,4 @@ func (z *bytesEncWriter) grow(n int) (oldcursor int) {
 func encErr(format string, params ...interface{}) {
 	doPanic(msgTagEnc, format, params...)
 }
+

+ 3 - 0
codec/helper.go

@@ -55,6 +55,9 @@ var (
 
 	intBitsize  uint8 = uint8(reflect.TypeOf(int(0)).Bits())
 	uintBitsize uint8 = uint8(reflect.TypeOf(uint(0)).Bits())
+
+	bsAll0x00 = []byte{0, 0, 0, 0, 0, 0, 0, 0}
+	bsAll0xff = []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
 )
 
 type encdecHandle struct {

+ 10 - 3
codec/msgpack.go

@@ -200,7 +200,8 @@ func (e *msgpackEncDriver) encodeExtPreamble(xtag byte, l int) {
 	case l <= 8:
 		e.w.writen2(0xc0|byte(l), xtag)
 	case l < 256:
-		e.w.writen3(mpXv4Fixext5, xtag, byte(l))
+		e.w.writen2(mpXv4Fixext5, xtag)
+		e.w.writen1(byte(l))
 	case l < 65536:
 		e.w.writen2(mpXv4Ext16, xtag)
 		e.w.writeUint16(uint16(l))
@@ -662,13 +663,19 @@ func (c msgpackSpecRpcCodec) parseCustomHeader(expectTypeByte byte, msgid *uint6
 		return
 	}
 	var b byte
-	if err = c.read(&b, msgid, methodOrError); err != nil {
+	if err = c.read(&b); err != nil {
 		return
 	}
 	if b != expectTypeByte {
 		err = fmt.Errorf("Unexpected byte descriptor in header. Expecting %v. Received %v", expectTypeByte, b)
 		return
 	}
+	if err = c.read(msgid); err != nil {
+		return
+	}
+	if err = c.read(methodOrError); err != nil {
+		return
+	}
 	return
 }
 
@@ -684,7 +691,7 @@ func (c msgpackSpecRpcCodec) writeCustomBody(typeByte byte, msgid uint64, method
 		}
 	}
 	r2 := []interface{}{typeByte, uint32(msgid), moe, body}
-	return c.enc.Encode(r2)
+	return c.write(r2, nil, false, true)
 }
 
 //--------------------------------------------------

+ 31 - 25
codec/rpc.go

@@ -11,6 +11,7 @@ package codec
 
 import (
 	"io"
+	"bufio"
 	"net/rpc"
 )
 
@@ -27,6 +28,7 @@ type rpcCodec struct {
 	rwc io.ReadWriteCloser
 	dec *Decoder
 	enc *Encoder
+	encbuf *bufio.Writer 
 }
 
 type goRpcCodec struct {
@@ -46,45 +48,50 @@ func (x goRpc) ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec {
 }
 
 func newRPCCodec(conn io.ReadWriteCloser, h Handle) rpcCodec {
+	encbuf := bufio.NewWriter(conn)
 	return rpcCodec{
 		rwc: conn,
-		dec: NewDecoder(conn, h),
-		enc: NewEncoder(conn, h),
+		encbuf: encbuf,
+		enc: NewEncoder(encbuf, h),
+		dec: NewDecoder(bufio.NewReader(conn), h),
+		//enc: NewEncoder(conn, h),
+		//dec: NewDecoder(conn, h),
 	}
 }
 
 // /////////////// RPC Codec Shared Methods ///////////////////
-func (c rpcCodec) write(objs ...interface{}) (err error) {
-	for _, obj := range objs {
-		if err = c.enc.Encode(obj); err != nil {
+func (c rpcCodec) write(obj1, obj2 interface{}, writeObj2, doFlush bool) (err error) {
+	if err = c.enc.Encode(obj1); err != nil {
+		return
+	}
+	if writeObj2 {
+		if err = c.enc.Encode(obj2); err != nil {
 			return
 		}
 	}
+	if doFlush && c.encbuf != nil {
+		//println("rpc flushing")
+		return c.encbuf.Flush()
+	}
 	return
 }
 
-func (c rpcCodec) read(objs ...interface{}) (err error) {
-	for _, obj := range objs {
-		//If nil is passed in, we should still attempt to read content to nowhere.
-		if obj == nil {
-			//obj = &obj //This bombs/uses all memory up. Dunno why (maybe because obj is not addressable???).
-			var n interface{}
-			obj = &n
-		}
-		if err = c.dec.Decode(obj); err != nil {
-			return
-		}
+
+func (c rpcCodec) read(obj interface{}) (err error) {
+	//If nil is passed in, we should still attempt to read content to nowhere.
+	if obj == nil {
+		var obj2 interface{}
+		return c.dec.Decode(&obj2)
 	}
-	return
+	return c.dec.Decode(obj)
 }
 
 func (c rpcCodec) Close() error {
 	return c.rwc.Close()
 }
 
-func (c rpcCodec) ReadResponseBody(body interface{}) (err error) {
-	err = c.read(body)
-	return
+func (c rpcCodec) ReadResponseBody(body interface{}) error {
+	return c.read(body)
 }
 
 func (c rpcCodec) ReadRequestBody(body interface{}) error {
@@ -93,16 +100,15 @@ func (c rpcCodec) ReadRequestBody(body interface{}) error {
 
 // /////////////// Go RPC Codec ///////////////////
 func (c goRpcCodec) WriteRequest(r *rpc.Request, body interface{}) error {
-	return c.write(r, body)
+	return c.write(r, body, true, true)
 }
 
 func (c goRpcCodec) WriteResponse(r *rpc.Response, body interface{}) error {
-	return c.write(r, body)
+	return c.write(r, body, true, true)
 }
 
-func (c goRpcCodec) ReadResponseHeader(r *rpc.Response) (err error) {
-	err = c.read(r)
-	return
+func (c goRpcCodec) ReadResponseHeader(r *rpc.Response) error {
+	return c.read(r)
 }
 
 func (c goRpcCodec) ReadRequestHeader(r *rpc.Request) error {

+ 6 - 7
codec/time.go

@@ -8,8 +8,7 @@ import (
 )
 
 var (
-	timeBs0xff = []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
-	digits = [...]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
+	timeDigits = [...]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
 )
 
 // encodeTime encodes a time.Time as a []byte, including
@@ -79,7 +78,7 @@ func decodeTime(bs []byte) (tt time.Time, err error) {
 		copy(btmp[8-n:], bs[i:i2])
 		//if first bit of bs[i] is set, then fill btmp[0..8-n] with 0xff (ie sign extend it)
 		if bs[i] & (1 << 7) != 0 {
-			copy(btmp[0:8-n], timeBs0xff)
+			copy(btmp[0:8-n], bsAll0xff)
 			//for j,k := byte(0), 8-n; j < k; j++ {	btmp[j] = 0xff }
 		}
 		i = i2
@@ -122,10 +121,10 @@ func decodeTime(bs []byte) (tt time.Time, err error) {
 	} else {
 		tzhr, tzmin = tzint/60, tzint%60
 	}
-	tzname[4] = digits[tzhr/10]
-	tzname[5] = digits[tzhr%10]
-	tzname[7] = digits[tzmin/10]
-	tzname[8] = digits[tzmin%10]
+	tzname[4] = timeDigits[tzhr/10]
+	tzname[5] = timeDigits[tzhr%10]
+	tzname[7] = timeDigits[tzmin/10]
+	tzname[8] = timeDigits[tzmin%10]
 
 	tt = time.Unix(tsec, int64(tnsec)).In(time.FixedZone(string(tzname), int(tzint)*60))
 	return