|
|
@@ -202,33 +202,44 @@ func Connect(host *HostInfo, addr string, cfg *ConnConfig,
|
|
|
c.setKeepalive(cfg.Keepalive)
|
|
|
}
|
|
|
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
frameTicker := make(chan struct{}, 1)
|
|
|
startupErr := make(chan error, 1)
|
|
|
go func() {
|
|
|
for range frameTicker {
|
|
|
err := c.recv()
|
|
|
- startupErr <- err
|
|
|
- if err != nil {
|
|
|
+ select {
|
|
|
+ case startupErr <- err:
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ case <-ctx.Done():
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
- err = c.startup(frameTicker)
|
|
|
- close(frameTicker)
|
|
|
- if err != nil {
|
|
|
- conn.Close()
|
|
|
- return nil, err
|
|
|
- }
|
|
|
+ go func() {
|
|
|
+ defer close(frameTicker)
|
|
|
+ err := c.startup(ctx, frameTicker)
|
|
|
+ if err != nil {
|
|
|
+ select {
|
|
|
+ case startupErr <- err:
|
|
|
+ case <-ctx.Done():
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }()
|
|
|
|
|
|
select {
|
|
|
case err := <-startupErr:
|
|
|
if err != nil {
|
|
|
- log.Println(err)
|
|
|
c.Close()
|
|
|
return nil, err
|
|
|
}
|
|
|
- case <-time.After(c.timeout):
|
|
|
+ case <-ctx.Done():
|
|
|
c.Close()
|
|
|
return nil, errors.New("gocql: no response to connection startup within timeout")
|
|
|
}
|
|
|
@@ -269,7 +280,7 @@ func (c *Conn) Read(p []byte) (n int, err error) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func (c *Conn) startup(frameTicker chan struct{}) error {
|
|
|
+func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
|
|
|
m := map[string]string{
|
|
|
"CQL_VERSION": c.cfg.CQLVersion,
|
|
|
}
|
|
|
@@ -278,8 +289,13 @@ func (c *Conn) startup(frameTicker chan struct{}) error {
|
|
|
m["COMPRESSION"] = c.compressor.Name()
|
|
|
}
|
|
|
|
|
|
- frameTicker <- struct{}{}
|
|
|
- framer, err := c.exec(context.Background(), &writeStartupFrame{opts: m}, nil)
|
|
|
+ select {
|
|
|
+ case frameTicker <- struct{}{}:
|
|
|
+ case <-ctx.Done():
|
|
|
+ return ctx.Err()
|
|
|
+ }
|
|
|
+
|
|
|
+ framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -295,13 +311,13 @@ func (c *Conn) startup(frameTicker chan struct{}) error {
|
|
|
case *readyFrame:
|
|
|
return nil
|
|
|
case *authenticateFrame:
|
|
|
- return c.authenticateHandshake(v, frameTicker)
|
|
|
+ return c.authenticateHandshake(ctx, v, frameTicker)
|
|
|
default:
|
|
|
return NewErrProtocol("Unknown type of response to startup frame: %s", v)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (c *Conn) authenticateHandshake(authFrame *authenticateFrame, frameTicker chan struct{}) error {
|
|
|
+func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error {
|
|
|
if c.auth == nil {
|
|
|
return fmt.Errorf("authentication required (using %q)", authFrame.class)
|
|
|
}
|
|
|
@@ -314,7 +330,12 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame, frameTicker c
|
|
|
req := &writeAuthResponseFrame{data: resp}
|
|
|
|
|
|
for {
|
|
|
- frameTicker <- struct{}{}
|
|
|
+ select {
|
|
|
+ case frameTicker <- struct{}{}:
|
|
|
+ case <-ctx.Done():
|
|
|
+ return ctx.Err()
|
|
|
+ }
|
|
|
+
|
|
|
framer, err := c.exec(context.Background(), req, nil)
|
|
|
if err != nil {
|
|
|
return err
|