
Added ConnectTimeout for initial connections to nodes

Derrick J. Wippler, 8 years ago
commit baa79dc6f5

6 changed files with 97 additions and 26 deletions
  1. AUTHORS (+1 -0)
  2. cluster.go (+2 -0)
  3. conn.go (+11 -10)
  4. conn_test.go (+59 -6)
  5. connectionpool.go (+10 -9)
  6. logger.go (+14 -1)

+ 1 - 0
AUTHORS

@@ -85,3 +85,4 @@ Nathan Davies <nathanjamesdavies@gmail.com>
 Bo Blanton <bo.blanton@gmail.com>
 Vincent Rischmann <me@vrischmann.me>
 Jesse Claven <jesse.claven@gmail.com>
+Derrick Wippler <thrawn01@gmail.com>

+ 2 - 0
cluster.go

@@ -46,6 +46,7 @@ type ClusterConfig struct {
 	// versions the protocol selected is not defined (ie, it can be any of the supported in the cluster)
 	ProtoVersion      int
 	Timeout           time.Duration     // connection timeout (default: 600ms)
+	ConnectTimeout    time.Duration     // initial connection timeout, used during initial dial to server (default: 600ms)
 	Port              int               // port (default: 9042)
 	Keyspace          string            // initial keyspace (optional)
 	NumConns          int               // number of connections per host (default: 2)
@@ -132,6 +133,7 @@ func NewCluster(hosts ...string) *ClusterConfig {
 		Hosts:                  hosts,
 		CQLVersion:             "3.0.0",
 		Timeout:                600 * time.Millisecond,
+		ConnectTimeout:         600 * time.Millisecond,
 		Port:                   9042,
 		NumConns:               2,
 		Consistency:            Quorum,
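
For context, a minimal usage sketch of the new option (host address and timeout values are placeholders, not from this commit). The point of the change: the initial dial/startup timeout can now be tuned independently of the per-query Timeout.

```go
package main

import (
	"log"
	"time"

	"github.com/gocql/gocql"
)

func main() {
	cluster := gocql.NewCluster("127.0.0.1") // placeholder host
	cluster.Timeout = 5 * time.Second               // per-query timeout
	cluster.ConnectTimeout = 600 * time.Millisecond // initial connection timeout
	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()
}
```

Keeping the two knobs separate lets a client fail over quickly from an unreachable node without also shortening the deadline for legitimate long-running queries.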

+ 11 - 10
conn.go

@@ -93,13 +93,14 @@ type SslOptions struct {
 }
 
 type ConnConfig struct {
-	ProtoVersion  int
-	CQLVersion    string
-	Timeout       time.Duration
-	Compressor    Compressor
-	Authenticator Authenticator
-	Keepalive     time.Duration
-	tlsConfig     *tls.Config
+	ProtoVersion   int
+	CQLVersion     string
+	Timeout        time.Duration
+	ConnectTimeout time.Duration
+	Compressor     Compressor
+	Authenticator  Authenticator
+	Keepalive      time.Duration
+	tlsConfig      *tls.Config
 }
 
 type ConnErrorHandler interface {
@@ -167,7 +168,7 @@ func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, ses
 	)
 
 	dialer := &net.Dialer{
-		Timeout: cfg.Timeout,
+		Timeout: cfg.ConnectTimeout,
 	}
 
 	// TODO(zariel): handle ipv6 zone
@@ -212,8 +213,8 @@ func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, ses
 		ctx    context.Context
 		cancel func()
 	)
-	if c.timeout > 0 {
-		ctx, cancel = context.WithTimeout(context.Background(), c.timeout)
+	if cfg.ConnectTimeout > 0 {
+		ctx, cancel = context.WithTimeout(context.Background(), cfg.ConnectTimeout)
 	} else {
 		ctx, cancel = context.WithCancel(context.Background())
 	}
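
Both uses of cfg.Timeout that this hunk replaces guard connection setup: the TCP dial (net.Dialer.Timeout) and the startup exchange (context.WithTimeout). A general sketch of that pattern, with a hypothetical startup function standing in for the CQL startup frame:

```go
package main

import (
	"context"
	"fmt"
	"net"
	"time"
)

// connect mirrors the pattern the change introduces: one connect timeout
// bounds both the TCP dial and the application-level startup exchange.
func connect(addr string, connectTimeout time.Duration) (net.Conn, error) {
	dialer := &net.Dialer{Timeout: connectTimeout}
	conn, err := dialer.Dial("tcp", addr)
	if err != nil {
		return nil, err
	}

	ctx, cancel := context.Background(), func() {}
	if connectTimeout > 0 {
		ctx, cancel = context.WithTimeout(context.Background(), connectTimeout)
	}
	defer cancel()

	if err := startup(ctx, conn); err != nil { // hypothetical startup exchange
		conn.Close()
		return nil, err
	}
	return conn, nil
}

// startup is a placeholder for a protocol handshake that honors ctx's deadline.
func startup(ctx context.Context, conn net.Conn) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
		return nil
	}
}

func main() {
	if _, err := connect("127.0.0.1:9042", 600*time.Millisecond); err != nil {
		fmt.Println("connect failed:", err)
	}
}
```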

+ 59 - 6
conn_test.go

@@ -140,6 +140,50 @@ func newTestSession(addr string, proto protoVersion) (*Session, error) {
 	return testCluster(addr, proto).CreateSession()
 }
 
+func TestStartupTimeout(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	log := &testLogger{}
+	Logger = log
+	defer func() {
+		Logger = &defaultLogger{}
+	}()
+
+	srv := NewTestServer(t, defaultProto, ctx)
+	defer srv.Stop()
+
+	// Tell the server to never respond to Startup frame
+	atomic.StoreInt32(&srv.TimeoutOnStartup, 1)
+
+	startTime := time.Now()
+	cluster := NewCluster(srv.Address)
+	cluster.ProtoVersion = int(defaultProto)
+	cluster.disableControlConn = true
+	// Set a very long query timeout so we know
+	// CreateSession() is timing out via ConnectTimeout
+	cluster.Timeout = time.Second * 5
+
+	// Create session should timeout during connect attempt
+	_, err := cluster.CreateSession()
+	if err == nil {
+		t.Fatal("CreateSession() should have returned a timeout error")
+	}
+
+	elapsed := time.Since(startTime)
+	if elapsed > time.Second*5 {
+		t.Fatal("ConnectTimeout is not respected")
+	}
+
+	if !strings.Contains(err.Error(), "no connections were made when creating the session") {
+		t.Fatalf("Expected to receive no connections error - got '%s'", err)
+	}
+
+	if !strings.Contains(log.String(), "no response to connection startup within timeout") {
+		t.Fatalf("Expected to receive timeout log message - got '%s'", log.String())
+	}
+
+	cancel()
+}
+
 func TestTimeout(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 
@@ -619,12 +663,13 @@ func NewSSLTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestSe
 }
 
 type TestServer struct {
-	Address    string
-	t          testing.TB
-	nreq       uint64
-	listen     net.Listener
-	nKillReq   int64
-	compressor Compressor
+	Address          string
+	TimeoutOnStartup int32
+	t                testing.TB
+	nreq             uint64
+	listen           net.Listener
+	nKillReq         int64
+	compressor       Compressor
 
 	protocol   byte
 	headerSize int
@@ -738,6 +783,14 @@ func (srv *TestServer) process(f *framer) {
 
 	switch head.op {
 	case opStartup:
+		if atomic.LoadInt32(&srv.TimeoutOnStartup) > 0 {
+			// Do not respond to the startup command;
+			// block until the test cancels the server's context.
+			<-srv.ctx.Done()
+			return
+		}
 		f.writeHeader(0, opReady, head.stream)
 	case opOptions:
 		f.writeHeader(0, opSupported, head.stream)

+ 10 - 9
connectionpool.go

@@ -85,13 +85,14 @@ func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
 	}
 
 	return &ConnConfig{
-		ProtoVersion:  cfg.ProtoVersion,
-		CQLVersion:    cfg.CQLVersion,
-		Timeout:       cfg.Timeout,
-		Compressor:    cfg.Compressor,
-		Authenticator: cfg.Authenticator,
-		Keepalive:     cfg.SocketKeepalive,
-		tlsConfig:     tlsConfig,
+		ProtoVersion:   cfg.ProtoVersion,
+		CQLVersion:     cfg.CQLVersion,
+		Timeout:        cfg.Timeout,
+		ConnectTimeout: cfg.ConnectTimeout,
+		Compressor:     cfg.Compressor,
+		Authenticator:  cfg.Authenticator,
+		Keepalive:      cfg.SocketKeepalive,
+		tlsConfig:      tlsConfig,
 	}, nil
 }
 
@@ -395,8 +396,8 @@ func (pool *hostConnPool) fill() {
 			// probably unreachable host
 			pool.fillingStopped(true)
 
-			// this is calle with the connetion pool mutex held, this call will
-			// then recursivly try to lock it again. FIXME
+			// this is called with the connection pool mutex held, this call will
+			// then recursively try to lock it again. FIXME
 			go pool.session.handleNodeDown(pool.host.Peer(), pool.port)
 			return
 		}
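
The FIXME refers to a re-entrant locking hazard: sync.Mutex is not re-entrant, so invoking handleNodeDown inline while fill holds the pool mutex would self-deadlock. A minimal sketch with hypothetical names of why dispatching to a goroutine sidesteps this (though, as the FIXME notes, it is a workaround rather than a fix):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

type pool struct {
	mu sync.Mutex
}

// handleDown stands in for handleNodeDown: it acquires the same mutex.
func (p *pool) handleDown() {
	p.mu.Lock() // blocks here until fill releases the lock
	defer p.mu.Unlock()
	fmt.Println("node marked down")
}

func (p *pool) fill() {
	p.mu.Lock()
	defer p.mu.Unlock()
	// Calling p.handleDown() inline here would self-deadlock, because
	// sync.Mutex is not re-entrant. Running it in a goroutine lets it
	// wait for the lock without blocking fill.
	go p.handleDown()
}

func main() {
	p := &pool{}
	p.fill()
	time.Sleep(10 * time.Millisecond) // demo only: let the goroutine finish
}
```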

+ 14 - 1
logger.go

@@ -1,6 +1,10 @@
 package gocql
 
-import "log"
+import (
+	"bytes"
+	"fmt"
+	"log"
+)
 
 type StdLogger interface {
 	Print(v ...interface{})
@@ -8,6 +12,15 @@ type StdLogger interface {
 	Println(v ...interface{})
 }
 
+type testLogger struct {
+	capture bytes.Buffer
+}
+
+func (l *testLogger) Print(v ...interface{})                 { fmt.Fprint(&l.capture, v...) }
+func (l *testLogger) Printf(format string, v ...interface{}) { fmt.Fprintf(&l.capture, format, v...) }
+func (l *testLogger) Println(v ...interface{})               { fmt.Fprintln(&l.capture, v...) }
+func (l *testLogger) String() string                         { return l.capture.String() }
+
 type defaultLogger struct{}
 
 func (l *defaultLogger) Print(v ...interface{})                 { log.Print(v...) }
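
A short sketch of how the new testLogger is used (mirroring TestStartupTimeout above; it assumes a test file inside the gocql package with the usual strings and testing imports): swap the package-level Logger for the capturing implementation, then assert on what was logged.

```go
func TestLogsTimeout(t *testing.T) { // hypothetical test name
	log := &testLogger{}
	Logger = log
	defer func() { Logger = &defaultLogger{} }()

	// ... exercise code that writes through the package-level Logger ...

	if !strings.Contains(log.String(), "within timeout") {
		t.Fatalf("expected timeout log message, got %q", log.String())
	}
}
```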