Преглед изворни кода

conn: pass context through startup chain

The startup procoss is very concurrent, and can fail in different
places. To ensure we dont end up blocking somewhere indefinetly make use
a context.

Run the startup loop concurrently whilst waiting for an error from recv
so that we get the recv error instead of just timing out waiting for
startup to finish.
Chris Bannister пре 10 година
родитељ
комит
4780c25d06
1 измењених фајлова са 37 додато и 16 уклоњено
  1. 37 16
      conn.go

+ 37 - 16
conn.go

@@ -202,33 +202,44 @@ func Connect(host *HostInfo, addr string, cfg *ConnConfig,
 		c.setKeepalive(cfg.Keepalive)
 		c.setKeepalive(cfg.Keepalive)
 	}
 	}
 
 
+	ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
+	defer cancel()
+
 	frameTicker := make(chan struct{}, 1)
 	frameTicker := make(chan struct{}, 1)
 	startupErr := make(chan error, 1)
 	startupErr := make(chan error, 1)
 	go func() {
 	go func() {
 		for range frameTicker {
 		for range frameTicker {
 			err := c.recv()
 			err := c.recv()
-			startupErr <- err
-			if err != nil {
+			select {
+			case startupErr <- err:
+				if err != nil {
+					return
+				}
+			case <-ctx.Done():
 				return
 				return
 			}
 			}
 		}
 		}
 	}()
 	}()
 
 
-	err = c.startup(frameTicker)
-	close(frameTicker)
-	if err != nil {
-		conn.Close()
-		return nil, err
-	}
+	go func() {
+		defer close(frameTicker)
+		err := c.startup(ctx, frameTicker)
+		if err != nil {
+			select {
+			case startupErr <- err:
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
 
 
 	select {
 	select {
 	case err := <-startupErr:
 	case err := <-startupErr:
 		if err != nil {
 		if err != nil {
-			log.Println(err)
 			c.Close()
 			c.Close()
 			return nil, err
 			return nil, err
 		}
 		}
-	case <-time.After(c.timeout):
+	case <-ctx.Done():
 		c.Close()
 		c.Close()
 		return nil, errors.New("gocql: no response to connection startup within timeout")
 		return nil, errors.New("gocql: no response to connection startup within timeout")
 	}
 	}
@@ -269,7 +280,7 @@ func (c *Conn) Read(p []byte) (n int, err error) {
 	return
 	return
 }
 }
 
 
-func (c *Conn) startup(frameTicker chan struct{}) error {
+func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
 	m := map[string]string{
 	m := map[string]string{
 		"CQL_VERSION": c.cfg.CQLVersion,
 		"CQL_VERSION": c.cfg.CQLVersion,
 	}
 	}
@@ -278,8 +289,13 @@ func (c *Conn) startup(frameTicker chan struct{}) error {
 		m["COMPRESSION"] = c.compressor.Name()
 		m["COMPRESSION"] = c.compressor.Name()
 	}
 	}
 
 
-	frameTicker <- struct{}{}
-	framer, err := c.exec(context.Background(), &writeStartupFrame{opts: m}, nil)
+	select {
+	case frameTicker <- struct{}{}:
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+
+	framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -295,13 +311,13 @@ func (c *Conn) startup(frameTicker chan struct{}) error {
 	case *readyFrame:
 	case *readyFrame:
 		return nil
 		return nil
 	case *authenticateFrame:
 	case *authenticateFrame:
-		return c.authenticateHandshake(v, frameTicker)
+		return c.authenticateHandshake(ctx, v, frameTicker)
 	default:
 	default:
 		return NewErrProtocol("Unknown type of response to startup frame: %s", v)
 		return NewErrProtocol("Unknown type of response to startup frame: %s", v)
 	}
 	}
 }
 }
 
 
-func (c *Conn) authenticateHandshake(authFrame *authenticateFrame, frameTicker chan struct{}) error {
+func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error {
 	if c.auth == nil {
 	if c.auth == nil {
 		return fmt.Errorf("authentication required (using %q)", authFrame.class)
 		return fmt.Errorf("authentication required (using %q)", authFrame.class)
 	}
 	}
@@ -314,7 +330,12 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame, frameTicker c
 	req := &writeAuthResponseFrame{data: resp}
 	req := &writeAuthResponseFrame{data: resp}
 
 
 	for {
 	for {
-		frameTicker <- struct{}{}
+		select {
+		case frameTicker <- struct{}{}:
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+
 		framer, err := c.exec(context.Background(), req, nil)
 		framer, err := c.exec(context.Background(), req, nil)
 		if err != nil {
 		if err != nil {
 			return err
 			return err