Quellcode durchsuchen

internal/socket: don't crash with empty Message.Buffers

Fixes golang/go#22117.

Change-Id: I0d9c3e126aaf97cd297c84e064e9a521ddac626f
Reviewed-on: https://go-review.googlesource.com/67750
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Mikio Hara vor 8 Jahren
Ursprung
Commit
4f8c76a975

+ 5 - 1
internal/socket/iovec_32bit.go

@@ -10,6 +10,10 @@ package socket
 import "unsafe"
 
 func (v *iovec) set(b []byte) {
+	l := len(b)
+	if l == 0 {
+		return
+	}
 	v.Base = (*byte)(unsafe.Pointer(&b[0]))
-	v.Len = uint32(len(b))
+	v.Len = uint32(l)
 }

+ 5 - 1
internal/socket/iovec_64bit.go

@@ -10,6 +10,10 @@ package socket
 import "unsafe"
 
 func (v *iovec) set(b []byte) {
+	l := len(b)
+	if l == 0 {
+		return
+	}
 	v.Base = (*byte)(unsafe.Pointer(&b[0]))
-	v.Len = uint64(len(b))
+	v.Len = uint64(l)
 }

+ 5 - 1
internal/socket/iovec_solaris_64bit.go

@@ -10,6 +10,10 @@ package socket
 import "unsafe"
 
 func (v *iovec) set(b []byte) {
+	l := len(b)
+	if l == 0 {
+		return
+	}
 	v.Base = (*int8)(unsafe.Pointer(&b[0]))
-	v.Len = uint64(len(b))
+	v.Len = uint64(l)
 }

+ 5 - 1
internal/socket/msghdr_bsdvar.go

@@ -7,6 +7,10 @@
 package socket
 
 func (h *msghdr) setIov(vs []iovec) {
+	l := len(vs)
+	if l == 0 {
+		return
+	}
 	h.Iov = &vs[0]
-	h.Iovlen = int32(len(vs))
+	h.Iovlen = int32(l)
 }

+ 5 - 1
internal/socket/msghdr_linux_32bit.go

@@ -10,8 +10,12 @@ package socket
 import "unsafe"
 
 func (h *msghdr) setIov(vs []iovec) {
+	l := len(vs)
+	if l == 0 {
+		return
+	}
 	h.Iov = &vs[0]
-	h.Iovlen = uint32(len(vs))
+	h.Iovlen = uint32(l)
 }
 
 func (h *msghdr) setControl(b []byte) {

+ 5 - 1
internal/socket/msghdr_linux_64bit.go

@@ -10,8 +10,12 @@ package socket
 import "unsafe"
 
 func (h *msghdr) setIov(vs []iovec) {
+	l := len(vs)
+	if l == 0 {
+		return
+	}
 	h.Iov = &vs[0]
-	h.Iovlen = uint64(len(vs))
+	h.Iovlen = uint64(l)
 }
 
 func (h *msghdr) setControl(b []byte) {

+ 5 - 1
internal/socket/msghdr_openbsd.go

@@ -5,6 +5,10 @@
 package socket
 
 func (h *msghdr) setIov(vs []iovec) {
+	l := len(vs)
+	if l == 0 {
+		return
+	}
 	h.Iov = &vs[0]
-	h.Iovlen = uint32(len(vs))
+	h.Iovlen = uint32(l)
 }

+ 4 - 2
internal/socket/msghdr_solaris_64bit.go

@@ -13,8 +13,10 @@ func (h *msghdr) pack(vs []iovec, bs [][]byte, oob []byte, sa []byte) {
 	for i := range vs {
 		vs[i].set(bs[i])
 	}
-	h.Iov = &vs[0]
-	h.Iovlen = int32(len(vs))
+	if len(vs) > 0 {
+		h.Iov = &vs[0]
+		h.Iovlen = int32(len(vs))
+	}
 	if len(oob) > 0 {
 		h.Accrights = (*int8)(unsafe.Pointer(&oob[0]))
 		h.Accrightslen = int32(len(oob))

+ 64 - 61
internal/socket/socket_go1_9_test.go

@@ -119,81 +119,84 @@ func TestUDP(t *testing.T) {
 		t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
 	}
 	defer c.Close()
+	cc, err := socket.NewConn(c.(net.Conn))
+	if err != nil {
+		t.Fatal(err)
+	}
 
 	t.Run("Message", func(t *testing.T) {
-		testUDPMessage(t, c.(net.Conn))
+		data := []byte("HELLO-R-U-THERE")
+		wm := socket.Message{
+			Buffers: bytes.SplitAfter(data, []byte("-")),
+			Addr:    c.LocalAddr(),
+		}
+		if err := cc.SendMsg(&wm, 0); err != nil {
+			t.Fatal(err)
+		}
+		b := make([]byte, 32)
+		rm := socket.Message{
+			Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
+		}
+		if err := cc.RecvMsg(&rm, 0); err != nil {
+			t.Fatal(err)
+		}
+		if !bytes.Equal(b[:rm.N], data) {
+			t.Fatalf("got %#v; want %#v", b[:rm.N], data)
+		}
 	})
 	switch runtime.GOOS {
 	case "linux":
 		t.Run("Messages", func(t *testing.T) {
-			testUDPMessages(t, c.(net.Conn))
+			data := []byte("HELLO-R-U-THERE")
+			wmbs := bytes.SplitAfter(data, []byte("-"))
+			wms := []socket.Message{
+				{Buffers: wmbs[:1], Addr: c.LocalAddr()},
+				{Buffers: wmbs[1:], Addr: c.LocalAddr()},
+			}
+			n, err := cc.SendMsgs(wms, 0)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if n != len(wms) {
+				t.Fatalf("got %d; want %d", n, len(wms))
+			}
+			b := make([]byte, 32)
+			rmbs := [][][]byte{{b[:len(wmbs[0])]}, {b[len(wmbs[0]):]}}
+			rms := []socket.Message{
+				{Buffers: rmbs[0]},
+				{Buffers: rmbs[1]},
+			}
+			n, err = cc.RecvMsgs(rms, 0)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if n != len(rms) {
+				t.Fatalf("got %d; want %d", n, len(rms))
+			}
+			nn := 0
+			for i := 0; i < n; i++ {
+				nn += rms[i].N
+			}
+			if !bytes.Equal(b[:nn], data) {
+				t.Fatalf("got %#v; want %#v", b[:nn], data)
+			}
 		})
 	}
-}
 
-func testUDPMessage(t *testing.T, c net.Conn) {
-	cc, err := socket.NewConn(c)
-	if err != nil {
-		t.Fatal(err)
-	}
-	data := []byte("HELLO-R-U-THERE")
+	// The behavior of transmission for zero byte paylaod depends
+	// on each platform implementation. Some may transmit only
+	// protocol header and options, other may transmit nothing.
+	// We test only that SendMsg and SendMsgs will not crash with
+	// empty buffers.
 	wm := socket.Message{
-		Buffers: bytes.SplitAfter(data, []byte("-")),
+		Buffers: [][]byte{{}},
 		Addr:    c.LocalAddr(),
 	}
-	if err := cc.SendMsg(&wm, 0); err != nil {
-		t.Fatal(err)
-	}
-	b := make([]byte, 32)
-	rm := socket.Message{
-		Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
-	}
-	if err := cc.RecvMsg(&rm, 0); err != nil {
-		t.Fatal(err)
-	}
-	if !bytes.Equal(b[:rm.N], data) {
-		t.Fatalf("got %#v; want %#v", b[:rm.N], data)
-	}
-}
-
-func testUDPMessages(t *testing.T, c net.Conn) {
-	cc, err := socket.NewConn(c)
-	if err != nil {
-		t.Fatal(err)
-	}
-	data := []byte("HELLO-R-U-THERE")
-	wmbs := bytes.SplitAfter(data, []byte("-"))
+	cc.SendMsg(&wm, 0)
 	wms := []socket.Message{
-		{Buffers: wmbs[:1], Addr: c.LocalAddr()},
-		{Buffers: wmbs[1:], Addr: c.LocalAddr()},
-	}
-	n, err := cc.SendMsgs(wms, 0)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if n != len(wms) {
-		t.Fatalf("got %d; want %d", n, len(wms))
-	}
-	b := make([]byte, 32)
-	rmbs := [][][]byte{{b[:len(wmbs[0])]}, {b[len(wmbs[0]):]}}
-	rms := []socket.Message{
-		{Buffers: rmbs[0]},
-		{Buffers: rmbs[1]},
-	}
-	n, err = cc.RecvMsgs(rms, 0)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if n != len(rms) {
-		t.Fatalf("got %d; want %d", n, len(rms))
-	}
-	nn := 0
-	for i := 0; i < n; i++ {
-		nn += rms[i].N
-	}
-	if !bytes.Equal(b[:nn], data) {
-		t.Fatalf("got %#v; want %#v", b[:nn], data)
+		{Buffers: [][]byte{{}}, Addr: c.LocalAddr()},
 	}
+	cc.SendMsgs(wms, 0)
 }
 
 func BenchmarkUDP(b *testing.B) {