Browse Source

ipv4: add {Read,Write}Batch methods to {Packet,Raw}Conn

This change provides message IO functionality that may support the
construction of modern datagram transport protocols.

The modern datagram transport protocols on a multihomed node basically
need to control each packet path for traffic engineering by using
information belongs to network- or link-layer implementation. In
addtion, it's desirable to be able to do simultaneous transmission
across multiple network- or link-layer adjacencies wihtout any
additional cost.

The ReadBatch and WriteBatch methods of PacketConn and RawConn can be
used to read and write an IO message that contains the information of
network- or link-layer implementation, and read and write a batch of
IO messages on Linux. The Marshal and Parse methods of ControlMessage
and Header can help to marshal and parse information contained in IO
messages.

Updates golang/go#3661.

Change-Id: Ia84a9d3bc51641406eaaf4258f2a3066945cc323
Reviewed-on: https://go-review.googlesource.com/38275
Run-TryBot: Mikio Hara <mikioh.mikioh@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Mikio Hara 8 years ago
parent
commit
b7a1f62a47

+ 191 - 0
ipv4/batch.go

@@ -0,0 +1,191 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.9
+
+package ipv4
+
+import (
+	"net"
+	"runtime"
+	"syscall"
+
+	"golang.org/x/net/internal/socket"
+)
+
+// BUG(mikio): On Windows, the ReadBatch and WriteBatch methods of
+// PacketConn are not implemented.
+
+// BUG(mikio): On Windows, the ReadBatch and WriteBatch methods of
+// RawConn are not implemented.
+
+// A Message represents an IO message.
+//
+//	type Message struct {
+//		Buffers [][]byte
+//		OOB     []byte
+//		Addr    net.Addr
+//		N       int
+//		NN      int
+//		Flags   int
+//	}
+//
+// The Buffers fields represents a list of contiguous buffers, which
+// can be used for vectored IO, for example, putting a header and a
+// payload in each slice.
+// When writing, the Buffers field must contain at least one byte to
+// write.
+// When reading, the Buffers field will always contain a byte to read.
+//
+// The OOB field contains protocol-specific control or miscellaneous
+// ancillary data known as out-of-band data.
+// It can be nil when not required.
+//
+// The Addr field specifies a destination address when writing.
+// It can be nil when the underlying protocol of the endpoint uses
+// connection-oriented communication.
+// After a successful read, it may contain the source address on the
+// received packet.
+//
+// The N field indicates the number of bytes read or written from/to
+// Buffers.
+//
+// The NN field indicates the number of bytes read or written from/to
+// OOB.
+//
+// The Flags field contains protocol-specific information on the
+// received message.
+type Message = socket.Message
+
+// ReadBatch reads a batch of messages.
+//
+// The provided flags is a set of platform-dependent flags, such as
+// syscall.MSG_PEEK.
+//
+// On a successful read it returns the number of messages received, up
+// to len(ms).
+//
+// On Linux, a batch read will be optimized.
+// On other platforms, this method will read only a single message.
+//
+// Unlike the ReadFrom method, it doesn't strip the IPv4 header
+// followed by option headers from the received IPv4 datagram when the
+// underlying transport is net.IPConn. Each Buffers field of Message
+// must be large enough to accommodate an IPv4 header and option
+// headers.
+func (c *payloadHandler) ReadBatch(ms []Message, flags int) (int, error) {
+	if !c.ok() {
+		return 0, syscall.EINVAL
+	}
+	switch runtime.GOOS {
+	case "linux":
+		n, err := c.RecvMsgs([]socket.Message(ms), flags)
+		if err != nil {
+			err = &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err}
+		}
+		return n, err
+	default:
+		n := 1
+		err := c.RecvMsg(&ms[0], flags)
+		if err != nil {
+			n = 0
+			err = &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err}
+		}
+		return n, err
+	}
+}
+
+// WriteBatch writes a batch of messages.
+//
+// The provided flags is a set of platform-dependent flags, such as
+// syscall.MSG_DONTROUTE.
+//
+// It returns the number of messages written on a successful write.
+//
+// On Linux, a batch write will be optimized.
+// On other platforms, this method will write only a single message.
+func (c *payloadHandler) WriteBatch(ms []Message, flags int) (int, error) {
+	if !c.ok() {
+		return 0, syscall.EINVAL
+	}
+	switch runtime.GOOS {
+	case "linux":
+		n, err := c.SendMsgs([]socket.Message(ms), flags)
+		if err != nil {
+			err = &net.OpError{Op: "write", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err}
+		}
+		return n, err
+	default:
+		n := 1
+		err := c.SendMsg(&ms[0], flags)
+		if err != nil {
+			n = 0
+			err = &net.OpError{Op: "write", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err}
+		}
+		return n, err
+	}
+}
+
+// ReadBatch reads a batch of messages.
+//
+// The provided flags is a set of platform-dependent flags, such as
+// syscall.MSG_PEEK.
+//
+// On a successful read it returns the number of messages received, up
+// to len(ms).
+//
+// On Linux, a batch read will be optimized.
+// On other platforms, this method will read only a single message.
+func (c *packetHandler) ReadBatch(ms []Message, flags int) (int, error) {
+	if !c.ok() {
+		return 0, syscall.EINVAL
+	}
+	switch runtime.GOOS {
+	case "linux":
+		n, err := c.RecvMsgs([]socket.Message(ms), flags)
+		if err != nil {
+			err = &net.OpError{Op: "read", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err}
+		}
+		return n, err
+	default:
+		n := 1
+		err := c.RecvMsg(&ms[0], flags)
+		if err != nil {
+			n = 0
+			err = &net.OpError{Op: "read", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err}
+		}
+		return n, err
+	}
+}
+
+// WriteBatch writes a batch of messages.
+//
+// The provided flags is a set of platform-dependent flags, such as
+// syscall.MSG_DONTROUTE.
+//
+// It returns the number of messages written on a successful write.
+//
+// On Linux, a batch write will be optimized.
+// On other platforms, this method will write only a single message.
+func (c *packetHandler) WriteBatch(ms []Message, flags int) (int, error) {
+	if !c.ok() {
+		return 0, syscall.EINVAL
+	}
+	switch runtime.GOOS {
+	case "linux":
+		n, err := c.SendMsgs([]socket.Message(ms), flags)
+		if err != nil {
+			err = &net.OpError{Op: "write", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err}
+		}
+		return n, err
+	default:
+		n := 1
+		err := c.SendMsg(&ms[0], flags)
+		if err != nil {
+			n = 0
+			err = &net.OpError{Op: "write", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err}
+		}
+		return n, err
+	}
+}

+ 74 - 0
ipv4/control.go

@@ -8,6 +8,9 @@ import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"sync"
 	"sync"
+
+	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/internal/socket"
 )
 )
 
 
 type rawOpt struct {
 type rawOpt struct {
@@ -51,6 +54,77 @@ func (cm *ControlMessage) String() string {
 	return fmt.Sprintf("ttl=%d src=%v dst=%v ifindex=%d", cm.TTL, cm.Src, cm.Dst, cm.IfIndex)
 	return fmt.Sprintf("ttl=%d src=%v dst=%v ifindex=%d", cm.TTL, cm.Src, cm.Dst, cm.IfIndex)
 }
 }
 
 
+// Marshal returns the binary encoding of cm.
+func (cm *ControlMessage) Marshal() []byte {
+	if cm == nil {
+		return nil
+	}
+	var m socket.ControlMessage
+	if ctlOpts[ctlPacketInfo].name > 0 && (cm.Src.To4() != nil || cm.IfIndex > 0) {
+		m = socket.NewControlMessage([]int{ctlOpts[ctlPacketInfo].length})
+	}
+	if len(m) > 0 {
+		ctlOpts[ctlPacketInfo].marshal(m, cm)
+	}
+	return m
+}
+
+// Parse parses b as a control message and stores the result in cm.
+func (cm *ControlMessage) Parse(b []byte) error {
+	ms, err := socket.ControlMessage(b).Parse()
+	if err != nil {
+		return err
+	}
+	for _, m := range ms {
+		lvl, typ, l, err := m.ParseHeader()
+		if err != nil {
+			return err
+		}
+		if lvl != iana.ProtocolIP {
+			continue
+		}
+		switch typ {
+		case ctlOpts[ctlTTL].name:
+			ctlOpts[ctlTTL].parse(cm, m.Data(l))
+		case ctlOpts[ctlDst].name:
+			ctlOpts[ctlDst].parse(cm, m.Data(l))
+		case ctlOpts[ctlInterface].name:
+			ctlOpts[ctlInterface].parse(cm, m.Data(l))
+		case ctlOpts[ctlPacketInfo].name:
+			ctlOpts[ctlPacketInfo].parse(cm, m.Data(l))
+		}
+	}
+	return nil
+}
+
+// NewControlMessage returns a new control message.
+//
+// The returned message is large enough for options specified by cf.
+func NewControlMessage(cf ControlFlags) []byte {
+	opt := rawOpt{cflags: cf}
+	var l int
+	if opt.isset(FlagTTL) && ctlOpts[ctlTTL].name > 0 {
+		l += socket.ControlMessageSpace(ctlOpts[ctlTTL].length)
+	}
+	if ctlOpts[ctlPacketInfo].name > 0 {
+		if opt.isset(FlagSrc | FlagDst | FlagInterface) {
+			l += socket.ControlMessageSpace(ctlOpts[ctlPacketInfo].length)
+		}
+	} else {
+		if opt.isset(FlagDst) && ctlOpts[ctlDst].name > 0 {
+			l += socket.ControlMessageSpace(ctlOpts[ctlDst].length)
+		}
+		if opt.isset(FlagInterface) && ctlOpts[ctlInterface].name > 0 {
+			l += socket.ControlMessageSpace(ctlOpts[ctlInterface].length)
+		}
+	}
+	var b []byte
+	if l > 0 {
+		b = make([]byte, l)
+	}
+	return b
+}
+
 // Ancillary data socket options
 // Ancillary data socket options
 const (
 const (
 	ctlTTL        = iota // header field
 	ctlTTL        = iota // header field

+ 11 - 11
ipv4/control_bsd.go

@@ -12,26 +12,26 @@ import (
 	"unsafe"
 	"unsafe"
 
 
 	"golang.org/x/net/internal/iana"
 	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/internal/socket"
 )
 )
 
 
 func marshalDst(b []byte, cm *ControlMessage) []byte {
 func marshalDst(b []byte, cm *ControlMessage) []byte {
-	m := (*syscall.Cmsghdr)(unsafe.Pointer(&b[0]))
-	m.Level = iana.ProtocolIP
-	m.Type = sysIP_RECVDSTADDR
-	m.SetLen(syscall.CmsgLen(net.IPv4len))
-	return b[syscall.CmsgSpace(net.IPv4len):]
+	m := socket.ControlMessage(b)
+	m.MarshalHeader(iana.ProtocolIP, sysIP_RECVDSTADDR, net.IPv4len)
+	return m.Next(net.IPv4len)
 }
 }
 
 
 func parseDst(cm *ControlMessage, b []byte) {
 func parseDst(cm *ControlMessage, b []byte) {
-	cm.Dst = b[:net.IPv4len]
+	if len(cm.Dst) < net.IPv4len {
+		cm.Dst = make(net.IP, net.IPv4len)
+	}
+	copy(cm.Dst, b[:net.IPv4len])
 }
 }
 
 
 func marshalInterface(b []byte, cm *ControlMessage) []byte {
 func marshalInterface(b []byte, cm *ControlMessage) []byte {
-	m := (*syscall.Cmsghdr)(unsafe.Pointer(&b[0]))
-	m.Level = iana.ProtocolIP
-	m.Type = sysIP_RECVIF
-	m.SetLen(syscall.CmsgLen(syscall.SizeofSockaddrDatalink))
-	return b[syscall.CmsgSpace(syscall.SizeofSockaddrDatalink):]
+	m := socket.ControlMessage(b)
+	m.MarshalHeader(iana.ProtocolIP, sysIP_RECVIF, syscall.SizeofSockaddrDatalink)
+	return m.Next(syscall.SizeofSockaddrDatalink)
 }
 }
 
 
 func parseInterface(cm *ControlMessage, b []byte) {
 func parseInterface(cm *ControlMessage, b []byte) {

+ 10 - 8
ipv4/control_pktinfo.go

@@ -7,19 +7,18 @@
 package ipv4
 package ipv4
 
 
 import (
 import (
-	"syscall"
+	"net"
 	"unsafe"
 	"unsafe"
 
 
 	"golang.org/x/net/internal/iana"
 	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/internal/socket"
 )
 )
 
 
 func marshalPacketInfo(b []byte, cm *ControlMessage) []byte {
 func marshalPacketInfo(b []byte, cm *ControlMessage) []byte {
-	m := (*syscall.Cmsghdr)(unsafe.Pointer(&b[0]))
-	m.Level = iana.ProtocolIP
-	m.Type = sysIP_PKTINFO
-	m.SetLen(syscall.CmsgLen(sizeofInetPktinfo))
+	m := socket.ControlMessage(b)
+	m.MarshalHeader(iana.ProtocolIP, sysIP_PKTINFO, sizeofInetPktinfo)
 	if cm != nil {
 	if cm != nil {
-		pi := (*inetPktinfo)(unsafe.Pointer(&b[syscall.CmsgLen(0)]))
+		pi := (*inetPktinfo)(unsafe.Pointer(&m.Data(sizeofInetPktinfo)[0]))
 		if ip := cm.Src.To4(); ip != nil {
 		if ip := cm.Src.To4(); ip != nil {
 			copy(pi.Spec_dst[:], ip)
 			copy(pi.Spec_dst[:], ip)
 		}
 		}
@@ -27,11 +26,14 @@ func marshalPacketInfo(b []byte, cm *ControlMessage) []byte {
 			pi.setIfindex(cm.IfIndex)
 			pi.setIfindex(cm.IfIndex)
 		}
 		}
 	}
 	}
-	return b[syscall.CmsgSpace(sizeofInetPktinfo):]
+	return m.Next(sizeofInetPktinfo)
 }
 }
 
 
 func parsePacketInfo(cm *ControlMessage, b []byte) {
 func parsePacketInfo(cm *ControlMessage, b []byte) {
 	pi := (*inetPktinfo)(unsafe.Pointer(&b[0]))
 	pi := (*inetPktinfo)(unsafe.Pointer(&b[0]))
 	cm.IfIndex = int(pi.Ifindex)
 	cm.IfIndex = int(pi.Ifindex)
-	cm.Dst = pi.Addr[:]
+	if len(cm.Dst) < net.IPv4len {
+		cm.Dst = make(net.IP, net.IPv4len)
+	}
+	copy(cm.Dst, pi.Addr[:])
 }
 }

+ 0 - 12
ipv4/control_stub.go

@@ -11,15 +11,3 @@ import "golang.org/x/net/internal/socket"
 func setControlMessage(c *socket.Conn, opt *rawOpt, cf ControlFlags, on bool) error {
 func setControlMessage(c *socket.Conn, opt *rawOpt, cf ControlFlags, on bool) error {
 	return errOpNoSupport
 	return errOpNoSupport
 }
 }
-
-func newControlMessage(opt *rawOpt) []byte {
-	return nil
-}
-
-func parseControlMessage(b []byte) (*ControlMessage, error) {
-	return nil, errOpNoSupport
-}
-
-func marshalControlMessage(cm *ControlMessage) []byte {
-	return nil
-}

+ 3 - 79
ipv4/control_unix.go

@@ -7,8 +7,6 @@
 package ipv4
 package ipv4
 
 
 import (
 import (
-	"os"
-	"syscall"
 	"unsafe"
 	"unsafe"
 
 
 	"golang.org/x/net/internal/iana"
 	"golang.org/x/net/internal/iana"
@@ -64,84 +62,10 @@ func setControlMessage(c *socket.Conn, opt *rawOpt, cf ControlFlags, on bool) er
 	return nil
 	return nil
 }
 }
 
 
-func newControlMessage(opt *rawOpt) (oob []byte) {
-	opt.RLock()
-	var l int
-	if opt.isset(FlagTTL) && ctlOpts[ctlTTL].name > 0 {
-		l += syscall.CmsgSpace(ctlOpts[ctlTTL].length)
-	}
-	if ctlOpts[ctlPacketInfo].name > 0 {
-		if opt.isset(FlagSrc | FlagDst | FlagInterface) {
-			l += syscall.CmsgSpace(ctlOpts[ctlPacketInfo].length)
-		}
-	} else {
-		if opt.isset(FlagDst) && ctlOpts[ctlDst].name > 0 {
-			l += syscall.CmsgSpace(ctlOpts[ctlDst].length)
-		}
-		if opt.isset(FlagInterface) && ctlOpts[ctlInterface].name > 0 {
-			l += syscall.CmsgSpace(ctlOpts[ctlInterface].length)
-		}
-	}
-	if l > 0 {
-		oob = make([]byte, l)
-	}
-	opt.RUnlock()
-	return
-}
-
-func parseControlMessage(b []byte) (*ControlMessage, error) {
-	if len(b) == 0 {
-		return nil, nil
-	}
-	cmsgs, err := syscall.ParseSocketControlMessage(b)
-	if err != nil {
-		return nil, os.NewSyscallError("parse socket control message", err)
-	}
-	cm := &ControlMessage{}
-	for _, m := range cmsgs {
-		if m.Header.Level != iana.ProtocolIP {
-			continue
-		}
-		switch int(m.Header.Type) {
-		case ctlOpts[ctlTTL].name:
-			ctlOpts[ctlTTL].parse(cm, m.Data[:])
-		case ctlOpts[ctlDst].name:
-			ctlOpts[ctlDst].parse(cm, m.Data[:])
-		case ctlOpts[ctlInterface].name:
-			ctlOpts[ctlInterface].parse(cm, m.Data[:])
-		case ctlOpts[ctlPacketInfo].name:
-			ctlOpts[ctlPacketInfo].parse(cm, m.Data[:])
-		}
-	}
-	return cm, nil
-}
-
-func marshalControlMessage(cm *ControlMessage) (oob []byte) {
-	if cm == nil {
-		return nil
-	}
-	var l int
-	pktinfo := false
-	if ctlOpts[ctlPacketInfo].name > 0 && (cm.Src.To4() != nil || cm.IfIndex > 0) {
-		pktinfo = true
-		l += syscall.CmsgSpace(ctlOpts[ctlPacketInfo].length)
-	}
-	if l > 0 {
-		oob = make([]byte, l)
-		b := oob
-		if pktinfo {
-			b = ctlOpts[ctlPacketInfo].marshal(b, cm)
-		}
-	}
-	return
-}
-
 func marshalTTL(b []byte, cm *ControlMessage) []byte {
 func marshalTTL(b []byte, cm *ControlMessage) []byte {
-	m := (*syscall.Cmsghdr)(unsafe.Pointer(&b[0]))
-	m.Level = iana.ProtocolIP
-	m.Type = sysIP_RECVTTL
-	m.SetLen(syscall.CmsgLen(1))
-	return b[syscall.CmsgSpace(1):]
+	m := socket.ControlMessage(b)
+	m.MarshalHeader(iana.ProtocolIP, sysIP_RECVTTL, 1)
+	return m.Next(1)
 }
 }
 
 
 func parseTTL(cm *ControlMessage, b []byte) {
 func parseTTL(cm *ControlMessage, b []byte) {

+ 0 - 15
ipv4/control_windows.go

@@ -14,18 +14,3 @@ func setControlMessage(c *socket.Conn, opt *rawOpt, cf ControlFlags, on bool) er
 	// TODO(mikio): implement this
 	// TODO(mikio): implement this
 	return syscall.EWINDOWS
 	return syscall.EWINDOWS
 }
 }
-
-func newControlMessage(opt *rawOpt) []byte {
-	// TODO(mikio): implement this
-	return nil
-}
-
-func parseControlMessage(b []byte) (*ControlMessage, error) {
-	// TODO(mikio): implement this
-	return nil, syscall.EWINDOWS
-}
-
-func marshalControlMessage(cm *ControlMessage) []byte {
-	// TODO(mikio): implement this
-	return nil
-}

+ 6 - 11
ipv4/endpoint.go

@@ -105,12 +105,7 @@ func NewPacketConn(c net.PacketConn) *PacketConn {
 	p := &PacketConn{
 	p := &PacketConn{
 		genericOpt:     genericOpt{Conn: cc},
 		genericOpt:     genericOpt{Conn: cc},
 		dgramOpt:       dgramOpt{Conn: cc},
 		dgramOpt:       dgramOpt{Conn: cc},
-		payloadHandler: payloadHandler{PacketConn: c},
-	}
-	if _, ok := c.(*net.IPConn); ok {
-		if so, ok := sockOpts[ssoStripHeader]; ok {
-			so.SetInt(p.dgramOpt.Conn, boolint(true))
-		}
+		payloadHandler: payloadHandler{PacketConn: c, Conn: cc},
 	}
 	}
 	return p
 	return p
 }
 }
@@ -140,7 +135,7 @@ func (c *RawConn) SetDeadline(t time.Time) error {
 	if !c.packetHandler.ok() {
 	if !c.packetHandler.ok() {
 		return syscall.EINVAL
 		return syscall.EINVAL
 	}
 	}
-	return c.packetHandler.c.SetDeadline(t)
+	return c.packetHandler.IPConn.SetDeadline(t)
 }
 }
 
 
 // SetReadDeadline sets the read deadline associated with the
 // SetReadDeadline sets the read deadline associated with the
@@ -149,7 +144,7 @@ func (c *RawConn) SetReadDeadline(t time.Time) error {
 	if !c.packetHandler.ok() {
 	if !c.packetHandler.ok() {
 		return syscall.EINVAL
 		return syscall.EINVAL
 	}
 	}
-	return c.packetHandler.c.SetReadDeadline(t)
+	return c.packetHandler.IPConn.SetReadDeadline(t)
 }
 }
 
 
 // SetWriteDeadline sets the write deadline associated with the
 // SetWriteDeadline sets the write deadline associated with the
@@ -158,7 +153,7 @@ func (c *RawConn) SetWriteDeadline(t time.Time) error {
 	if !c.packetHandler.ok() {
 	if !c.packetHandler.ok() {
 		return syscall.EINVAL
 		return syscall.EINVAL
 	}
 	}
-	return c.packetHandler.c.SetWriteDeadline(t)
+	return c.packetHandler.IPConn.SetWriteDeadline(t)
 }
 }
 
 
 // Close closes the endpoint.
 // Close closes the endpoint.
@@ -166,7 +161,7 @@ func (c *RawConn) Close() error {
 	if !c.packetHandler.ok() {
 	if !c.packetHandler.ok() {
 		return syscall.EINVAL
 		return syscall.EINVAL
 	}
 	}
-	return c.packetHandler.c.Close()
+	return c.packetHandler.IPConn.Close()
 }
 }
 
 
 // NewRawConn returns a new RawConn using c as its underlying
 // NewRawConn returns a new RawConn using c as its underlying
@@ -179,7 +174,7 @@ func NewRawConn(c net.PacketConn) (*RawConn, error) {
 	r := &RawConn{
 	r := &RawConn{
 		genericOpt:    genericOpt{Conn: cc},
 		genericOpt:    genericOpt{Conn: cc},
 		dgramOpt:      dgramOpt{Conn: cc},
 		dgramOpt:      dgramOpt{Conn: cc},
-		packetHandler: packetHandler{c: c.(*net.IPConn)},
+		packetHandler: packetHandler{IPConn: c.(*net.IPConn), Conn: cc},
 	}
 	}
 	so, ok := sockOpts[ssoHeaderPrepend]
 	so, ok := sockOpts[ssoHeaderPrepend]
 	if !ok {
 	if !ok {

+ 32 - 20
ipv4/header.go

@@ -51,7 +51,7 @@ func (h *Header) String() string {
 	return fmt.Sprintf("ver=%d hdrlen=%d tos=%#x totallen=%d id=%#x flags=%#x fragoff=%#x ttl=%d proto=%d cksum=%#x src=%v dst=%v", h.Version, h.Len, h.TOS, h.TotalLen, h.ID, h.Flags, h.FragOff, h.TTL, h.Protocol, h.Checksum, h.Src, h.Dst)
 	return fmt.Sprintf("ver=%d hdrlen=%d tos=%#x totallen=%d id=%#x flags=%#x fragoff=%#x ttl=%d proto=%d cksum=%#x src=%v dst=%v", h.Version, h.Len, h.TOS, h.TotalLen, h.ID, h.Flags, h.FragOff, h.TTL, h.Protocol, h.Checksum, h.Src, h.Dst)
 }
 }
 
 
-// Marshal returns the binary encoding of the IPv4 header h.
+// Marshal returns the binary encoding of h.
 func (h *Header) Marshal() ([]byte, error) {
 func (h *Header) Marshal() ([]byte, error) {
 	if h == nil {
 	if h == nil {
 		return nil, syscall.EINVAL
 		return nil, syscall.EINVAL
@@ -98,26 +98,24 @@ func (h *Header) Marshal() ([]byte, error) {
 	return b, nil
 	return b, nil
 }
 }
 
 
-// ParseHeader parses b as an IPv4 header.
-func ParseHeader(b []byte) (*Header, error) {
-	if len(b) < HeaderLen {
-		return nil, errHeaderTooShort
+// Parse parses b as an IPv4 header and sotres the result in h.
+func (h *Header) Parse(b []byte) error {
+	if h == nil || len(b) < HeaderLen {
+		return errHeaderTooShort
 	}
 	}
 	hdrlen := int(b[0]&0x0f) << 2
 	hdrlen := int(b[0]&0x0f) << 2
 	if hdrlen > len(b) {
 	if hdrlen > len(b) {
-		return nil, errBufferTooShort
-	}
-	h := &Header{
-		Version:  int(b[0] >> 4),
-		Len:      hdrlen,
-		TOS:      int(b[1]),
-		ID:       int(binary.BigEndian.Uint16(b[4:6])),
-		TTL:      int(b[8]),
-		Protocol: int(b[9]),
-		Checksum: int(binary.BigEndian.Uint16(b[10:12])),
-		Src:      net.IPv4(b[12], b[13], b[14], b[15]),
-		Dst:      net.IPv4(b[16], b[17], b[18], b[19]),
+		return errBufferTooShort
 	}
 	}
+	h.Version = int(b[0] >> 4)
+	h.Len = hdrlen
+	h.TOS = int(b[1])
+	h.ID = int(binary.BigEndian.Uint16(b[4:6]))
+	h.TTL = int(b[8])
+	h.Protocol = int(b[9])
+	h.Checksum = int(binary.BigEndian.Uint16(b[10:12]))
+	h.Src = net.IPv4(b[12], b[13], b[14], b[15])
+	h.Dst = net.IPv4(b[16], b[17], b[18], b[19])
 	switch runtime.GOOS {
 	switch runtime.GOOS {
 	case "darwin", "dragonfly", "netbsd":
 	case "darwin", "dragonfly", "netbsd":
 		h.TotalLen = int(socket.NativeEndian.Uint16(b[2:4])) + hdrlen
 		h.TotalLen = int(socket.NativeEndian.Uint16(b[2:4])) + hdrlen
@@ -139,9 +137,23 @@ func ParseHeader(b []byte) (*Header, error) {
 	}
 	}
 	h.Flags = HeaderFlags(h.FragOff&0xe000) >> 13
 	h.Flags = HeaderFlags(h.FragOff&0xe000) >> 13
 	h.FragOff = h.FragOff & 0x1fff
 	h.FragOff = h.FragOff & 0x1fff
-	if hdrlen-HeaderLen > 0 {
-		h.Options = make([]byte, hdrlen-HeaderLen)
-		copy(h.Options, b[HeaderLen:])
+	optlen := hdrlen - HeaderLen
+	if optlen > 0 && len(b) >= hdrlen {
+		if cap(h.Options) < optlen {
+			h.Options = make([]byte, optlen)
+		} else {
+			h.Options = h.Options[:optlen]
+		}
+		copy(h.Options, b[HeaderLen:hdrlen])
+	}
+	return nil
+}
+
+// ParseHeader parses b as an IPv4 header.
+func ParseHeader(b []byte) (*Header, error) {
+	h := new(Header)
+	if err := h.Parse(b); err != nil {
+		return nil, err
 	}
 	}
 	return h, nil
 	return h, nil
 }
 }

+ 181 - 107
ipv4/header_test.go

@@ -17,138 +17,212 @@ import (
 )
 )
 
 
 type headerTest struct {
 type headerTest struct {
-	wireHeaderFromKernel          [HeaderLen]byte
-	wireHeaderToKernel            [HeaderLen]byte
-	wireHeaderFromTradBSDKernel   [HeaderLen]byte
-	wireHeaderToTradBSDKernel     [HeaderLen]byte
-	wireHeaderFromFreeBSD10Kernel [HeaderLen]byte
-	wireHeaderToFreeBSD10Kernel   [HeaderLen]byte
+	wireHeaderFromKernel          []byte
+	wireHeaderToKernel            []byte
+	wireHeaderFromTradBSDKernel   []byte
+	wireHeaderToTradBSDKernel     []byte
+	wireHeaderFromFreeBSD10Kernel []byte
+	wireHeaderToFreeBSD10Kernel   []byte
 	*Header
 	*Header
 }
 }
 
 
-var headerLittleEndianTest = headerTest{
+var headerLittleEndianTests = []headerTest{
 	// TODO(mikio): Add platform dependent wire header formats when
 	// TODO(mikio): Add platform dependent wire header formats when
 	// we support new platforms.
 	// we support new platforms.
-	wireHeaderFromKernel: [HeaderLen]byte{
-		0x45, 0x01, 0xbe, 0xef,
-		0xca, 0xfe, 0x45, 0xdc,
-		0xff, 0x01, 0xde, 0xad,
-		172, 16, 254, 254,
-		192, 168, 0, 1,
+	{
+		wireHeaderFromKernel: []byte{
+			0x45, 0x01, 0xbe, 0xef,
+			0xca, 0xfe, 0x45, 0xdc,
+			0xff, 0x01, 0xde, 0xad,
+			172, 16, 254, 254,
+			192, 168, 0, 1,
+		},
+		wireHeaderToKernel: []byte{
+			0x45, 0x01, 0xbe, 0xef,
+			0xca, 0xfe, 0x45, 0xdc,
+			0xff, 0x01, 0xde, 0xad,
+			172, 16, 254, 254,
+			192, 168, 0, 1,
+		},
+		wireHeaderFromTradBSDKernel: []byte{
+			0x45, 0x01, 0xdb, 0xbe,
+			0xca, 0xfe, 0xdc, 0x45,
+			0xff, 0x01, 0xde, 0xad,
+			172, 16, 254, 254,
+			192, 168, 0, 1,
+		},
+		wireHeaderToTradBSDKernel: []byte{
+			0x45, 0x01, 0xef, 0xbe,
+			0xca, 0xfe, 0xdc, 0x45,
+			0xff, 0x01, 0xde, 0xad,
+			172, 16, 254, 254,
+			192, 168, 0, 1,
+		},
+		wireHeaderFromFreeBSD10Kernel: []byte{
+			0x45, 0x01, 0xef, 0xbe,
+			0xca, 0xfe, 0xdc, 0x45,
+			0xff, 0x01, 0xde, 0xad,
+			172, 16, 254, 254,
+			192, 168, 0, 1,
+		},
+		wireHeaderToFreeBSD10Kernel: []byte{
+			0x45, 0x01, 0xef, 0xbe,
+			0xca, 0xfe, 0xdc, 0x45,
+			0xff, 0x01, 0xde, 0xad,
+			172, 16, 254, 254,
+			192, 168, 0, 1,
+		},
+		Header: &Header{
+			Version:  Version,
+			Len:      HeaderLen,
+			TOS:      1,
+			TotalLen: 0xbeef,
+			ID:       0xcafe,
+			Flags:    DontFragment,
+			FragOff:  1500,
+			TTL:      255,
+			Protocol: 1,
+			Checksum: 0xdead,
+			Src:      net.IPv4(172, 16, 254, 254),
+			Dst:      net.IPv4(192, 168, 0, 1),
+		},
 	},
 	},
-	wireHeaderToKernel: [HeaderLen]byte{
-		0x45, 0x01, 0xbe, 0xef,
-		0xca, 0xfe, 0x45, 0xdc,
-		0xff, 0x01, 0xde, 0xad,
-		172, 16, 254, 254,
-		192, 168, 0, 1,
-	},
-	wireHeaderFromTradBSDKernel: [HeaderLen]byte{
-		0x45, 0x01, 0xdb, 0xbe,
-		0xca, 0xfe, 0xdc, 0x45,
-		0xff, 0x01, 0xde, 0xad,
-		172, 16, 254, 254,
-		192, 168, 0, 1,
-	},
-	wireHeaderToTradBSDKernel: [HeaderLen]byte{
-		0x45, 0x01, 0xef, 0xbe,
-		0xca, 0xfe, 0xdc, 0x45,
-		0xff, 0x01, 0xde, 0xad,
-		172, 16, 254, 254,
-		192, 168, 0, 1,
-	},
-	wireHeaderFromFreeBSD10Kernel: [HeaderLen]byte{
-		0x45, 0x01, 0xef, 0xbe,
-		0xca, 0xfe, 0xdc, 0x45,
-		0xff, 0x01, 0xde, 0xad,
-		172, 16, 254, 254,
-		192, 168, 0, 1,
-	},
-	wireHeaderToFreeBSD10Kernel: [HeaderLen]byte{
-		0x45, 0x01, 0xef, 0xbe,
-		0xca, 0xfe, 0xdc, 0x45,
-		0xff, 0x01, 0xde, 0xad,
-		172, 16, 254, 254,
-		192, 168, 0, 1,
-	},
-	Header: &Header{
-		Version:  Version,
-		Len:      HeaderLen,
-		TOS:      1,
-		TotalLen: 0xbeef,
-		ID:       0xcafe,
-		Flags:    DontFragment,
-		FragOff:  1500,
-		TTL:      255,
-		Protocol: 1,
-		Checksum: 0xdead,
-		Src:      net.IPv4(172, 16, 254, 254),
-		Dst:      net.IPv4(192, 168, 0, 1),
+
+	// with option headers
+	{
+		wireHeaderFromKernel: []byte{
+			0x46, 0x01, 0xbe, 0xf3,
+			0xca, 0xfe, 0x45, 0xdc,
+			0xff, 0x01, 0xde, 0xad,
+			172, 16, 254, 254,
+			192, 168, 0, 1,
+			0xff, 0xfe, 0xfe, 0xff,
+		},
+		wireHeaderToKernel: []byte{
+			0x46, 0x01, 0xbe, 0xf3,
+			0xca, 0xfe, 0x45, 0xdc,
+			0xff, 0x01, 0xde, 0xad,
+			172, 16, 254, 254,
+			192, 168, 0, 1,
+			0xff, 0xfe, 0xfe, 0xff,
+		},
+		wireHeaderFromTradBSDKernel: []byte{
+			0x46, 0x01, 0xdb, 0xbe,
+			0xca, 0xfe, 0xdc, 0x45,
+			0xff, 0x01, 0xde, 0xad,
+			172, 16, 254, 254,
+			192, 168, 0, 1,
+			0xff, 0xfe, 0xfe, 0xff,
+		},
+		wireHeaderToTradBSDKernel: []byte{
+			0x46, 0x01, 0xf3, 0xbe,
+			0xca, 0xfe, 0xdc, 0x45,
+			0xff, 0x01, 0xde, 0xad,
+			172, 16, 254, 254,
+			192, 168, 0, 1,
+			0xff, 0xfe, 0xfe, 0xff,
+		},
+		wireHeaderFromFreeBSD10Kernel: []byte{
+			0x46, 0x01, 0xf3, 0xbe,
+			0xca, 0xfe, 0xdc, 0x45,
+			0xff, 0x01, 0xde, 0xad,
+			172, 16, 254, 254,
+			192, 168, 0, 1,
+			0xff, 0xfe, 0xfe, 0xff,
+		},
+		wireHeaderToFreeBSD10Kernel: []byte{
+			0x46, 0x01, 0xf3, 0xbe,
+			0xca, 0xfe, 0xdc, 0x45,
+			0xff, 0x01, 0xde, 0xad,
+			172, 16, 254, 254,
+			192, 168, 0, 1,
+			0xff, 0xfe, 0xfe, 0xff,
+		},
+		Header: &Header{
+			Version:  Version,
+			Len:      HeaderLen + 4,
+			TOS:      1,
+			TotalLen: 0xbef3,
+			ID:       0xcafe,
+			Flags:    DontFragment,
+			FragOff:  1500,
+			TTL:      255,
+			Protocol: 1,
+			Checksum: 0xdead,
+			Src:      net.IPv4(172, 16, 254, 254),
+			Dst:      net.IPv4(192, 168, 0, 1),
+			Options:  []byte{0xff, 0xfe, 0xfe, 0xff},
+		},
 	},
 	},
 }
 }
 
 
 func TestMarshalHeader(t *testing.T) {
 func TestMarshalHeader(t *testing.T) {
-	tt := &headerLittleEndianTest
 	if socket.NativeEndian != binary.LittleEndian {
 	if socket.NativeEndian != binary.LittleEndian {
 		t.Skip("no test for non-little endian machine yet")
 		t.Skip("no test for non-little endian machine yet")
 	}
 	}
 
 
-	b, err := tt.Header.Marshal()
-	if err != nil {
-		t.Fatal(err)
-	}
-	var wh []byte
-	switch runtime.GOOS {
-	case "darwin", "dragonfly", "netbsd":
-		wh = tt.wireHeaderToTradBSDKernel[:]
-	case "freebsd":
-		switch {
-		case freebsdVersion < 1000000:
-			wh = tt.wireHeaderToTradBSDKernel[:]
-		case 1000000 <= freebsdVersion && freebsdVersion < 1100000:
-			wh = tt.wireHeaderToFreeBSD10Kernel[:]
+	for _, tt := range headerLittleEndianTests {
+		b, err := tt.Header.Marshal()
+		if err != nil {
+			t.Fatal(err)
+		}
+		var wh []byte
+		switch runtime.GOOS {
+		case "darwin", "dragonfly", "netbsd":
+			wh = tt.wireHeaderToTradBSDKernel
+		case "freebsd":
+			switch {
+			case freebsdVersion < 1000000:
+				wh = tt.wireHeaderToTradBSDKernel
+			case 1000000 <= freebsdVersion && freebsdVersion < 1100000:
+				wh = tt.wireHeaderToFreeBSD10Kernel
+			default:
+				wh = tt.wireHeaderToKernel
+			}
 		default:
 		default:
-			wh = tt.wireHeaderToKernel[:]
+			wh = tt.wireHeaderToKernel
+		}
+		if !bytes.Equal(b, wh) {
+			t.Fatalf("got %#v; want %#v", b, wh)
 		}
 		}
-	default:
-		wh = tt.wireHeaderToKernel[:]
-	}
-	if !bytes.Equal(b, wh) {
-		t.Fatalf("got %#v; want %#v", b, wh)
 	}
 	}
 }
 }
 
 
 func TestParseHeader(t *testing.T) {
 func TestParseHeader(t *testing.T) {
-	tt := &headerLittleEndianTest
 	if socket.NativeEndian != binary.LittleEndian {
 	if socket.NativeEndian != binary.LittleEndian {
 		t.Skip("no test for big endian machine yet")
 		t.Skip("no test for big endian machine yet")
 	}
 	}
 
 
-	var wh []byte
-	switch runtime.GOOS {
-	case "darwin", "dragonfly", "netbsd":
-		wh = tt.wireHeaderFromTradBSDKernel[:]
-	case "freebsd":
-		switch {
-		case freebsdVersion < 1000000:
-			wh = tt.wireHeaderFromTradBSDKernel[:]
-		case 1000000 <= freebsdVersion && freebsdVersion < 1100000:
-			wh = tt.wireHeaderFromFreeBSD10Kernel[:]
+	for _, tt := range headerLittleEndianTests {
+		var wh []byte
+		switch runtime.GOOS {
+		case "darwin", "dragonfly", "netbsd":
+			wh = tt.wireHeaderFromTradBSDKernel
+		case "freebsd":
+			switch {
+			case freebsdVersion < 1000000:
+				wh = tt.wireHeaderFromTradBSDKernel
+			case 1000000 <= freebsdVersion && freebsdVersion < 1100000:
+				wh = tt.wireHeaderFromFreeBSD10Kernel
+			default:
+				wh = tt.wireHeaderFromKernel
+			}
 		default:
 		default:
-			wh = tt.wireHeaderFromKernel[:]
+			wh = tt.wireHeaderFromKernel
+		}
+		h, err := ParseHeader(wh)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if err := h.Parse(wh); err != nil {
+			t.Fatal(err)
+		}
+		if !reflect.DeepEqual(h, tt.Header) {
+			t.Fatalf("got %#v; want %#v", h, tt.Header)
+		}
+		s := h.String()
+		if strings.Contains(s, ",") {
+			t.Fatalf("should be space-separated values: %s", s)
 		}
 		}
-	default:
-		wh = tt.wireHeaderFromKernel[:]
-	}
-	h, err := ParseHeader(wh)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if !reflect.DeepEqual(h, tt.Header) {
-		t.Fatalf("got %#v; want %#v", h, tt.Header)
-	}
-	s := h.String()
-	if strings.Contains(s, ",") {
-		t.Fatalf("should be space-separated values: %s", s)
 	}
 	}
 }
 }

+ 7 - 38
ipv4/packet.go

@@ -7,6 +7,8 @@ package ipv4
 import (
 import (
 	"net"
 	"net"
 	"syscall"
 	"syscall"
+
+	"golang.org/x/net/internal/socket"
 )
 )
 
 
 // BUG(mikio): On Windows, the ReadFrom and WriteTo methods of RawConn
 // BUG(mikio): On Windows, the ReadFrom and WriteTo methods of RawConn
@@ -14,11 +16,12 @@ import (
 
 
 // A packetHandler represents the IPv4 datagram handler.
 // A packetHandler represents the IPv4 datagram handler.
 type packetHandler struct {
 type packetHandler struct {
-	c *net.IPConn
+	*net.IPConn
+	*socket.Conn
 	rawOpt
 	rawOpt
 }
 }
 
 
-func (c *packetHandler) ok() bool { return c != nil && c.c != nil }
+func (c *packetHandler) ok() bool { return c != nil && c.IPConn != nil && c.Conn != nil }
 
 
 // ReadFrom reads an IPv4 datagram from the endpoint c, copying the
 // ReadFrom reads an IPv4 datagram from the endpoint c, copying the
 // datagram into b. It returns the received datagram as the IPv4
 // datagram into b. It returns the received datagram as the IPv4
@@ -27,25 +30,7 @@ func (c *packetHandler) ReadFrom(b []byte) (h *Header, p []byte, cm *ControlMess
 	if !c.ok() {
 	if !c.ok() {
 		return nil, nil, nil, syscall.EINVAL
 		return nil, nil, nil, syscall.EINVAL
 	}
 	}
-	oob := newControlMessage(&c.rawOpt)
-	n, oobn, _, src, err := c.c.ReadMsgIP(b, oob)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-	var hs []byte
-	if hs, p, err = slicePacket(b[:n]); err != nil {
-		return nil, nil, nil, err
-	}
-	if h, err = ParseHeader(hs); err != nil {
-		return nil, nil, nil, err
-	}
-	if cm, err = parseControlMessage(oob[:oobn]); err != nil {
-		return nil, nil, nil, err
-	}
-	if src != nil && cm != nil {
-		cm.Src = src.IP
-	}
-	return
+	return c.readFrom(b)
 }
 }
 
 
 func slicePacket(b []byte) (h, p []byte, err error) {
 func slicePacket(b []byte) (h, p []byte, err error) {
@@ -80,21 +65,5 @@ func (c *packetHandler) WriteTo(h *Header, p []byte, cm *ControlMessage) error {
 	if !c.ok() {
 	if !c.ok() {
 		return syscall.EINVAL
 		return syscall.EINVAL
 	}
 	}
-	oob := marshalControlMessage(cm)
-	wh, err := h.Marshal()
-	if err != nil {
-		return err
-	}
-	dst := &net.IPAddr{}
-	if cm != nil {
-		if ip := cm.Dst.To4(); ip != nil {
-			dst.IP = ip
-		}
-	}
-	if dst.IP == nil {
-		dst.IP = h.Dst
-	}
-	wh = append(wh, p...)
-	_, _, err = c.c.WriteMsgIP(wh, oob, dst)
-	return err
+	return c.writeTo(h, p, cm)
 }
 }

+ 56 - 0
ipv4/packet_go1_8.go

@@ -0,0 +1,56 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !go1.9
+
+package ipv4
+
+import "net"
+
+func (c *packetHandler) readFrom(b []byte) (h *Header, p []byte, cm *ControlMessage, err error) {
+	c.rawOpt.RLock()
+	oob := NewControlMessage(c.rawOpt.cflags)
+	c.rawOpt.RUnlock()
+	n, nn, _, src, err := c.ReadMsgIP(b, oob)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+	var hs []byte
+	if hs, p, err = slicePacket(b[:n]); err != nil {
+		return nil, nil, nil, err
+	}
+	if h, err = ParseHeader(hs); err != nil {
+		return nil, nil, nil, err
+	}
+	if nn > 0 {
+		cm = new(ControlMessage)
+		if err := cm.Parse(oob[:nn]); err != nil {
+			return nil, nil, nil, err
+		}
+	}
+	if src != nil && cm != nil {
+		cm.Src = src.IP
+	}
+	return
+}
+
+func (c *packetHandler) writeTo(h *Header, p []byte, cm *ControlMessage) error {
+	oob := cm.Marshal()
+	wh, err := h.Marshal()
+	if err != nil {
+		return err
+	}
+	dst := new(net.IPAddr)
+	if cm != nil {
+		if ip := cm.Dst.To4(); ip != nil {
+			dst.IP = ip
+		}
+	}
+	if dst.IP == nil {
+		dst.IP = h.Dst
+	}
+	wh = append(wh, p...)
+	_, _, err = c.WriteMsgIP(wh, oob, dst)
+	return err
+}

+ 67 - 0
ipv4/packet_go1_9.go

@@ -0,0 +1,67 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.9
+
+package ipv4
+
+import (
+	"net"
+
+	"golang.org/x/net/internal/socket"
+)
+
+func (c *packetHandler) readFrom(b []byte) (h *Header, p []byte, cm *ControlMessage, err error) {
+	c.rawOpt.RLock()
+	m := socket.Message{
+		Buffers: [][]byte{b},
+		OOB:     NewControlMessage(c.rawOpt.cflags),
+	}
+	c.rawOpt.RUnlock()
+	if err := c.RecvMsg(&m, 0); err != nil {
+		return nil, nil, nil, &net.OpError{Op: "read", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err}
+	}
+	var hs []byte
+	if hs, p, err = slicePacket(b[:m.N]); err != nil {
+		return nil, nil, nil, &net.OpError{Op: "read", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err}
+	}
+	if h, err = ParseHeader(hs); err != nil {
+		return nil, nil, nil, &net.OpError{Op: "read", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err}
+	}
+	if m.NN > 0 {
+		cm = new(ControlMessage)
+		if err := cm.Parse(m.OOB[:m.NN]); err != nil {
+			return nil, nil, nil, &net.OpError{Op: "read", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err}
+		}
+	}
+	if src, ok := m.Addr.(*net.IPAddr); ok && cm != nil {
+		cm.Src = src.IP
+	}
+	return
+}
+
+func (c *packetHandler) writeTo(h *Header, p []byte, cm *ControlMessage) error {
+	m := socket.Message{
+		OOB: cm.Marshal(),
+	}
+	wh, err := h.Marshal()
+	if err != nil {
+		return err
+	}
+	m.Buffers = [][]byte{wh, p}
+	dst := new(net.IPAddr)
+	if cm != nil {
+		if ip := cm.Dst.To4(); ip != nil {
+			dst.IP = ip
+		}
+	}
+	if dst.IP == nil {
+		dst.IP = h.Dst
+	}
+	m.Addr = dst
+	if err := c.SendMsg(&m, 0); err != nil {
+		return &net.OpError{Op: "write", Net: c.IPConn.LocalAddr().Network(), Source: c.IPConn.LocalAddr(), Err: err}
+	}
+	return nil
+}

+ 7 - 2
ipv4/payload.go

@@ -4,7 +4,11 @@
 
 
 package ipv4
 package ipv4
 
 
-import "net"
+import (
+	"net"
+
+	"golang.org/x/net/internal/socket"
+)
 
 
 // BUG(mikio): On Windows, the ControlMessage for ReadFrom and WriteTo
 // BUG(mikio): On Windows, the ControlMessage for ReadFrom and WriteTo
 // methods of PacketConn is not implemented.
 // methods of PacketConn is not implemented.
@@ -12,7 +16,8 @@ import "net"
 // A payloadHandler represents the IPv4 datagram payload handler.
 // A payloadHandler represents the IPv4 datagram payload handler.
 type payloadHandler struct {
 type payloadHandler struct {
 	net.PacketConn
 	net.PacketConn
+	*socket.Conn
 	rawOpt
 	rawOpt
 }
 }
 
 
-func (c *payloadHandler) ok() bool { return c != nil && c.PacketConn != nil }
+func (c *payloadHandler) ok() bool { return c != nil && c.PacketConn != nil && c.Conn != nil }

+ 3 - 48
ipv4/payload_cmsg.go

@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 // license that can be found in the LICENSE file.
 
 
-// +build !plan9,!windows
+// +build !nacl,!plan9,!windows
 
 
 package ipv4
 package ipv4
 
 
@@ -19,37 +19,7 @@ func (c *payloadHandler) ReadFrom(b []byte) (n int, cm *ControlMessage, src net.
 	if !c.ok() {
 	if !c.ok() {
 		return 0, nil, nil, syscall.EINVAL
 		return 0, nil, nil, syscall.EINVAL
 	}
 	}
-	oob := newControlMessage(&c.rawOpt)
-	var oobn int
-	switch c := c.PacketConn.(type) {
-	case *net.UDPConn:
-		if n, oobn, _, src, err = c.ReadMsgUDP(b, oob); err != nil {
-			return 0, nil, nil, err
-		}
-	case *net.IPConn:
-		if _, ok := sockOpts[ssoStripHeader]; ok {
-			if n, oobn, _, src, err = c.ReadMsgIP(b, oob); err != nil {
-				return 0, nil, nil, err
-			}
-		} else {
-			nb := make([]byte, maxHeaderLen+len(b))
-			if n, oobn, _, src, err = c.ReadMsgIP(nb, oob); err != nil {
-				return 0, nil, nil, err
-			}
-			hdrlen := int(nb[0]&0x0f) << 2
-			copy(b, nb[hdrlen:])
-			n -= hdrlen
-		}
-	default:
-		return 0, nil, nil, errInvalidConnType
-	}
-	if cm, err = parseControlMessage(oob[:oobn]); err != nil {
-		return 0, nil, nil, err
-	}
-	if cm != nil {
-		cm.Src = netAddrToIP4(src)
-	}
-	return
+	return c.readFrom(b)
 }
 }
 
 
 // WriteTo writes a payload of the IPv4 datagram, to the destination
 // WriteTo writes a payload of the IPv4 datagram, to the destination
@@ -62,20 +32,5 @@ func (c *payloadHandler) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n
 	if !c.ok() {
 	if !c.ok() {
 		return 0, syscall.EINVAL
 		return 0, syscall.EINVAL
 	}
 	}
-	oob := marshalControlMessage(cm)
-	if dst == nil {
-		return 0, errMissingAddress
-	}
-	switch c := c.PacketConn.(type) {
-	case *net.UDPConn:
-		n, _, err = c.WriteMsgUDP(b, oob, dst.(*net.UDPAddr))
-	case *net.IPConn:
-		n, _, err = c.WriteMsgIP(b, oob, dst.(*net.IPAddr))
-	default:
-		return 0, errInvalidConnType
-	}
-	if err != nil {
-		return 0, err
-	}
-	return
+	return c.writeTo(b, cm, dst)
 }
 }

+ 59 - 0
ipv4/payload_cmsg_go1_8.go

@@ -0,0 +1,59 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !go1.9
+// +build !nacl,!plan9,!windows
+
+package ipv4
+
+import "net"
+
+func (c *payloadHandler) readFrom(b []byte) (n int, cm *ControlMessage, src net.Addr, err error) {
+	c.rawOpt.RLock()
+	oob := NewControlMessage(c.rawOpt.cflags)
+	c.rawOpt.RUnlock()
+	var nn int
+	switch c := c.PacketConn.(type) {
+	case *net.UDPConn:
+		if n, nn, _, src, err = c.ReadMsgUDP(b, oob); err != nil {
+			return 0, nil, nil, err
+		}
+	case *net.IPConn:
+		nb := make([]byte, maxHeaderLen+len(b))
+		if n, nn, _, src, err = c.ReadMsgIP(nb, oob); err != nil {
+			return 0, nil, nil, err
+		}
+		hdrlen := int(nb[0]&0x0f) << 2
+		copy(b, nb[hdrlen:])
+		n -= hdrlen
+	default:
+		return 0, nil, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Err: errInvalidConnType}
+	}
+	if nn > 0 {
+		cm = new(ControlMessage)
+		if err = cm.Parse(oob[:nn]); err != nil {
+			return 0, nil, nil, &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err}
+		}
+	}
+	if cm != nil {
+		cm.Src = netAddrToIP4(src)
+	}
+	return
+}
+
+func (c *payloadHandler) writeTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) {
+	oob := cm.Marshal()
+	if dst == nil {
+		return 0, &net.OpError{Op: "write", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: errMissingAddress}
+	}
+	switch c := c.PacketConn.(type) {
+	case *net.UDPConn:
+		n, _, err = c.WriteMsgUDP(b, oob, dst.(*net.UDPAddr))
+	case *net.IPConn:
+		n, _, err = c.WriteMsgIP(b, oob, dst.(*net.IPAddr))
+	default:
+		return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Err: errInvalidConnType}
+	}
+	return
+}

+ 67 - 0
ipv4/payload_cmsg_go1_9.go

@@ -0,0 +1,67 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.9
+// +build !nacl,!plan9,!windows
+
+package ipv4
+
+import (
+	"net"
+
+	"golang.org/x/net/internal/socket"
+)
+
+func (c *payloadHandler) readFrom(b []byte) (int, *ControlMessage, net.Addr, error) {
+	c.rawOpt.RLock()
+	m := socket.Message{
+		OOB: NewControlMessage(c.rawOpt.cflags),
+	}
+	c.rawOpt.RUnlock()
+	switch c.PacketConn.(type) {
+	case *net.UDPConn:
+		m.Buffers = [][]byte{b}
+		if err := c.RecvMsg(&m, 0); err != nil {
+			return 0, nil, nil, &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err}
+		}
+	case *net.IPConn:
+		h := make([]byte, HeaderLen)
+		m.Buffers = [][]byte{h, b}
+		if err := c.RecvMsg(&m, 0); err != nil {
+			return 0, nil, nil, &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err}
+		}
+		hdrlen := int(h[0]&0x0f) << 2
+		if hdrlen > len(h) {
+			d := hdrlen - len(h)
+			copy(b, b[d:])
+			m.N -= d
+		} else {
+			m.N -= hdrlen
+		}
+	default:
+		return 0, nil, nil, &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: errInvalidConnType}
+	}
+	var cm *ControlMessage
+	if m.NN > 0 {
+		cm = new(ControlMessage)
+		if err := cm.Parse(m.OOB[:m.NN]); err != nil {
+			return 0, nil, nil, &net.OpError{Op: "read", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err}
+		}
+		cm.Src = netAddrToIP4(m.Addr)
+	}
+	return m.N, cm, m.Addr, nil
+}
+
+func (c *payloadHandler) writeTo(b []byte, cm *ControlMessage, dst net.Addr) (int, error) {
+	m := socket.Message{
+		Buffers: [][]byte{b},
+		OOB:     cm.Marshal(),
+		Addr:    dst,
+	}
+	err := c.SendMsg(&m, 0)
+	if err != nil {
+		err = &net.OpError{Op: "write", Net: c.PacketConn.LocalAddr().Network(), Source: c.PacketConn.LocalAddr(), Err: err}
+	}
+	return m.N, err
+}

+ 1 - 1
ipv4/payload_nocmsg.go

@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 // license that can be found in the LICENSE file.
 
 
-// +build plan9 windows
+// +build nacl plan9 windows
 
 
 package ipv4
 package ipv4
 
 

+ 248 - 0
ipv4/readwrite_go1_8_test.go

@@ -0,0 +1,248 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !go1.9
+
+package ipv4_test
+
+import (
+	"bytes"
+	"fmt"
+	"net"
+	"runtime"
+	"strings"
+	"sync"
+	"testing"
+
+	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/internal/nettest"
+	"golang.org/x/net/ipv4"
+)
+
+func BenchmarkPacketConnReadWriteUnicast(b *testing.B) {
+	switch runtime.GOOS {
+	case "nacl", "plan9", "windows":
+		b.Skipf("not supported on %s", runtime.GOOS)
+	}
+
+	payload := []byte("HELLO-R-U-THERE")
+	iph, err := (&ipv4.Header{
+		Version:  ipv4.Version,
+		Len:      ipv4.HeaderLen,
+		TotalLen: ipv4.HeaderLen + len(payload),
+		TTL:      1,
+		Protocol: iana.ProtocolReserved,
+		Src:      net.IPv4(192, 0, 2, 1),
+		Dst:      net.IPv4(192, 0, 2, 254),
+	}).Marshal()
+	if err != nil {
+		b.Fatal(err)
+	}
+	greh := []byte{0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00}
+	datagram := append(greh, append(iph, payload...)...)
+	bb := make([]byte, 128)
+	cm := ipv4.ControlMessage{
+		Src: net.IPv4(127, 0, 0, 1),
+	}
+	if ifi := nettest.RoutedInterface("ip4", net.FlagUp|net.FlagLoopback); ifi != nil {
+		cm.IfIndex = ifi.Index
+	}
+
+	b.Run("UDP", func(b *testing.B) {
+		c, err := nettest.NewLocalPacketListener("udp4")
+		if err != nil {
+			b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
+		}
+		defer c.Close()
+		p := ipv4.NewPacketConn(c)
+		dst := c.LocalAddr()
+		cf := ipv4.FlagTTL | ipv4.FlagInterface
+		if err := p.SetControlMessage(cf, true); err != nil {
+			b.Fatal(err)
+		}
+		b.Run("Net", func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				if _, err := c.WriteTo(payload, dst); err != nil {
+					b.Fatal(err)
+				}
+				if _, _, err := c.ReadFrom(bb); err != nil {
+					b.Fatal(err)
+				}
+			}
+		})
+		b.Run("ToFrom", func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				if _, err := p.WriteTo(payload, &cm, dst); err != nil {
+					b.Fatal(err)
+				}
+				if _, _, _, err := p.ReadFrom(bb); err != nil {
+					b.Fatal(err)
+				}
+			}
+		})
+	})
+	b.Run("IP", func(b *testing.B) {
+		switch runtime.GOOS {
+		case "netbsd":
+			b.Skip("need to configure gre on netbsd")
+		case "openbsd":
+			b.Skip("net.inet.gre.allow=0 by default on openbsd")
+		}
+
+		c, err := net.ListenPacket(fmt.Sprintf("ip4:%d", iana.ProtocolGRE), "127.0.0.1")
+		if err != nil {
+			b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
+		}
+		defer c.Close()
+		p := ipv4.NewPacketConn(c)
+		dst := c.LocalAddr()
+		cf := ipv4.FlagTTL | ipv4.FlagInterface
+		if err := p.SetControlMessage(cf, true); err != nil {
+			b.Fatal(err)
+		}
+		b.Run("Net", func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				if _, err := c.WriteTo(datagram, dst); err != nil {
+					b.Fatal(err)
+				}
+				if _, _, err := c.ReadFrom(bb); err != nil {
+					b.Fatal(err)
+				}
+			}
+		})
+		b.Run("ToFrom", func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				if _, err := p.WriteTo(datagram, &cm, dst); err != nil {
+					b.Fatal(err)
+				}
+				if _, _, _, err := p.ReadFrom(bb); err != nil {
+					b.Fatal(err)
+				}
+			}
+		})
+	})
+}
+
+func TestPacketConnConcurrentReadWriteUnicast(t *testing.T) {
+	switch runtime.GOOS {
+	case "nacl", "plan9", "windows":
+		t.Skipf("not supported on %s", runtime.GOOS)
+	}
+
+	payload := []byte("HELLO-R-U-THERE")
+	iph, err := (&ipv4.Header{
+		Version:  ipv4.Version,
+		Len:      ipv4.HeaderLen,
+		TotalLen: ipv4.HeaderLen + len(payload),
+		TTL:      1,
+		Protocol: iana.ProtocolReserved,
+		Src:      net.IPv4(192, 0, 2, 1),
+		Dst:      net.IPv4(192, 0, 2, 254),
+	}).Marshal()
+	if err != nil {
+		t.Fatal(err)
+	}
+	greh := []byte{0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00}
+	datagram := append(greh, append(iph, payload...)...)
+
+	t.Run("UDP", func(t *testing.T) {
+		c, err := nettest.NewLocalPacketListener("udp4")
+		if err != nil {
+			t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
+		}
+		defer c.Close()
+		p := ipv4.NewPacketConn(c)
+		t.Run("ToFrom", func(t *testing.T) {
+			testPacketConnConcurrentReadWriteUnicast(t, p, payload, c.LocalAddr())
+		})
+	})
+	t.Run("IP", func(t *testing.T) {
+		switch runtime.GOOS {
+		case "netbsd":
+			t.Skip("need to configure gre on netbsd")
+		case "openbsd":
+			t.Skip("net.inet.gre.allow=0 by default on openbsd")
+		}
+
+		c, err := net.ListenPacket(fmt.Sprintf("ip4:%d", iana.ProtocolGRE), "127.0.0.1")
+		if err != nil {
+			t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
+		}
+		defer c.Close()
+		p := ipv4.NewPacketConn(c)
+		t.Run("ToFrom", func(t *testing.T) {
+			testPacketConnConcurrentReadWriteUnicast(t, p, datagram, c.LocalAddr())
+		})
+	})
+}
+
+func testPacketConnConcurrentReadWriteUnicast(t *testing.T, p *ipv4.PacketConn, data []byte, dst net.Addr) {
+	ifi := nettest.RoutedInterface("ip4", net.FlagUp|net.FlagLoopback)
+	cf := ipv4.FlagTTL | ipv4.FlagSrc | ipv4.FlagDst | ipv4.FlagInterface
+
+	if err := p.SetControlMessage(cf, true); err != nil { // probe before test
+		if nettest.ProtocolNotSupported(err) {
+			t.Skipf("not supported on %s", runtime.GOOS)
+		}
+		t.Fatal(err)
+	}
+
+	var wg sync.WaitGroup
+	reader := func() {
+		defer wg.Done()
+		b := make([]byte, 128)
+		n, cm, _, err := p.ReadFrom(b)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		if !bytes.Equal(b[:n], data) {
+			t.Errorf("got %#v; want %#v", b[:n], data)
+			return
+		}
+		s := cm.String()
+		if strings.Contains(s, ",") {
+			t.Errorf("should be space-separated values: %s", s)
+			return
+		}
+	}
+	writer := func(toggle bool) {
+		defer wg.Done()
+		cm := ipv4.ControlMessage{
+			Src: net.IPv4(127, 0, 0, 1),
+		}
+		if ifi != nil {
+			cm.IfIndex = ifi.Index
+		}
+		if err := p.SetControlMessage(cf, toggle); err != nil {
+			t.Error(err)
+			return
+		}
+		n, err := p.WriteTo(data, &cm, dst)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		if n != len(data) {
+			t.Errorf("got %d; want %d", n, len(data))
+			return
+		}
+	}
+
+	const N = 10
+	wg.Add(N)
+	for i := 0; i < N; i++ {
+		go reader()
+	}
+	wg.Add(2 * N)
+	for i := 0; i < 2*N; i++ {
+		go writer(i%2 != 0)
+
+	}
+	wg.Add(N)
+	for i := 0; i < N; i++ {
+		go reader()
+	}
+	wg.Wait()
+}

+ 388 - 0
ipv4/readwrite_go1_9_test.go

@@ -0,0 +1,388 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.9
+
+package ipv4_test
+
+import (
+	"bytes"
+	"fmt"
+	"net"
+	"runtime"
+	"strings"
+	"sync"
+	"testing"
+
+	"golang.org/x/net/internal/iana"
+	"golang.org/x/net/internal/nettest"
+	"golang.org/x/net/ipv4"
+)
+
+func BenchmarkPacketConnReadWriteUnicast(b *testing.B) {
+	switch runtime.GOOS {
+	case "nacl", "plan9", "windows":
+		b.Skipf("not supported on %s", runtime.GOOS)
+	}
+
+	payload := []byte("HELLO-R-U-THERE")
+	iph, err := (&ipv4.Header{
+		Version:  ipv4.Version,
+		Len:      ipv4.HeaderLen,
+		TotalLen: ipv4.HeaderLen + len(payload),
+		TTL:      1,
+		Protocol: iana.ProtocolReserved,
+		Src:      net.IPv4(192, 0, 2, 1),
+		Dst:      net.IPv4(192, 0, 2, 254),
+	}).Marshal()
+	if err != nil {
+		b.Fatal(err)
+	}
+	greh := []byte{0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00}
+	datagram := append(greh, append(iph, payload...)...)
+	bb := make([]byte, 128)
+	cm := ipv4.ControlMessage{
+		Src: net.IPv4(127, 0, 0, 1),
+	}
+	if ifi := nettest.RoutedInterface("ip4", net.FlagUp|net.FlagLoopback); ifi != nil {
+		cm.IfIndex = ifi.Index
+	}
+
+	b.Run("UDP", func(b *testing.B) {
+		c, err := nettest.NewLocalPacketListener("udp4")
+		if err != nil {
+			b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
+		}
+		defer c.Close()
+		p := ipv4.NewPacketConn(c)
+		dst := c.LocalAddr()
+		cf := ipv4.FlagTTL | ipv4.FlagInterface
+		if err := p.SetControlMessage(cf, true); err != nil {
+			b.Fatal(err)
+		}
+		wms := []ipv4.Message{
+			{
+				Buffers: [][]byte{payload},
+				Addr:    dst,
+				OOB:     cm.Marshal(),
+			},
+		}
+		rms := []ipv4.Message{
+			{
+				Buffers: [][]byte{bb},
+				OOB:     ipv4.NewControlMessage(cf),
+			},
+		}
+		b.Run("Net", func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				if _, err := c.WriteTo(payload, dst); err != nil {
+					b.Fatal(err)
+				}
+				if _, _, err := c.ReadFrom(bb); err != nil {
+					b.Fatal(err)
+				}
+			}
+		})
+		b.Run("ToFrom", func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				if _, err := p.WriteTo(payload, &cm, dst); err != nil {
+					b.Fatal(err)
+				}
+				if _, _, _, err := p.ReadFrom(bb); err != nil {
+					b.Fatal(err)
+				}
+			}
+		})
+		b.Run("Batch", func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				if _, err := p.WriteBatch(wms, 0); err != nil {
+					b.Fatal(err)
+				}
+				if _, err := p.ReadBatch(rms, 0); err != nil {
+					b.Fatal(err)
+				}
+			}
+		})
+	})
+	b.Run("IP", func(b *testing.B) {
+		switch runtime.GOOS {
+		case "netbsd":
+			b.Skip("need to configure gre on netbsd")
+		case "openbsd":
+			b.Skip("net.inet.gre.allow=0 by default on openbsd")
+		}
+
+		c, err := net.ListenPacket(fmt.Sprintf("ip4:%d", iana.ProtocolGRE), "127.0.0.1")
+		if err != nil {
+			b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
+		}
+		defer c.Close()
+		p := ipv4.NewPacketConn(c)
+		dst := c.LocalAddr()
+		cf := ipv4.FlagTTL | ipv4.FlagInterface
+		if err := p.SetControlMessage(cf, true); err != nil {
+			b.Fatal(err)
+		}
+		wms := []ipv4.Message{
+			{
+				Buffers: [][]byte{datagram},
+				Addr:    dst,
+				OOB:     cm.Marshal(),
+			},
+		}
+		rms := []ipv4.Message{
+			{
+				Buffers: [][]byte{bb},
+				OOB:     ipv4.NewControlMessage(cf),
+			},
+		}
+		b.Run("Net", func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				if _, err := c.WriteTo(datagram, dst); err != nil {
+					b.Fatal(err)
+				}
+				if _, _, err := c.ReadFrom(bb); err != nil {
+					b.Fatal(err)
+				}
+			}
+		})
+		b.Run("ToFrom", func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				if _, err := p.WriteTo(datagram, &cm, dst); err != nil {
+					b.Fatal(err)
+				}
+				if _, _, _, err := p.ReadFrom(bb); err != nil {
+					b.Fatal(err)
+				}
+			}
+		})
+		b.Run("Batch", func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				if _, err := p.WriteBatch(wms, 0); err != nil {
+					b.Fatal(err)
+				}
+				if _, err := p.ReadBatch(rms, 0); err != nil {
+					b.Fatal(err)
+				}
+			}
+		})
+	})
+}
+
+func TestPacketConnConcurrentReadWriteUnicast(t *testing.T) {
+	switch runtime.GOOS {
+	case "nacl", "plan9", "windows":
+		t.Skipf("not supported on %s", runtime.GOOS)
+	}
+
+	payload := []byte("HELLO-R-U-THERE")
+	iph, err := (&ipv4.Header{
+		Version:  ipv4.Version,
+		Len:      ipv4.HeaderLen,
+		TotalLen: ipv4.HeaderLen + len(payload),
+		TTL:      1,
+		Protocol: iana.ProtocolReserved,
+		Src:      net.IPv4(192, 0, 2, 1),
+		Dst:      net.IPv4(192, 0, 2, 254),
+	}).Marshal()
+	if err != nil {
+		t.Fatal(err)
+	}
+	greh := []byte{0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00}
+	datagram := append(greh, append(iph, payload...)...)
+
+	t.Run("UDP", func(t *testing.T) {
+		c, err := nettest.NewLocalPacketListener("udp4")
+		if err != nil {
+			t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
+		}
+		defer c.Close()
+		p := ipv4.NewPacketConn(c)
+		t.Run("ToFrom", func(t *testing.T) {
+			testPacketConnConcurrentReadWriteUnicast(t, p, payload, c.LocalAddr(), false)
+		})
+		t.Run("Batch", func(t *testing.T) {
+			testPacketConnConcurrentReadWriteUnicast(t, p, payload, c.LocalAddr(), true)
+		})
+	})
+	t.Run("IP", func(t *testing.T) {
+		switch runtime.GOOS {
+		case "netbsd":
+			t.Skip("need to configure gre on netbsd")
+		case "openbsd":
+			t.Skip("net.inet.gre.allow=0 by default on openbsd")
+		}
+
+		c, err := net.ListenPacket(fmt.Sprintf("ip4:%d", iana.ProtocolGRE), "127.0.0.1")
+		if err != nil {
+			t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
+		}
+		defer c.Close()
+		p := ipv4.NewPacketConn(c)
+		t.Run("ToFrom", func(t *testing.T) {
+			testPacketConnConcurrentReadWriteUnicast(t, p, datagram, c.LocalAddr(), false)
+		})
+		t.Run("Batch", func(t *testing.T) {
+			testPacketConnConcurrentReadWriteUnicast(t, p, datagram, c.LocalAddr(), true)
+		})
+	})
+}
+
+func testPacketConnConcurrentReadWriteUnicast(t *testing.T, p *ipv4.PacketConn, data []byte, dst net.Addr, batch bool) {
+	ifi := nettest.RoutedInterface("ip4", net.FlagUp|net.FlagLoopback)
+	cf := ipv4.FlagTTL | ipv4.FlagSrc | ipv4.FlagDst | ipv4.FlagInterface
+
+	if err := p.SetControlMessage(cf, true); err != nil { // probe before test
+		if nettest.ProtocolNotSupported(err) {
+			t.Skipf("not supported on %s", runtime.GOOS)
+		}
+		t.Fatal(err)
+	}
+
+	var wg sync.WaitGroup
+	reader := func() {
+		defer wg.Done()
+		b := make([]byte, 128)
+		n, cm, _, err := p.ReadFrom(b)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		if !bytes.Equal(b[:n], data) {
+			t.Errorf("got %#v; want %#v", b[:n], data)
+			return
+		}
+		s := cm.String()
+		if strings.Contains(s, ",") {
+			t.Errorf("should be space-separated values: %s", s)
+			return
+		}
+	}
+	batchReader := func() {
+		defer wg.Done()
+		ms := []ipv4.Message{
+			{
+				Buffers: [][]byte{make([]byte, 128)},
+				OOB:     ipv4.NewControlMessage(cf),
+			},
+		}
+		n, err := p.ReadBatch(ms, 0)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		if n != len(ms) {
+			t.Errorf("got %d; want %d", n, len(ms))
+			return
+		}
+		var cm ipv4.ControlMessage
+		if err := cm.Parse(ms[0].OOB[:ms[0].NN]); err != nil {
+			t.Error(err)
+			return
+		}
+		var b []byte
+		if _, ok := dst.(*net.IPAddr); ok {
+			var h ipv4.Header
+			if err := h.Parse(ms[0].Buffers[0][:ms[0].N]); err != nil {
+				t.Error(err)
+				return
+			}
+			b = ms[0].Buffers[0][h.Len:ms[0].N]
+		} else {
+			b = ms[0].Buffers[0][:ms[0].N]
+		}
+		if !bytes.Equal(b, data) {
+			t.Errorf("got %#v; want %#v", b, data)
+			return
+		}
+		s := cm.String()
+		if strings.Contains(s, ",") {
+			t.Errorf("should be space-separated values: %s", s)
+			return
+		}
+	}
+	writer := func(toggle bool) {
+		defer wg.Done()
+		cm := ipv4.ControlMessage{
+			Src: net.IPv4(127, 0, 0, 1),
+		}
+		if ifi != nil {
+			cm.IfIndex = ifi.Index
+		}
+		if err := p.SetControlMessage(cf, toggle); err != nil {
+			t.Error(err)
+			return
+		}
+		n, err := p.WriteTo(data, &cm, dst)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		if n != len(data) {
+			t.Errorf("got %d; want %d", n, len(data))
+			return
+		}
+	}
+	batchWriter := func(toggle bool) {
+		defer wg.Done()
+		cm := ipv4.ControlMessage{
+			Src: net.IPv4(127, 0, 0, 1),
+		}
+		if ifi != nil {
+			cm.IfIndex = ifi.Index
+		}
+		if err := p.SetControlMessage(cf, toggle); err != nil {
+			t.Error(err)
+			return
+		}
+		ms := []ipv4.Message{
+			{
+				Buffers: [][]byte{data},
+				OOB:     cm.Marshal(),
+				Addr:    dst,
+			},
+		}
+		n, err := p.WriteBatch(ms, 0)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		if n != len(ms) {
+			t.Errorf("got %d; want %d", n, len(ms))
+			return
+		}
+		if ms[0].N != len(data) {
+			t.Errorf("got %d; want %d", ms[0].N, len(data))
+			return
+		}
+	}
+
+	const N = 10
+	wg.Add(N)
+	for i := 0; i < N; i++ {
+		if batch {
+			go batchReader()
+		} else {
+			go reader()
+		}
+	}
+	wg.Add(2 * N)
+	for i := 0; i < 2*N; i++ {
+		if batch {
+			go batchWriter(i%2 != 0)
+		} else {
+			go writer(i%2 != 0)
+		}
+
+	}
+	wg.Add(N)
+	for i := 0; i < N; i++ {
+		if batch {
+			go batchReader()
+		} else {
+			go reader()
+		}
+	}
+	wg.Wait()
+}