Browse Source

discovery: add a test case for srv

During srv discovery, it should try to match local member with
resolved addr and return unresolved hostnames for the cluster.
Xiang Li 10 years ago
parent
commit
f5d4c86153
2 changed files with 36 additions and 5 deletions
  1. 4 3
      discovery/srv.go
  2. 32 2
      discovery/srv_test.go

+ 4 - 3
discovery/srv.go

@@ -25,7 +25,8 @@ import (
 
 
 var (
 var (
 	// indirection for testing
 	// indirection for testing
-	lookupSRV = net.LookupSRV
+	lookupSRV      = net.LookupSRV
+	resolveTCPAddr = net.ResolveTCPAddr
 )
 )
 
 
 // TODO(barakmich): Currently ignores priority and weight (as they don't make as much sense for a bootstrap)
 // TODO(barakmich): Currently ignores priority and weight (as they don't make as much sense for a bootstrap)
@@ -38,7 +39,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
 
 
 	// First, resolve the apurls
 	// First, resolve the apurls
 	for _, url := range apurls {
 	for _, url := range apurls {
-		tcpAddr, err := net.ResolveTCPAddr("tcp", url.Host)
+		tcpAddr, err := resolveTCPAddr("tcp", url.Host)
 		if err != nil {
 		if err != nil {
 			log.Printf("discovery: Couldn't resolve host %s during SRV discovery", url.Host)
 			log.Printf("discovery: Couldn't resolve host %s during SRV discovery", url.Host)
 			return "", "", err
 			return "", "", err
@@ -53,7 +54,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
 		}
 		}
 		for _, srv := range addrs {
 		for _, srv := range addrs {
 			host := net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port))
 			host := net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port))
-			tcpAddr, err := net.ResolveTCPAddr("tcp", host)
+			tcpAddr, err := resolveTCPAddr("tcp", host)
 			if err != nil {
 			if err != nil {
 				log.Printf("discovery: Couldn't resolve host %s during SRV discovery", host)
 				log.Printf("discovery: Couldn't resolve host %s during SRV discovery", host)
 				continue
 				continue

+ 32 - 2
discovery/srv_test.go

@@ -23,19 +23,26 @@ import (
 )
 )
 
 
 func TestSRVGetCluster(t *testing.T) {
 func TestSRVGetCluster(t *testing.T) {
-	defer func() { lookupSRV = net.LookupSRV }()
+	defer func() {
+		lookupSRV = net.LookupSRV
+		resolveTCPAddr = net.ResolveTCPAddr
+	}()
 
 
 	name := "dnsClusterTest"
 	name := "dnsClusterTest"
 	tests := []struct {
 	tests := []struct {
 		withSSL    []*net.SRV
 		withSSL    []*net.SRV
 		withoutSSL []*net.SRV
 		withoutSSL []*net.SRV
 		urls       []string
 		urls       []string
-		expected   string
+		dns        map[string]string
+
+		expected string
 	}{
 	}{
 		{
 		{
 			[]*net.SRV{},
 			[]*net.SRV{},
 			[]*net.SRV{},
 			[]*net.SRV{},
 			nil,
 			nil,
+			nil,
+
 			"",
 			"",
 		},
 		},
 		{
 		{
@@ -46,6 +53,8 @@ func TestSRVGetCluster(t *testing.T) {
 			},
 			},
 			[]*net.SRV{},
 			[]*net.SRV{},
 			nil,
 			nil,
+			nil,
+
 			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480",
 			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480",
 		},
 		},
 		{
 		{
@@ -58,6 +67,7 @@ func TestSRVGetCluster(t *testing.T) {
 				&net.SRV{Target: "10.0.0.1", Port: 2380},
 				&net.SRV{Target: "10.0.0.1", Port: 2380},
 			},
 			},
 			nil,
 			nil,
+			nil,
 			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480,3=http://10.0.0.1:2380",
 			"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480,3=http://10.0.0.1:2380",
 		},
 		},
 		{
 		{
@@ -70,8 +80,22 @@ func TestSRVGetCluster(t *testing.T) {
 				&net.SRV{Target: "10.0.0.1", Port: 2380},
 				&net.SRV{Target: "10.0.0.1", Port: 2380},
 			},
 			},
 			[]string{"https://10.0.0.1:2480"},
 			[]string{"https://10.0.0.1:2480"},
+			nil,
 			"dnsClusterTest=https://10.0.0.1:2480,0=https://10.0.0.2:2480,1=https://10.0.0.3:2480,2=http://10.0.0.1:2380",
 			"dnsClusterTest=https://10.0.0.1:2480,0=https://10.0.0.2:2480,1=https://10.0.0.3:2480,2=http://10.0.0.1:2380",
 		},
 		},
+		// matching local member with resolved addr and return unresolved hostnames
+		{
+			[]*net.SRV{
+				&net.SRV{Target: "1.example.com.", Port: 2480},
+				&net.SRV{Target: "2.example.com.", Port: 2480},
+				&net.SRV{Target: "3.example.com.", Port: 2480},
+			},
+			nil,
+			[]string{"https://10.0.0.1:2480"},
+			map[string]string{"1.example.com:2480": "10.0.0.1:2480", "2.example.com:2480": "10.0.0.2:2480", "3.example.com:2480": "10.0.0.3:2480"},
+
+			"dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480",
+		},
 	}
 	}
 
 
 	for i, tt := range tests {
 	for i, tt := range tests {
@@ -84,6 +108,12 @@ func TestSRVGetCluster(t *testing.T) {
 			}
 			}
 			return "", nil, errors.New("Unkown service in mock")
 			return "", nil, errors.New("Unkown service in mock")
 		}
 		}
+		resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
+			if tt.dns == nil || tt.dns[addr] == "" {
+				return net.ResolveTCPAddr(network, addr)
+			}
+			return net.ResolveTCPAddr(network, tt.dns[addr])
+		}
 		urls := testutil.MustNewURLs(t, tt.urls)
 		urls := testutil.MustNewURLs(t, tt.urls)
 		str, token, err := SRVGetCluster(name, "example.com", "token", urls)
 		str, token, err := SRVGetCluster(name, "example.com", "token", urls)
 		if err != nil {
 		if err != nil {