Browse Source

discovery: add a test case for srv

During srv discovery, it should try to match local member with
resolved addr and return unresolved hostnames for the cluster.
Xiang Li 10 years ago
parent
commit
f5d4c86153
2 changed files with 36 additions and 5 deletions
  1. 4 3
      discovery/srv.go
  2. 32 2
      discovery/srv_test.go

+ 4 - 3
discovery/srv.go

@@ -25,7 +25,8 @@ import (
 
 var (
 	// indirection for testing
-	lookupSRV = net.LookupSRV
+	lookupSRV      = net.LookupSRV
+	resolveTCPAddr = net.ResolveTCPAddr
 )
 
 // TODO(barakmich): Currently ignores priority and weight (as they don't make as much sense for a bootstrap)
@@ -38,7 +39,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
 
 	// First, resolve the apurls
 	for _, url := range apurls {
-		tcpAddr, err := net.ResolveTCPAddr("tcp", url.Host)
+		tcpAddr, err := resolveTCPAddr("tcp", url.Host)
 		if err != nil {
 			log.Printf("discovery: Couldn't resolve host %s during SRV discovery", url.Host)
 			return "", "", err
@@ -53,7 +54,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
 		}
 		for _, srv := range addrs {
 			host := net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port))
-			tcpAddr, err := net.ResolveTCPAddr("tcp", host)
+			tcpAddr, err := resolveTCPAddr("tcp", host)
 			if err != nil {
 				log.Printf("discovery: Couldn't resolve host %s during SRV discovery", host)
 				continue

+ 32 - 2
discovery/srv_test.go

@@ -23,19 +23,26 @@ import (
 )
 
 func TestSRVGetCluster(t *testing.T) {
-	defer func() { lookupSRV = net.LookupSRV }()
+	defer func() {
+		lookupSRV = net.LookupSRV
+		resolveTCPAddr = net.ResolveTCPAddr
+	}()
 
 	name := "dnsClusterTest"
 	tests := []struct {
 		withSSL    []*net.SRV
 		withoutSSL []*net.SRV
 		urls       []string
-		expected   string
+		dns        map[string]string
+
+		expected string
 	}{
 		{
 			[]*net.SRV{},
 			[]*net.SRV{},
 			nil,
+			nil,
+
 			"",
 		},
 		{
@@ -46,6 +53,8 @@ func TestSRVGetCluster(t *testing.T) {
 			},
 			[]*net.SRV{},
 			nil,
+			nil,
+
 			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480",
 		},
 		{
@@ -58,6 +67,7 @@ func TestSRVGetCluster(t *testing.T) {
 				&net.SRV{Target: "10.0.0.1", Port: 2380},
 			},
 			nil,
+			nil,
 			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480,3=http://10.0.0.1:2380",
 		},
 		{
@@ -70,8 +80,22 @@ func TestSRVGetCluster(t *testing.T) {
 				&net.SRV{Target: "10.0.0.1", Port: 2380},
 			},
 			[]string{"https://10.0.0.1:2480"},
+			nil,
 			"dnsClusterTest=https://10.0.0.1:2480,0=https://10.0.0.2:2480,1=https://10.0.0.3:2480,2=http://10.0.0.1:2380",
 		},
+		// matching local member with resolved addr and return unresolved hostnames
+		{
+			[]*net.SRV{
+				&net.SRV{Target: "1.example.com.", Port: 2480},
+				&net.SRV{Target: "2.example.com.", Port: 2480},
+				&net.SRV{Target: "3.example.com.", Port: 2480},
+			},
+			nil,
+			[]string{"https://10.0.0.1:2480"},
+			map[string]string{"1.example.com:2480": "10.0.0.1:2480", "2.example.com:2480": "10.0.0.2:2480", "3.example.com:2480": "10.0.0.3:2480"},
+
+			"dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480",
+		},
 	}
 
 	for i, tt := range tests {
@@ -84,6 +108,12 @@ func TestSRVGetCluster(t *testing.T) {
 			}
 			return "", nil, errors.New("Unkown service in mock")
 		}
+		resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
+			if tt.dns == nil || tt.dns[addr] == "" {
+				return net.ResolveTCPAddr(network, addr)
+			}
+			return net.ResolveTCPAddr(network, tt.dns[addr])
+		}
 		urls := testutil.MustNewURLs(t, tt.urls)
 		str, token, err := SRVGetCluster(name, "example.com", "token", urls)
 		if err != nil {