فهرست منبع

Improve write error handling

- Do not fail NextWriter when close of previous writer fails.
- Replace closeSent field with mutex protected writeErr. Set writeErr on
  any error writing to underlying network connection. Check and return
  writeErr before attempting to write to network connection. Check
  writeErr in NextWriter so application can detect failed connection
  before attempting to write.
- Do not close underlying network connection on error.
- Move message writing state and method flushFrame from Conn to
  messageWriter. This makes error code paths (and the code in general)
  easier to understand.
- Add messageWriter field err to latch errors in messageWriter.

Bonus: Improve test coverage.
Gary Burd 9 سال پیش
والد
کامیت
80a0029a65
2فایلهای تغییر یافته به همراه195 افزوده شده و 117 حذف شده
  1. 123 103
      conn.go
  2. 72 14
      conn_test.go

+ 123 - 103
conn.go

@@ -12,6 +12,7 @@ import (
 	"io/ioutil"
 	"io/ioutil"
 	"net"
 	"net"
 	"strconv"
 	"strconv"
+	"sync"
 	"time"
 	"time"
 	"unicode/utf8"
 	"unicode/utf8"
 )
 )
@@ -223,19 +224,16 @@ type Conn struct {
 	subprotocol string
 	subprotocol string
 
 
 	// Write fields
 	// Write fields
-	mu             chan bool // used as mutex to protect write to conn and closeSent
-	closeSent      bool      // whether close message was sent
-	writeErr       error
-	writeBuf       []byte // frame is constructed in this buffer.
-	writePos       int    // end of data in writeBuf.
-	writeFrameType int    // type of the current frame.
-	writeDeadline  time.Time
-	messageWriter  *messageWriter // the current low-level message writer
-	writer         io.WriteCloser // the current writer returned to the application
-	isWriting      bool           // for best-effort concurrent write detection
+	mu            chan bool // used as mutex to protect write to conn
+	writeBuf      []byte    // frame is constructed in this buffer.
+	writeDeadline time.Time
+	writer        io.WriteCloser // the current writer returned to the application
+	isWriting     bool           // for best-effort concurrent write detection
+
+	writeErrMu sync.Mutex
+	writeErr   error
 
 
 	enableWriteCompression bool
 	enableWriteCompression bool
-	writeCompress          bool // whether next call to flushFrame should set RSV1
 	newCompressionWriter   func(io.WriteCloser) (io.WriteCloser, error)
 	newCompressionWriter   func(io.WriteCloser) (io.WriteCloser, error)
 
 
 	// Read fields
 	// Read fields
@@ -277,8 +275,6 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
 		mu:                     mu,
 		mu:                     mu,
 		readFinal:              true,
 		readFinal:              true,
 		writeBuf:               make([]byte, writeBufferSize+maxFrameHeaderSize),
 		writeBuf:               make([]byte, writeBufferSize+maxFrameHeaderSize),
-		writeFrameType:         noFrame,
-		writePos:               maxFrameHeaderSize,
 		enableWriteCompression: true,
 		enableWriteCompression: true,
 	}
 	}
 	c.SetPingHandler(nil)
 	c.SetPingHandler(nil)
@@ -308,29 +304,40 @@ func (c *Conn) RemoteAddr() net.Addr {
 
 
 // Write methods
 // Write methods
 
 
+func (c *Conn) writeFatal(err error) error {
+	err = hideTempErr(err)
+	c.writeErrMu.Lock()
+	if c.writeErr == nil {
+		c.writeErr = err
+	}
+	c.writeErrMu.Unlock()
+	return err
+}
+
 func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
 func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
 	<-c.mu
 	<-c.mu
 	defer func() { c.mu <- true }()
 	defer func() { c.mu <- true }()
 
 
-	if c.closeSent {
-		return ErrCloseSent
-	} else if frameType == CloseMessage {
-		c.closeSent = true
+	c.writeErrMu.Lock()
+	err := c.writeErr
+	c.writeErrMu.Unlock()
+	if err != nil {
+		return err
 	}
 	}
 
 
 	c.conn.SetWriteDeadline(deadline)
 	c.conn.SetWriteDeadline(deadline)
 	for _, buf := range bufs {
 	for _, buf := range bufs {
 		if len(buf) > 0 {
 		if len(buf) > 0 {
-			n, err := c.conn.Write(buf)
-			if n != len(buf) {
-				// Close on partial write.
-				c.conn.Close()
-			}
+			_, err := c.conn.Write(buf)
 			if err != nil {
 			if err != nil {
-				return err
+				return c.writeFatal(err)
 			}
 			}
 		}
 		}
 	}
 	}
+
+	if frameType == CloseMessage {
+		c.writeFatal(ErrCloseSent)
+	}
 	return nil
 	return nil
 }
 }
 
 
@@ -379,18 +386,22 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
 	}
 	}
 	defer func() { c.mu <- true }()
 	defer func() { c.mu <- true }()
 
 
-	if c.closeSent {
-		return ErrCloseSent
-	} else if messageType == CloseMessage {
-		c.closeSent = true
+	c.writeErrMu.Lock()
+	err := c.writeErr
+	c.writeErrMu.Unlock()
+	if err != nil {
+		return err
 	}
 	}
 
 
 	c.conn.SetWriteDeadline(deadline)
 	c.conn.SetWriteDeadline(deadline)
-	n, err := c.conn.Write(buf)
-	if n != 0 && n != len(buf) {
-		c.conn.Close()
+	_, err = c.conn.Write(buf)
+	if err != nil {
+		return c.writeFatal(err)
+	}
+	if messageType == CloseMessage {
+		c.writeFatal(ErrCloseSent)
 	}
 	}
-	return hideTempErr(err)
+	return err
 }
 }
 
 
 // NextWriter returns a writer for the next message to send. The writer's Close
 // NextWriter returns a writer for the next message to send. The writer's Close
@@ -399,64 +410,79 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
 // There can be at most one open writer on a connection. NextWriter closes the
 // There can be at most one open writer on a connection. NextWriter closes the
 // previous writer if the application has not already done so.
 // previous writer if the application has not already done so.
 func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
 func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
-	if c.writeErr != nil {
-		return nil, c.writeErr
-	}
-
 	// Close previous writer if not already closed by the application. It's
 	// Close previous writer if not already closed by the application. It's
 	// probably better to return an error in this situation, but we cannot
 	// probably better to return an error in this situation, but we cannot
 	// change this without breaking existing applications.
 	// change this without breaking existing applications.
 	if c.writer != nil {
 	if c.writer != nil {
-		err := c.writer.Close()
-		if err != nil {
-			return nil, err
-		}
+		c.writer.Close()
+		c.writer = nil
 	}
 	}
 
 
 	if !isControl(messageType) && !isData(messageType) {
 	if !isControl(messageType) && !isData(messageType) {
 		return nil, errBadWriteOpCode
 		return nil, errBadWriteOpCode
 	}
 	}
 
 
-	c.writeFrameType = messageType
-	c.messageWriter = &messageWriter{c}
+	c.writeErrMu.Lock()
+	err := c.writeErr
+	c.writeErrMu.Unlock()
+	if err != nil {
+		return nil, err
+	}
 
 
-	var w io.WriteCloser = c.messageWriter
+	mw := &messageWriter{
+		c:         c,
+		frameType: messageType,
+		pos:       maxFrameHeaderSize,
+	}
+	c.writer = mw
 	if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
 	if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
-		c.writeCompress = true
-		var err error
-		w, err = c.newCompressionWriter(w)
+		w, err := c.newCompressionWriter(c.writer)
 		if err != nil {
 		if err != nil {
-			c.writer.Close()
+			c.writer = nil
 			return nil, err
 			return nil, err
 		}
 		}
+		mw.compress = true
+		c.writer = w
 	}
 	}
+	return c.writer, nil
+}
 
 
-	return w, nil
+type messageWriter struct {
+	c         *Conn
+	compress  bool // whether next call to flushFrame should set RSV1
+	pos       int  // end of data in writeBuf.
+	frameType int  // type of the current frame.
+	err       error
+}
+
+func (w *messageWriter) fatal(err error) error {
+	if w.err != nil {
+		w.err = err
+		w.c.writer = nil
+	}
+	return err
 }
 }
 
 
 // flushFrame writes buffered data and extra as a frame to the network. The
 // flushFrame writes buffered data and extra as a frame to the network. The
 // final argument indicates that this is the last frame in the message.
 // final argument indicates that this is the last frame in the message.
-func (c *Conn) flushFrame(final bool, extra []byte) error {
-	length := c.writePos - maxFrameHeaderSize + len(extra)
+func (w *messageWriter) flushFrame(final bool, extra []byte) error {
+	c := w.c
+	length := w.pos - maxFrameHeaderSize + len(extra)
 
 
 	// Check for invalid control frames.
 	// Check for invalid control frames.
-	if isControl(c.writeFrameType) &&
+	if isControl(w.frameType) &&
 		(!final || length > maxControlFramePayloadSize) {
 		(!final || length > maxControlFramePayloadSize) {
-		c.messageWriter = nil
-		c.writer = nil
-		c.writeFrameType = noFrame
-		c.writePos = maxFrameHeaderSize
-		return errInvalidControlFrame
+		return w.fatal(errInvalidControlFrame)
 	}
 	}
 
 
-	b0 := byte(c.writeFrameType)
+	b0 := byte(w.frameType)
 	if final {
 	if final {
 		b0 |= finalBit
 		b0 |= finalBit
 	}
 	}
-	if c.writeCompress {
+	if w.compress {
 		b0 |= rsv1Bit
 		b0 |= rsv1Bit
 	}
 	}
-	c.writeCompress = false
+	w.compress = false
 
 
 	b1 := byte(0)
 	b1 := byte(0)
 	if !c.isServer {
 	if !c.isServer {
@@ -489,10 +515,9 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
 	if !c.isServer {
 	if !c.isServer {
 		key := newMaskKey()
 		key := newMaskKey()
 		copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
 		copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
-		maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos])
+		maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
 		if len(extra) > 0 {
 		if len(extra) > 0 {
-			c.writeErr = errors.New("websocket: internal error, extra used in client mode")
-			return c.writeErr
+			return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
 		}
 		}
 	}
 	}
 
 
@@ -505,44 +530,35 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
 	}
 	}
 	c.isWriting = true
 	c.isWriting = true
 
 
-	c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra)
+	err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)
 
 
 	if !c.isWriting {
 	if !c.isWriting {
 		panic("concurrent write to websocket connection")
 		panic("concurrent write to websocket connection")
 	}
 	}
 	c.isWriting = false
 	c.isWriting = false
 
 
-	// Setup for next frame.
-	c.writePos = maxFrameHeaderSize
-	c.writeFrameType = continuationFrame
+	if err != nil {
+		return w.fatal(err)
+	}
+
 	if final {
 	if final {
-		c.messageWriter = nil
 		c.writer = nil
 		c.writer = nil
-		c.writeFrameType = noFrame
+		return nil
 	}
 	}
-	return c.writeErr
-}
-
-type messageWriter struct{ c *Conn }
 
 
-func (w *messageWriter) err() error {
-	c := w.c
-	if c.messageWriter != w {
-		return errWriteClosed
-	}
-	if c.writeErr != nil {
-		return c.writeErr
-	}
+	// Setup for next frame.
+	w.pos = maxFrameHeaderSize
+	w.frameType = continuationFrame
 	return nil
 	return nil
 }
 }
 
 
 func (w *messageWriter) ncopy(max int) (int, error) {
 func (w *messageWriter) ncopy(max int) (int, error) {
-	n := len(w.c.writeBuf) - w.c.writePos
+	n := len(w.c.writeBuf) - w.pos
 	if n <= 0 {
 	if n <= 0 {
-		if err := w.c.flushFrame(false, nil); err != nil {
+		if err := w.flushFrame(false, nil); err != nil {
 			return 0, err
 			return 0, err
 		}
 		}
-		n = len(w.c.writeBuf) - w.c.writePos
+		n = len(w.c.writeBuf) - w.pos
 	}
 	}
 	if n > max {
 	if n > max {
 		n = max
 		n = max
@@ -551,13 +567,13 @@ func (w *messageWriter) ncopy(max int) (int, error) {
 }
 }
 
 
 func (w *messageWriter) Write(p []byte) (int, error) {
 func (w *messageWriter) Write(p []byte) (int, error) {
-	if err := w.err(); err != nil {
-		return 0, err
+	if w.err != nil {
+		return 0, w.err
 	}
 	}
 
 
 	if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
 	if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
 		// Don't buffer large messages.
 		// Don't buffer large messages.
-		err := w.c.flushFrame(false, p)
+		err := w.flushFrame(false, p)
 		if err != nil {
 		if err != nil {
 			return 0, err
 			return 0, err
 		}
 		}
@@ -570,16 +586,16 @@ func (w *messageWriter) Write(p []byte) (int, error) {
 		if err != nil {
 		if err != nil {
 			return 0, err
 			return 0, err
 		}
 		}
-		copy(w.c.writeBuf[w.c.writePos:], p[:n])
-		w.c.writePos += n
+		copy(w.c.writeBuf[w.pos:], p[:n])
+		w.pos += n
 		p = p[n:]
 		p = p[n:]
 	}
 	}
 	return nn, nil
 	return nn, nil
 }
 }
 
 
 func (w *messageWriter) WriteString(p string) (int, error) {
 func (w *messageWriter) WriteString(p string) (int, error) {
-	if err := w.err(); err != nil {
-		return 0, err
+	if w.err != nil {
+		return 0, w.err
 	}
 	}
 
 
 	nn := len(p)
 	nn := len(p)
@@ -588,27 +604,27 @@ func (w *messageWriter) WriteString(p string) (int, error) {
 		if err != nil {
 		if err != nil {
 			return 0, err
 			return 0, err
 		}
 		}
-		copy(w.c.writeBuf[w.c.writePos:], p[:n])
-		w.c.writePos += n
+		copy(w.c.writeBuf[w.pos:], p[:n])
+		w.pos += n
 		p = p[n:]
 		p = p[n:]
 	}
 	}
 	return nn, nil
 	return nn, nil
 }
 }
 
 
 func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
 func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
-	if err := w.err(); err != nil {
-		return 0, err
+	if w.err != nil {
+		return 0, w.err
 	}
 	}
 	for {
 	for {
-		if w.c.writePos == len(w.c.writeBuf) {
-			err = w.c.flushFrame(false, nil)
+		if w.pos == len(w.c.writeBuf) {
+			err = w.flushFrame(false, nil)
 			if err != nil {
 			if err != nil {
 				break
 				break
 			}
 			}
 		}
 		}
 		var n int
 		var n int
-		n, err = r.Read(w.c.writeBuf[w.c.writePos:])
-		w.c.writePos += n
+		n, err = r.Read(w.c.writeBuf[w.pos:])
+		w.pos += n
 		nn += int64(n)
 		nn += int64(n)
 		if err != nil {
 		if err != nil {
 			if err == io.EOF {
 			if err == io.EOF {
@@ -621,10 +637,14 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
 }
 }
 
 
 func (w *messageWriter) Close() error {
 func (w *messageWriter) Close() error {
-	if err := w.err(); err != nil {
+	if w.err != nil {
+		return w.err
+	}
+	if err := w.flushFrame(true, nil); err != nil {
 		return err
 		return err
 	}
 	}
-	return w.c.flushFrame(true, nil)
+	w.err = errWriteClosed
+	return nil
 }
 }
 
 
 // WriteMessage is a helper method for getting a writer using NextWriter,
 // WriteMessage is a helper method for getting a writer using NextWriter,
@@ -634,12 +654,12 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	if _, ok := w.(*messageWriter); ok && c.isServer {
+	if mw, ok := w.(*messageWriter); ok && c.isServer {
 		// Optimize write as a single frame.
 		// Optimize write as a single frame.
-		n := copy(c.writeBuf[c.writePos:], data)
-		c.writePos += n
+		n := copy(c.writeBuf[mw.pos:], data)
+		mw.pos += n
 		data = data[n:]
 		data = data[n:]
-		err = c.flushFrame(true, data)
+		err = mw.flushFrame(true, data)
 		return err
 		return err
 	}
 	}
 	if _, err = w.Write(data); err != nil {
 	if _, err = w.Write(data); err != nil {

+ 72 - 14
conn_test.go

@@ -26,12 +26,27 @@ type fakeNetConn struct {
 }
 }
 
 
 func (c fakeNetConn) Close() error                       { return nil }
 func (c fakeNetConn) Close() error                       { return nil }
-func (c fakeNetConn) LocalAddr() net.Addr                { return nil }
-func (c fakeNetConn) RemoteAddr() net.Addr               { return nil }
+func (c fakeNetConn) LocalAddr() net.Addr                { return localAddr }
+func (c fakeNetConn) RemoteAddr() net.Addr               { return remoteAddr }
 func (c fakeNetConn) SetDeadline(t time.Time) error      { return nil }
 func (c fakeNetConn) SetDeadline(t time.Time) error      { return nil }
 func (c fakeNetConn) SetReadDeadline(t time.Time) error  { return nil }
 func (c fakeNetConn) SetReadDeadline(t time.Time) error  { return nil }
 func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
 func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
 
 
+type fakeAddr int
+
+var (
+	localAddr  = fakeAddr(1)
+	remoteAddr = fakeAddr(2)
+)
+
+func (a fakeAddr) Network() string {
+	return "net"
+}
+
+func (a fakeAddr) String() string {
+	return "str"
+}
+
 func TestFraming(t *testing.T) {
 func TestFraming(t *testing.T) {
 	frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
 	frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
 	var readChunkers = []struct {
 	var readChunkers = []struct {
@@ -42,11 +57,25 @@ func TestFraming(t *testing.T) {
 		{"one", iotest.OneByteReader},
 		{"one", iotest.OneByteReader},
 		{"asis", func(r io.Reader) io.Reader { return r }},
 		{"asis", func(r io.Reader) io.Reader { return r }},
 	}
 	}
-
 	writeBuf := make([]byte, 65537)
 	writeBuf := make([]byte, 65537)
 	for i := range writeBuf {
 	for i := range writeBuf {
 		writeBuf[i] = byte(i)
 		writeBuf[i] = byte(i)
 	}
 	}
+	var writers = []struct {
+		name string
+		f    func(w io.Writer, n int) (int, error)
+	}{
+		{"iocopy", func(w io.Writer, n int) (int, error) {
+			nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
+			return int(nn), err
+		}},
+		{"write", func(w io.Writer, n int) (int, error) {
+			return w.Write(writeBuf[:n])
+		}},
+		{"string", func(w io.Writer, n int) (int, error) {
+			return io.WriteString(w, string(writeBuf[:n]))
+		}},
+	}
 
 
 	for _, compress := range []bool{false, true} {
 	for _, compress := range []bool{false, true} {
 		for _, isServer := range []bool{true, false} {
 		for _, isServer := range []bool{true, false} {
@@ -60,22 +89,15 @@ func TestFraming(t *testing.T) {
 					rc.newDecompressionReader = decompressNoContextTakeover
 					rc.newDecompressionReader = decompressNoContextTakeover
 				}
 				}
 				for _, n := range frameSizes {
 				for _, n := range frameSizes {
-					for _, iocopy := range []bool{true, false} {
-						name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d c:%v", compress, isServer, chunker.name, n, iocopy)
+					for _, writer := range writers {
+						name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
 
 
 						w, err := wc.NextWriter(TextMessage)
 						w, err := wc.NextWriter(TextMessage)
 						if err != nil {
 						if err != nil {
 							t.Errorf("%s: wc.NextWriter() returned %v", name, err)
 							t.Errorf("%s: wc.NextWriter() returned %v", name, err)
 							continue
 							continue
 						}
 						}
-						var nn int
-						if iocopy {
-							var n64 int64
-							n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n]))
-							nn = int(n64)
-						} else {
-							nn, err = w.Write(writeBuf[:n])
-						}
+						nn, err := writer.f(w, n)
 						if err != nil || nn != n {
 						if err != nil || nn != n {
 							t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
 							t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
 							continue
 							continue
@@ -151,7 +173,7 @@ func TestControl(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func TestCloseBeforeFinalFrame(t *testing.T) {
+func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
 	const bufSize = 512
 	const bufSize = 512
 
 
 	expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
 	expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
@@ -238,6 +260,32 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestWriteAfterMessageWriterClose(t *testing.T) {
+	wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024)
+	w, _ := wc.NextWriter(BinaryMessage)
+	io.WriteString(w, "hello")
+	if err := w.Close(); err != nil {
+		t.Fatalf("unxpected error closing message writer, %v", err)
+	}
+
+	if _, err := io.WriteString(w, "world"); err == nil {
+		t.Fatalf("no error writing after close")
+	}
+
+	w, _ = wc.NextWriter(BinaryMessage)
+	io.WriteString(w, "hello")
+
+	// close w by getting next writer
+	_, err := wc.NextWriter(BinaryMessage)
+	if err != nil {
+		t.Fatalf("unexpected error getting next writer, %v", err)
+	}
+
+	if _, err := io.WriteString(w, "world"); err == nil {
+		t.Fatalf("no error writing after close")
+	}
+}
+
 func TestReadLimit(t *testing.T) {
 func TestReadLimit(t *testing.T) {
 
 
 	const readLimit = 512
 	const readLimit = 512
@@ -272,6 +320,16 @@ func TestReadLimit(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestAddrs(t *testing.T) {
+	c := newConn(&fakeNetConn{}, true, 1024, 1024)
+	if c.LocalAddr() != localAddr {
+		t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
+	}
+	if c.RemoteAddr() != remoteAddr {
+		t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
+	}
+}
+
 func TestUnderlyingConn(t *testing.T) {
 func TestUnderlyingConn(t *testing.T) {
 	var b1, b2 bytes.Buffer
 	var b1, b2 bytes.Buffer
 	fc := fakeNetConn{Reader: &b1, Writer: &b2}
 	fc := fakeNetConn{Reader: &b1, Writer: &b2}