123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369 |
- // Copyright 2017 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows
- package socket_test
- import (
- "bytes"
- "fmt"
- "io/ioutil"
- "net"
- "os"
- "os/exec"
- "path/filepath"
- "runtime"
- "strings"
- "syscall"
- "testing"
- "golang.org/x/net/internal/socket"
- "golang.org/x/net/nettest"
- )
- func TestSocket(t *testing.T) {
- t.Run("Option", func(t *testing.T) {
- testSocketOption(t, &socket.Option{Level: syscall.SOL_SOCKET, Name: syscall.SO_RCVBUF, Len: 4})
- })
- }
- func testSocketOption(t *testing.T, so *socket.Option) {
- c, err := nettest.NewLocalPacketListener("udp")
- if err != nil {
- t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
- }
- defer c.Close()
- cc, err := socket.NewConn(c.(net.Conn))
- if err != nil {
- t.Fatal(err)
- }
- const N = 2048
- if err := so.SetInt(cc, N); err != nil {
- t.Fatal(err)
- }
- n, err := so.GetInt(cc)
- if err != nil {
- t.Fatal(err)
- }
- if n < N {
- t.Fatalf("got %d; want greater than or equal to %d", n, N)
- }
- }
// mockControl is a test fixture describing a single control message:
// the level and type written into its header and the data it carries.
type mockControl struct {
	Level, Type int
	Data        []byte
}
- func TestControlMessage(t *testing.T) {
- switch runtime.GOOS {
- case "windows":
- t.Skipf("not supported on %s", runtime.GOOS)
- }
- for _, tt := range []struct {
- cs []mockControl
- }{
- {
- []mockControl{
- {Level: 1, Type: 1},
- },
- },
- {
- []mockControl{
- {Level: 2, Type: 2, Data: []byte{0xfe}},
- },
- },
- {
- []mockControl{
- {Level: 3, Type: 3, Data: []byte{0xfe, 0xff, 0xff, 0xfe}},
- },
- },
- {
- []mockControl{
- {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
- },
- },
- {
- []mockControl{
- {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
- {Level: 2, Type: 2, Data: []byte{0xfe}},
- },
- },
- } {
- var w []byte
- var tailPadLen int
- mm := socket.NewControlMessage([]int{0})
- for i, c := range tt.cs {
- m := socket.NewControlMessage([]int{len(c.Data)})
- l := len(m) - len(mm)
- if i == len(tt.cs)-1 && l > len(c.Data) {
- tailPadLen = l - len(c.Data)
- }
- w = append(w, m...)
- }
- var err error
- ww := make([]byte, len(w))
- copy(ww, w)
- m := socket.ControlMessage(ww)
- for _, c := range tt.cs {
- if err = m.MarshalHeader(c.Level, c.Type, len(c.Data)); err != nil {
- t.Fatalf("(%v).MarshalHeader() = %v", tt.cs, err)
- }
- copy(m.Data(len(c.Data)), c.Data)
- m = m.Next(len(c.Data))
- }
- m = socket.ControlMessage(w)
- for _, c := range tt.cs {
- m, err = m.Marshal(c.Level, c.Type, c.Data)
- if err != nil {
- t.Fatalf("(%v).Marshal() = %v", tt.cs, err)
- }
- }
- if !bytes.Equal(ww, w) {
- t.Fatalf("got %#v; want %#v", ww, w)
- }
- ws := [][]byte{w}
- if tailPadLen > 0 {
- // Test a message with no tail padding.
- nopad := w[:len(w)-tailPadLen]
- ws = append(ws, [][]byte{nopad}...)
- }
- for _, w := range ws {
- ms, err := socket.ControlMessage(w).Parse()
- if err != nil {
- t.Fatalf("(%v).Parse() = %v", tt.cs, err)
- }
- for i, m := range ms {
- lvl, typ, dataLen, err := m.ParseHeader()
- if err != nil {
- t.Fatalf("(%v).ParseHeader() = %v", tt.cs, err)
- }
- if lvl != tt.cs[i].Level || typ != tt.cs[i].Type || dataLen != len(tt.cs[i].Data) {
- t.Fatalf("%v: got %d, %d, %d; want %d, %d, %d", tt.cs[i], lvl, typ, dataLen, tt.cs[i].Level, tt.cs[i].Type, len(tt.cs[i].Data))
- }
- }
- }
- }
- }
- func TestUDP(t *testing.T) {
- switch runtime.GOOS {
- case "windows":
- t.Skipf("not supported on %s", runtime.GOOS)
- }
- c, err := nettest.NewLocalPacketListener("udp")
- if err != nil {
- t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
- }
- defer c.Close()
- cc, err := socket.NewConn(c.(net.Conn))
- if err != nil {
- t.Fatal(err)
- }
- t.Run("Message", func(t *testing.T) {
- data := []byte("HELLO-R-U-THERE")
- wm := socket.Message{
- Buffers: bytes.SplitAfter(data, []byte("-")),
- Addr: c.LocalAddr(),
- }
- if err := cc.SendMsg(&wm, 0); err != nil {
- t.Fatal(err)
- }
- b := make([]byte, 32)
- rm := socket.Message{
- Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
- }
- if err := cc.RecvMsg(&rm, 0); err != nil {
- t.Fatal(err)
- }
- if !bytes.Equal(b[:rm.N], data) {
- t.Fatalf("got %#v; want %#v", b[:rm.N], data)
- }
- })
- switch runtime.GOOS {
- case "android", "linux":
- t.Run("Messages", func(t *testing.T) {
- data := []byte("HELLO-R-U-THERE")
- wmbs := bytes.SplitAfter(data, []byte("-"))
- wms := []socket.Message{
- {Buffers: wmbs[:1], Addr: c.LocalAddr()},
- {Buffers: wmbs[1:], Addr: c.LocalAddr()},
- }
- n, err := cc.SendMsgs(wms, 0)
- if err != nil {
- t.Fatal(err)
- }
- if n != len(wms) {
- t.Fatalf("got %d; want %d", n, len(wms))
- }
- b := make([]byte, 32)
- rmbs := [][][]byte{{b[:len(wmbs[0])]}, {b[len(wmbs[0]):]}}
- rms := []socket.Message{
- {Buffers: rmbs[0]},
- {Buffers: rmbs[1]},
- }
- n, err = cc.RecvMsgs(rms, 0)
- if err != nil {
- t.Fatal(err)
- }
- if n != len(rms) {
- t.Fatalf("got %d; want %d", n, len(rms))
- }
- nn := 0
- for i := 0; i < n; i++ {
- nn += rms[i].N
- }
- if !bytes.Equal(b[:nn], data) {
- t.Fatalf("got %#v; want %#v", b[:nn], data)
- }
- })
- }
- // The behavior of transmission for zero byte paylaod depends
- // on each platform implementation. Some may transmit only
- // protocol header and options, other may transmit nothing.
- // We test only that SendMsg and SendMsgs will not crash with
- // empty buffers.
- wm := socket.Message{
- Buffers: [][]byte{{}},
- Addr: c.LocalAddr(),
- }
- cc.SendMsg(&wm, 0)
- wms := []socket.Message{
- {Buffers: [][]byte{{}}, Addr: c.LocalAddr()},
- }
- cc.SendMsgs(wms, 0)
- }
- func BenchmarkUDP(b *testing.B) {
- c, err := nettest.NewLocalPacketListener("udp")
- if err != nil {
- b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
- }
- defer c.Close()
- cc, err := socket.NewConn(c.(net.Conn))
- if err != nil {
- b.Fatal(err)
- }
- data := []byte("HELLO-R-U-THERE")
- wm := socket.Message{
- Buffers: [][]byte{data},
- Addr: c.LocalAddr(),
- }
- rm := socket.Message{
- Buffers: [][]byte{make([]byte, 128)},
- OOB: make([]byte, 128),
- }
- for M := 1; M <= 1<<9; M = M << 1 {
- b.Run(fmt.Sprintf("Iter-%d", M), func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- for j := 0; j < M; j++ {
- if err := cc.SendMsg(&wm, 0); err != nil {
- b.Fatal(err)
- }
- if err := cc.RecvMsg(&rm, 0); err != nil {
- b.Fatal(err)
- }
- }
- }
- })
- switch runtime.GOOS {
- case "android", "linux":
- wms := make([]socket.Message, M)
- for i := range wms {
- wms[i].Buffers = [][]byte{data}
- wms[i].Addr = c.LocalAddr()
- }
- rms := make([]socket.Message, M)
- for i := range rms {
- rms[i].Buffers = [][]byte{make([]byte, 128)}
- rms[i].OOB = make([]byte, 128)
- }
- b.Run(fmt.Sprintf("Batch-%d", M), func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- if _, err := cc.SendMsgs(wms, 0); err != nil {
- b.Fatal(err)
- }
- if _, err := cc.RecvMsgs(rms, 0); err != nil {
- b.Fatal(err)
- }
- }
- })
- }
- }
- }
- func TestRace(t *testing.T) {
- tests := []string{
- `
- package main
- import "net"
- import "golang.org/x/net/ipv4"
- var g byte
- func main() {
- c, _ := net.ListenPacket("udp", "127.0.0.1:0")
- cc := ipv4.NewPacketConn(c)
- sync := make(chan bool)
- src := make([]byte, 1)
- dst := make([]byte, 1)
- go func() { cc.WriteTo(src, nil, c.LocalAddr()) }()
- go func() { cc.ReadFrom(dst); sync <- true }()
- g = dst[0]
- <- sync
- }
- `,
- `
- package main
- import "net"
- import "golang.org/x/net/ipv4"
- func main() {
- c, _ := net.ListenPacket("udp", "127.0.0.1:0")
- cc := ipv4.NewPacketConn(c)
- sync := make(chan bool)
- src := make([]byte, 1)
- dst := make([]byte, 1)
- go func() { cc.WriteTo(src, nil, c.LocalAddr()); sync <- true }()
- src[0] = 0
- go func() { cc.ReadFrom(dst) }()
- <- sync
- }
- `,
- }
- platforms := map[string]bool{
- "linux/amd64": true,
- "linux/ppc64le": true,
- "linux/arm64": true,
- }
- if !platforms[runtime.GOOS+"/"+runtime.GOARCH] {
- t.Skip("skipping test on non-race-enabled host.")
- }
- dir, err := ioutil.TempDir("", "testrace")
- if err != nil {
- t.Fatalf("failed to create temp directory: %v", err)
- }
- defer os.RemoveAll(dir)
- goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
- for i, test := range tests {
- t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
- src := filepath.Join(dir, fmt.Sprintf("test%d.go", i))
- if err := ioutil.WriteFile(src, []byte(test), 0644); err != nil {
- t.Fatalf("failed to write file: %v", err)
- }
- got, err := exec.Command(goBinary, "run", "-race", src).CombinedOutput()
- if strings.Contains(string(got), "-race requires cgo") {
- t.Log("CGO is not enabled so can't use -race")
- } else if !strings.Contains(string(got), "WARNING: DATA RACE") {
- t.Errorf("race not detected for test %d: err:%v out:%s", i, err, string(got))
- }
- })
- }
- }
|