|
|
@@ -451,7 +451,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
-func (c *Conn) prepWrite(messageType int) error {
|
|
|
+// beginMessage prepares a connection and message writer for a new message.
|
|
|
+func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
|
|
|
// Close previous writer if not already closed by the application. It's
|
|
|
// probably better to return an error in this situation, but we cannot
|
|
|
// change this without breaking existing applications.
|
|
|
@@ -471,6 +472,10 @@ func (c *Conn) prepWrite(messageType int) error {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
+ mw.c = c
|
|
|
+ mw.frameType = messageType
|
|
|
+ mw.pos = maxFrameHeaderSize
|
|
|
+
|
|
|
if c.writeBuf == nil {
|
|
|
wpd, ok := c.writePool.Get().(writePoolData)
|
|
|
if ok {
|
|
|
@@ -491,16 +496,11 @@ func (c *Conn) prepWrite(messageType int) error {
|
|
|
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
|
|
|
// PongMessage) are supported.
|
|
|
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
|
|
- if err := c.prepWrite(messageType); err != nil {
|
|
|
+ var mw messageWriter
|
|
|
+ if err := c.beginMessage(&mw, messageType); err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
-
|
|
|
- mw := &messageWriter{
|
|
|
- c: c,
|
|
|
- frameType: messageType,
|
|
|
- pos: maxFrameHeaderSize,
|
|
|
- }
|
|
|
- c.writer = mw
|
|
|
+ c.writer = &mw
|
|
|
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
|
|
w := c.newCompressionWriter(c.writer, c.compressionLevel)
|
|
|
mw.compress = true
|
|
|
@@ -517,10 +517,16 @@ type messageWriter struct {
|
|
|
err error
|
|
|
}
|
|
|
|
|
|
-func (w *messageWriter) fatal(err error) error {
|
|
|
+func (w *messageWriter) endMessage(err error) error {
|
|
|
if w.err != nil {
|
|
|
- w.err = err
|
|
|
- w.c.writer = nil
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ c := w.c
|
|
|
+ w.err = err
|
|
|
+ c.writer = nil
|
|
|
+ if c.writePool != nil {
|
|
|
+ c.writePool.Put(writePoolData{buf: c.writeBuf})
|
|
|
+ c.writeBuf = nil
|
|
|
}
|
|
|
return err
|
|
|
}
|
|
|
@@ -534,7 +540,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
|
|
// Check for invalid control frames.
|
|
|
if isControl(w.frameType) &&
|
|
|
(!final || length > maxControlFramePayloadSize) {
|
|
|
- return w.fatal(errInvalidControlFrame)
|
|
|
+ return w.endMessage(errInvalidControlFrame)
|
|
|
}
|
|
|
|
|
|
b0 := byte(w.frameType)
|
|
|
@@ -579,7 +585,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
|
|
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
|
|
|
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
|
|
|
if len(extra) > 0 {
|
|
|
- return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
|
|
|
+ return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -600,15 +606,11 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
|
|
c.isWriting = false
|
|
|
|
|
|
if err != nil {
|
|
|
- return w.fatal(err)
|
|
|
+ return w.endMessage(err)
|
|
|
}
|
|
|
|
|
|
if final {
|
|
|
- c.writer = nil
|
|
|
- if c.writePool != nil {
|
|
|
- c.writePool.Put(writePoolData{buf: c.writeBuf})
|
|
|
- c.writeBuf = nil
|
|
|
- }
|
|
|
+ w.endMessage(errWriteClosed)
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
@@ -709,7 +711,6 @@ func (w *messageWriter) Close() error {
|
|
|
if err := w.flushFrame(true, nil); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- w.err = errWriteClosed
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
@@ -742,10 +743,10 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
|
|
|
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
|
|
|
// Fast path with no allocations and single frame.
|
|
|
|
|
|
- if err := c.prepWrite(messageType); err != nil {
|
|
|
+ var mw messageWriter
|
|
|
+ if err := c.beginMessage(&mw, messageType); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
|
|
|
n := copy(c.writeBuf[mw.pos:], data)
|
|
|
mw.pos += n
|
|
|
data = data[n:]
|