Pārlūkot izejas kodu

x/net/ipv4: add support for source-specific multicast

This CL introduces methods for the manipulation of source-specifc
group into both PacketConn and RawConn as follows:

JoinSourceSpecificGroup(*net.Interface, net.Addr, net.Addr) error
LeaveSourceSpecificGroup(*net.Interface, net.Addr, net.Addr) error
ExcludeSourceSpecificGroup(*net.Interface, net.Addr, net.Addr) error
IncludeSourceSpecificGroup(*net.Interface, net.Addr, net.Addr) error

Fixes golang/go#8266.

LGTM=iant
R=iant
CC=golang-codereviews
https://golang.org/cl/174030043
Mikio Hara 11 gadi atpakaļ
vecāks
revīzija
a33e90a7ec
6 mainītis faili ar 471 papildinājumiem un 208 dzēšanām
  1. 95 1
      ipv4/dgramopt_posix.go
  2. 16 0
      ipv4/dgramopt_stub.go
  3. 2 2
      ipv4/doc.go
  4. 4 0
      ipv4/header.go
  5. 245 178
      ipv4/multicast_test.go
  6. 109 27
      ipv4/multicastsockopt_test.go

+ 95 - 1
ipv4/dgramopt_posix.go

@@ -93,7 +93,11 @@ func (c *dgramOpt) SetMulticastLoopback(on bool) error {
 	return setInt(fd, &sockOpts[ssoMulticastLoopback], boolint(on))
 }
 
-// JoinGroup joins the group address group on the interface ifi.
+// JoinGroup joins the group address group on the interface ifi. By
+// default all sources that can cast data to group are accepted. It's
+// possible to mute and unmute data transmission from a specific
+// source by using ExcludeSourceSpecificGroup and
+// IncludeSourceSpecificGroup.
 // It uses the system assigned multicast interface when ifi is nil,
 // although this is not recommended because the assignment depends on
 // platforms and sometimes it might require routing configuration.
@@ -113,6 +117,8 @@ func (c *dgramOpt) JoinGroup(ifi *net.Interface, group net.Addr) error {
 }
 
 // LeaveGroup leaves the group address group on the interface ifi.
+// It's allowed to leave the group which is formed by
+// JoinSourceSpecificGroup for convenience.
 func (c *dgramOpt) LeaveGroup(ifi *net.Interface, group net.Addr) error {
 	if !c.ok() {
 		return syscall.EINVAL
@@ -127,3 +133,91 @@ func (c *dgramOpt) LeaveGroup(ifi *net.Interface, group net.Addr) error {
 	}
 	return setGroup(fd, &sockOpts[ssoLeaveGroup], ifi, grp)
 }
+
+// JoinSourceSpecificGroup joins the source-specific group consisting
+// group and source on the interface ifi. It uses the system assigned
+// multicast interface when ifi is nil, although this is not
+// recommended because the assignment depends on platforms and
+// sometimes it might require routing configuration.
+func (c *dgramOpt) JoinSourceSpecificGroup(ifi *net.Interface, group, source net.Addr) error {
+	if !c.ok() {
+		return syscall.EINVAL
+	}
+	fd, err := c.sysfd()
+	if err != nil {
+		return err
+	}
+	grp := netAddrToIP4(group)
+	if grp == nil {
+		return errMissingAddress
+	}
+	src := netAddrToIP4(source)
+	if src == nil {
+		return errMissingAddress
+	}
+	return setSourceGroup(fd, &sockOpts[ssoJoinSourceGroup], ifi, grp, src)
+}
+
+// LeaveSourceSpecificGroup leaves the source-specific group on the
+// interface ifi.
+func (c *dgramOpt) LeaveSourceSpecificGroup(ifi *net.Interface, group, source net.Addr) error {
+	if !c.ok() {
+		return syscall.EINVAL
+	}
+	fd, err := c.sysfd()
+	if err != nil {
+		return err
+	}
+	grp := netAddrToIP4(group)
+	if grp == nil {
+		return errMissingAddress
+	}
+	src := netAddrToIP4(source)
+	if src == nil {
+		return errMissingAddress
+	}
+	return setSourceGroup(fd, &sockOpts[ssoLeaveSourceGroup], ifi, grp, src)
+}
+
+// ExcludeSourceSpecificGroup excludes the source-specific group from
+// the already joined groups by either JoinGroup or
+// JoinSourceSpecificGroup on the interface ifi.
+func (c *dgramOpt) ExcludeSourceSpecificGroup(ifi *net.Interface, group, source net.Addr) error {
+	if !c.ok() {
+		return syscall.EINVAL
+	}
+	fd, err := c.sysfd()
+	if err != nil {
+		return err
+	}
+	grp := netAddrToIP4(group)
+	if grp == nil {
+		return errMissingAddress
+	}
+	src := netAddrToIP4(source)
+	if src == nil {
+		return errMissingAddress
+	}
+	return setSourceGroup(fd, &sockOpts[ssoBlockSourceGroup], ifi, grp, src)
+}
+
+// IncludeSourceSpecificGroup includes the excluded source-specific
+// group by ExcludeSourceSpecificGroup again on the interface ifi.
+func (c *dgramOpt) IncludeSourceSpecificGroup(ifi *net.Interface, group, source net.Addr) error {
+	if !c.ok() {
+		return syscall.EINVAL
+	}
+	fd, err := c.sysfd()
+	if err != nil {
+		return err
+	}
+	grp := netAddrToIP4(group)
+	if grp == nil {
+		return errMissingAddress
+	}
+	src := netAddrToIP4(source)
+	if src == nil {
+		return errMissingAddress
+	}
+	return setSourceGroup(fd, &sockOpts[ssoUnblockSourceGroup], ifi, grp, src)
+}

+ 16 - 0
ipv4/dgramopt_stub.go

@@ -39,3 +39,19 @@ func (c *dgramOpt) JoinGroup(ifi *net.Interface, grp net.Addr) error {
 func (c *dgramOpt) LeaveGroup(ifi *net.Interface, grp net.Addr) error {
 	return errOpNoSupport
 }
+
+func (c *dgramOpt) JoinSourceSpecificGroup(ifi *net.Interface, group, source net.Addr) error {
+	return errOpNoSupport
+}
+
+func (c *dgramOpt) LeaveSourceSpecificGroup(ifi *net.Interface, group, source net.Addr) error {
+	return errOpNoSupport
+}
+
+func (c *dgramOpt) ExcludeSourceSpecificGroup(ifi *net.Interface, group, source net.Addr) error {
+	return errOpNoSupport
+}
+
+func (c *dgramOpt) IncludeSourceSpecificGroup(ifi *net.Interface, group, source net.Addr) error {
+	return errOpNoSupport
+}

+ 2 - 2
ipv4/doc.go

@@ -7,8 +7,8 @@
 //
 // The package provides IP-level socket options that allow
 // manipulation of IPv4 facilities.  The IPv4 and basic host
-// requirements for IPv4 are defined in RFC 791, RFC 1112 and RFC
-// 1122.
+// requirements for IPv4 are defined in RFC 791, RFC 1112, RFC 1122,
+// RFC 3678 and RFC 4607.
 //
 //
 // Unicasting

+ 4 - 0
ipv4/header.go

@@ -29,6 +29,10 @@ var (
 //	http://tools.ietf.org/html/rfc1112
 // RFC 1122  Requirements for Internet Hosts
 //	http://tools.ietf.org/html/rfc1122
+// RFC 3678  Socket Interface Extensions for Multicast Source Filters
+//	http://tools.ietf.org/html/rfc3678
+// RFC 4607  Source-Specific Multicast for IP
+//	http://tools.ietf.org/html/rfc4607
 
 const (
 	Version      = 4  // protocol version

+ 245 - 178
ipv4/multicast_test.go

@@ -5,6 +5,7 @@
 package ipv4_test
 
 import (
+	"bytes"
 	"net"
 	"os"
 	"runtime"
@@ -17,6 +18,15 @@ import (
 	"golang.org/x/net/ipv4"
 )
 
+var packetConnReadWriteMulticastUDPTests = []struct {
+	addr     string
+	grp, src *net.UDPAddr
+}{
+	{"224.0.0.0:0", &net.UDPAddr{IP: net.IPv4(224, 0, 0, 254)}, nil}, // see RFC 4727
+
+	{"232.0.1.0:0", &net.UDPAddr{IP: net.IPv4(232, 0, 1, 254)}, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}}, // see RFC 5771
+}
+
 func TestPacketConnReadWriteMulticastUDP(t *testing.T) {
 	switch runtime.GOOS {
 	case "nacl", "plan9", "solaris", "windows":
@@ -27,63 +37,86 @@ func TestPacketConnReadWriteMulticastUDP(t *testing.T) {
 		t.Skipf("not available on %q", runtime.GOOS)
 	}
 
-	c, err := net.ListenPacket("udp4", "224.0.0.0:0") // see RFC 4727
-	if err != nil {
-		t.Fatalf("net.ListenPacket failed: %v", err)
-	}
-	defer c.Close()
-
-	_, port, err := net.SplitHostPort(c.LocalAddr().String())
-	if err != nil {
-		t.Fatalf("net.SplitHostPort failed: %v", err)
-	}
-	dst, err := net.ResolveUDPAddr("udp4", "224.0.0.254:"+port) // see RFC 4727
-	if err != nil {
-		t.Fatalf("net.ResolveUDPAddr failed: %v", err)
-	}
-
-	p := ipv4.NewPacketConn(c)
-	defer p.Close()
-	if err := p.JoinGroup(ifi, dst); err != nil {
-		t.Fatalf("ipv4.PacketConn.JoinGroup on %v failed: %v", ifi, err)
-	}
-	if err := p.SetMulticastInterface(ifi); err != nil {
-		t.Fatalf("ipv4.PacketConn.SetMulticastInterface failed: %v", err)
-	}
-	if _, err := p.MulticastInterface(); err != nil {
-		t.Fatalf("ipv4.PacketConn.MulticastInterface failed: %v", err)
-	}
-	if err := p.SetMulticastLoopback(true); err != nil {
-		t.Fatalf("ipv4.PacketConn.SetMulticastLoopback failed: %v", err)
-	}
-	if _, err := p.MulticastLoopback(); err != nil {
-		t.Fatalf("ipv4.PacketConn.MulticastLoopback failed: %v", err)
-	}
-	cf := ipv4.FlagTTL | ipv4.FlagDst | ipv4.FlagInterface
+	for _, tt := range packetConnReadWriteMulticastUDPTests {
+		c, err := net.ListenPacket("udp4", tt.addr)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer c.Close()
 
-	for i, toggle := range []bool{true, false, true} {
-		if err := p.SetControlMessage(cf, toggle); err != nil {
-			if nettest.ProtocolNotSupported(err) {
-				t.Skipf("not supported on %q", runtime.GOOS)
+		grp := *tt.grp
+		grp.Port = c.LocalAddr().(*net.UDPAddr).Port
+		p := ipv4.NewPacketConn(c)
+		defer p.Close()
+		if tt.src == nil {
+			if err := p.JoinGroup(ifi, &grp); err != nil {
+				t.Fatal(err)
 			}
-			t.Fatalf("ipv4.PacketConn.SetControlMessage failed: %v", err)
+			defer p.LeaveGroup(ifi, &grp)
+		} else {
+			if err := p.JoinSourceSpecificGroup(ifi, &grp, tt.src); err != nil {
+				switch runtime.GOOS {
+				case "freebsd", "linux":
+				default: // platforms that don't support IGMPv2/3 fail here
+					t.Logf("not supported on %q", runtime.GOOS)
+					continue
+				}
+				t.Fatal(err)
+			}
+			defer p.LeaveSourceSpecificGroup(ifi, &grp, tt.src)
 		}
-		if err := p.SetDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
-			t.Fatalf("ipv4.PacketConn.SetDeadline failed: %v", err)
+		if err := p.SetMulticastInterface(ifi); err != nil {
+			t.Fatal(err)
 		}
-		p.SetMulticastTTL(i + 1)
-		if _, err := p.WriteTo([]byte("HELLO-R-U-THERE"), nil, dst); err != nil {
-			t.Fatalf("ipv4.PacketConn.WriteTo failed: %v", err)
+		if _, err := p.MulticastInterface(); err != nil {
+			t.Fatal(err)
 		}
-		b := make([]byte, 128)
-		if _, cm, _, err := p.ReadFrom(b); err != nil {
-			t.Fatalf("ipv4.PacketConn.ReadFrom failed: %v", err)
-		} else {
-			t.Logf("rcvd cmsg: %v", cm)
+		if err := p.SetMulticastLoopback(true); err != nil {
+			t.Fatal(err)
+		}
+		if _, err := p.MulticastLoopback(); err != nil {
+			t.Fatal(err)
+		}
+		cf := ipv4.FlagTTL | ipv4.FlagDst | ipv4.FlagInterface
+		wb := []byte("HELLO-R-U-THERE")
+
+		for i, toggle := range []bool{true, false, true} {
+			if err := p.SetControlMessage(cf, toggle); err != nil {
+				if nettest.ProtocolNotSupported(err) {
+					t.Logf("not supported on %q", runtime.GOOS)
+					continue
+				}
+				t.Fatal(err)
+			}
+			if err := p.SetDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
+				t.Fatal(err)
+			}
+			p.SetMulticastTTL(i + 1)
+			if n, err := p.WriteTo(wb, nil, &grp); err != nil {
+				t.Fatal(err)
+			} else if n != len(wb) {
+				t.Fatalf("got %v; expected %v", n, len(wb))
+			}
+			rb := make([]byte, 128)
+			if n, cm, _, err := p.ReadFrom(rb); err != nil {
+				t.Fatal(err)
+			} else if !bytes.Equal(rb[:n], wb) {
+				t.Fatalf("got %v; expected %v", rb[:n], wb)
+			} else {
+				t.Logf("rcvd cmsg: %v", cm)
+			}
 		}
 	}
 }
 
+var packetConnReadWriteMulticastICMPTests = []struct {
+	grp, src *net.IPAddr
+}{
+	{&net.IPAddr{IP: net.IPv4(224, 0, 0, 254)}, nil}, // see RFC 4727
+
+	{&net.IPAddr{IP: net.IPv4(232, 0, 1, 254)}, &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}}, // see RFC 5771
+}
+
 func TestPacketConnReadWriteMulticastICMP(t *testing.T) {
 	switch runtime.GOOS {
 	case "nacl", "plan9", "solaris", "windows":
@@ -97,79 +130,101 @@ func TestPacketConnReadWriteMulticastICMP(t *testing.T) {
 		t.Skipf("not available on %q", runtime.GOOS)
 	}
 
-	c, err := net.ListenPacket("ip4:icmp", "0.0.0.0")
-	if err != nil {
-		t.Fatalf("net.ListenPacket failed: %v", err)
-	}
-	defer c.Close()
-
-	dst, err := net.ResolveIPAddr("ip4", "224.0.0.254") // see RFC 4727
-	if err != nil {
-		t.Fatalf("net.ResolveIPAddr failed: %v", err)
-	}
-
-	p := ipv4.NewPacketConn(c)
-	defer p.Close()
-	if err := p.JoinGroup(ifi, dst); err != nil {
-		t.Fatalf("ipv4.PacketConn.JoinGroup on %v failed: %v", ifi, err)
-	}
-	if err := p.SetMulticastInterface(ifi); err != nil {
-		t.Fatalf("ipv4.PacketConn.SetMulticastInterface failed: %v", err)
-	}
-	if _, err := p.MulticastInterface(); err != nil {
-		t.Fatalf("ipv4.PacketConn.MulticastInterface failed: %v", err)
-	}
-	if err := p.SetMulticastLoopback(true); err != nil {
-		t.Fatalf("ipv4.PacketConn.SetMulticastLoopback failed: %v", err)
-	}
-	if _, err := p.MulticastLoopback(); err != nil {
-		t.Fatalf("ipv4.PacketConn.MulticastLoopback failed: %v", err)
-	}
-	cf := ipv4.FlagTTL | ipv4.FlagDst | ipv4.FlagInterface
-
-	for i, toggle := range []bool{true, false, true} {
-		wb, err := (&icmp.Message{
-			Type: ipv4.ICMPTypeEcho, Code: 0,
-			Body: &icmp.Echo{
-				ID: os.Getpid() & 0xffff, Seq: i + 1,
-				Data: []byte("HELLO-R-U-THERE"),
-			},
-		}).Marshal(nil)
+	for _, tt := range packetConnReadWriteMulticastICMPTests {
+		c, err := net.ListenPacket("ip4:icmp", "0.0.0.0")
 		if err != nil {
-			t.Fatalf("icmp.Message.Marshal failed: %v", err)
+			t.Fatal(err)
 		}
-		if err := p.SetControlMessage(cf, toggle); err != nil {
-			if nettest.ProtocolNotSupported(err) {
-				t.Skipf("not supported on %q", runtime.GOOS)
+		defer c.Close()
+
+		p := ipv4.NewPacketConn(c)
+		defer p.Close()
+		if tt.src == nil {
+			if err := p.JoinGroup(ifi, tt.grp); err != nil {
+				t.Fatal(err)
 			}
-			t.Fatalf("ipv4.PacketConn.SetControlMessage failed: %v", err)
+			defer p.LeaveGroup(ifi, tt.grp)
+		} else {
+			if err := p.JoinSourceSpecificGroup(ifi, tt.grp, tt.src); err != nil {
+				switch runtime.GOOS {
+				case "freebsd", "linux":
+				default: // platforms that don't support IGMPv2/3 fail here
+					t.Logf("not supported on %q", runtime.GOOS)
+					continue
+				}
+				t.Fatal(err)
+			}
+			defer p.LeaveSourceSpecificGroup(ifi, tt.grp, tt.src)
 		}
-		if err := p.SetDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
-			t.Fatalf("ipv4.PacketConn.SetDeadline failed: %v", err)
+		if err := p.SetMulticastInterface(ifi); err != nil {
+			t.Fatal(err)
 		}
-		p.SetMulticastTTL(i + 1)
-		if _, err := p.WriteTo(wb, nil, dst); err != nil {
-			t.Fatalf("ipv4.PacketConn.WriteTo failed: %v", err)
+		if _, err := p.MulticastInterface(); err != nil {
+			t.Fatal(err)
 		}
-		b := make([]byte, 128)
-		if n, cm, _, err := p.ReadFrom(b); err != nil {
-			t.Fatalf("ipv4.PacketConn.ReadFrom failed: %v", err)
-		} else {
-			t.Logf("rcvd cmsg: %v", cm)
-			m, err := icmp.ParseMessage(iana.ProtocolICMP, b[:n])
+		if err := p.SetMulticastLoopback(true); err != nil {
+			t.Fatal(err)
+		}
+		if _, err := p.MulticastLoopback(); err != nil {
+			t.Fatal(err)
+		}
+		cf := ipv4.FlagTTL | ipv4.FlagDst | ipv4.FlagInterface
+
+		for i, toggle := range []bool{true, false, true} {
+			wb, err := (&icmp.Message{
+				Type: ipv4.ICMPTypeEcho, Code: 0,
+				Body: &icmp.Echo{
+					ID: os.Getpid() & 0xffff, Seq: i + 1,
+					Data: []byte("HELLO-R-U-THERE"),
+				},
+			}).Marshal(nil)
 			if err != nil {
-				t.Fatalf("icmp.ParseMessage failed: %v", err)
+				t.Fatal(err)
+			}
+			if err := p.SetControlMessage(cf, toggle); err != nil {
+				if nettest.ProtocolNotSupported(err) {
+					t.Logf("not supported on %q", runtime.GOOS)
+					continue
+				}
+				t.Fatal(err)
+			}
+			if err := p.SetDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
+				t.Fatal(err)
 			}
-			switch {
-			case m.Type == ipv4.ICMPTypeEchoReply && m.Code == 0: // net.inet.icmp.bmcastecho=1
-			case m.Type == ipv4.ICMPTypeEcho && m.Code == 0: // net.inet.icmp.bmcastecho=0
-			default:
-				t.Fatalf("got type=%v, code=%v; expected type=%v, code=%v", m.Type, m.Code, ipv4.ICMPTypeEchoReply, 0)
+			p.SetMulticastTTL(i + 1)
+			if n, err := p.WriteTo(wb, nil, tt.grp); err != nil {
+				t.Fatal(err)
+			} else if n != len(wb) {
+				t.Fatalf("got %v; expected %v", n, len(wb))
+			}
+			rb := make([]byte, 128)
+			if n, cm, _, err := p.ReadFrom(rb); err != nil {
+				t.Fatal(err)
+			} else {
+				t.Logf("rcvd cmsg: %v", cm)
+				m, err := icmp.ParseMessage(iana.ProtocolICMP, rb[:n])
+				if err != nil {
+					t.Fatal(err)
+				}
+				switch {
+				case m.Type == ipv4.ICMPTypeEchoReply && m.Code == 0: // net.inet.icmp.bmcastecho=1
+				case m.Type == ipv4.ICMPTypeEcho && m.Code == 0: // net.inet.icmp.bmcastecho=0
+				default:
+					t.Fatalf("got type=%v, code=%v; expected type=%v, code=%v", m.Type, m.Code, ipv4.ICMPTypeEchoReply, 0)
+				}
 			}
 		}
 	}
 }
 
+var rawConnReadWriteMulticastICMPTests = []struct {
+	grp, src *net.IPAddr
+}{
+	{&net.IPAddr{IP: net.IPv4(224, 0, 0, 254)}, nil}, // see RFC 4727
+
+	{&net.IPAddr{IP: net.IPv4(232, 0, 1, 254)}, &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}}, // see RFC 5771
+}
+
 func TestRawConnReadWriteMulticastICMP(t *testing.T) {
 	if testing.Short() {
 		t.Skip("to avoid external network")
@@ -182,85 +237,97 @@ func TestRawConnReadWriteMulticastICMP(t *testing.T) {
 		t.Skipf("not available on %q", runtime.GOOS)
 	}
 
-	c, err := net.ListenPacket("ip4:icmp", "0.0.0.0")
-	if err != nil {
-		t.Fatalf("net.ListenPacket failed: %v", err)
-	}
-	defer c.Close()
-
-	dst, err := net.ResolveIPAddr("ip4", "224.0.0.254") // see RFC 4727
-	if err != nil {
-		t.Fatalf("ResolveIPAddr failed: %v", err)
-	}
-
-	r, err := ipv4.NewRawConn(c)
-	if err != nil {
-		t.Fatalf("ipv4.NewRawConn failed: %v", err)
-	}
-	defer r.Close()
-	if err := r.JoinGroup(ifi, dst); err != nil {
-		t.Fatalf("ipv4.RawConn.JoinGroup on %v failed: %v", ifi, err)
-	}
-	if err := r.SetMulticastInterface(ifi); err != nil {
-		t.Fatalf("ipv4.RawConn.SetMulticastInterface failed: %v", err)
-	}
-	if _, err := r.MulticastInterface(); err != nil {
-		t.Fatalf("ipv4.RawConn.MulticastInterface failed: %v", err)
-	}
-	if err := r.SetMulticastLoopback(true); err != nil {
-		t.Fatalf("ipv4.RawConn.SetMulticastLoopback failed: %v", err)
-	}
-	if _, err := r.MulticastLoopback(); err != nil {
-		t.Fatalf("ipv4.RawConn.MulticastLoopback failed: %v", err)
-	}
-	cf := ipv4.FlagTTL | ipv4.FlagDst | ipv4.FlagInterface
-
-	for i, toggle := range []bool{true, false, true} {
-		wb, err := (&icmp.Message{
-			Type: ipv4.ICMPTypeEcho, Code: 0,
-			Body: &icmp.Echo{
-				ID: os.Getpid() & 0xffff, Seq: i + 1,
-				Data: []byte("HELLO-R-U-THERE"),
-			},
-		}).Marshal(nil)
+	for _, tt := range rawConnReadWriteMulticastICMPTests {
+		c, err := net.ListenPacket("ip4:icmp", "0.0.0.0")
 		if err != nil {
-			t.Fatalf("icmp.Message.Marshal failed: %v", err)
+			t.Fatal(err)
 		}
-		wh := &ipv4.Header{
-			Version:  ipv4.Version,
-			Len:      ipv4.HeaderLen,
-			TOS:      i + 1,
-			TotalLen: ipv4.HeaderLen + len(wb),
-			Protocol: 1,
-			Dst:      dst.IP,
+		defer c.Close()
+
+		r, err := ipv4.NewRawConn(c)
+		if err != nil {
+			t.Fatal(err)
 		}
-		if err := r.SetControlMessage(cf, toggle); err != nil {
-			if nettest.ProtocolNotSupported(err) {
-				t.Skipf("not supported on %q", runtime.GOOS)
+		defer r.Close()
+		if tt.src == nil {
+			if err := r.JoinGroup(ifi, tt.grp); err != nil {
+				t.Fatal(err)
+			}
+			defer r.LeaveGroup(ifi, tt.grp)
+		} else {
+			if err := r.JoinSourceSpecificGroup(ifi, tt.grp, tt.src); err != nil {
+				switch runtime.GOOS {
+				case "freebsd", "linux":
+				default: // platforms that don't support IGMPv2/3 fail here
+					t.Logf("not supported on %q", runtime.GOOS)
+					continue
+				}
+				t.Fatal(err)
 			}
-			t.Fatalf("ipv4.RawConn.SetControlMessage failed: %v", err)
+			defer r.LeaveSourceSpecificGroup(ifi, tt.grp, tt.src)
 		}
-		if err := r.SetDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
-			t.Fatalf("ipv4.RawConn.SetDeadline failed: %v", err)
+		if err := r.SetMulticastInterface(ifi); err != nil {
+			t.Fatal(err)
 		}
-		r.SetMulticastTTL(i + 1)
-		if err := r.WriteTo(wh, wb, nil); err != nil {
-			t.Fatalf("ipv4.RawConn.WriteTo failed: %v", err)
+		if _, err := r.MulticastInterface(); err != nil {
+			t.Fatal(err)
 		}
-		rb := make([]byte, ipv4.HeaderLen+128)
-		if rh, b, cm, err := r.ReadFrom(rb); err != nil {
-			t.Fatalf("ipv4.RawConn.ReadFrom failed: %v", err)
-		} else {
-			t.Logf("rcvd cmsg: %v", cm)
-			m, err := icmp.ParseMessage(iana.ProtocolICMP, b)
+		if err := r.SetMulticastLoopback(true); err != nil {
+			t.Fatal(err)
+		}
+		if _, err := r.MulticastLoopback(); err != nil {
+			t.Fatal(err)
+		}
+		cf := ipv4.FlagTTL | ipv4.FlagDst | ipv4.FlagInterface
+
+		for i, toggle := range []bool{true, false, true} {
+			wb, err := (&icmp.Message{
+				Type: ipv4.ICMPTypeEcho, Code: 0,
+				Body: &icmp.Echo{
+					ID: os.Getpid() & 0xffff, Seq: i + 1,
+					Data: []byte("HELLO-R-U-THERE"),
+				},
+			}).Marshal(nil)
 			if err != nil {
-				t.Fatalf("icmp.ParseMessage failed: %v", err)
+				t.Fatal(err)
+			}
+			wh := &ipv4.Header{
+				Version:  ipv4.Version,
+				Len:      ipv4.HeaderLen,
+				TOS:      i + 1,
+				TotalLen: ipv4.HeaderLen + len(wb),
+				Protocol: 1,
+				Dst:      tt.grp.IP,
+			}
+			if err := r.SetControlMessage(cf, toggle); err != nil {
+				if nettest.ProtocolNotSupported(err) {
+					t.Logf("not supported on %q", runtime.GOOS)
+					continue
+				}
+				t.Fatal(err)
+			}
+			if err := r.SetDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
+				t.Fatal(err)
+			}
+			r.SetMulticastTTL(i + 1)
+			if err := r.WriteTo(wh, wb, nil); err != nil {
+				t.Fatal(err)
 			}
-			switch {
-			case (rh.Dst.IsLoopback() || rh.Dst.IsLinkLocalUnicast() || rh.Dst.IsGlobalUnicast()) && m.Type == ipv4.ICMPTypeEchoReply && m.Code == 0: // net.inet.icmp.bmcastecho=1
-			case rh.Dst.IsMulticast() && m.Type == ipv4.ICMPTypeEcho && m.Code == 0: // net.inet.icmp.bmcastecho=0
-			default:
-				t.Fatalf("got type=%v, code=%v; expected type=%v, code=%v", m.Type, m.Code, ipv4.ICMPTypeEchoReply, 0)
+			rb := make([]byte, ipv4.HeaderLen+128)
+			if rh, b, cm, err := r.ReadFrom(rb); err != nil {
+				t.Fatal(err)
+			} else {
+				t.Logf("rcvd cmsg: %v", cm)
+				m, err := icmp.ParseMessage(iana.ProtocolICMP, b)
+				if err != nil {
+					t.Fatal(err)
+				}
+				switch {
+				case (rh.Dst.IsLoopback() || rh.Dst.IsLinkLocalUnicast() || rh.Dst.IsGlobalUnicast()) && m.Type == ipv4.ICMPTypeEchoReply && m.Code == 0: // net.inet.icmp.bmcastecho=1
+				case rh.Dst.IsMulticast() && m.Type == ipv4.ICMPTypeEcho && m.Code == 0: // net.inet.icmp.bmcastecho=0
+				default:
+					t.Fatalf("got type=%v, code=%v; expected type=%v, code=%v", m.Type, m.Code, ipv4.ICMPTypeEchoReply, 0)
+				}
 			}
 		}
 	}

+ 109 - 27
ipv4/multicastsockopt_test.go

@@ -16,10 +16,13 @@ import (
 
 var packetConnMulticastSocketOptionTests = []struct {
 	net, proto, addr string
-	gaddr            net.Addr
+	grp, src         net.Addr
 }{
-	{"udp4", "", "224.0.0.0:0", &net.UDPAddr{IP: net.IPv4(224, 0, 0, 249)}}, // see RFC 4727
-	{"ip4", ":icmp", "0.0.0.0", &net.IPAddr{IP: net.IPv4(224, 0, 0, 250)}},  // see RFC 4727
+	{"udp4", "", "224.0.0.0:0", &net.UDPAddr{IP: net.IPv4(224, 0, 0, 249)}, nil}, // see RFC 4727
+	{"ip4", ":icmp", "0.0.0.0", &net.IPAddr{IP: net.IPv4(224, 0, 0, 250)}, nil},  // see RFC 4727
+
+	{"udp4", "", "232.0.0.0:0", &net.UDPAddr{IP: net.IPv4(232, 0, 1, 249)}, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}}, // see RFC 5771
+	{"ip4", ":icmp", "0.0.0.0", &net.IPAddr{IP: net.IPv4(232, 0, 1, 250)}, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}},  // see RFC 5771
 }
 
 func TestPacketConnMulticastSocketOptions(t *testing.T) {
@@ -34,18 +37,33 @@ func TestPacketConnMulticastSocketOptions(t *testing.T) {
 
 	for _, tt := range packetConnMulticastSocketOptionTests {
 		if tt.net == "ip4" && os.Getuid() != 0 {
-			t.Skip("must be root")
+			t.Log("must be root")
+			continue
 		}
 		c, err := net.ListenPacket(tt.net+tt.proto, tt.addr)
 		if err != nil {
-			t.Fatalf("net.ListenPacket failed: %v", err)
+			t.Fatal(err)
 		}
 		defer c.Close()
+		p := ipv4.NewPacketConn(c)
+		defer p.Close()
 
-		testMulticastSocketOptions(t, ipv4.NewPacketConn(c), ifi, tt.gaddr)
+		if tt.src == nil {
+			testMulticastSocketOptions(t, p, ifi, tt.grp)
+		} else {
+			testSourceSpecificMulticastSocketOptions(t, p, ifi, tt.grp, tt.src)
+		}
 	}
 }
 
+var rawConnMulticastSocketOptionTests = []struct {
+	grp, src net.Addr
+}{
+	{&net.IPAddr{IP: net.IPv4(224, 0, 0, 250)}, nil}, // see RFC 4727
+
+	{&net.IPAddr{IP: net.IPv4(232, 0, 1, 250)}, &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}}, // see RFC 5771
+}
+
 func TestRawConnMulticastSocketOptions(t *testing.T) {
 	switch runtime.GOOS {
 	case "nacl", "plan9", "solaris":
@@ -59,18 +77,24 @@ func TestRawConnMulticastSocketOptions(t *testing.T) {
 		t.Skipf("not available on %q", runtime.GOOS)
 	}
 
-	c, err := net.ListenPacket("ip4:icmp", "0.0.0.0")
-	if err != nil {
-		t.Fatalf("net.ListenPacket failed: %v", err)
-	}
-	defer c.Close()
+	for _, tt := range rawConnMulticastSocketOptionTests {
+		c, err := net.ListenPacket("ip4:icmp", "0.0.0.0")
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer c.Close()
+		r, err := ipv4.NewRawConn(c)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer r.Close()
 
-	r, err := ipv4.NewRawConn(c)
-	if err != nil {
-		t.Fatalf("ipv4.NewRawConn failed: %v", err)
+		if tt.src == nil {
+			testMulticastSocketOptions(t, r, ifi, tt.grp)
+		} else {
+			testSourceSpecificMulticastSocketOptions(t, r, ifi, tt.grp, tt.src)
+		}
 	}
-
-	testMulticastSocketOptions(t, r, ifi, &net.IPAddr{IP: net.IPv4(224, 0, 0, 250)}) /// see RFC 4727
 }
 
 type testIPv4MulticastConn interface {
@@ -80,34 +104,92 @@ type testIPv4MulticastConn interface {
 	SetMulticastLoopback(bool) error
 	JoinGroup(*net.Interface, net.Addr) error
 	LeaveGroup(*net.Interface, net.Addr) error
+	JoinSourceSpecificGroup(*net.Interface, net.Addr, net.Addr) error
+	LeaveSourceSpecificGroup(*net.Interface, net.Addr, net.Addr) error
+	ExcludeSourceSpecificGroup(*net.Interface, net.Addr, net.Addr) error
+	IncludeSourceSpecificGroup(*net.Interface, net.Addr, net.Addr) error
 }
 
-func testMulticastSocketOptions(t *testing.T, c testIPv4MulticastConn, ifi *net.Interface, gaddr net.Addr) {
+func testMulticastSocketOptions(t *testing.T, c testIPv4MulticastConn, ifi *net.Interface, grp net.Addr) {
 	const ttl = 255
 	if err := c.SetMulticastTTL(ttl); err != nil {
-		t.Fatalf("ipv4.PacketConn.SetMulticastTTL failed: %v", err)
+		t.Error(err)
+		return
 	}
 	if v, err := c.MulticastTTL(); err != nil {
-		t.Fatalf("ipv4.PacketConn.MulticastTTL failed: %v", err)
+		t.Error(err)
+		return
 	} else if v != ttl {
-		t.Fatalf("got unexpected multicast TTL value %v; expected %v", v, ttl)
+		t.Errorf("got unexpected multicast ttl %v; expected %v", v, ttl)
+		return
 	}
 
 	for _, toggle := range []bool{true, false} {
 		if err := c.SetMulticastLoopback(toggle); err != nil {
-			t.Fatalf("ipv4.PacketConn.SetMulticastLoopback failed: %v", err)
+			t.Error(err)
+			return
 		}
 		if v, err := c.MulticastLoopback(); err != nil {
-			t.Fatalf("ipv4.PacketConn.MulticastLoopback failed: %v", err)
+			t.Error(err)
+			return
 		} else if v != toggle {
-			t.Fatalf("got unexpected multicast loopback %v; expected %v", v, toggle)
+			t.Errorf("got unexpected multicast loopback %v; expected %v", v, toggle)
+			return
 		}
 	}
 
-	if err := c.JoinGroup(ifi, gaddr); err != nil {
-		t.Fatalf("ipv4.PacketConn.JoinGroup(%v, %v) failed: %v", ifi, gaddr, err)
+	if err := c.JoinGroup(ifi, grp); err != nil {
+		t.Error(err)
+		return
+	}
+	if err := c.LeaveGroup(ifi, grp); err != nil {
+		t.Error(err)
+		return
+	}
+}
+
+func testSourceSpecificMulticastSocketOptions(t *testing.T, c testIPv4MulticastConn, ifi *net.Interface, grp, src net.Addr) {
+	// MCAST_JOIN_GROUP -> MCAST_BLOCK_SOURCE -> MCAST_UNBLOCK_SOURCE -> MCAST_LEAVE_GROUP
+	if err := c.JoinGroup(ifi, grp); err != nil {
+		t.Error(err)
+		return
+	}
+	if err := c.ExcludeSourceSpecificGroup(ifi, grp, src); err != nil {
+		switch runtime.GOOS {
+		case "freebsd", "linux":
+		default: // platforms that don't support IGMPv2/3 fail here
+			t.Logf("not supported on %q", runtime.GOOS)
+			return
+		}
+		t.Error(err)
+		return
+	}
+	if err := c.IncludeSourceSpecificGroup(ifi, grp, src); err != nil {
+		t.Error(err)
+		return
+	}
+	if err := c.LeaveGroup(ifi, grp); err != nil {
+		t.Error(err)
+		return
+	}
+
+	// MCAST_JOIN_SOURCE_GROUP -> MCAST_LEAVE_SOURCE_GROUP
+	if err := c.JoinSourceSpecificGroup(ifi, grp, src); err != nil {
+		t.Error(err)
+		return
+	}
+	if err := c.LeaveSourceSpecificGroup(ifi, grp, src); err != nil {
+		t.Error(err)
+		return
+	}
+
+	// MCAST_JOIN_SOURCE_GROUP -> MCAST_LEAVE_GROUP
+	if err := c.JoinSourceSpecificGroup(ifi, grp, src); err != nil {
+		t.Error(err)
+		return
 	}
-	if err := c.LeaveGroup(ifi, gaddr); err != nil {
-		t.Fatalf("ipv4.PacketConn.LeaveGroup(%v, %v) failed: %v", ifi, gaddr, err)
+	if err := c.LeaveGroup(ifi, grp); err != nil {
+		t.Error(err)
+		return
 	}
 }