浏览代码

goprotobuf: Fix a couple of instances where bad input could cause a panic.

R=r
CC=golang-dev
http://codereview.appspot.com/5554045
David Symonds 14 年之前
父节点
当前提交
22ac150239
共有 3 个文件被更改,包括 52 次插入4 次删除
  1. 1 1
      compiler/testdata/main.go
  2. 46 0
      proto/all_test.go
  3. 5 3
      proto/decode.go

+ 1 - 1
compiler/testdata/main.go

@@ -34,8 +34,8 @@
 package main
 package main
 
 
 import (
 import (
-	"./test.pb"
 	"./multi1.pb"
 	"./multi1.pb"
+	"./test.pb"
 )
 )
 
 
 func main() {
 func main() {

+ 46 - 0
proto/all_test.go

@@ -36,9 +36,12 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
 	"math"
 	"math"
+	"math/rand"
 	"reflect"
 	"reflect"
+	"runtime/debug"
 	"strings"
 	"strings"
 	"testing"
 	"testing"
+	"time"
 
 
 	. "./testdata/_obj/test_proto"
 	. "./testdata/_obj/test_proto"
 	. "code.google.com/p/goprotobuf/proto"
 	. "code.google.com/p/goprotobuf/proto"
@@ -1319,6 +1322,49 @@ func TestJSON(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestBadWireType(t *testing.T) {
+	b := []byte{7<<3 | 6} // field 7, wire type 6
+	pb := new(OtherMessage)
+	if err := Unmarshal(b, pb); err == nil {
+		t.Errorf("Unmarshal did not fail")
+	} else if !strings.Contains(err.Error(), "unknown wire type") {
+		t.Errorf("wrong error: %v", err)
+	}
+}
+
+func TestBytesWithInvalidLength(t *testing.T) {
+	// If a byte sequence has an invalid (negative) length, Unmarshal should not panic.
+	b := []byte{2<<3 | WireBytes, 0xff, 0xff, 0xff, 0xff, 0xff, 0}
+	Unmarshal(b, new(MyMessage))
+}
+
+func TestUnmarshalFuzz(t *testing.T) {
+	const N = 1000
+	seed := time.Now().UnixNano()
+	t.Logf("RNG seed is %d", seed)
+	rng := rand.New(rand.NewSource(seed))
+	buf := make([]byte, 20)
+	for i := 0; i < N; i++ {
+		for j := range buf {
+			buf[j] = byte(rng.Intn(256))
+		}
+		fuzzUnmarshal(t, buf)
+	}
+}
+
+func fuzzUnmarshal(t *testing.T, data []byte) {
+	defer func() {
+		if e := recover(); e != nil {
+			t.Errorf("These bytes caused a panic: %+v", data)
+			t.Logf("Stack:\n%s", debug.Stack())
+			t.FailNow()
+		}
+	}()
+
+	pb := new(MyMessage)
+	Unmarshal(data, pb)
+}
+
 func BenchmarkMarshal(b *testing.B) {
 func BenchmarkMarshal(b *testing.B) {
 	b.StopTimer()
 	b.StopTimer()
 
 

+ 5 - 3
proto/decode.go

@@ -181,9 +181,11 @@ func (p *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) {
 	}
 	}
 
 
 	nb := int(n)
 	nb := int(n)
+	if nb < 0 {
+		return nil, fmt.Errorf("proto: bad byte length %d", nb)
+	}
 	if p.index+nb > len(p.buf) {
 	if p.index+nb > len(p.buf) {
-		err = io.ErrUnexpectedEOF
-		return
+		return nil, io.ErrUnexpectedEOF
 	}
 	}
 
 
 	if !alloc {
 	if !alloc {
@@ -279,7 +281,7 @@ func (o *Buffer) skip(t reflect.Type, tag, wire int) error {
 			}
 			}
 		}
 		}
 	default:
 	default:
-		fmt.Fprintf(os.Stderr, "proto: can't skip wire type %d for %s\n", wire, t)
+		err = fmt.Errorf("proto: can't skip unknown wire type %d for %s", wire, t)
 	}
 	}
 	return err
 	return err
 }
 }