Просмотр исходного кода

Read Limit Fix (#537)

This fix addresses a potential denial-of-service (DoS) vector that can cause an integer overflow in the presence of malicious WebSocket frames.

The fix adds additional checks against the remaining bytes on a connection, as well as a test to prevent regression.

Credit to Max Justicz (https://justi.cz/) for discovering and reporting this, as well as providing a robust PoC and review.

* build: go.mod to go1.12
* bugfix: fix DoS vector caused by readLimit bypass
* test: update TestReadLimit sub-test
* bugfix: payload length 127 should read bytes as uint64
* bugfix: defend against readLength overflows
Matt Silverlock 6 лет назад
Родитель
Сommit
5b740c2926
3 измененных файлов с 138 добавлено и 37 удалено
  1. 48 10
      conn.go
  2. 88 27
      conn_test.go
  3. 2 0
      go.mod

+ 48 - 10
conn.go

@@ -260,10 +260,12 @@ type Conn struct {
 	newCompressionWriter   func(io.WriteCloser, int) io.WriteCloser
 
 	// Read fields
-	reader        io.ReadCloser // the current reader returned to the application
-	readErr       error
-	br            *bufio.Reader
-	readRemaining int64 // bytes remaining in current frame.
+	reader  io.ReadCloser // the current reader returned to the application
+	readErr error
+	br      *bufio.Reader
+	// bytes remaining in current frame.
+	// set setReadRemaining to safely update this value and prevent overflow
+	readRemaining int64
 	readFinal     bool  // true the current message has more frames.
 	readLength    int64 // Message size.
 	readLimit     int64 // Maximum message size.
@@ -320,6 +322,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
 	return c
 }
 
+// setReadRemaining tracks the number of bytes remaining on the connection. If n
+// overflows, an ErrReadLimit is returned.
+func (c *Conn) setReadRemaining(n int64) error {
+	if n < 0 {
+		return ErrReadLimit
+	}
+
+	c.readRemaining = n
+	return nil
+}
+
 // Subprotocol returns the negotiated protocol for the connection.
 func (c *Conn) Subprotocol() string {
 	return c.subprotocol
@@ -790,7 +803,7 @@ func (c *Conn) advanceFrame() (int, error) {
 	final := p[0]&finalBit != 0
 	frameType := int(p[0] & 0xf)
 	mask := p[1]&maskBit != 0
-	c.readRemaining = int64(p[1] & 0x7f)
+	c.setReadRemaining(int64(p[1] & 0x7f))
 
 	c.readDecompress = false
 	if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
@@ -824,7 +837,17 @@ func (c *Conn) advanceFrame() (int, error) {
 		return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
 	}
 
-	// 3. Read and parse frame length.
+	// 3. Read and parse frame length as per
+	// https://tools.ietf.org/html/rfc6455#section-5.2
+	//
+	// The length of the "Payload data", in bytes: if 0-125, that is the payload
+	// length.
+	// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
+	// integer are the payload length.
+	// - If 127, the following 8 bytes interpreted as
+	// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
+	// payload length. Multibyte length quantities are expressed in network byte
+	// order.
 
 	switch c.readRemaining {
 	case 126:
@@ -832,13 +855,19 @@ func (c *Conn) advanceFrame() (int, error) {
 		if err != nil {
 			return noFrame, err
 		}
-		c.readRemaining = int64(binary.BigEndian.Uint16(p))
+
+		if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
+			return noFrame, err
+		}
 	case 127:
 		p, err := c.read(8)
 		if err != nil {
 			return noFrame, err
 		}
-		c.readRemaining = int64(binary.BigEndian.Uint64(p))
+
+		if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
+			return noFrame, err
+		}
 	}
 
 	// 4. Handle frame masking.
@@ -861,6 +890,12 @@ func (c *Conn) advanceFrame() (int, error) {
 	if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
 
 		c.readLength += c.readRemaining
+		// Don't allow readLength to overflow in the presence of a large readRemaining
+		// counter.
+		if c.readLength < 0 {
+			return noFrame, ErrReadLimit
+		}
+
 		if c.readLimit > 0 && c.readLength > c.readLimit {
 			c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
 			return noFrame, ErrReadLimit
@@ -874,7 +909,7 @@ func (c *Conn) advanceFrame() (int, error) {
 	var payload []byte
 	if c.readRemaining > 0 {
 		payload, err = c.read(int(c.readRemaining))
-		c.readRemaining = 0
+		c.setReadRemaining(0)
 		if err != nil {
 			return noFrame, err
 		}
@@ -947,6 +982,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
 			c.readErr = hideTempErr(err)
 			break
 		}
+
 		if frameType == TextMessage || frameType == BinaryMessage {
 			c.messageReader = &messageReader{c}
 			c.reader = c.messageReader
@@ -987,7 +1023,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
 			if c.isServer {
 				c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
 			}
-			c.readRemaining -= int64(n)
+			rem := c.readRemaining
+			rem -= int64(n)
+			c.setReadRemaining(rem)
 			if c.readRemaining > 0 && c.readErr == io.EOF {
 				c.readErr = errUnexpectedEOF
 			}

+ 88 - 27
conn_test.go

@@ -55,7 +55,10 @@ func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
 }
 
 func TestFraming(t *testing.T) {
-	frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
+	frameSizes := []int{
+		0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
+		// 65536, 65537
+	}
 	var readChunkers = []struct {
 		name string
 		f    func(io.Reader) io.Reader
@@ -120,6 +123,8 @@ func TestFraming(t *testing.T) {
 							t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
 							continue
 						}
+
+						t.Logf("frame size: %d", n)
 						rbuf, err := ioutil.ReadAll(r)
 						if err != nil {
 							t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
@@ -458,37 +463,93 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
 }
 
 func TestReadLimit(t *testing.T) {
+	t.Run("Test ReadLimit is enforced", func(t *testing.T) {
+		const readLimit = 512
+		message := make([]byte, readLimit+1)
 
-	const readLimit = 512
-	message := make([]byte, readLimit+1)
+		var b1, b2 bytes.Buffer
+		wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
+		rc := newTestConn(&b1, &b2, true)
+		rc.SetReadLimit(readLimit)
 
-	var b1, b2 bytes.Buffer
-	wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
-	rc := newTestConn(&b1, &b2, true)
-	rc.SetReadLimit(readLimit)
+		// Send message at the limit with interleaved pong.
+		w, _ := wc.NextWriter(BinaryMessage)
+		w.Write(message[:readLimit-1])
+		wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
+		w.Write(message[:1])
+		w.Close()
 
-	// Send message at the limit with interleaved pong.
-	w, _ := wc.NextWriter(BinaryMessage)
-	w.Write(message[:readLimit-1])
-	wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
-	w.Write(message[:1])
-	w.Close()
+		// Send message larger than the limit.
+		wc.WriteMessage(BinaryMessage, message[:readLimit+1])
 
-	// Send message larger than the limit.
-	wc.WriteMessage(BinaryMessage, message[:readLimit+1])
+		op, _, err := rc.NextReader()
+		if op != BinaryMessage || err != nil {
+			t.Fatalf("1: NextReader() returned %d, %v", op, err)
+		}
+		op, r, err := rc.NextReader()
+		if op != BinaryMessage || err != nil {
+			t.Fatalf("2: NextReader() returned %d, %v", op, err)
+		}
+		_, err = io.Copy(ioutil.Discard, r)
+		if err != ErrReadLimit {
+			t.Fatalf("io.Copy() returned %v", err)
+		}
+	})
 
-	op, _, err := rc.NextReader()
-	if op != BinaryMessage || err != nil {
-		t.Fatalf("1: NextReader() returned %d, %v", op, err)
-	}
-	op, r, err := rc.NextReader()
-	if op != BinaryMessage || err != nil {
-		t.Fatalf("2: NextReader() returned %d, %v", op, err)
-	}
-	_, err = io.Copy(ioutil.Discard, r)
-	if err != ErrReadLimit {
-		t.Fatalf("io.Copy() returned %v", err)
-	}
+	t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
+		const readLimit = 1
+
+		var b1, b2 bytes.Buffer
+		rc := newTestConn(&b1, &b2, true)
+		rc.SetReadLimit(readLimit)
+
+		// First, send a non-final binary message
+		b1.Write([]byte("\x02\x81"))
+
+		// Mask key
+		b1.Write([]byte("\x00\x00\x00\x00"))
+
+		// First payload
+		b1.Write([]byte("A"))
+
+		// Next, send a negative-length, non-final continuation frame
+		b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))
+
+		// Mask key
+		b1.Write([]byte("\x00\x00\x00\x00"))
+
+		// Next, send a too long, final continuation frame
+		b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))
+
+		// Mask key
+		b1.Write([]byte("\x00\x00\x00\x00"))
+
+		// Too-long payload
+		b1.Write([]byte("BCDEF"))
+
+		op, r, err := rc.NextReader()
+		if op != BinaryMessage || err != nil {
+			t.Fatalf("1: NextReader() returned %d, %v", op, err)
+		}
+
+		var buf [10]byte
+		var read int
+		n, err := r.Read(buf[:])
+		if err != nil && err != ErrReadLimit {
+			t.Fatalf("unexpected error testing read limit: %v", err)
+		}
+		read += n
+
+		n, err = r.Read(buf[:])
+		if err != nil && err != ErrReadLimit {
+			t.Fatalf("unexpected error testing read limit: %v", err)
+		}
+		read += n
+
+		if err == nil && read > readLimit {
+			t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
+		}
+	})
 }
 
 func TestAddrs(t *testing.T) {

+ 2 - 0
go.mod

@@ -1 +1,3 @@
 module github.com/gorilla/websocket
+
+go 1.12