Przeglądaj źródła

internal/socket: tell race detector about syscall reads and writes

The syscalls that send and receive messages write to buffers provided
by the user. The race detector can't see those reads and writes by
default (they are done by the kernel), so we need to tell the race
detector explicitly about them.

Fixes golang/go#35329

Change-Id: Ibf4ef1b937535c4834aa9eeb744722d91f669a27
Reviewed-on: https://go-review.googlesource.com/c/net/+/205461
Run-TryBot: Keith Randall <khr@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
Keith Randall 6 lat temu
rodzic
commit
2180aed223

+ 12 - 0
internal/socket/norace.go

@@ -0,0 +1,12 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !race
+
+package socket
+
+func (m *Message) raceRead() {
+}
+func (m *Message) raceWrite() {
+}

+ 37 - 0
internal/socket/race.go

@@ -0,0 +1,37 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build race
+
+package socket
+
+import (
+	"runtime"
+	"unsafe"
+)
+
+// This package reads and writes the Message buffers using a
+// direct system call, which the race detector can't see.
+// These functions tell the race detector what is going on during the syscall.
+
+func (m *Message) raceRead() {
+	for _, b := range m.Buffers {
+		if len(b) > 0 {
+			runtime.RaceReadRange(unsafe.Pointer(&b[0]), len(b))
+		}
+	}
+	if b := m.OOB; len(b) > 0 {
+		runtime.RaceReadRange(unsafe.Pointer(&b[0]), len(b))
+	}
+}
+func (m *Message) raceWrite() {
+	for _, b := range m.Buffers {
+		if len(b) > 0 {
+			runtime.RaceWriteRange(unsafe.Pointer(&b[0]), len(b))
+		}
+	}
+	if b := m.OOB; len(b) > 0 {
+		runtime.RaceWriteRange(unsafe.Pointer(&b[0]), len(b))
+	}
+}

+ 6 - 0
internal/socket/rawconn_mmsg.go

@@ -13,6 +13,9 @@ import (
 )
 
 func (c *Conn) recvMsgs(ms []Message, flags int) (int, error) {
+	for i := range ms {
+		ms[i].raceWrite()
+	}
 	hs := make(mmsghdrs, len(ms))
 	var parseFn func([]byte, string) (net.Addr, error)
 	if c.network != "tcp" {
@@ -43,6 +46,9 @@ func (c *Conn) recvMsgs(ms []Message, flags int) (int, error) {
 }
 
 func (c *Conn) sendMsgs(ms []Message, flags int) (int, error) {
+	for i := range ms {
+		ms[i].raceRead()
+	}
 	hs := make(mmsghdrs, len(ms))
 	var marshalFn func(net.Addr) []byte
 	if c.network != "tcp" {

+ 2 - 0
internal/socket/rawconn_msg.go

@@ -12,6 +12,7 @@ import (
 )
 
 func (c *Conn) recvMsg(m *Message, flags int) error {
+	m.raceWrite()
 	var h msghdr
 	vs := make([]iovec, len(m.Buffers))
 	var sa []byte
@@ -48,6 +49,7 @@ func (c *Conn) recvMsg(m *Message, flags int) error {
 }
 
 func (c *Conn) sendMsg(m *Message, flags int) error {
+	m.raceRead()
 	var h msghdr
 	vs := make([]iovec, len(m.Buffers))
 	var sa []byte

+ 69 - 0
internal/socket/socket_test.go

@@ -9,8 +9,13 @@ package socket_test
 import (
 	"bytes"
 	"fmt"
+	"io/ioutil"
 	"net"
+	"os"
+	"os/exec"
+	"path/filepath"
 	"runtime"
+	"strings"
 	"syscall"
 	"testing"
 
@@ -296,3 +301,67 @@ func BenchmarkUDP(b *testing.B) {
 		}
 	}
 }
+
+func TestRace(t *testing.T) {
+	tests := []string{
+		`
+package main
+import "net"
+import "golang.org/x/net/ipv4"
+var g byte
+func main() {
+	c, _ := net.ListenPacket("udp", "127.0.0.1:0")
+	cc := ipv4.NewPacketConn(c)
+	sync := make(chan bool)
+	src := make([]byte, 1)
+	dst := make([]byte, 1)
+	go func() { cc.WriteTo(src, nil, c.LocalAddr()) }()
+	go func() { cc.ReadFrom(dst); sync <- true }()
+	g = dst[0]
+	<- sync
+}
+`,
+		`
+package main
+import "net"
+import "golang.org/x/net/ipv4"
+func main() {
+	c, _ := net.ListenPacket("udp", "127.0.0.1:0")
+	cc := ipv4.NewPacketConn(c)
+	sync := make(chan bool)
+	src := make([]byte, 1)
+	dst := make([]byte, 1)
+	go func() { cc.WriteTo(src, nil, c.LocalAddr()); sync <- true }()
+	src[0] = 0
+	go func() { cc.ReadFrom(dst) }()
+	<- sync
+}
+`,
+	}
+	platforms := map[string]bool{
+		"linux/amd64":   true,
+		"linux/ppc64le": true,
+		"linux/arm64":   true,
+	}
+	if !platforms[runtime.GOOS+"/"+runtime.GOARCH] {
+		t.Skip("skipping test on non-race-enabled host.")
+	}
+	dir, err := ioutil.TempDir("", "testrace")
+	if err != nil {
+		t.Fatalf("failed to create temp directory: %v", err)
+	}
+	defer os.RemoveAll(dir)
+	goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
+	for i, test := range tests {
+		t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
+			src := filepath.Join(dir, fmt.Sprintf("test%d.go", i))
+			if err := ioutil.WriteFile(src, []byte(test), 0644); err != nil {
+				t.Fatalf("failed to write file: %v", err)
+			}
+			got, err := exec.Command(goBinary, "run", "-race", src).CombinedOutput()
+			if !strings.Contains(string(got), "WARNING: DATA RACE") {
+				t.Errorf("race not detected for test %d: err:%v out:%s", i, err, string(got))
+			}
+		})
+	}
+}