Browse Source

netutil: add dualstack to linux_route

in v3.1.0 netutil couldn't get default interface for ipv6only hosts

Fixes #7219
felixoid 9 years ago
parent
commit
0f53ad0b84
4 changed files with 126 additions and 60 deletions
  1. 2 2
      embed/config.go
  2. 17 10
      pkg/netutil/isolate_linux.go
  3. 97 46
      pkg/netutil/routes_linux.go
  4. 10 2
      pkg/netutil/routes_linux_test.go

+ 2 - 2
embed/config.go

@@ -68,8 +68,8 @@ func init() {
 		return
 		return
 	}
 	}
 	// found default host, advertise on it
 	// found default host, advertise on it
-	DefaultInitialAdvertisePeerURLs = "http://" + ip + ":2380"
-	DefaultAdvertiseClientURLs = "http://" + ip + ":2379"
+	DefaultInitialAdvertisePeerURLs = "http://" + net.JoinHostPort(ip, "2380")
+	DefaultAdvertiseClientURLs = "http://" + net.JoinHostPort(ip, "2379")
 	defaultHostname = ip
 	defaultHostname = ip
 }
 }
 
 

+ 17 - 10
pkg/netutil/isolate_linux.go

@@ -43,7 +43,7 @@ func RecoverPort(port int) error {
 
 
 // SetLatency adds latency in millisecond scale with random variations.
 // SetLatency adds latency in millisecond scale with random variations.
 func SetLatency(ms, rv int) error {
 func SetLatency(ms, rv int) error {
-	ifce, err := GetDefaultInterface()
+	ifces, err := GetDefaultInterfaces()
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -51,14 +51,16 @@ func SetLatency(ms, rv int) error {
 	if rv > ms {
 	if rv > ms {
 		rv = 1
 		rv = 1
 	}
 	}
-	cmdStr := fmt.Sprintf("sudo tc qdisc add dev %s root netem delay %dms %dms distribution normal", ifce, ms, rv)
-	_, err = exec.Command("/bin/sh", "-c", cmdStr).Output()
-	if err != nil {
-		// the rule has already been added. Overwrite it.
-		cmdStr = fmt.Sprintf("sudo tc qdisc change dev %s root netem delay %dms %dms distribution normal", ifce, ms, rv)
+	for ifce := range ifces {
+		cmdStr := fmt.Sprintf("sudo tc qdisc add dev %s root netem delay %dms %dms distribution normal", ifce, ms, rv)
 		_, err = exec.Command("/bin/sh", "-c", cmdStr).Output()
 		_, err = exec.Command("/bin/sh", "-c", cmdStr).Output()
 		if err != nil {
 		if err != nil {
-			return err
+			// the rule has already been added. Overwrite it.
+			cmdStr = fmt.Sprintf("sudo tc qdisc change dev %s root netem delay %dms %dms distribution normal", ifce, ms, rv)
+			_, err = exec.Command("/bin/sh", "-c", cmdStr).Output()
+			if err != nil {
+				return err
+			}
 		}
 		}
 	}
 	}
 	return nil
 	return nil
@@ -66,10 +68,15 @@ func SetLatency(ms, rv int) error {
 
 
 // RemoveLatency resets latency configurations.
 // RemoveLatency resets latency configurations.
 func RemoveLatency() error {
 func RemoveLatency() error {
-	ifce, err := GetDefaultInterface()
+	ifces, err := GetDefaultInterfaces()
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	_, err = exec.Command("/bin/sh", "-c", fmt.Sprintf("sudo tc qdisc del dev %s root netem", ifce)).Output()
-	return err
+	for ifce := range ifces {
+		_, err = exec.Command("/bin/sh", "-c", fmt.Sprintf("sudo tc qdisc del dev %s root netem", ifce)).Output()
+		if err != nil {
+			return err
+		}
+	}
+	return nil
 }
 }

+ 97 - 46
pkg/netutil/routes_linux.go

@@ -27,42 +27,49 @@ import (
 )
 )
 
 
 var errNoDefaultRoute = fmt.Errorf("could not find default route")
 var errNoDefaultRoute = fmt.Errorf("could not find default route")
+var errNoDefaultHost = fmt.Errorf("could not find default host")
+var errNoDefaultInterface = fmt.Errorf("could not find default interface")
 
 
+// GetDefaultHost obtains the first IP address of machine from the routing table and returns the IP address as string.
+// An IPv4 address is preferred to an IPv6 address for backward compatibility.
 func GetDefaultHost() (string, error) {
 func GetDefaultHost() (string, error) {
-	rmsg, rerr := getDefaultRoute()
+	rmsgs, rerr := getDefaultRoutes()
 	if rerr != nil {
 	if rerr != nil {
 		return "", rerr
 		return "", rerr
 	}
 	}
 
 
-	host, oif, err := parsePREFSRC(rmsg)
-	if err != nil {
-		return "", err
-	}
-	if host != "" {
-		return host, nil
-	}
+	for family, rmsg := range rmsgs {
+		host, oif, err := parsePREFSRC(rmsg)
+		if err != nil {
+			return "", err
+		}
+		if host != "" {
+			return host, nil
+		}
 
 
-	// prefsrc not detected, fall back to getting address from iface
-	ifmsg, ierr := getIface(oif)
-	if ierr != nil {
-		return "", ierr
-	}
+		// prefsrc not detected, fall back to getting address from iface
+		ifmsg, ierr := getIfaceAddr(oif, family)
+		if ierr != nil {
+			return "", ierr
+		}
 
 
-	attrs, aerr := syscall.ParseNetlinkRouteAttr(ifmsg)
-	if aerr != nil {
-		return "", aerr
-	}
+		attrs, aerr := syscall.ParseNetlinkRouteAttr(ifmsg)
+		if aerr != nil {
+			return "", aerr
+		}
 
 
-	for _, attr := range attrs {
-		if attr.Attr.Type == syscall.RTA_SRC {
-			return net.IP(attr.Value).String(), nil
+		for _, attr := range attrs {
+			// search for RTA_DST because ipv6 doesn't have RTA_SRC
+			if attr.Attr.Type == syscall.RTA_DST {
+				return net.IP(attr.Value).String(), nil
+			}
 		}
 		}
 	}
 	}
 
 
-	return "", errNoDefaultRoute
+	return "", errNoDefaultHost
 }
 }
 
 
-func getDefaultRoute() (*syscall.NetlinkMessage, error) {
+func getDefaultRoutes() (map[uint8]*syscall.NetlinkMessage, error) {
 	dat, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
 	dat, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -73,6 +80,7 @@ func getDefaultRoute() (*syscall.NetlinkMessage, error) {
 		return nil, msgErr
 		return nil, msgErr
 	}
 	}
 
 
+	routes := make(map[uint8]*syscall.NetlinkMessage)
 	rtmsg := syscall.RtMsg{}
 	rtmsg := syscall.RtMsg{}
 	for _, m := range msgs {
 	for _, m := range msgs {
 		if m.Header.Type != syscall.RTM_NEWROUTE {
 		if m.Header.Type != syscall.RTM_NEWROUTE {
@@ -82,17 +90,23 @@ func getDefaultRoute() (*syscall.NetlinkMessage, error) {
 		if rerr := binary.Read(buf, cpuutil.ByteOrder(), &rtmsg); rerr != nil {
 		if rerr := binary.Read(buf, cpuutil.ByteOrder(), &rtmsg); rerr != nil {
 			continue
 			continue
 		}
 		}
-		if rtmsg.Dst_len == 0 {
+		if rtmsg.Dst_len == 0 && rtmsg.Table == syscall.RT_TABLE_MAIN {
 			// zero-length Dst_len implies default route
 			// zero-length Dst_len implies default route
-			return &m, nil
+			msg := m
+			routes[rtmsg.Family] = &msg
 		}
 		}
 	}
 	}
 
 
+	if len(routes) > 0 {
+		return routes, nil
+	}
+
 	return nil, errNoDefaultRoute
 	return nil, errNoDefaultRoute
 }
 }
 
 
-func getIface(idx uint32) (*syscall.NetlinkMessage, error) {
-	dat, err := syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_UNSPEC)
+// Used to get an address of interface.
+func getIfaceAddr(idx uint32, family uint8) (*syscall.NetlinkMessage, error) {
+	dat, err := syscall.NetlinkRIB(syscall.RTM_GETADDR, int(family))
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -116,38 +130,75 @@ func getIface(idx uint32) (*syscall.NetlinkMessage, error) {
 		}
 		}
 	}
 	}
 
 
-	return nil, errNoDefaultRoute
-}
+	return nil, fmt.Errorf("could not find address for interface index %v", idx)
 
 
-var errNoDefaultInterface = fmt.Errorf("could not find default interface")
+}
 
 
-func GetDefaultInterface() (string, error) {
-	rmsg, rerr := getDefaultRoute()
-	if rerr != nil {
-		return "", rerr
+// Used to get a name of interface.
+func getIfaceLink(idx uint32) (*syscall.NetlinkMessage, error) {
+	dat, err := syscall.NetlinkRIB(syscall.RTM_GETLINK, syscall.AF_UNSPEC)
+	if err != nil {
+		return nil, err
 	}
 	}
 
 
-	_, oif, err := parsePREFSRC(rmsg)
-	if err != nil {
-		return "", err
+	msgs, msgErr := syscall.ParseNetlinkMessage(dat)
+	if msgErr != nil {
+		return nil, msgErr
 	}
 	}
 
 
-	ifmsg, ierr := getIface(oif)
-	if ierr != nil {
-		return "", ierr
+	ifinfomsg := syscall.IfInfomsg{}
+	for _, m := range msgs {
+		if m.Header.Type != syscall.RTM_NEWLINK {
+			continue
+		}
+		buf := bytes.NewBuffer(m.Data[:syscall.SizeofIfInfomsg])
+		if rerr := binary.Read(buf, cpuutil.ByteOrder(), &ifinfomsg); rerr != nil {
+			continue
+		}
+		if ifinfomsg.Index == int32(idx) {
+			return &m, nil
+		}
 	}
 	}
 
 
-	attrs, aerr := syscall.ParseNetlinkRouteAttr(ifmsg)
-	if aerr != nil {
-		return "", aerr
+	return nil, fmt.Errorf("could not find link for interface index %v", idx)
+}
+
+// GetDefaultInterfaces gets names of interfaces and returns a map[interface]families.
+func GetDefaultInterfaces() (map[string]uint8, error) {
+	interfaces := make(map[string]uint8)
+	rmsgs, rerr := getDefaultRoutes()
+	if rerr != nil {
+		return interfaces, rerr
 	}
 	}
 
 
-	for _, attr := range attrs {
-		if attr.Attr.Type == syscall.IFLA_IFNAME {
-			return string(attr.Value[:len(attr.Value)-1]), nil
+	for family, rmsg := range rmsgs {
+		_, oif, err := parsePREFSRC(rmsg)
+		if err != nil {
+			return interfaces, err
+		}
+
+		ifmsg, ierr := getIfaceLink(oif)
+		if ierr != nil {
+			return interfaces, ierr
 		}
 		}
+
+		attrs, aerr := syscall.ParseNetlinkRouteAttr(ifmsg)
+		if aerr != nil {
+			return interfaces, aerr
+		}
+
+		for _, attr := range attrs {
+			if attr.Attr.Type == syscall.IFLA_IFNAME {
+				// key is an interface name
+				// possible values: 2 - AF_INET, 10 - AF_INET6, 12 - dualstack
+				interfaces[string(attr.Value[:len(attr.Value)-1])] += family
+			}
+		}
+	}
+	if len(interfaces) > 0 {
+		return interfaces, nil
 	}
 	}
-	return "", errNoDefaultInterface
+	return interfaces, errNoDefaultInterface
 }
 }
 
 
 // parsePREFSRC returns preferred source address and output interface index (RTA_OIF).
 // parsePREFSRC returns preferred source address and output interface index (RTA_OIF).

+ 10 - 2
pkg/netutil/routes_linux_test.go

@@ -19,9 +19,17 @@ package netutil
 import "testing"
 import "testing"
 
 
 func TestGetDefaultInterface(t *testing.T) {
 func TestGetDefaultInterface(t *testing.T) {
-	ifc, err := GetDefaultInterface()
+	ifc, err := GetDefaultInterfaces()
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	t.Logf("default network interface: %q\n", ifc)
+	t.Logf("default network interfaces: %+v\n", ifc)
+}
+
+func TestGetDefaultHost(t *testing.T) {
+	ip, err := GetDefaultHost()
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Logf("default ip: %v", ip)
 }
 }