socket_test.go 8.7 KB


  1. // Copyright 2017 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows
  5. package socket_test
  6. import (
  7. "bytes"
  8. "fmt"
  9. "io/ioutil"
  10. "net"
  11. "os"
  12. "os/exec"
  13. "path/filepath"
  14. "runtime"
  15. "strings"
  16. "syscall"
  17. "testing"
  18. "golang.org/x/net/internal/socket"
  19. "golang.org/x/net/nettest"
  20. )
  21. func TestSocket(t *testing.T) {
  22. t.Run("Option", func(t *testing.T) {
  23. testSocketOption(t, &socket.Option{Level: syscall.SOL_SOCKET, Name: syscall.SO_RCVBUF, Len: 4})
  24. })
  25. }
  26. func testSocketOption(t *testing.T, so *socket.Option) {
  27. c, err := nettest.NewLocalPacketListener("udp")
  28. if err != nil {
  29. t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
  30. }
  31. defer c.Close()
  32. cc, err := socket.NewConn(c.(net.Conn))
  33. if err != nil {
  34. t.Fatal(err)
  35. }
  36. const N = 2048
  37. if err := so.SetInt(cc, N); err != nil {
  38. t.Fatal(err)
  39. }
  40. n, err := so.GetInt(cc)
  41. if err != nil {
  42. t.Fatal(err)
  43. }
  44. if n < N {
  45. t.Fatalf("got %d; want greater than or equal to %d", n, N)
  46. }
  47. }
  // mockControl is a test fixture describing one control message: the
// level and type written into its header, and the payload bytes.
type mockControl struct {
	Level, Type int
	Data        []byte
}
  53. func TestControlMessage(t *testing.T) {
  54. switch runtime.GOOS {
  55. case "windows":
  56. t.Skipf("not supported on %s", runtime.GOOS)
  57. }
  58. for _, tt := range []struct {
  59. cs []mockControl
  60. }{
  61. {
  62. []mockControl{
  63. {Level: 1, Type: 1},
  64. },
  65. },
  66. {
  67. []mockControl{
  68. {Level: 2, Type: 2, Data: []byte{0xfe}},
  69. },
  70. },
  71. {
  72. []mockControl{
  73. {Level: 3, Type: 3, Data: []byte{0xfe, 0xff, 0xff, 0xfe}},
  74. },
  75. },
  76. {
  77. []mockControl{
  78. {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
  79. },
  80. },
  81. {
  82. []mockControl{
  83. {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
  84. {Level: 2, Type: 2, Data: []byte{0xfe}},
  85. },
  86. },
  87. } {
  88. var w []byte
  89. var tailPadLen int
  90. mm := socket.NewControlMessage([]int{0})
  91. for i, c := range tt.cs {
  92. m := socket.NewControlMessage([]int{len(c.Data)})
  93. l := len(m) - len(mm)
  94. if i == len(tt.cs)-1 && l > len(c.Data) {
  95. tailPadLen = l - len(c.Data)
  96. }
  97. w = append(w, m...)
  98. }
  99. var err error
  100. ww := make([]byte, len(w))
  101. copy(ww, w)
  102. m := socket.ControlMessage(ww)
  103. for _, c := range tt.cs {
  104. if err = m.MarshalHeader(c.Level, c.Type, len(c.Data)); err != nil {
  105. t.Fatalf("(%v).MarshalHeader() = %v", tt.cs, err)
  106. }
  107. copy(m.Data(len(c.Data)), c.Data)
  108. m = m.Next(len(c.Data))
  109. }
  110. m = socket.ControlMessage(w)
  111. for _, c := range tt.cs {
  112. m, err = m.Marshal(c.Level, c.Type, c.Data)
  113. if err != nil {
  114. t.Fatalf("(%v).Marshal() = %v", tt.cs, err)
  115. }
  116. }
  117. if !bytes.Equal(ww, w) {
  118. t.Fatalf("got %#v; want %#v", ww, w)
  119. }
  120. ws := [][]byte{w}
  121. if tailPadLen > 0 {
  122. // Test a message with no tail padding.
  123. nopad := w[:len(w)-tailPadLen]
  124. ws = append(ws, [][]byte{nopad}...)
  125. }
  126. for _, w := range ws {
  127. ms, err := socket.ControlMessage(w).Parse()
  128. if err != nil {
  129. t.Fatalf("(%v).Parse() = %v", tt.cs, err)
  130. }
  131. for i, m := range ms {
  132. lvl, typ, dataLen, err := m.ParseHeader()
  133. if err != nil {
  134. t.Fatalf("(%v).ParseHeader() = %v", tt.cs, err)
  135. }
  136. if lvl != tt.cs[i].Level || typ != tt.cs[i].Type || dataLen != len(tt.cs[i].Data) {
  137. t.Fatalf("%v: got %d, %d, %d; want %d, %d, %d", tt.cs[i], lvl, typ, dataLen, tt.cs[i].Level, tt.cs[i].Type, len(tt.cs[i].Data))
  138. }
  139. }
  140. }
  141. }
  142. }
  143. func TestUDP(t *testing.T) {
  144. switch runtime.GOOS {
  145. case "windows":
  146. t.Skipf("not supported on %s", runtime.GOOS)
  147. }
  148. c, err := nettest.NewLocalPacketListener("udp")
  149. if err != nil {
  150. t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
  151. }
  152. defer c.Close()
  153. cc, err := socket.NewConn(c.(net.Conn))
  154. if err != nil {
  155. t.Fatal(err)
  156. }
  157. t.Run("Message", func(t *testing.T) {
  158. data := []byte("HELLO-R-U-THERE")
  159. wm := socket.Message{
  160. Buffers: bytes.SplitAfter(data, []byte("-")),
  161. Addr: c.LocalAddr(),
  162. }
  163. if err := cc.SendMsg(&wm, 0); err != nil {
  164. t.Fatal(err)
  165. }
  166. b := make([]byte, 32)
  167. rm := socket.Message{
  168. Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
  169. }
  170. if err := cc.RecvMsg(&rm, 0); err != nil {
  171. t.Fatal(err)
  172. }
  173. if !bytes.Equal(b[:rm.N], data) {
  174. t.Fatalf("got %#v; want %#v", b[:rm.N], data)
  175. }
  176. })
  177. switch runtime.GOOS {
  178. case "android", "linux":
  179. t.Run("Messages", func(t *testing.T) {
  180. data := []byte("HELLO-R-U-THERE")
  181. wmbs := bytes.SplitAfter(data, []byte("-"))
  182. wms := []socket.Message{
  183. {Buffers: wmbs[:1], Addr: c.LocalAddr()},
  184. {Buffers: wmbs[1:], Addr: c.LocalAddr()},
  185. }
  186. n, err := cc.SendMsgs(wms, 0)
  187. if err != nil {
  188. t.Fatal(err)
  189. }
  190. if n != len(wms) {
  191. t.Fatalf("got %d; want %d", n, len(wms))
  192. }
  193. b := make([]byte, 32)
  194. rmbs := [][][]byte{{b[:len(wmbs[0])]}, {b[len(wmbs[0]):]}}
  195. rms := []socket.Message{
  196. {Buffers: rmbs[0]},
  197. {Buffers: rmbs[1]},
  198. }
  199. n, err = cc.RecvMsgs(rms, 0)
  200. if err != nil {
  201. t.Fatal(err)
  202. }
  203. if n != len(rms) {
  204. t.Fatalf("got %d; want %d", n, len(rms))
  205. }
  206. nn := 0
  207. for i := 0; i < n; i++ {
  208. nn += rms[i].N
  209. }
  210. if !bytes.Equal(b[:nn], data) {
  211. t.Fatalf("got %#v; want %#v", b[:nn], data)
  212. }
  213. })
  214. }
  215. // The behavior of transmission for zero byte paylaod depends
  216. // on each platform implementation. Some may transmit only
  217. // protocol header and options, other may transmit nothing.
  218. // We test only that SendMsg and SendMsgs will not crash with
  219. // empty buffers.
  220. wm := socket.Message{
  221. Buffers: [][]byte{{}},
  222. Addr: c.LocalAddr(),
  223. }
  224. cc.SendMsg(&wm, 0)
  225. wms := []socket.Message{
  226. {Buffers: [][]byte{{}}, Addr: c.LocalAddr()},
  227. }
  228. cc.SendMsgs(wms, 0)
  229. }
  230. func BenchmarkUDP(b *testing.B) {
  231. c, err := nettest.NewLocalPacketListener("udp")
  232. if err != nil {
  233. b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
  234. }
  235. defer c.Close()
  236. cc, err := socket.NewConn(c.(net.Conn))
  237. if err != nil {
  238. b.Fatal(err)
  239. }
  240. data := []byte("HELLO-R-U-THERE")
  241. wm := socket.Message{
  242. Buffers: [][]byte{data},
  243. Addr: c.LocalAddr(),
  244. }
  245. rm := socket.Message{
  246. Buffers: [][]byte{make([]byte, 128)},
  247. OOB: make([]byte, 128),
  248. }
  249. for M := 1; M <= 1<<9; M = M << 1 {
  250. b.Run(fmt.Sprintf("Iter-%d", M), func(b *testing.B) {
  251. for i := 0; i < b.N; i++ {
  252. for j := 0; j < M; j++ {
  253. if err := cc.SendMsg(&wm, 0); err != nil {
  254. b.Fatal(err)
  255. }
  256. if err := cc.RecvMsg(&rm, 0); err != nil {
  257. b.Fatal(err)
  258. }
  259. }
  260. }
  261. })
  262. switch runtime.GOOS {
  263. case "android", "linux":
  264. wms := make([]socket.Message, M)
  265. for i := range wms {
  266. wms[i].Buffers = [][]byte{data}
  267. wms[i].Addr = c.LocalAddr()
  268. }
  269. rms := make([]socket.Message, M)
  270. for i := range rms {
  271. rms[i].Buffers = [][]byte{make([]byte, 128)}
  272. rms[i].OOB = make([]byte, 128)
  273. }
  274. b.Run(fmt.Sprintf("Batch-%d", M), func(b *testing.B) {
  275. for i := 0; i < b.N; i++ {
  276. if _, err := cc.SendMsgs(wms, 0); err != nil {
  277. b.Fatal(err)
  278. }
  279. if _, err := cc.RecvMsgs(rms, 0); err != nil {
  280. b.Fatal(err)
  281. }
  282. }
  283. })
  284. }
  285. }
  286. }
  287. func TestRace(t *testing.T) {
  288. tests := []string{
  289. `
  290. package main
  291. import "net"
  292. import "golang.org/x/net/ipv4"
  293. var g byte
  294. func main() {
  295. c, _ := net.ListenPacket("udp", "127.0.0.1:0")
  296. cc := ipv4.NewPacketConn(c)
  297. sync := make(chan bool)
  298. src := make([]byte, 1)
  299. dst := make([]byte, 1)
  300. go func() { cc.WriteTo(src, nil, c.LocalAddr()) }()
  301. go func() { cc.ReadFrom(dst); sync <- true }()
  302. g = dst[0]
  303. <- sync
  304. }
  305. `,
  306. `
  307. package main
  308. import "net"
  309. import "golang.org/x/net/ipv4"
  310. func main() {
  311. c, _ := net.ListenPacket("udp", "127.0.0.1:0")
  312. cc := ipv4.NewPacketConn(c)
  313. sync := make(chan bool)
  314. src := make([]byte, 1)
  315. dst := make([]byte, 1)
  316. go func() { cc.WriteTo(src, nil, c.LocalAddr()); sync <- true }()
  317. src[0] = 0
  318. go func() { cc.ReadFrom(dst) }()
  319. <- sync
  320. }
  321. `,
  322. }
  323. platforms := map[string]bool{
  324. "linux/amd64": true,
  325. "linux/ppc64le": true,
  326. "linux/arm64": true,
  327. }
  328. if !platforms[runtime.GOOS+"/"+runtime.GOARCH] {
  329. t.Skip("skipping test on non-race-enabled host.")
  330. }
  331. dir, err := ioutil.TempDir("", "testrace")
  332. if err != nil {
  333. t.Fatalf("failed to create temp directory: %v", err)
  334. }
  335. defer os.RemoveAll(dir)
  336. goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
  337. for i, test := range tests {
  338. t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
  339. src := filepath.Join(dir, fmt.Sprintf("test%d.go", i))
  340. if err := ioutil.WriteFile(src, []byte(test), 0644); err != nil {
  341. t.Fatalf("failed to write file: %v", err)
  342. }
  343. got, err := exec.Command(goBinary, "run", "-race", src).CombinedOutput()
  344. if strings.Contains(string(got), "-race requires cgo") {
  345. t.Log("CGO is not enabled so can't use -race")
  346. } else if !strings.Contains(string(got), "WARNING: DATA RACE") {
  347. t.Errorf("race not detected for test %d: err:%v out:%s", i, err, string(got))
  348. }
  349. })
  350. }
  351. }