Browse Source

codec: Support RawExt and RpcCodecBuffered.

RawExt works like encoding/json.RawMessage.
It represents raw unprocessed extension data.
Users can have a RawExt in their struct, and the codec
will just put the extension data (Tag byte and []byte data)
as-is for potential further processing later.

RawExt is also returned if decoding into a nil interface{} and
no extension function has been registered.

RpcCodecBuffered interface allows access to the Buffered Reader
and Writer used internally by the rpc codec. This accomodates
use-cases where the connection should be used by rpc and non-rpc functions,
e.g. streaming a file after sending an rpc response.

Both GoRpc and MsgpackSpecRpc return codecs that implement RpcCodecBuffered.

In addition, some minor re-org and fixes were done:
- decoding numbers into a nil interface{} for Binc always uses 64-bit values
  (int64, uint64)

Updates issue #16 .
Ugorji Nwoke 12 years ago
parent
commit
4f88b73ce9
6 changed files with 171 additions and 109 deletions
  1. 11 8
      codec/binc.go
  2. 37 25
      codec/decode.go
  3. 31 16
      codec/encode.go
  4. 8 3
      codec/helper.go
  5. 47 43
      codec/msgpack.go
  6. 37 14
      codec/rpc.go

+ 11 - 8
codec/binc.go

@@ -733,11 +733,12 @@ func (d *bincDecDriver) decodeBytes(bs []byte) (bsOut []byte, changed bool) {
 	return
 	return
 }
 }
 
 
-func (d *bincDecDriver) decodeExt(tag byte) (xbs []byte) {
+func (d *bincDecDriver) decodeExt(verifyTag bool, tag byte) (xtag byte, xbs []byte) {
 	switch d.vd {
 	switch d.vd {
 	case bincVdCustomExt:
 	case bincVdCustomExt:
 		l := d.decLen()
 		l := d.decLen()
-		if xtag := d.r.readn1(); xtag != tag {
+		xtag = d.r.readn1()
+		if verifyTag && xtag != tag {
 			decErr("Wrong extension tag. Got %b. Expecting: %v", xtag, tag)
 			decErr("Wrong extension tag. Got %b. Expecting: %v", xtag, tag)
 		}
 		}
 		xbs = d.r.readn(l)
 		xbs = d.r.readn(l)
@@ -773,14 +774,14 @@ func (d *bincDecDriver) decodeNaked() (rv reflect.Value, ctx decodeNakedContext)
 		case bincSpZeroFloat:
 		case bincSpZeroFloat:
 			v = float64(0)
 			v = float64(0)
 		case bincSpZero:
 		case bincSpZero:
-			v = int8(0)
+			v = int64(0) // int8(0)
 		case bincSpNegOne:
 		case bincSpNegOne:
-			v = int8(-1)
+			v = int64(-1) // int8(-1)
 		default:
 		default:
 			decErr("decodeNaked: Unrecognized special value 0x%x", d.vs)
 			decErr("decodeNaked: Unrecognized special value 0x%x", d.vs)
 		}
 		}
 	case bincVdSmallInt:
 	case bincVdSmallInt:
-		v = int8(d.vs) + 1
+		v = int64(int8(d.vs)) + 1 // int8(d.vs) + 1
 	case bincVdUint:
 	case bincVdUint:
 		v = d.decUint()
 		v = d.decUint()
 	case bincVdInt:
 	case bincVdInt:
@@ -803,12 +804,14 @@ func (d *bincDecDriver) decodeNaked() (rv reflect.Value, ctx decodeNakedContext)
 		//ctx = dncExt
 		//ctx = dncExt
 		l := d.decLen()
 		l := d.decLen()
 		xtag := d.r.readn1()
 		xtag := d.r.readn1()
+		xbs := d.r.readn(l)
 		var bfn func(reflect.Value, []byte) error
 		var bfn func(reflect.Value, []byte) error
 		rv, bfn = d.h.getDecodeExtForTag(xtag)
 		rv, bfn = d.h.getDecodeExtForTag(xtag)
 		if bfn == nil {
 		if bfn == nil {
-			decErr("decodeNaked: Unable to find type mapped to extension tag: %v", xtag)
-		}
-		if fnerr := bfn(rv, d.r.readn(l)); fnerr != nil {
+			// decErr("decodeNaked: Unable to find type mapped to extension tag: %v", xtag)
+			re := RawExt { xtag, xbs }
+			rv = reflect.ValueOf(&re).Elem()
+		} else if fnerr := bfn(rv, xbs); fnerr != nil {
 			panic(fnerr)
 			panic(fnerr)
 		}
 		}
 	case bincVdArray:
 	case bincVdArray:

+ 37 - 25
codec/decode.go

@@ -70,7 +70,7 @@ type decDriver interface {
 	// decodeString can also decode symbols
 	// decodeString can also decode symbols
 	decodeString() (s string)
 	decodeString() (s string)
 	decodeBytes(bs []byte) (bsOut []byte, changed bool)
 	decodeBytes(bs []byte) (bsOut []byte, changed bool)
-	decodeExt(tag byte) []byte
+	decodeExt(verifyTag bool, tag byte) (xtag byte, xbs []byte)
 	readMapLen() int
 	readMapLen() int
 	readArrayLen() int
 	readArrayLen() int
 }
 }
@@ -101,22 +101,27 @@ type Decoder struct {
 }
 }
 
 
 func (f *decFnInfo) builtin(rv reflect.Value) {
 func (f *decFnInfo) builtin(rv reflect.Value) {
-	baseRv := rv
-	baseIndir := f.sis.baseIndir
-	for j := int8(0); j < baseIndir; j++ {
-		baseRv = baseRv.Elem()
+	for j, k := int8(0), f.sis.baseIndir; j < k; j++ {
+		rv = rv.Elem()
 	}
 	}
-	f.dd.decodeBuiltinType(f.sis.baseId, baseRv)
+	f.dd.decodeBuiltinType(f.sis.baseId, rv)
+}
+
+func (f *decFnInfo) rawExt(rv reflect.Value) {
+	xtag, xbs := f.dd.decodeExt(false, 0)
+	for j, k := int8(0), f.sis.baseIndir; j < k; j++ {
+		rv = rv.Elem()
+	}
+	rv.Field(0).SetUint(uint64(xtag))
+	rv.Field(1).SetBytes(xbs)
 }
 }
 
 
 func (f *decFnInfo) ext(rv reflect.Value) {
 func (f *decFnInfo) ext(rv reflect.Value) {
-	xbs := f.dd.decodeExt(f.xfTag)
-	baseRv := rv
-	baseIndir := f.sis.baseIndir
-	for j := int8(0); j < baseIndir; j++ {
-		baseRv = baseRv.Elem()
+	_, xbs := f.dd.decodeExt(true, f.xfTag)
+	for j, k := int8(0), f.sis.baseIndir; j < k; j++ {
+		rv = rv.Elem()
 	}
 	}
-	if fnerr := f.xfFn(baseRv, xbs); fnerr != nil {
+	if fnerr := f.xfFn(rv, xbs); fnerr != nil {
 		panic(fnerr)
 		panic(fnerr)
 	}
 	}
 }
 }
@@ -128,12 +133,10 @@ func (f *decFnInfo) binaryMarshal(rv reflect.Value) {
 	} else if f.sis.unmIndir == 0 {
 	} else if f.sis.unmIndir == 0 {
 		bm = rv.Interface().(binaryUnmarshaler)
 		bm = rv.Interface().(binaryUnmarshaler)
 	} else {
 	} else {
-		rv2 := rv
-		unmIndir := f.sis.unmIndir
-		for j := int8(0); j < unmIndir; j++ {
-			rv2 = rv.Elem()
+		for j, k := int8(0), f.sis.unmIndir; j < k; j++ {
+			rv = rv.Elem()
 		}
 		}
-		bm = rv2.Interface().(binaryUnmarshaler)
+		bm = rv.Interface().(binaryUnmarshaler)
 	}
 	}
 	xbs, _ := f.dd.decodeBytes(nil)
 	xbs, _ := f.dd.decodeBytes(nil)
 	if fnerr := bm.UnmarshalBinary(xbs); fnerr != nil {
 	if fnerr := bm.UnmarshalBinary(xbs); fnerr != nil {
@@ -235,7 +238,8 @@ func (f *decFnInfo) kStruct(rv reflect.Value) {
 				// f.d.decodeValue(sis.field(k, rv))
 				// f.d.decodeValue(sis.field(k, rv))
 			} else {
 			} else {
 				if f.d.h.errorIfNoField() {
 				if f.d.h.errorIfNoField() {
-					decErr("No matching struct field found when decoding stream map with key: %v", rvkencname)
+					decErr("No matching struct field found when decoding stream map with key: %v", 
+						rvkencname)
 				} else {
 				} else {
 					var nilintf0 interface{}
 					var nilintf0 interface{}
 					f.d.decodeValue(reflect.ValueOf(&nilintf0).Elem())
 					f.d.decodeValue(reflect.ValueOf(&nilintf0).Elem())
@@ -265,7 +269,8 @@ func (f *decFnInfo) kStruct(rv reflect.Value) {
 			}
 			}
 		}
 		}
 	} else {
 	} else {
-		decErr("Only encoded map or array can be decoded into a struct. (decodeEncodedType: %x)", currEncodedType)
+		decErr("Only encoded map or array can be decoded into a struct. (decodeEncodedType: %x)", 
+			currEncodedType)
 	}
 	}
 }
 }
 
 
@@ -300,7 +305,8 @@ func (f *decFnInfo) kSlice(rv reflect.Value) {
 			}
 			}
 			rv.Set(rvn)
 			rv.Set(rvn)
 		} else {
 		} else {
-			decErr("Cannot reset slice with less cap: %v than stream contents: %v", rvcap, containerLen)
+			decErr("Cannot reset slice with less cap: %v than stream contents: %v", 
+				rvcap, containerLen)
 		}
 		}
 	} else if containerLen > rvlen {
 	} else if containerLen > rvlen {
 		rv.SetLen(containerLen)
 		rv.SetLen(containerLen)
@@ -422,9 +428,13 @@ func NewDecoderBytes(in []byte, h Handle) *Decoder {
 //   err = dec.Decode(&v)
 //   err = dec.Decode(&v)
 // 
 // 
 // When decoding into a nil interface{}, we will decode into an appropriate value based
 // When decoding into a nil interface{}, we will decode into an appropriate value based
-// on the contents of the stream. Numbers are decoded as float64, int64 or uint64. Other values
-// are decoded appropriately (e.g. bool), and configurations exist on the Handle to override
-// defaults (e.g. for MapType, SliceType and how to decode raw bytes).
+// on the contents of the stream:
+//   - Numbers are decoded as float64, int64 or uint64. 
+//   - Other values are decoded appropriately depending on the encoding: 
+//     bool, string, []byte, time.Time, etc
+//   - Extensions are decoded as RawExt (if no ext function registered for the tag)
+// Configurations exist on the Handle to override defaults 
+// (e.g. for MapType, SliceType and how to decode raw bytes).
 // 
 // 
 // When decoding into a non-nil interface{} value, the mode of encoding is based on the 
 // When decoding into a non-nil interface{} value, the mode of encoding is based on the 
 // type of the value. When a value is seen:
 // type of the value. When a value is seen:
@@ -555,7 +565,9 @@ func (d *Decoder) decodeValue(rv reflect.Value) {
 		//
 		//
 		// If we are checking for builtin or ext type here, it means we didn't go through decodeNaked,
 		// If we are checking for builtin or ext type here, it means we didn't go through decodeNaked,
 		// Because decodeNaked would have handled it. It also means wasNilIntf = false.
 		// Because decodeNaked would have handled it. It also means wasNilIntf = false.
-		if d.d.isBuiltinType(fi.sis.baseId) {
+		if fi.sis.baseId == rawExtTypId {
+			fn = decFn { &fi, (*decFnInfo).rawExt }
+		} else if d.d.isBuiltinType(fi.sis.baseId) {
 			fn = decFn { &fi, (*decFnInfo).builtin }
 			fn = decFn { &fi, (*decFnInfo).builtin }
 		} else if xfTag, xfFn := d.h.getDecodeExt(fi.sis.baseId); xfFn != nil {
 		} else if xfTag, xfFn := d.h.getDecodeExt(fi.sis.baseId); xfFn != nil {
 			fi.xfTag, fi.xfFn = xfTag, xfFn
 			fi.xfTag, fi.xfFn = xfTag, xfFn
@@ -685,7 +697,7 @@ func (z *bytesDecReader) consume(n int) (oldcursor int) {
 		panic(io.EOF)
 		panic(io.EOF)
 	}
 	}
 	if n > z.a {
 	if n > z.a {
-		doPanic(msgTagDec, "Trying to read %v bytes. Only %v available", n, z.a)
+		decErr("Trying to read %v bytes. Only %v available", n, z.a)
 	}
 	}
 	// z.checkAvailable(n)
 	// z.checkAvailable(n)
 	oldcursor = z.c
 	oldcursor = z.c

+ 31 - 16
codec/encode.go

@@ -145,19 +145,34 @@ func (o *EncodeOptions) structToArray() bool {
 }
 }
 
 
 func (f *encFnInfo) builtin(rv reflect.Value) {
 func (f *encFnInfo) builtin(rv reflect.Value) {
-	baseRv := rv
-	for j := int8(0); j < f.sis.baseIndir; j++ {
-		baseRv = baseRv.Elem()
+	for j, k := int8(0), f.sis.baseIndir; j < k; j++ {
+		rv = rv.Elem()
+	}
+	f.ee.encodeBuiltinType(f.sis.baseId, rv)
+}
+
+func (f *encFnInfo) rawExt(rv reflect.Value) {
+	for j, k := int8(0), f.sis.baseIndir; j < k; j++ {
+		rv = rv.Elem()
+	}
+	re := rv.Interface().(RawExt)
+	if re.Data == nil {
+		f.ee.encodeNil()
+		return
+	}
+	if f.e.h.writeExt() {
+		f.ee.encodeExtPreamble(re.Tag, len(re.Data))
+		f.e.w.writeb(re.Data)
+	} else {
+		f.ee.encodeStringBytes(c_RAW, re.Data)
 	}
 	}
-	f.ee.encodeBuiltinType(f.sis.baseId, baseRv)
 }
 }
 
 
 func (f *encFnInfo) ext(rv reflect.Value) {
 func (f *encFnInfo) ext(rv reflect.Value) {
-	baseRv := rv
-	for j := int8(0); j < f.sis.baseIndir; j++ {
-		baseRv = baseRv.Elem()
+	for j, k := int8(0), f.sis.baseIndir; j < k; j++ {
+		rv = rv.Elem()
 	}
 	}
-	bs, fnerr := f.xfFn(baseRv)
+	bs, fnerr := f.xfFn(rv)
 	if fnerr != nil {
 	if fnerr != nil {
 		panic(fnerr)
 		panic(fnerr)
 	}
 	}
@@ -181,11 +196,10 @@ func (f *encFnInfo) binaryMarshal(rv reflect.Value) {
 	} else if f.sis.mIndir == -1 {
 	} else if f.sis.mIndir == -1 {
 		bm = rv.Addr().Interface().(binaryMarshaler)
 		bm = rv.Addr().Interface().(binaryMarshaler)
 	} else {
 	} else {
-		rv2 := rv
-		for j := int8(0); j < f.sis.mIndir; j++ {
-			rv2 = rv.Elem()
+		for j, k := int8(0), f.sis.mIndir; j < k; j++ {
+			rv = rv.Elem()
 		}
 		}
-		bm = rv2.Interface().(binaryMarshaler)
+		bm = rv.Interface().(binaryMarshaler)
 	}
 	}
 	// debugf(">>>> binaryMarshaler: %T", rv.Interface())
 	// debugf(">>>> binaryMarshaler: %T", rv.Interface())
 	bs, fnerr := bm.MarshalBinary()
 	bs, fnerr := bm.MarshalBinary()
@@ -197,7 +211,6 @@ func (f *encFnInfo) binaryMarshal(rv reflect.Value) {
 	} else {
 	} else {
 		f.ee.encodeStringBytes(c_RAW, bs)
 		f.ee.encodeStringBytes(c_RAW, bs)
 	}
 	}
-
 }
 }
 
 
 func (f *encFnInfo) kBool(rv reflect.Value) {
 func (f *encFnInfo) kBool(rv reflect.Value) {
@@ -523,7 +536,9 @@ func (e *Encoder) encodeValue(rv reflect.Value) {
 	if !ok {
 	if !ok {
 		// debugf("\tCreating new enc fn for type: %v\n", rt)
 		// debugf("\tCreating new enc fn for type: %v\n", rt)
 		fi := encFnInfo { sis:getTypeInfo(rtid, rt), e:e, ee:e.e, rt:rt, rtid:rtid }
 		fi := encFnInfo { sis:getTypeInfo(rtid, rt), e:e, ee:e.e, rt:rt, rtid:rtid }
-		if e.e.isBuiltinType(fi.sis.baseId) {
+		if fi.sis.baseId == rawExtTypId {
+			fn = encFn{ &fi, (*encFnInfo).rawExt }
+		} else if e.e.isBuiltinType(fi.sis.baseId) {
 			fn = encFn{ &fi, (*encFnInfo).builtin }
 			fn = encFn{ &fi, (*encFnInfo).builtin }
 		} else if xfTag, xfFn := e.h.getEncodeExt(fi.sis.baseId); xfFn != nil {
 		} else if xfTag, xfFn := e.h.getEncodeExt(fi.sis.baseId); xfFn != nil {
 			fi.xfTag, fi.xfFn = xfTag, xfFn
 			fi.xfTag, fi.xfFn = xfTag, xfFn
@@ -592,7 +607,7 @@ func (z *ioEncWriter) writeb(bs []byte) {
 		panic(err)
 		panic(err)
 	}
 	}
 	if n != len(bs) {
 	if n != len(bs) {
-		doPanic(msgTagEnc, "write: Incorrect num bytes written. Expecting: %v, Wrote: %v", len(bs), n)
+		encErr("write: Incorrect num bytes written. Expecting: %v, Wrote: %v", len(bs), n)
 	}
 	}
 }
 }
 
 
@@ -602,7 +617,7 @@ func (z *ioEncWriter) writestr(s string) {
 		panic(err)
 		panic(err)
 	}
 	}
 	if n != len(s) {
 	if n != len(s) {
-		doPanic(msgTagEnc, "write: Incorrect num bytes written. Expecting: %v, Wrote: %v", len(s), n)
+		encErr("write: Incorrect num bytes written. Expecting: %v, Wrote: %v", len(s), n)
 	}
 	}
 }
 }
 
 

+ 8 - 3
codec/helper.go

@@ -73,17 +73,16 @@ var (
 	intfSliceTyp     = reflect.TypeOf(nilIntfSlice)
 	intfSliceTyp     = reflect.TypeOf(nilIntfSlice)
 	intfTyp          = intfSliceTyp.Elem()
 	intfTyp          = intfSliceTyp.Elem()
 	byteSliceTyp     = reflect.TypeOf([]byte(nil))
 	byteSliceTyp     = reflect.TypeOf([]byte(nil))
-	ptrByteSliceTyp  = reflect.TypeOf((*[]byte)(nil))
 	mapStringIntfTyp = reflect.TypeOf(map[string]interface{}(nil))
 	mapStringIntfTyp = reflect.TypeOf(map[string]interface{}(nil))
 	mapIntfIntfTyp   = reflect.TypeOf(map[interface{}]interface{}(nil))
 	mapIntfIntfTyp   = reflect.TypeOf(map[interface{}]interface{}(nil))
 	
 	
 	timeTyp          = reflect.TypeOf(time.Time{})
 	timeTyp          = reflect.TypeOf(time.Time{})
-	ptrTimeTyp       = reflect.TypeOf((*time.Time)(nil))
 	int64SliceTyp    = reflect.TypeOf([]int64(nil))
 	int64SliceTyp    = reflect.TypeOf([]int64(nil))
+	rawExtTyp        = reflect.TypeOf(RawExt{})
 	
 	
 	timeTypId        = reflect.ValueOf(timeTyp).Pointer()
 	timeTypId        = reflect.ValueOf(timeTyp).Pointer()
-	ptrTimeTypId     = reflect.ValueOf(ptrTimeTyp).Pointer()
 	byteSliceTypId   = reflect.ValueOf(byteSliceTyp).Pointer()
 	byteSliceTypId   = reflect.ValueOf(byteSliceTyp).Pointer()
+	rawExtTypId      = reflect.ValueOf(rawExtTyp).Pointer()
 	
 	
 	binaryMarshalerTyp = reflect.TypeOf((*binaryMarshaler)(nil)).Elem()
 	binaryMarshalerTyp = reflect.TypeOf((*binaryMarshaler)(nil)).Elem()
 	binaryUnmarshalerTyp = reflect.TypeOf((*binaryUnmarshaler)(nil)).Elem()
 	binaryUnmarshalerTyp = reflect.TypeOf((*binaryUnmarshaler)(nil)).Elem()
@@ -98,6 +97,12 @@ var (
 	bsAll0xff = []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
 	bsAll0xff = []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
 )
 )
 
 
+// The RawExt type represents raw unprocessed extension data. 
+type RawExt struct {
+	Tag byte
+	Data []byte
+}
+
 // Handle is the interface for a specific encoding format.
 // Handle is the interface for a specific encoding format.
 // 
 // 
 // Typically, a Handle is pre-configured before first time use,
 // Typically, a Handle is pre-configured before first time use,

+ 47 - 43
codec/msgpack.go

@@ -65,7 +65,8 @@ const (
 )
 )
 
 
 // MsgpackSpecRpc implements Rpc using the communication protocol defined in
 // MsgpackSpecRpc implements Rpc using the communication protocol defined in
-// the msgpack spec at http://wiki.msgpack.org/display/MSGPACK/RPC+specification
+// the msgpack spec at http://wiki.msgpack.org/display/MSGPACK/RPC+specification .
+// It's methods (ServerCodec and ClientCodec) return values that implement RpcCodecBuffered.
 var MsgpackSpecRpc msgpackSpecRpc
 var MsgpackSpecRpc msgpackSpecRpc
 
 
 // MsgpackSpecRpcMultiArgs is a special type which signifies to the MsgpackSpecRpcCodec
 // MsgpackSpecRpcMultiArgs is a special type which signifies to the MsgpackSpecRpcCodec
@@ -371,12 +372,14 @@ func (d *msgpackDecDriver) decodeNaked() (rv reflect.Value, ctx decodeNakedConte
 			//ctx = dncExt
 			//ctx = dncExt
 			clen := d.readExtLen()
 			clen := d.readExtLen()
 			xtag := d.r.readn1()
 			xtag := d.r.readn1()
+			xbs := d.r.readn(clen)
 			var bfn func(reflect.Value, []byte) error
 			var bfn func(reflect.Value, []byte) error
 			rv, bfn = d.h.getDecodeExtForTag(xtag)
 			rv, bfn = d.h.getDecodeExtForTag(xtag)
 			if bfn == nil {
 			if bfn == nil {
-				decErr("Unable to find type mapped to extension tag: %v", xtag)
-			}
-			if fnerr := bfn(rv, d.r.readn(clen)); fnerr != nil {
+				// decErr("Unable to find type mapped to extension tag: %v", xtag)
+				re := RawExt { xtag, xbs }
+				rv = reflect.ValueOf(&re).Elem()
+			} else if fnerr := bfn(rv, xbs); fnerr != nil {
 				panic(fnerr)
 				panic(fnerr)
 			}
 			}
 		default:
 		default:
@@ -680,7 +683,7 @@ func (d *msgpackDecDriver) readExtLen() (clen int) {
 	return
 	return
 }
 }
 
 
-func (d *msgpackDecDriver) decodeExt(tag byte) (xbs []byte) {
+func (d *msgpackDecDriver) decodeExt(verifyTag bool, tag byte) (xtag byte, xbs []byte) {
 	xbd := d.bd
 	xbd := d.bd
 	switch {
 	switch {
 	case xbd == mpBin8, xbd == mpBin16, xbd == mpBin32: 
 	case xbd == mpBin8, xbd == mpBin16, xbd == mpBin32: 
@@ -690,7 +693,8 @@ func (d *msgpackDecDriver) decodeExt(tag byte) (xbs []byte) {
 		xbs = []byte(d.decodeString())
 		xbs = []byte(d.decodeString())
 	default:
 	default:
 		clen := d.readExtLen()
 		clen := d.readExtLen()
-		if xtag := d.r.readn1(); xtag != tag {
+		xtag = d.r.readn1()
+		if verifyTag && xtag != tag {
 			decErr("Wrong extension tag. Got %b. Expecting: %v", xtag, tag)
 			decErr("Wrong extension tag. Got %b. Expecting: %v", xtag, tag)
 		}
 		}
 		xbs = d.r.readn(clen)
 		xbs = d.r.readn(clen)
@@ -701,6 +705,43 @@ func (d *msgpackDecDriver) decodeExt(tag byte) (xbs []byte) {
 
 
 //--------------------------------------------------
 //--------------------------------------------------
 
 
+// TimeEncodeExt encodes a time.Time as a byte slice.
+// Configure this to support the Time Extension, e.g. using tag 1.
+func (_ *MsgpackHandle) TimeEncodeExt(rv reflect.Value) (bs []byte, err error) {
+	rvi := rv.Interface()
+	switch iv := rvi.(type) {
+	case time.Time:
+		bs = encodeTime(iv)
+	default:
+		err = fmt.Errorf("codec/msgpack: TimeEncodeExt expects a time.Time. Received %T", rvi)
+	}
+	return
+}
+
+// TimeDecodeExt decodes a time.Time from the byte slice parameter, and sets it into the reflect value.
+// Configure this to support the Time Extension, e.g. using tag 1.
+func (_ *MsgpackHandle) TimeDecodeExt(rv reflect.Value, bs []byte) (err error) {
+	tt, err := decodeTime(bs)
+	if err == nil {
+		rv.Set(reflect.ValueOf(tt))
+	}
+	return
+}
+
+func (h *MsgpackHandle) newEncDriver(w encWriter) encDriver {
+	return &msgpackEncDriver{w: w, h: h}
+}
+
+func (h *MsgpackHandle) newDecDriver(r decReader) decDriver {
+	return &msgpackDecDriver{r: r, h: h}
+}
+
+func (h *MsgpackHandle) writeExt() bool {
+	return h.WriteExt
+}
+
+//--------------------------------------------------
+
 func (x msgpackSpecRpc) ServerCodec(conn io.ReadWriteCloser, h Handle) rpc.ServerCodec {
 func (x msgpackSpecRpc) ServerCodec(conn io.ReadWriteCloser, h Handle) rpc.ServerCodec {
 	return &msgpackSpecRpcCodec{newRPCCodec(conn, h)}
 	return &msgpackSpecRpcCodec{newRPCCodec(conn, h)}
 }
 }
@@ -788,40 +829,3 @@ func (c *msgpackSpecRpcCodec) parseCustomHeader(expectTypeByte byte, msgid *uint
 	return
 	return
 }
 }
 
 
-//--------------------------------------------------
-
-// TimeEncodeExt encodes a time.Time as a byte slice.
-// Configure this to support the Time Extension, e.g. using tag 1.
-func (_ *MsgpackHandle) TimeEncodeExt(rv reflect.Value) (bs []byte, err error) {
-	rvi := rv.Interface()
-	switch iv := rvi.(type) {
-	case time.Time:
-		bs = encodeTime(iv)
-	default:
-		err = fmt.Errorf("codec/msgpack: TimeEncodeExt expects a time.Time. Received %T", rvi)
-	}
-	return
-}
-
-// TimeDecodeExt decodes a time.Time from the byte slice parameter, and sets it into the reflect value.
-// Configure this to support the Time Extension, e.g. using tag 1.
-func (_ *MsgpackHandle) TimeDecodeExt(rv reflect.Value, bs []byte) (err error) {
-	tt, err := decodeTime(bs)
-	if err == nil {
-		rv.Set(reflect.ValueOf(tt))
-	}
-	return
-}
-
-func (h *MsgpackHandle) newEncDriver(w encWriter) encDriver {
-	return &msgpackEncDriver{w: w, h: h}
-}
-
-func (h *MsgpackHandle) newDecDriver(r decReader) decDriver {
-	return &msgpackDecDriver{r: r, h: h}
-}
-
-func (h *MsgpackHandle) writeExt() bool {
-	return h.WriteExt
-}
-

+ 37 - 14
codec/rpc.go

@@ -16,19 +16,31 @@ import (
 )
 )
 
 
 // GoRpc implements Rpc using the communication protocol defined in net/rpc package.
 // GoRpc implements Rpc using the communication protocol defined in net/rpc package.
+// It's methods (ServerCodec and ClientCodec) return values that implement RpcCodecBuffered.
 var GoRpc goRpc
 var GoRpc goRpc
 
 
-// Rpc interface provides a rpc Server or Client Codec for rpc communication.
+// Rpc provides a rpc Server or Client Codec for rpc communication.
 type Rpc interface {
 type Rpc interface {
 	ServerCodec(conn io.ReadWriteCloser, h Handle) rpc.ServerCodec
 	ServerCodec(conn io.ReadWriteCloser, h Handle) rpc.ServerCodec
 	ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec
 	ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec
 }
 }
 
 
+// RpcCodecBuffered allows access to the underlying bufio.Reader/Writer
+// used by the rpc connection. It accomodates use-cases where the connection
+// should be used by rpc and non-rpc functions, e.g. streaming a file after
+// sending an rpc response.
+type RpcCodecBuffered interface {
+	BufferedReader() *bufio.Reader
+	BufferedWriter() *bufio.Writer
+}
+
+// rpcCodec defines the struct members and common methods.
 type rpcCodec struct {
 type rpcCodec struct {
 	rwc io.ReadWriteCloser
 	rwc io.ReadWriteCloser
 	dec *Decoder
 	dec *Decoder
 	enc *Encoder
 	enc *Encoder
-	encbuf *bufio.Writer 
+	bw *bufio.Writer
+	br *bufio.Reader
 }
 }
 
 
 type goRpcCodec struct {
 type goRpcCodec struct {
@@ -48,18 +60,26 @@ func (x goRpc) ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec {
 }
 }
 
 
 func newRPCCodec(conn io.ReadWriteCloser, h Handle) rpcCodec {
 func newRPCCodec(conn io.ReadWriteCloser, h Handle) rpcCodec {
-	encbuf := bufio.NewWriter(conn)
+	bw := bufio.NewWriter(conn)
+	br := bufio.NewReader(conn)
 	return rpcCodec{
 	return rpcCodec{
 		rwc: conn,
 		rwc: conn,
-		encbuf: encbuf,
-		enc: NewEncoder(encbuf, h),
-		dec: NewDecoder(bufio.NewReader(conn), h),
-		//enc: NewEncoder(conn, h),
-		//dec: NewDecoder(conn, h),
+		bw: bw,
+		br: br,
+		enc: NewEncoder(bw, h),
+		dec: NewDecoder(br, h),
 	}
 	}
 }
 }
 
 
 // /////////////// RPC Codec Shared Methods ///////////////////
 // /////////////// RPC Codec Shared Methods ///////////////////
+func (c *rpcCodec) BufferedReader() *bufio.Reader {
+	return c.br
+}
+
+func (c *rpcCodec) BufferedWriter() *bufio.Writer {
+	return c.bw
+}
+
 func (c *rpcCodec) write(obj1, obj2 interface{}, writeObj2, doFlush bool) (err error) {
 func (c *rpcCodec) write(obj1, obj2 interface{}, writeObj2, doFlush bool) (err error) {
 	if err = c.enc.Encode(obj1); err != nil {
 	if err = c.enc.Encode(obj1); err != nil {
 		return
 		return
@@ -69,9 +89,9 @@ func (c *rpcCodec) write(obj1, obj2 interface{}, writeObj2, doFlush bool) (err e
 			return
 			return
 		}
 		}
 	}
 	}
-	if doFlush && c.encbuf != nil {
+	if doFlush && c.bw != nil {
 		//println("rpc flushing")
 		//println("rpc flushing")
-		return c.encbuf.Flush()
+		return c.bw.Flush()
 	}
 	}
 	return
 	return
 }
 }
@@ -94,10 +114,6 @@ func (c *rpcCodec) ReadResponseBody(body interface{}) error {
 	return c.read(body)
 	return c.read(body)
 }
 }
 
 
-func (c *rpcCodec) ReadRequestBody(body interface{}) error {
-	return c.read(body)
-}
-
 // /////////////// Go RPC Codec ///////////////////
 // /////////////// Go RPC Codec ///////////////////
 func (c *goRpcCodec) WriteRequest(r *rpc.Request, body interface{}) error {
 func (c *goRpcCodec) WriteRequest(r *rpc.Request, body interface{}) error {
 	return c.write(r, body, true, true)
 	return c.write(r, body, true, true)
@@ -114,3 +130,10 @@ func (c *goRpcCodec) ReadResponseHeader(r *rpc.Response) error {
 func (c *goRpcCodec) ReadRequestHeader(r *rpc.Request) error {
 func (c *goRpcCodec) ReadRequestHeader(r *rpc.Request) error {
 	return c.read(r)
 	return c.read(r)
 }
 }
+
+func (c *goRpcCodec) ReadRequestBody(body interface{}) error {
+	return c.read(body)
+}
+
+var _ RpcCodecBuffered = (*rpcCodec)(nil) // ensure *rpcCodec implements RpcCodecBuffered
+