Browse Source

Merge pull request #6349 from gyuho/decode-length-limit

rafthttp: check decode size before buffer alloc
Gyu-Ho Lee 9 years ago
parent
commit
a66b1e7c60
2 changed files with 56 additions and 22 deletions
  1. 9 0
      rafthttp/msg_codec.go
  2. 47 22
      rafthttp/msg_codec_test.go

+ 9 - 0
rafthttp/msg_codec.go

@@ -16,6 +16,7 @@ package rafthttp
 
 
 import (
 import (
 	"encoding/binary"
 	"encoding/binary"
+	"errors"
 	"io"
 	"io"
 
 
 	"github.com/coreos/etcd/pkg/pbutil"
 	"github.com/coreos/etcd/pkg/pbutil"
@@ -41,12 +42,20 @@ type messageDecoder struct {
 	r io.Reader
 	r io.Reader
 }
 }
 
 
+var (
+	readBytesLimit     uint64 = 512 * 1024 // 512 MB
+	ErrExceedSizeLimit        = errors.New("rafthttp: error limit exceeded")
+)
+
 func (dec *messageDecoder) decode() (raftpb.Message, error) {
 func (dec *messageDecoder) decode() (raftpb.Message, error) {
 	var m raftpb.Message
 	var m raftpb.Message
 	var l uint64
 	var l uint64
 	if err := binary.Read(dec.r, binary.BigEndian, &l); err != nil {
 	if err := binary.Read(dec.r, binary.BigEndian, &l); err != nil {
 		return m, err
 		return m, err
 	}
 	}
+	if l > readBytesLimit {
+		return m, ErrExceedSizeLimit
+	}
 	buf := make([]byte, int(l))
 	buf := make([]byte, int(l))
 	if _, err := io.ReadFull(dec.r, buf); err != nil {
 	if _, err := io.ReadFull(dec.r, buf); err != nil {
 		return m, err
 		return m, err

+ 47 - 22
rafthttp/msg_codec_test.go

@@ -23,43 +23,68 @@ import (
 )
 )
 
 
 func TestMessage(t *testing.T) {
 func TestMessage(t *testing.T) {
-	tests := []raftpb.Message{
+	tests := []struct {
+		msg       raftpb.Message
+		encodeErr error
+		decodeErr error
+	}{
 		{
 		{
-			Type:    raftpb.MsgApp,
-			From:    1,
-			To:      2,
-			Term:    1,
-			LogTerm: 1,
-			Index:   3,
-			Entries: []raftpb.Entry{{Term: 1, Index: 4}},
+			raftpb.Message{
+				Type:    raftpb.MsgApp,
+				From:    1,
+				To:      2,
+				Term:    1,
+				LogTerm: 1,
+				Index:   3,
+				Entries: []raftpb.Entry{{Term: 1, Index: 4}},
+			},
+			nil,
+			nil,
+		},
+		{
+			raftpb.Message{
+				Type: raftpb.MsgProp,
+				From: 1,
+				To:   2,
+				Entries: []raftpb.Entry{
+					{Data: []byte("some data")},
+					{Data: []byte("some data")},
+					{Data: []byte("some data")},
+				},
+			},
+			nil,
+			nil,
 		},
 		},
 		{
 		{
-			Type: raftpb.MsgProp,
-			From: 1,
-			To:   2,
-			Entries: []raftpb.Entry{
-				{Data: []byte("some data")},
-				{Data: []byte("some data")},
-				{Data: []byte("some data")},
+			raftpb.Message{
+				Type: raftpb.MsgProp,
+				From: 1,
+				To:   2,
+				Entries: []raftpb.Entry{
+					{Data: bytes.Repeat([]byte("a"), int(readBytesLimit+10))},
+				},
 			},
 			},
+			nil,
+			ErrExceedSizeLimit,
 		},
 		},
-		linkHeartbeatMessage,
 	}
 	}
 	for i, tt := range tests {
 	for i, tt := range tests {
 		b := &bytes.Buffer{}
 		b := &bytes.Buffer{}
 		enc := &messageEncoder{w: b}
 		enc := &messageEncoder{w: b}
-		if err := enc.encode(&tt); err != nil {
-			t.Errorf("#%d: unexpected encode message error: %v", i, err)
+		if err := enc.encode(&tt.msg); err != tt.encodeErr {
+			t.Errorf("#%d: encode message error expected %v, got %v", i, tt.encodeErr, err)
 			continue
 			continue
 		}
 		}
 		dec := &messageDecoder{r: b}
 		dec := &messageDecoder{r: b}
 		m, err := dec.decode()
 		m, err := dec.decode()
-		if err != nil {
-			t.Errorf("#%d: unexpected decode message error: %v", i, err)
+		if err != tt.decodeErr {
+			t.Errorf("#%d: decode message error expected %v, got %v", i, tt.decodeErr, err)
 			continue
 			continue
 		}
 		}
-		if !reflect.DeepEqual(m, tt) {
-			t.Errorf("#%d: message = %+v, want %+v", i, m, tt)
+		if err == nil {
+			if !reflect.DeepEqual(m, tt.msg) {
+				t.Errorf("#%d: message = %+v, want %+v", i, m, tt.msg)
+			}
 		}
 		}
 	}
 	}
 }
 }