ソースを参照

codec: encWriterSwitch.end() panic's, and should not be called AS-IS during a final "recover"

When a Writer throws errors during Write, and a bufio writer consequently
throws an error during flush, we MUST handle this appropriately during
a final deferred function which recovers from panics. We cannot have that
function itself throw a panic, as the panic then propagates to the callee.

Fix this by making flush/flushErr and end/endErr pairs, and call the
appropriate one i.e. mustEncode, etc calls flush/end, while recover
block calls endErr.

Fixes #285
Ugorji Nwoke 6 年 前
コミット
c4a1c341dc
3 ファイル変更61 行追加14 行削除
  1. 30 0
      codec/codec_test.go
  2. 30 14
      codec/encode.go
  3. 1 0
      codec/z_all_test.go

+ 30 - 0
codec/codec_test.go

@@ -7,6 +7,7 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/gob"
+	"errors"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -71,6 +72,14 @@ type testIntfMapT2 struct {
 
 func (x testIntfMapT2) GetIntfMapV() string { return x.IntfMapV }
 
+var testErrWriterErr = errors.New("testErrWriterErr")
+
+type testErrWriter struct{}
+
+func (x *testErrWriter) Write(p []byte) (int, error) {
+	return 0, testErrWriterErr
+}
+
 // ----
 
 type testVerifyFlag uint8
@@ -1419,6 +1428,22 @@ func doTestAnonCycle(t *testing.T, name string, h Handle) {
 	logT(t, "pti: %v", pti)
 }
 
+func doTestErrWriter(t *testing.T, name string, h Handle) {
+	var ew testErrWriter
+	w := bufio.NewWriterSize(&ew, 4)
+	enc := NewEncoder(w, h)
+	for i := 0; i < 4; i++ {
+		err := enc.Encode("ugorji")
+		if ev, ok := err.(encodeError); ok {
+			err = ev.Cause()
+		}
+		if err != testErrWriterErr {
+			logT(t, "%s: expecting err: %v, received: %v", name, testErrWriterErr, err)
+			failT(t)
+		}
+	}
+}
+
 func doTestJsonLargeInteger(t *testing.T, v interface{}, ias uint8) {
 	testOnce.Do(testInitAll)
 	logT(t, "Running doTestJsonLargeInteger: v: %#v, ias: %c", v, ias)
@@ -2902,6 +2927,11 @@ func TestAllAnonCycle(t *testing.T) {
 	doTestAnonCycle(t, "cbor", testCborH)
 }
 
+func TestAllErrWriter(t *testing.T) {
+	doTestErrWriter(t, "cbor", testCborH)
+	doTestErrWriter(t, "json", testJsonH)
+}
+
 // ----- RPC -----
 
 func TestBincRpcGo(t *testing.T) {

+ 30 - 14
codec/encode.go

@@ -318,16 +318,20 @@ func (z *bufioEncWriter) release() {
 }
 
 //go:noinline - flush only called intermittently
-func (z *bufioEncWriter) flush() {
+func (z *bufioEncWriter) flushErr() (err error) {
 	n, err := z.w.Write(z.buf[:z.n])
 	z.n -= n
 	if z.n > 0 && err == nil {
 		err = io.ErrShortWrite
 	}
-	if err != nil {
-		if n > 0 && z.n > 0 {
-			copy(z.buf, z.buf[n:z.n+n])
-		}
+	if n > 0 && z.n > 0 {
+		copy(z.buf, z.buf[n:z.n+n])
+	}
+	return err
+}
+
+func (z *bufioEncWriter) flush() {
+	if err := z.flushErr(); err != nil {
 		panic(err)
 	}
 }
@@ -374,10 +378,11 @@ func (z *bufioEncWriter) writen2(b1, b2 byte) {
 	z.n += 2
 }
 
-func (z *bufioEncWriter) end() {
+func (z *bufioEncWriter) endErr() (err error) {
 	if z.n > 0 {
-		z.flush()
+		err = z.flushErr()
 	}
+	return
 }
 
 // ---------------------------------------------
@@ -400,8 +405,9 @@ func (z *bytesEncAppender) writen1(b1 byte) {
 func (z *bytesEncAppender) writen2(b1, b2 byte) {
 	z.b = append(z.b, b1, b2)
 }
-func (z *bytesEncAppender) end() {
+func (z *bytesEncAppender) endErr() error {
 	*(z.out) = z.b
+	return nil
 }
 func (z *bytesEncAppender) reset(in []byte, out *[]byte) {
 	z.b = in[:0]
@@ -1119,11 +1125,16 @@ func (z *encWriterSwitch) writen2(b1, b2 byte) {
 		z.wf.writen2(b1, b2)
 	}
 }
-func (z *encWriterSwitch) end() {
+func (z *encWriterSwitch) endErr() error {
 	if z.bytes {
-		z.wb.end()
-	} else {
-		z.wf.end()
+		return z.wb.endErr()
+	}
+	return z.wf.endErr()
+}
+
+func (z *encWriterSwitch) end() {
+	if err := z.endErr(); err != nil {
+		panic(err)
 	}
 }
 
@@ -1456,8 +1467,13 @@ func (e *Encoder) Encode(v interface{}) (err error) {
 	}
 	if recoverPanicToErr {
 		defer func() {
-			e.w.end()
-			if x := recover(); x != nil {
+			// if error occurred during encoding, return that error;
+			// else if error occurred on end'ing (i.e. during flush), return that error.
+			err = e.w.endErr()
+			x := recover()
+			if x == nil {
+				e.err = err
+			} else {
 				panicValToErr(e, x, &e.err)
 				err = e.err
 			}

+ 1 - 0
codec/z_all_test.go

@@ -305,6 +305,7 @@ func testNonHandlesGroup(t *testing.T) {
 	t.Run("TestAllEncCircularRef", TestAllEncCircularRef)
 	t.Run("TestAllAnonCycle", TestAllAnonCycle)
 	t.Run("TestMultipleEncDec", TestMultipleEncDec)
+	t.Run("TestAllErrWriter", TestAllErrWriter)
 }
 
 func TestCodecSuite(t *testing.T) {