Bläddra i källkod

CreateSession() will now attempt to resolve all host names (#889)

* CreateSession() will now attempt to resolve all host names before returning a DNS error

* Added unit test and improved error message
Derrick J. Wippler 8 år sedan
förälder
incheckning
223afc7aa9
2 ändrade filer med 68 tillägg och 4 borttagningar
  1. 56 0
      conn_test.go
  2. 12 4
      session.go

+ 56 - 0
conn_test.go

@@ -141,6 +141,62 @@ func newTestSession(addr string, proto protoVersion) (*Session, error) {
 	return testCluster(addr, proto).CreateSession()
 }
 
+func TestDNSLookupConnected(t *testing.T) {
+	log := &testLogger{}
+	Logger = log
+	defer func() {
+		Logger = &defaultLogger{}
+	}()
+
+	srv := NewTestServer(t, defaultProto, context.Background())
+	defer srv.Stop()
+
+	cluster := NewCluster("cassandra1.invalid", srv.Address, "cassandra2.invalid")
+	cluster.ProtoVersion = int(defaultProto)
+	cluster.disableControlConn = true
+
+	// CreateSession() should attempt to resolve the DNS name "cassandraX.invalid"
+	// and fail, but continue to connect via srv.Address
+	_, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal("CreateSession() should have connected")
+	}
+
+	if !strings.Contains(log.String(), "gocql: dns error") {
+		t.Fatalf("Expected to receive dns error log message  - got '%s' instead", log.String())
+	}
+}
+
+func TestDNSLookupError(t *testing.T) {
+	log := &testLogger{}
+	Logger = log
+	defer func() {
+		Logger = &defaultLogger{}
+	}()
+
+	srv := NewTestServer(t, defaultProto, context.Background())
+	defer srv.Stop()
+
+	cluster := NewCluster("cassandra1.invalid", "cassandra2.invalid")
+	cluster.ProtoVersion = int(defaultProto)
+	cluster.disableControlConn = true
+
+	// CreateSession() should attempt to resolve each DNS name "cassandraX.invalid"
+	// and fail since it could not resolve any dns entries
+	_, err := cluster.CreateSession()
+	if err == nil {
+		t.Fatal("CreateSession() should have returned an error")
+	}
+
+	if !strings.Contains(log.String(), "gocql: dns error") {
+		t.Fatalf("Expected to receive dns error log message  - got '%s' instead", log.String())
+	}
+
+	if err.Error() != "gocql: unable to create session: failed to resolve any of the provided hostnames" {
+		t.Fatalf("Expected CreateSession() to fail with message  - got '%s' instead", err.Error())
+	}
+}
+
 func TestStartupTimeout(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	log := &testLogger{}

+ 12 - 4
session.go

@@ -10,6 +10,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -76,16 +77,23 @@ var queryPool = &sync.Pool{
 }
 
 func addrsToHosts(addrs []string, defaultPort int) ([]*HostInfo, error) {
-	hosts := make([]*HostInfo, len(addrs))
-	for i, hostport := range addrs {
+	var hosts []*HostInfo
+	for _, hostport := range addrs {
 		host, err := hostInfo(hostport, defaultPort)
 		if err != nil {
+			// Try other hosts if unable to resolve DNS name
+			if _, ok := err.(*net.DNSError); ok {
+				Logger.Printf("gocql: dns error: %v\n", err)
+				continue
+			}
 			return nil, err
 		}
 
-		hosts[i] = host
+		hosts = append(hosts, host)
+	}
+	if len(hosts) == 0 {
+		return nil, errors.New("failed to resolve any of the provided hostnames")
 	}
-
 	return hosts, nil
 }