فهرست منبع

Reduce memory allocations in NextReader, NextWriter

Redo 8b209f63177a963547dc3cee89350a327ead0412 with support for old
versions of Go.
Gary Burd 10 سال پیش
والد
کامیت
be01041b66
3فایلهای تغییر یافته به همراه116 افزوده شده و 86 حذف شده
  1. 77 86
      conn.go
  2. 18 0
      conn_read.go
  3. 21 0
      conn_read_legacy.go

+ 77 - 86
conn.go

@@ -238,16 +238,15 @@ type Conn struct {
 	writeBuf       []byte // frame is constructed in this buffer.
 	writePos       int    // end of data in writeBuf.
 	writeFrameType int    // type of the current frame.
-	writeSeq       int    // incremented to invalidate message writers.
 	writeDeadline  time.Time
-	isWriting      bool // for best-effort concurrent write detection
+	isWriting      bool           // for best-effort concurrent write detection
+	messageWriter  *messageWriter // the current writer
 
 	// Read fields
 	readErr       error
 	br            *bufio.Reader
 	readRemaining int64 // bytes remaining in current frame.
 	readFinal     bool  // true the current message has more frames.
-	readSeq       int   // incremented to invalidate message readers.
 	readLength    int64 // Message size.
 	readLimit     int64 // Maximum message size.
 	readMaskPos   int
@@ -255,6 +254,7 @@ type Conn struct {
 	handlePong    func(string) error
 	handlePing    func(string) error
 	readErrCount  int
+	messageReader *messageReader // the current reader
 }
 
 func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@@ -264,6 +264,9 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
 	if readBufferSize == 0 {
 		readBufferSize = defaultReadBufferSize
 	}
+	if readBufferSize < maxControlFramePayloadSize {
+		readBufferSize = maxControlFramePayloadSize
+	}
 	if writeBufferSize == 0 {
 		writeBufferSize = defaultWriteBufferSize
 	}
@@ -390,8 +393,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
 	return hideTempErr(err)
 }
 
-// NextWriter returns a writer for the next message to send.  The writer's
-// Close method flushes the complete message to the network.
+// NextWriter returns a writer for the next message to send. The writer's Close
+// method flushes the complete message to the network.
 //
 // There can be at most one open writer on a connection. NextWriter closes the
 // previous writer if the application has not already done so.
@@ -411,7 +414,9 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
 	}
 
 	c.writeFrameType = messageType
-	return messageWriter{c, c.writeSeq}, nil
+	w := &messageWriter{c}
+	c.messageWriter = w
+	return w, nil
 }
 
 func (c *Conn) flushFrame(final bool, extra []byte) error {
@@ -420,7 +425,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
 	// Check for invalid control frames.
 	if isControl(c.writeFrameType) &&
 		(!final || length > maxControlFramePayloadSize) {
-		c.writeSeq++
+		c.messageWriter = nil
 		c.writeFrameType = noFrame
 		c.writePos = maxFrameHeaderSize
 		return errInvalidControlFrame
@@ -488,20 +493,17 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
 	c.writePos = maxFrameHeaderSize
 	c.writeFrameType = continuationFrame
 	if final {
-		c.writeSeq++
+		c.messageWriter = nil
 		c.writeFrameType = noFrame
 	}
 	return c.writeErr
 }
 
-type messageWriter struct {
-	c   *Conn
-	seq int
-}
+type messageWriter struct{ c *Conn }
 
-func (w messageWriter) err() error {
+func (w *messageWriter) err() error {
 	c := w.c
-	if c.writeSeq != w.seq {
+	if c.messageWriter != w {
 		return errWriteClosed
 	}
 	if c.writeErr != nil {
@@ -510,7 +512,7 @@ func (w messageWriter) err() error {
 	return nil
 }
 
-func (w messageWriter) ncopy(max int) (int, error) {
+func (w *messageWriter) ncopy(max int) (int, error) {
 	n := len(w.c.writeBuf) - w.c.writePos
 	if n <= 0 {
 		if err := w.c.flushFrame(false, nil); err != nil {
@@ -524,7 +526,7 @@ func (w messageWriter) ncopy(max int) (int, error) {
 	return n, nil
 }
 
-func (w messageWriter) write(final bool, p []byte) (int, error) {
+func (w *messageWriter) write(final bool, p []byte) (int, error) {
 	if err := w.err(); err != nil {
 		return 0, err
 	}
@@ -551,11 +553,11 @@ func (w messageWriter) write(final bool, p []byte) (int, error) {
 	return nn, nil
 }
 
-func (w messageWriter) Write(p []byte) (int, error) {
+func (w *messageWriter) Write(p []byte) (int, error) {
 	return w.write(false, p)
 }
 
-func (w messageWriter) WriteString(p string) (int, error) {
+func (w *messageWriter) WriteString(p string) (int, error) {
 	if err := w.err(); err != nil {
 		return 0, err
 	}
@@ -573,7 +575,7 @@ func (w messageWriter) WriteString(p string) (int, error) {
 	return nn, nil
 }
 
-func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
+func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
 	if err := w.err(); err != nil {
 		return 0, err
 	}
@@ -598,7 +600,7 @@ func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
 	return nn, err
 }
 
-func (w messageWriter) Close() error {
+func (w *messageWriter) Close() error {
 	if err := w.err(); err != nil {
 		return err
 	}
@@ -608,20 +610,22 @@ func (w messageWriter) Close() error {
 // WriteMessage is a helper method for getting a writer using NextWriter,
 // writing the message and closing the writer.
 func (c *Conn) WriteMessage(messageType int, data []byte) error {
-	wr, err := c.NextWriter(messageType)
+	w, err := c.NextWriter(messageType)
 	if err != nil {
 		return err
 	}
-	w := wr.(messageWriter)
-	if _, err := w.write(true, data); err != nil {
+	if _, ok := w.(*messageWriter); ok && c.isServer {
+		// Optimize write as a single frame.
+		n := copy(c.writeBuf[c.writePos:], data)
+		c.writePos += n
+		data = data[n:]
+		err = c.flushFrame(true, data)
 		return err
 	}
-	if c.writeSeq == w.seq {
-		if err := c.flushFrame(true, nil); err != nil {
-			return err
-		}
+	if _, err = w.Write(data); err != nil {
+		return err
 	}
-	return nil
+	return w.Close()
 }
 
 // SetWriteDeadline sets the write deadline on the underlying network
@@ -635,22 +639,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
 
 // Read methods
 
-// readFull is like io.ReadFull except that io.EOF is never returned.
-func (c *Conn) readFull(p []byte) (err error) {
-	var n int
-	for n < len(p) && err == nil {
-		var nn int
-		nn, err = c.br.Read(p[n:])
-		n += nn
-	}
-	if n == len(p) {
-		err = nil
-	} else if err == io.EOF {
-		err = errUnexpectedEOF
-	}
-	return
-}
-
 func (c *Conn) advanceFrame() (int, error) {
 
 	// 1. Skip remainder of previous frame.
@@ -663,16 +651,16 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	// 2. Read and parse first two bytes of frame header.
 
-	var b [8]byte
-	if err := c.readFull(b[:2]); err != nil {
+	p, err := c.read(2)
+	if err != nil {
 		return noFrame, err
 	}
 
-	final := b[0]&finalBit != 0
-	frameType := int(b[0] & 0xf)
-	reserved := int((b[0] >> 4) & 0x7)
-	mask := b[1]&maskBit != 0
-	c.readRemaining = int64(b[1] & 0x7f)
+	final := p[0]&finalBit != 0
+	frameType := int(p[0] & 0xf)
+	reserved := int((p[0] >> 4) & 0x7)
+	mask := p[1]&maskBit != 0
+	c.readRemaining = int64(p[1] & 0x7f)
 
 	if reserved != 0 {
 		return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
@@ -704,15 +692,17 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	switch c.readRemaining {
 	case 126:
-		if err := c.readFull(b[:2]); err != nil {
+		p, err := c.read(2)
+		if err != nil {
 			return noFrame, err
 		}
-		c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
+		c.readRemaining = int64(binary.BigEndian.Uint16(p))
 	case 127:
-		if err := c.readFull(b[:8]); err != nil {
+		p, err := c.read(8)
+		if err != nil {
 			return noFrame, err
 		}
-		c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
+		c.readRemaining = int64(binary.BigEndian.Uint64(p))
 	}
 
 	// 4. Handle frame masking.
@@ -723,9 +713,11 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	if mask {
 		c.readMaskPos = 0
-		if err := c.readFull(c.readMaskKey[:]); err != nil {
+		p, err := c.read(len(c.readMaskKey))
+		if err != nil {
 			return noFrame, err
 		}
+		copy(c.readMaskKey[:], p)
 	}
 
 	// 5. For text and binary messages, enforce read limit and return.
@@ -745,9 +737,9 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	var payload []byte
 	if c.readRemaining > 0 {
-		payload = make([]byte, c.readRemaining)
+		payload, err = c.read(int(c.readRemaining))
 		c.readRemaining = 0
-		if err := c.readFull(payload); err != nil {
+		if err != nil {
 			return noFrame, err
 		}
 		if c.isServer {
@@ -805,7 +797,7 @@ func (c *Conn) handleProtocolError(message string) error {
 // this method return the same error.
 func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
 
-	c.readSeq++
+	c.messageReader = nil
 	c.readLength = 0
 
 	for c.readErr == nil {
@@ -815,7 +807,9 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
 			break
 		}
 		if frameType == TextMessage || frameType == BinaryMessage {
-			return frameType, messageReader{c, c.readSeq}, nil
+			r := &messageReader{c}
+			c.messageReader = r
+			return frameType, r, nil
 		}
 	}
 
@@ -830,51 +824,48 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
 	return noFrame, nil, c.readErr
 }
 
-type messageReader struct {
-	c   *Conn
-	seq int
-}
-
-func (r messageReader) Read(b []byte) (int, error) {
+type messageReader struct{ c *Conn }
 
-	if r.seq != r.c.readSeq {
+func (r *messageReader) Read(b []byte) (int, error) {
+	c := r.c
+	if c.messageReader != r {
 		return 0, io.EOF
 	}
 
-	for r.c.readErr == nil {
+	for c.readErr == nil {
 
-		if r.c.readRemaining > 0 {
-			if int64(len(b)) > r.c.readRemaining {
-				b = b[:r.c.readRemaining]
+		if c.readRemaining > 0 {
+			if int64(len(b)) > c.readRemaining {
+				b = b[:c.readRemaining]
 			}
-			n, err := r.c.br.Read(b)
-			r.c.readErr = hideTempErr(err)
-			if r.c.isServer {
-				r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
+			n, err := c.br.Read(b)
+			c.readErr = hideTempErr(err)
+			if c.isServer {
+				c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
 			}
-			r.c.readRemaining -= int64(n)
-			if r.c.readRemaining > 0 && r.c.readErr == io.EOF {
-				r.c.readErr = errUnexpectedEOF
+			c.readRemaining -= int64(n)
+			if c.readRemaining > 0 && c.readErr == io.EOF {
+				c.readErr = errUnexpectedEOF
 			}
-			return n, r.c.readErr
+			return n, c.readErr
 		}
 
-		if r.c.readFinal {
-			r.c.readSeq++
+		if c.readFinal {
+			c.messageReader = nil
 			return 0, io.EOF
 		}
 
-		frameType, err := r.c.advanceFrame()
+		frameType, err := c.advanceFrame()
 		switch {
 		case err != nil:
-			r.c.readErr = hideTempErr(err)
+			c.readErr = hideTempErr(err)
 		case frameType == TextMessage || frameType == BinaryMessage:
-			r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
+			c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
 		}
 	}
 
-	err := r.c.readErr
-	if err == io.EOF && r.seq == r.c.readSeq {
+	err := c.readErr
+	if err == io.EOF && c.messageReader == r {
 		err = errUnexpectedEOF
 	}
 	return 0, err

+ 18 - 0
conn_read.go

@@ -0,0 +1,18 @@
+// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.5
+
+package websocket
+
+import "io"
+
+func (c *Conn) read(n int) ([]byte, error) {
+	p, err := c.br.Peek(n)
+	if err == io.EOF {
+		err = errUnexpectedEOF
+	}
+	c.br.Discard(len(p))
+	return p, err
+}

+ 21 - 0
conn_read_legacy.go

@@ -0,0 +1,21 @@
+// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !go1.5
+
+package websocket
+
+import "io"
+
+func (c *Conn) read(n int) ([]byte, error) {
+	p, err := c.br.Peek(n)
+	if err == io.EOF {
+		err = errUnexpectedEOF
+	}
+	if len(p) > 0 {
+		// advance over the bytes just read
+		io.ReadFull(c.br, p)
+	}
+	return p, err
+}