瀏覽代碼

discovery: add a test case for srv

During srv discovery, it should try to match local member with
resolved addr and return unresolved hostnames for the cluster.
Xiang Li 11 年之前
父節點
當前提交
f5d4c86153
共有 2 個文件被更改,包括 36 次插入5 次删除
  1. 4 3
      discovery/srv.go
  2. 32 2
      discovery/srv_test.go

+ 4 - 3
discovery/srv.go

@@ -25,7 +25,8 @@ import (
 
 var (
 	// indirection for testing
-	lookupSRV = net.LookupSRV
+	lookupSRV      = net.LookupSRV
+	resolveTCPAddr = net.ResolveTCPAddr
 )
 
 // TODO(barakmich): Currently ignores priority and weight (as they don't make as much sense for a bootstrap)
@@ -38,7 +39,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
 
 	// First, resolve the apurls
 	for _, url := range apurls {
-		tcpAddr, err := net.ResolveTCPAddr("tcp", url.Host)
+		tcpAddr, err := resolveTCPAddr("tcp", url.Host)
 		if err != nil {
 			log.Printf("discovery: Couldn't resolve host %s during SRV discovery", url.Host)
 			return "", "", err
@@ -53,7 +54,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
 		}
 		for _, srv := range addrs {
 			host := net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port))
-			tcpAddr, err := net.ResolveTCPAddr("tcp", host)
+			tcpAddr, err := resolveTCPAddr("tcp", host)
 			if err != nil {
 				log.Printf("discovery: Couldn't resolve host %s during SRV discovery", host)
 				continue

+ 32 - 2
discovery/srv_test.go

@@ -23,19 +23,26 @@ import (
 )
 
 func TestSRVGetCluster(t *testing.T) {
-	defer func() { lookupSRV = net.LookupSRV }()
+	defer func() {
+		lookupSRV = net.LookupSRV
+		resolveTCPAddr = net.ResolveTCPAddr
+	}()
 
 	name := "dnsClusterTest"
 	tests := []struct {
 		withSSL    []*net.SRV
 		withoutSSL []*net.SRV
 		urls       []string
-		expected   string
+		dns        map[string]string
+
+		expected string
 	}{
 		{
 			[]*net.SRV{},
 			[]*net.SRV{},
 			nil,
+			nil,
+
 			"",
 		},
 		{
@@ -46,6 +53,8 @@ func TestSRVGetCluster(t *testing.T) {
 			},
 			[]*net.SRV{},
 			nil,
+			nil,
+
 			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480",
 		},
 		{
@@ -58,6 +67,7 @@ func TestSRVGetCluster(t *testing.T) {
 				&net.SRV{Target: "10.0.0.1", Port: 2380},
 			},
 			nil,
+			nil,
 			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480,3=http://10.0.0.1:2380",
 		},
 		{
@@ -70,8 +80,22 @@ func TestSRVGetCluster(t *testing.T) {
 				&net.SRV{Target: "10.0.0.1", Port: 2380},
 			},
 			[]string{"https://10.0.0.1:2480"},
+			nil,
 			"dnsClusterTest=https://10.0.0.1:2480,0=https://10.0.0.2:2480,1=https://10.0.0.3:2480,2=http://10.0.0.1:2380",
 		},
+		// matching local member with resolved addr and return unresolved hostnames
+		{
+			[]*net.SRV{
+				&net.SRV{Target: "1.example.com.", Port: 2480},
+				&net.SRV{Target: "2.example.com.", Port: 2480},
+				&net.SRV{Target: "3.example.com.", Port: 2480},
+			},
+			nil,
+			[]string{"https://10.0.0.1:2480"},
+			map[string]string{"1.example.com:2480": "10.0.0.1:2480", "2.example.com:2480": "10.0.0.2:2480", "3.example.com:2480": "10.0.0.3:2480"},
+
+			"dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480",
+		},
 	}
 
 	for i, tt := range tests {
@@ -84,6 +108,12 @@ func TestSRVGetCluster(t *testing.T) {
 			}
 			return "", nil, errors.New("Unkown service in mock")
 		}
+		resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
+			if tt.dns == nil || tt.dns[addr] == "" {
+				return net.ResolveTCPAddr(network, addr)
+			}
+			return net.ResolveTCPAddr(network, tt.dns[addr])
+		}
 		urls := testutil.MustNewURLs(t, tt.urls)
 		str, token, err := SRVGetCluster(name, "example.com", "token", urls)
 		if err != nil {