|
|
@@ -12,6 +12,7 @@ import (
|
|
|
"io/ioutil"
|
|
|
"net"
|
|
|
"strconv"
|
|
|
+ "sync"
|
|
|
"time"
|
|
|
"unicode/utf8"
|
|
|
)
|
|
|
@@ -223,19 +224,16 @@ type Conn struct {
|
|
|
subprotocol string
|
|
|
|
|
|
// Write fields
|
|
|
- mu chan bool // used as mutex to protect write to conn and closeSent
|
|
|
- closeSent bool // whether close message was sent
|
|
|
- writeErr error
|
|
|
- writeBuf []byte // frame is constructed in this buffer.
|
|
|
- writePos int // end of data in writeBuf.
|
|
|
- writeFrameType int // type of the current frame.
|
|
|
- writeDeadline time.Time
|
|
|
- messageWriter *messageWriter // the current low-level message writer
|
|
|
- writer io.WriteCloser // the current writer returned to the application
|
|
|
- isWriting bool // for best-effort concurrent write detection
|
|
|
+ mu chan bool // used as mutex to protect write to conn
|
|
|
+ writeBuf []byte // frame is constructed in this buffer.
|
|
|
+ writeDeadline time.Time
|
|
|
+ writer io.WriteCloser // the current writer returned to the application
|
|
|
+ isWriting bool // for best-effort concurrent write detection
|
|
|
+
|
|
|
+ writeErrMu sync.Mutex
|
|
|
+ writeErr error
|
|
|
|
|
|
enableWriteCompression bool
|
|
|
- writeCompress bool // whether next call to flushFrame should set RSV1
|
|
|
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
|
|
|
|
|
|
// Read fields
|
|
|
@@ -277,8 +275,6 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
|
|
|
mu: mu,
|
|
|
readFinal: true,
|
|
|
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
|
|
|
- writeFrameType: noFrame,
|
|
|
- writePos: maxFrameHeaderSize,
|
|
|
enableWriteCompression: true,
|
|
|
}
|
|
|
c.SetPingHandler(nil)
|
|
|
@@ -308,29 +304,40 @@ func (c *Conn) RemoteAddr() net.Addr {
|
|
|
|
|
|
// Write methods
|
|
|
|
|
|
+func (c *Conn) writeFatal(err error) error {
|
|
|
+ err = hideTempErr(err)
|
|
|
+ c.writeErrMu.Lock()
|
|
|
+ if c.writeErr == nil {
|
|
|
+ c.writeErr = err
|
|
|
+ }
|
|
|
+ c.writeErrMu.Unlock()
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
|
|
|
<-c.mu
|
|
|
defer func() { c.mu <- true }()
|
|
|
|
|
|
- if c.closeSent {
|
|
|
- return ErrCloseSent
|
|
|
- } else if frameType == CloseMessage {
|
|
|
- c.closeSent = true
|
|
|
+ c.writeErrMu.Lock()
|
|
|
+ err := c.writeErr
|
|
|
+ c.writeErrMu.Unlock()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
c.conn.SetWriteDeadline(deadline)
|
|
|
for _, buf := range bufs {
|
|
|
if len(buf) > 0 {
|
|
|
- n, err := c.conn.Write(buf)
|
|
|
- if n != len(buf) {
|
|
|
- // Close on partial write.
|
|
|
- c.conn.Close()
|
|
|
- }
|
|
|
+ _, err := c.conn.Write(buf)
|
|
|
if err != nil {
|
|
|
- return err
|
|
|
+ return c.writeFatal(err)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ if frameType == CloseMessage {
|
|
|
+ c.writeFatal(ErrCloseSent)
|
|
|
+ }
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
@@ -379,18 +386,22 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
|
|
}
|
|
|
defer func() { c.mu <- true }()
|
|
|
|
|
|
- if c.closeSent {
|
|
|
- return ErrCloseSent
|
|
|
- } else if messageType == CloseMessage {
|
|
|
- c.closeSent = true
|
|
|
+ c.writeErrMu.Lock()
|
|
|
+ err := c.writeErr
|
|
|
+ c.writeErrMu.Unlock()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
c.conn.SetWriteDeadline(deadline)
|
|
|
- n, err := c.conn.Write(buf)
|
|
|
- if n != 0 && n != len(buf) {
|
|
|
- c.conn.Close()
|
|
|
+ _, err = c.conn.Write(buf)
|
|
|
+ if err != nil {
|
|
|
+ return c.writeFatal(err)
|
|
|
+ }
|
|
|
+ if messageType == CloseMessage {
|
|
|
+ c.writeFatal(ErrCloseSent)
|
|
|
}
|
|
|
- return hideTempErr(err)
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
// NextWriter returns a writer for the next message to send. The writer's Close
|
|
|
@@ -399,64 +410,79 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
|
|
// There can be at most one open writer on a connection. NextWriter closes the
|
|
|
// previous writer if the application has not already done so.
|
|
|
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
|
|
- if c.writeErr != nil {
|
|
|
- return nil, c.writeErr
|
|
|
- }
|
|
|
-
|
|
|
// Close previous writer if not already closed by the application. It's
|
|
|
// probably better to return an error in this situation, but we cannot
|
|
|
// change this without breaking existing applications.
|
|
|
if c.writer != nil {
|
|
|
- err := c.writer.Close()
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
+ c.writer.Close()
|
|
|
+ c.writer = nil
|
|
|
}
|
|
|
|
|
|
if !isControl(messageType) && !isData(messageType) {
|
|
|
return nil, errBadWriteOpCode
|
|
|
}
|
|
|
|
|
|
- c.writeFrameType = messageType
|
|
|
- c.messageWriter = &messageWriter{c}
|
|
|
+ c.writeErrMu.Lock()
|
|
|
+ err := c.writeErr
|
|
|
+ c.writeErrMu.Unlock()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
|
|
|
- var w io.WriteCloser = c.messageWriter
|
|
|
+ mw := &messageWriter{
|
|
|
+ c: c,
|
|
|
+ frameType: messageType,
|
|
|
+ pos: maxFrameHeaderSize,
|
|
|
+ }
|
|
|
+ c.writer = mw
|
|
|
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
|
|
- c.writeCompress = true
|
|
|
- var err error
|
|
|
- w, err = c.newCompressionWriter(w)
|
|
|
+ w, err := c.newCompressionWriter(c.writer)
|
|
|
if err != nil {
|
|
|
- c.writer.Close()
|
|
|
+ c.writer = nil
|
|
|
return nil, err
|
|
|
}
|
|
|
+ mw.compress = true
|
|
|
+ c.writer = w
|
|
|
}
|
|
|
+ return c.writer, nil
|
|
|
+}
|
|
|
|
|
|
- return w, nil
|
|
|
+type messageWriter struct {
|
|
|
+ c *Conn
|
|
|
+ compress bool // whether next call to flushFrame should set RSV1
|
|
|
+ pos int // end of data in writeBuf.
|
|
|
+ frameType int // type of the current frame.
|
|
|
+ err error
|
|
|
+}
|
|
|
+
|
|
|
+func (w *messageWriter) fatal(err error) error {
|
|
|
+ if w.err != nil {
|
|
|
+ w.err = err
|
|
|
+ w.c.writer = nil
|
|
|
+ }
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
// flushFrame writes buffered data and extra as a frame to the network. The
|
|
|
// final argument indicates that this is the last frame in the message.
|
|
|
-func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|
|
- length := c.writePos - maxFrameHeaderSize + len(extra)
|
|
|
+func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
|
|
+ c := w.c
|
|
|
+ length := w.pos - maxFrameHeaderSize + len(extra)
|
|
|
|
|
|
// Check for invalid control frames.
|
|
|
- if isControl(c.writeFrameType) &&
|
|
|
+ if isControl(w.frameType) &&
|
|
|
(!final || length > maxControlFramePayloadSize) {
|
|
|
- c.messageWriter = nil
|
|
|
- c.writer = nil
|
|
|
- c.writeFrameType = noFrame
|
|
|
- c.writePos = maxFrameHeaderSize
|
|
|
- return errInvalidControlFrame
|
|
|
+ return w.fatal(errInvalidControlFrame)
|
|
|
}
|
|
|
|
|
|
- b0 := byte(c.writeFrameType)
|
|
|
+ b0 := byte(w.frameType)
|
|
|
if final {
|
|
|
b0 |= finalBit
|
|
|
}
|
|
|
- if c.writeCompress {
|
|
|
+ if w.compress {
|
|
|
b0 |= rsv1Bit
|
|
|
}
|
|
|
- c.writeCompress = false
|
|
|
+ w.compress = false
|
|
|
|
|
|
b1 := byte(0)
|
|
|
if !c.isServer {
|
|
|
@@ -489,10 +515,9 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|
|
if !c.isServer {
|
|
|
key := newMaskKey()
|
|
|
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
|
|
|
- maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos])
|
|
|
+ maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
|
|
|
if len(extra) > 0 {
|
|
|
- c.writeErr = errors.New("websocket: internal error, extra used in client mode")
|
|
|
- return c.writeErr
|
|
|
+ return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -505,44 +530,35 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|
|
}
|
|
|
c.isWriting = true
|
|
|
|
|
|
- c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra)
|
|
|
+ err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)
|
|
|
|
|
|
if !c.isWriting {
|
|
|
panic("concurrent write to websocket connection")
|
|
|
}
|
|
|
c.isWriting = false
|
|
|
|
|
|
- // Setup for next frame.
|
|
|
- c.writePos = maxFrameHeaderSize
|
|
|
- c.writeFrameType = continuationFrame
|
|
|
+ if err != nil {
|
|
|
+ return w.fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
if final {
|
|
|
- c.messageWriter = nil
|
|
|
c.writer = nil
|
|
|
- c.writeFrameType = noFrame
|
|
|
+ return nil
|
|
|
}
|
|
|
- return c.writeErr
|
|
|
-}
|
|
|
-
|
|
|
-type messageWriter struct{ c *Conn }
|
|
|
|
|
|
-func (w *messageWriter) err() error {
|
|
|
- c := w.c
|
|
|
- if c.messageWriter != w {
|
|
|
- return errWriteClosed
|
|
|
- }
|
|
|
- if c.writeErr != nil {
|
|
|
- return c.writeErr
|
|
|
- }
|
|
|
+ // Setup for next frame.
|
|
|
+ w.pos = maxFrameHeaderSize
|
|
|
+ w.frameType = continuationFrame
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
func (w *messageWriter) ncopy(max int) (int, error) {
|
|
|
- n := len(w.c.writeBuf) - w.c.writePos
|
|
|
+ n := len(w.c.writeBuf) - w.pos
|
|
|
if n <= 0 {
|
|
|
- if err := w.c.flushFrame(false, nil); err != nil {
|
|
|
+ if err := w.flushFrame(false, nil); err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
- n = len(w.c.writeBuf) - w.c.writePos
|
|
|
+ n = len(w.c.writeBuf) - w.pos
|
|
|
}
|
|
|
if n > max {
|
|
|
n = max
|
|
|
@@ -551,13 +567,13 @@ func (w *messageWriter) ncopy(max int) (int, error) {
|
|
|
}
|
|
|
|
|
|
func (w *messageWriter) Write(p []byte) (int, error) {
|
|
|
- if err := w.err(); err != nil {
|
|
|
- return 0, err
|
|
|
+ if w.err != nil {
|
|
|
+ return 0, w.err
|
|
|
}
|
|
|
|
|
|
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
|
|
|
// Don't buffer large messages.
|
|
|
- err := w.c.flushFrame(false, p)
|
|
|
+ err := w.flushFrame(false, p)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
@@ -570,16 +586,16 @@ func (w *messageWriter) Write(p []byte) (int, error) {
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
- copy(w.c.writeBuf[w.c.writePos:], p[:n])
|
|
|
- w.c.writePos += n
|
|
|
+ copy(w.c.writeBuf[w.pos:], p[:n])
|
|
|
+ w.pos += n
|
|
|
p = p[n:]
|
|
|
}
|
|
|
return nn, nil
|
|
|
}
|
|
|
|
|
|
func (w *messageWriter) WriteString(p string) (int, error) {
|
|
|
- if err := w.err(); err != nil {
|
|
|
- return 0, err
|
|
|
+ if w.err != nil {
|
|
|
+ return 0, w.err
|
|
|
}
|
|
|
|
|
|
nn := len(p)
|
|
|
@@ -588,27 +604,27 @@ func (w *messageWriter) WriteString(p string) (int, error) {
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
- copy(w.c.writeBuf[w.c.writePos:], p[:n])
|
|
|
- w.c.writePos += n
|
|
|
+ copy(w.c.writeBuf[w.pos:], p[:n])
|
|
|
+ w.pos += n
|
|
|
p = p[n:]
|
|
|
}
|
|
|
return nn, nil
|
|
|
}
|
|
|
|
|
|
func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
|
|
|
- if err := w.err(); err != nil {
|
|
|
- return 0, err
|
|
|
+ if w.err != nil {
|
|
|
+ return 0, w.err
|
|
|
}
|
|
|
for {
|
|
|
- if w.c.writePos == len(w.c.writeBuf) {
|
|
|
- err = w.c.flushFrame(false, nil)
|
|
|
+ if w.pos == len(w.c.writeBuf) {
|
|
|
+ err = w.flushFrame(false, nil)
|
|
|
if err != nil {
|
|
|
break
|
|
|
}
|
|
|
}
|
|
|
var n int
|
|
|
- n, err = r.Read(w.c.writeBuf[w.c.writePos:])
|
|
|
- w.c.writePos += n
|
|
|
+ n, err = r.Read(w.c.writeBuf[w.pos:])
|
|
|
+ w.pos += n
|
|
|
nn += int64(n)
|
|
|
if err != nil {
|
|
|
if err == io.EOF {
|
|
|
@@ -621,10 +637,14 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
|
|
|
}
|
|
|
|
|
|
func (w *messageWriter) Close() error {
|
|
|
- if err := w.err(); err != nil {
|
|
|
+ if w.err != nil {
|
|
|
+ return w.err
|
|
|
+ }
|
|
|
+ if err := w.flushFrame(true, nil); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- return w.c.flushFrame(true, nil)
|
|
|
+ w.err = errWriteClosed
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
// WriteMessage is a helper method for getting a writer using NextWriter,
|
|
|
@@ -634,12 +654,12 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- if _, ok := w.(*messageWriter); ok && c.isServer {
|
|
|
+ if mw, ok := w.(*messageWriter); ok && c.isServer {
|
|
|
// Optimize write as a single frame.
|
|
|
- n := copy(c.writeBuf[c.writePos:], data)
|
|
|
- c.writePos += n
|
|
|
+ n := copy(c.writeBuf[mw.pos:], data)
|
|
|
+ mw.pos += n
|
|
|
data = data[n:]
|
|
|
- err = c.flushFrame(true, data)
|
|
|
+ err = mw.flushFrame(true, data)
|
|
|
return err
|
|
|
}
|
|
|
if _, err = w.Write(data); err != nil {
|