瀏覽代碼

conn: pass context through startup chain

The startup procoss is very concurrent, and can fail in different
places. To ensure we dont end up blocking somewhere indefinetly make use
a context.

Run the startup loop concurrently whilst waiting for an error from recv
so that we get the recv error instead of just timing out waiting for
startup to finish.
Chris Bannister 9 年之前
父節點
當前提交
4780c25d06
共有 1 個文件被更改,包括 37 次插入16 次删除
  1. 37 16
      conn.go

+ 37 - 16
conn.go

@@ -202,33 +202,44 @@ func Connect(host *HostInfo, addr string, cfg *ConnConfig,
 		c.setKeepalive(cfg.Keepalive)
 	}
 
+	ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
+	defer cancel()
+
 	frameTicker := make(chan struct{}, 1)
 	startupErr := make(chan error, 1)
 	go func() {
 		for range frameTicker {
 			err := c.recv()
-			startupErr <- err
-			if err != nil {
+			select {
+			case startupErr <- err:
+				if err != nil {
+					return
+				}
+			case <-ctx.Done():
 				return
 			}
 		}
 	}()
 
-	err = c.startup(frameTicker)
-	close(frameTicker)
-	if err != nil {
-		conn.Close()
-		return nil, err
-	}
+	go func() {
+		defer close(frameTicker)
+		err := c.startup(ctx, frameTicker)
+		if err != nil {
+			select {
+			case startupErr <- err:
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
 
 	select {
 	case err := <-startupErr:
 		if err != nil {
-			log.Println(err)
 			c.Close()
 			return nil, err
 		}
-	case <-time.After(c.timeout):
+	case <-ctx.Done():
 		c.Close()
 		return nil, errors.New("gocql: no response to connection startup within timeout")
 	}
@@ -269,7 +280,7 @@ func (c *Conn) Read(p []byte) (n int, err error) {
 	return
 }
 
-func (c *Conn) startup(frameTicker chan struct{}) error {
+func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
 	m := map[string]string{
 		"CQL_VERSION": c.cfg.CQLVersion,
 	}
@@ -278,8 +289,13 @@ func (c *Conn) startup(frameTicker chan struct{}) error {
 		m["COMPRESSION"] = c.compressor.Name()
 	}
 
-	frameTicker <- struct{}{}
-	framer, err := c.exec(context.Background(), &writeStartupFrame{opts: m}, nil)
+	select {
+	case frameTicker <- struct{}{}:
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+
+	framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil)
 	if err != nil {
 		return err
 	}
@@ -295,13 +311,13 @@ func (c *Conn) startup(frameTicker chan struct{}) error {
 	case *readyFrame:
 		return nil
 	case *authenticateFrame:
-		return c.authenticateHandshake(v, frameTicker)
+		return c.authenticateHandshake(ctx, v, frameTicker)
 	default:
 		return NewErrProtocol("Unknown type of response to startup frame: %s", v)
 	}
 }
 
-func (c *Conn) authenticateHandshake(authFrame *authenticateFrame, frameTicker chan struct{}) error {
+func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error {
 	if c.auth == nil {
 		return fmt.Errorf("authentication required (using %q)", authFrame.class)
 	}
@@ -314,7 +330,12 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame, frameTicker c
 	req := &writeAuthResponseFrame{data: resp}
 
 	for {
-		frameTicker <- struct{}{}
+		select {
+		case frameTicker <- struct{}{}:
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+
 		framer, err := c.exec(context.Background(), req, nil)
 		if err != nil {
 			return err