Browse Source

Merge pull request #6253 from heyitsanthony/srv-arec

discovery: reject IP address records in SRVGetCluster
Anthony Romano 9 years ago
parent
commit
1c989edb47
2 changed files with 43 additions and 45 deletions
  1. 7 4
      discovery/srv.go
  2. 36 41
      discovery/srv_test.go

+ 7 - 4
discovery/srv.go

@@ -55,8 +55,8 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
 			return err
 		}
 		for _, srv := range addrs {
-			target := strings.TrimSuffix(srv.Target, ".")
-			host := net.JoinHostPort(target, fmt.Sprintf("%d", srv.Port))
+			port := fmt.Sprintf("%d", srv.Port)
+			host := net.JoinHostPort(srv.Target, port)
 			tcpAddr, err := resolveTCPAddr("tcp", host)
 			if err != nil {
 				plog.Warningf("couldn't resolve host %s during SRV discovery", host)
@@ -72,8 +72,11 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
 				n = fmt.Sprintf("%d", tempName)
 				tempName++
 			}
-			stringParts = append(stringParts, fmt.Sprintf("%s=%s%s", n, prefix, host))
-			plog.Noticef("got bootstrap from DNS for %s at %s%s", service, prefix, host)
+			// SRV records have a trailing dot but URL shouldn't.
+			shortHost := strings.TrimSuffix(srv.Target, ".")
+			urlHost := net.JoinHostPort(shortHost, port)
+			stringParts = append(stringParts, fmt.Sprintf("%s=%s%s", n, prefix, urlHost))
+			plog.Noticef("got bootstrap from DNS for %s at %s%s", service, prefix, urlHost)
 		}
 		return nil
 	}

+ 36 - 41
discovery/srv_test.go

@@ -17,6 +17,7 @@ package discovery
 import (
 	"errors"
 	"net"
+	"strings"
 	"testing"
 
 	"github.com/coreos/etcd/pkg/testutil"
@@ -29,11 +30,22 @@ func TestSRVGetCluster(t *testing.T) {
 	}()
 
 	name := "dnsClusterTest"
+	dns := map[string]string{
+		"1.example.com.:2480": "10.0.0.1:2480",
+		"2.example.com.:2480": "10.0.0.2:2480",
+		"3.example.com.:2480": "10.0.0.3:2480",
+		"4.example.com.:2380": "10.0.0.3:2380",
+	}
+	srvAll := []*net.SRV{
+		{Target: "1.example.com.", Port: 2480},
+		{Target: "2.example.com.", Port: 2480},
+		{Target: "3.example.com.", Port: 2480},
+	}
+
 	tests := []struct {
 		withSSL    []*net.SRV
 		withoutSSL []*net.SRV
 		urls       []string
-		dns        map[string]string
 
 		expected string
 	}{
@@ -41,61 +53,50 @@ func TestSRVGetCluster(t *testing.T) {
 			[]*net.SRV{},
 			[]*net.SRV{},
 			nil,
-			nil,
 
 			"",
 		},
 		{
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 2480},
-				{Target: "10.0.0.2", Port: 2480},
-				{Target: "10.0.0.3", Port: 2480},
-			},
+			srvAll,
 			[]*net.SRV{},
 			nil,
-			nil,
 
-			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480",
+			"0=https://1.example.com:2480,1=https://2.example.com:2480,2=https://3.example.com:2480",
 		},
 		{
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 2480},
-				{Target: "10.0.0.2", Port: 2480},
-				{Target: "10.0.0.3", Port: 2480},
-			},
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 2380},
-			},
-			nil,
+			srvAll,
+			[]*net.SRV{{Target: "4.example.com.", Port: 2380}},
 			nil,
-			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480,3=http://10.0.0.1:2380",
+
+			"0=https://1.example.com:2480,1=https://2.example.com:2480,2=https://3.example.com:2480,3=http://4.example.com:2380",
 		},
 		{
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 2480},
-				{Target: "10.0.0.2", Port: 2480},
-				{Target: "10.0.0.3", Port: 2480},
-			},
-			[]*net.SRV{
-				{Target: "10.0.0.1", Port: 2380},
-			},
+			srvAll,
+			[]*net.SRV{{Target: "4.example.com.", Port: 2380}},
 			[]string{"https://10.0.0.1:2480"},
-			nil,
-			"dnsClusterTest=https://10.0.0.1:2480,0=https://10.0.0.2:2480,1=https://10.0.0.3:2480,2=http://10.0.0.1:2380",
+
+			"dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480,2=http://4.example.com:2380",
 		},
 		// matching local member with resolved addr and return unresolved hostnames
 		{
-			[]*net.SRV{
-				{Target: "1.example.com.", Port: 2480},
-				{Target: "2.example.com.", Port: 2480},
-				{Target: "3.example.com.", Port: 2480},
-			},
+			srvAll,
 			nil,
 			[]string{"https://10.0.0.1:2480"},
-			map[string]string{"1.example.com:2480": "10.0.0.1:2480", "2.example.com:2480": "10.0.0.2:2480", "3.example.com:2480": "10.0.0.3:2480"},
 
 			"dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480",
 		},
+		// invalid
+	}
+
+	resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
+		if strings.Contains(addr, "10.0.0.") {
+			// accept IP addresses when resolving apurls
+			return net.ResolveTCPAddr(network, addr)
+		}
+		if dns[addr] == "" {
+			return nil, errors.New("missing dns record")
+		}
+		return net.ResolveTCPAddr(network, dns[addr])
 	}
 
 	for i, tt := range tests {
@@ -108,12 +109,6 @@ func TestSRVGetCluster(t *testing.T) {
 			}
 			return "", nil, errors.New("Unknown service in mock")
 		}
-		resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
-			if tt.dns == nil || tt.dns[addr] == "" {
-				return net.ResolveTCPAddr(network, addr)
-			}
-			return net.ResolveTCPAddr(network, tt.dns[addr])
-		}
 		urls := testutil.MustNewURLs(t, tt.urls)
 		str, token, err := SRVGetCluster(name, "example.com", "token", urls)
 		if err != nil {