Ver Fonte

Revert "Reduce memory allocations in NextReader, NextWriter"

This reverts commit 8b209f63177a963547dc3cee89350a327ead0412.
Gary Burd há 9 anos atrás
pai
commit
50d660d6ac
2 ficheiros alterados com 83 adições e 82 exclusões
  1. 1 0
      .travis.yml
  2. 82 82
      conn.go

+ 1 - 0
.travis.yml

@@ -3,6 +3,7 @@ sudo: false
 
 matrix:
   include:
+    - go: 1.4
     - go: 1.5
     - go: 1.6
     - go: tip

+ 82 - 82
conn.go

@@ -238,15 +238,16 @@ type Conn struct {
 	writeBuf       []byte // frame is constructed in this buffer.
 	writePos       int    // end of data in writeBuf.
 	writeFrameType int    // type of the current frame.
+	writeSeq       int    // incremented to invalidate message writers.
 	writeDeadline  time.Time
-	isWriting      bool           // for best-effort concurrent write detection
-	messageWriter  *messageWriter // the current writer
+	isWriting      bool // for best-effort concurrent write detection
 
 	// Read fields
 	readErr       error
 	br            *bufio.Reader
 	readRemaining int64 // bytes remaining in current frame.
 	readFinal     bool  // true the current message has more frames.
+	readSeq       int   // incremented to invalidate message readers.
 	readLength    int64 // Message size.
 	readLimit     int64 // Maximum message size.
 	readMaskPos   int
@@ -254,7 +255,6 @@ type Conn struct {
 	handlePong    func(string) error
 	handlePing    func(string) error
 	readErrCount  int
-	messageReader *messageReader // the current reader
 }
 
 func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@@ -264,9 +264,6 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
 	if readBufferSize == 0 {
 		readBufferSize = defaultReadBufferSize
 	}
-	if readBufferSize < maxControlFramePayloadSize {
-		readBufferSize = maxControlFramePayloadSize
-	}
 	if writeBufferSize == 0 {
 		writeBufferSize = defaultWriteBufferSize
 	}
@@ -393,8 +390,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
 	return hideTempErr(err)
 }
 
-// NextWriter returns a writer for the next message to send. The writer's Close
-// method flushes the complete message to the network.
+// NextWriter returns a writer for the next message to send.  The writer's
+// Close method flushes the complete message to the network.
 //
 // There can be at most one open writer on a connection. NextWriter closes the
 // previous writer if the application has not already done so.
@@ -414,9 +411,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
 	}
 
 	c.writeFrameType = messageType
-	w := &messageWriter{c}
-	c.messageWriter = w
-	return w, nil
+	return messageWriter{c, c.writeSeq}, nil
 }
 
 func (c *Conn) flushFrame(final bool, extra []byte) error {
@@ -425,7 +420,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
 	// Check for invalid control frames.
 	if isControl(c.writeFrameType) &&
 		(!final || length > maxControlFramePayloadSize) {
-		c.messageWriter = nil
+		c.writeSeq++
 		c.writeFrameType = noFrame
 		c.writePos = maxFrameHeaderSize
 		return errInvalidControlFrame
@@ -493,17 +488,20 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
 	c.writePos = maxFrameHeaderSize
 	c.writeFrameType = continuationFrame
 	if final {
-		c.messageWriter = nil
+		c.writeSeq++
 		c.writeFrameType = noFrame
 	}
 	return c.writeErr
 }
 
-type messageWriter struct{ c *Conn }
+type messageWriter struct {
+	c   *Conn
+	seq int
+}
 
-func (w *messageWriter) err() error {
+func (w messageWriter) err() error {
 	c := w.c
-	if c.messageWriter != w {
+	if c.writeSeq != w.seq {
 		return errWriteClosed
 	}
 	if c.writeErr != nil {
@@ -512,7 +510,7 @@ func (w *messageWriter) err() error {
 	return nil
 }
 
-func (w *messageWriter) ncopy(max int) (int, error) {
+func (w messageWriter) ncopy(max int) (int, error) {
 	n := len(w.c.writeBuf) - w.c.writePos
 	if n <= 0 {
 		if err := w.c.flushFrame(false, nil); err != nil {
@@ -526,7 +524,7 @@ func (w *messageWriter) ncopy(max int) (int, error) {
 	return n, nil
 }
 
-func (w *messageWriter) write(final bool, p []byte) (int, error) {
+func (w messageWriter) write(final bool, p []byte) (int, error) {
 	if err := w.err(); err != nil {
 		return 0, err
 	}
@@ -553,11 +551,11 @@ func (w *messageWriter) write(final bool, p []byte) (int, error) {
 	return nn, nil
 }
 
-func (w *messageWriter) Write(p []byte) (int, error) {
+func (w messageWriter) Write(p []byte) (int, error) {
 	return w.write(false, p)
 }
 
-func (w *messageWriter) WriteString(p string) (int, error) {
+func (w messageWriter) WriteString(p string) (int, error) {
 	if err := w.err(); err != nil {
 		return 0, err
 	}
@@ -575,7 +573,7 @@ func (w *messageWriter) WriteString(p string) (int, error) {
 	return nn, nil
 }
 
-func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
+func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
 	if err := w.err(); err != nil {
 		return 0, err
 	}
@@ -600,7 +598,7 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
 	return nn, err
 }
 
-func (w *messageWriter) Close() error {
+func (w messageWriter) Close() error {
 	if err := w.err(); err != nil {
 		return err
 	}
@@ -610,22 +608,20 @@ func (w *messageWriter) Close() error {
 // WriteMessage is a helper method for getting a writer using NextWriter,
 // writing the message and closing the writer.
 func (c *Conn) WriteMessage(messageType int, data []byte) error {
-	w, err := c.NextWriter(messageType)
+	wr, err := c.NextWriter(messageType)
 	if err != nil {
 		return err
 	}
-	if _, ok := w.(*messageWriter); ok && c.isServer {
-		// Optimize write as a single frame.
-		n := copy(c.writeBuf[c.writePos:], data)
-		c.writePos += n
-		data = data[n:]
-		err = c.flushFrame(true, data)
+	w := wr.(messageWriter)
+	if _, err := w.write(true, data); err != nil {
 		return err
 	}
-	if _, err = w.Write(data); err != nil {
-		return err
+	if c.writeSeq == w.seq {
+		if err := c.flushFrame(true, nil); err != nil {
+			return err
+		}
 	}
-	return w.Close()
+	return nil
 }
 
 // SetWriteDeadline sets the write deadline on the underlying network
@@ -639,13 +635,20 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
 
 // Read methods
 
-func (c *Conn) read(n int) ([]byte, error) {
-	p, err := c.br.Peek(n)
-	if err == io.EOF {
+// readFull is like io.ReadFull except that io.EOF is never returned.
+func (c *Conn) readFull(p []byte) (err error) {
+	var n int
+	for n < len(p) && err == nil {
+		var nn int
+		nn, err = c.br.Read(p[n:])
+		n += nn
+	}
+	if n == len(p) {
+		err = nil
+	} else if err == io.EOF {
 		err = errUnexpectedEOF
 	}
-	c.br.Discard(len(p))
-	return p, err
+	return
 }
 
 func (c *Conn) advanceFrame() (int, error) {
@@ -660,16 +663,16 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	// 2. Read and parse first two bytes of frame header.
 
-	p, err := c.read(2)
-	if err != nil {
+	var b [8]byte
+	if err := c.readFull(b[:2]); err != nil {
 		return noFrame, err
 	}
 
-	final := p[0]&finalBit != 0
-	frameType := int(p[0] & 0xf)
-	reserved := int((p[0] >> 4) & 0x7)
-	mask := p[1]&maskBit != 0
-	c.readRemaining = int64(p[1] & 0x7f)
+	final := b[0]&finalBit != 0
+	frameType := int(b[0] & 0xf)
+	reserved := int((b[0] >> 4) & 0x7)
+	mask := b[1]&maskBit != 0
+	c.readRemaining = int64(b[1] & 0x7f)
 
 	if reserved != 0 {
 		return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
@@ -701,17 +704,15 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	switch c.readRemaining {
 	case 126:
-		p, err := c.read(2)
-		if err != nil {
+		if err := c.readFull(b[:2]); err != nil {
 			return noFrame, err
 		}
-		c.readRemaining = int64(binary.BigEndian.Uint16(p))
+		c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
 	case 127:
-		p, err := c.read(8)
-		if err != nil {
+		if err := c.readFull(b[:8]); err != nil {
 			return noFrame, err
 		}
-		c.readRemaining = int64(binary.BigEndian.Uint64(p))
+		c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
 	}
 
 	// 4. Handle frame masking.
@@ -722,11 +723,9 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	if mask {
 		c.readMaskPos = 0
-		p, err := c.read(len(c.readMaskKey))
-		if err != nil {
+		if err := c.readFull(c.readMaskKey[:]); err != nil {
 			return noFrame, err
 		}
-		copy(c.readMaskKey[:], p)
 	}
 
 	// 5. For text and binary messages, enforce read limit and return.
@@ -746,9 +745,9 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	var payload []byte
 	if c.readRemaining > 0 {
-		payload, err = c.read(int(c.readRemaining))
+		payload = make([]byte, c.readRemaining)
 		c.readRemaining = 0
-		if err != nil {
+		if err := c.readFull(payload); err != nil {
 			return noFrame, err
 		}
 		if c.isServer {
@@ -806,7 +805,7 @@ func (c *Conn) handleProtocolError(message string) error {
 // this method return the same error.
 func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
 
-	c.messageReader = nil
+	c.readSeq++
 	c.readLength = 0
 
 	for c.readErr == nil {
@@ -816,9 +815,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
 			break
 		}
 		if frameType == TextMessage || frameType == BinaryMessage {
-			r := &messageReader{c}
-			c.messageReader = r
-			return frameType, r, nil
+			return frameType, messageReader{c, c.readSeq}, nil
 		}
 	}
 
@@ -833,48 +830,51 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
 	return noFrame, nil, c.readErr
 }
 
-type messageReader struct{ c *Conn }
+type messageReader struct {
+	c   *Conn
+	seq int
+}
 
-func (r *messageReader) Read(b []byte) (int, error) {
-	c := r.c
-	if c.messageReader != r {
+func (r messageReader) Read(b []byte) (int, error) {
+
+	if r.seq != r.c.readSeq {
 		return 0, io.EOF
 	}
 
-	for c.readErr == nil {
+	for r.c.readErr == nil {
 
-		if c.readRemaining > 0 {
-			if int64(len(b)) > c.readRemaining {
-				b = b[:c.readRemaining]
+		if r.c.readRemaining > 0 {
+			if int64(len(b)) > r.c.readRemaining {
+				b = b[:r.c.readRemaining]
 			}
-			n, err := c.br.Read(b)
-			c.readErr = hideTempErr(err)
-			if c.isServer {
-				c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
+			n, err := r.c.br.Read(b)
+			r.c.readErr = hideTempErr(err)
+			if r.c.isServer {
+				r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
 			}
-			c.readRemaining -= int64(n)
-			if c.readRemaining > 0 && c.readErr == io.EOF {
-				c.readErr = errUnexpectedEOF
+			r.c.readRemaining -= int64(n)
+			if r.c.readRemaining > 0 && r.c.readErr == io.EOF {
+				r.c.readErr = errUnexpectedEOF
 			}
-			return n, c.readErr
+			return n, r.c.readErr
 		}
 
-		if c.readFinal {
-			c.messageReader = nil
+		if r.c.readFinal {
+			r.c.readSeq++
 			return 0, io.EOF
 		}
 
-		frameType, err := c.advanceFrame()
+		frameType, err := r.c.advanceFrame()
 		switch {
 		case err != nil:
-			c.readErr = hideTempErr(err)
+			r.c.readErr = hideTempErr(err)
 		case frameType == TextMessage || frameType == BinaryMessage:
-			c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
+			r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
 		}
 	}
 
-	err := c.readErr
-	if err == io.EOF && c.messageReader == r {
+	err := r.c.readErr
+	if err == io.EOF && r.seq == r.c.readSeq {
 		err = errUnexpectedEOF
 	}
 	return 0, err