Browse Source

conn: pass context through startup chain

The startup procoss is very concurrent, and can fail in different
places. To ensure we dont end up blocking somewhere indefinetly make use
a context.

Run the startup loop concurrently whilst waiting for an error from recv
so that we get the recv error instead of just timing out waiting for
startup to finish.
Chris Bannister 9 years ago
parent
commit
4780c25d06
1 changed files with 37 additions and 16 deletions
  1. 37 16
      conn.go

+ 37 - 16
conn.go

@@ -202,33 +202,44 @@ func Connect(host *HostInfo, addr string, cfg *ConnConfig,
 		c.setKeepalive(cfg.Keepalive)
 		c.setKeepalive(cfg.Keepalive)
 	}
 	}
 
 
+	ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
+	defer cancel()
+
 	frameTicker := make(chan struct{}, 1)
 	frameTicker := make(chan struct{}, 1)
 	startupErr := make(chan error, 1)
 	startupErr := make(chan error, 1)
 	go func() {
 	go func() {
 		for range frameTicker {
 		for range frameTicker {
 			err := c.recv()
 			err := c.recv()
-			startupErr <- err
-			if err != nil {
+			select {
+			case startupErr <- err:
+				if err != nil {
+					return
+				}
+			case <-ctx.Done():
 				return
 				return
 			}
 			}
 		}
 		}
 	}()
 	}()
 
 
-	err = c.startup(frameTicker)
-	close(frameTicker)
-	if err != nil {
-		conn.Close()
-		return nil, err
-	}
+	go func() {
+		defer close(frameTicker)
+		err := c.startup(ctx, frameTicker)
+		if err != nil {
+			select {
+			case startupErr <- err:
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
 
 
 	select {
 	select {
 	case err := <-startupErr:
 	case err := <-startupErr:
 		if err != nil {
 		if err != nil {
-			log.Println(err)
 			c.Close()
 			c.Close()
 			return nil, err
 			return nil, err
 		}
 		}
-	case <-time.After(c.timeout):
+	case <-ctx.Done():
 		c.Close()
 		c.Close()
 		return nil, errors.New("gocql: no response to connection startup within timeout")
 		return nil, errors.New("gocql: no response to connection startup within timeout")
 	}
 	}
@@ -269,7 +280,7 @@ func (c *Conn) Read(p []byte) (n int, err error) {
 	return
 	return
 }
 }
 
 
-func (c *Conn) startup(frameTicker chan struct{}) error {
+func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
 	m := map[string]string{
 	m := map[string]string{
 		"CQL_VERSION": c.cfg.CQLVersion,
 		"CQL_VERSION": c.cfg.CQLVersion,
 	}
 	}
@@ -278,8 +289,13 @@ func (c *Conn) startup(frameTicker chan struct{}) error {
 		m["COMPRESSION"] = c.compressor.Name()
 		m["COMPRESSION"] = c.compressor.Name()
 	}
 	}
 
 
-	frameTicker <- struct{}{}
-	framer, err := c.exec(context.Background(), &writeStartupFrame{opts: m}, nil)
+	select {
+	case frameTicker <- struct{}{}:
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+
+	framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -295,13 +311,13 @@ func (c *Conn) startup(frameTicker chan struct{}) error {
 	case *readyFrame:
 	case *readyFrame:
 		return nil
 		return nil
 	case *authenticateFrame:
 	case *authenticateFrame:
-		return c.authenticateHandshake(v, frameTicker)
+		return c.authenticateHandshake(ctx, v, frameTicker)
 	default:
 	default:
 		return NewErrProtocol("Unknown type of response to startup frame: %s", v)
 		return NewErrProtocol("Unknown type of response to startup frame: %s", v)
 	}
 	}
 }
 }
 
 
-func (c *Conn) authenticateHandshake(authFrame *authenticateFrame, frameTicker chan struct{}) error {
+func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error {
 	if c.auth == nil {
 	if c.auth == nil {
 		return fmt.Errorf("authentication required (using %q)", authFrame.class)
 		return fmt.Errorf("authentication required (using %q)", authFrame.class)
 	}
 	}
@@ -314,7 +330,12 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame, frameTicker c
 	req := &writeAuthResponseFrame{data: resp}
 	req := &writeAuthResponseFrame{data: resp}
 
 
 	for {
 	for {
-		frameTicker <- struct{}{}
+		select {
+		case frameTicker <- struct{}{}:
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+
 		framer, err := c.exec(context.Background(), req, nil)
 		framer, err := c.exec(context.Background(), req, nil)
 		if err != nil {
 		if err != nil {
 			return err
 			return err