|
@@ -227,40 +227,14 @@ func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHa
|
|
|
}
|
|
}
|
|
|
defer cancel()
|
|
defer cancel()
|
|
|
|
|
|
|
|
- frameTicker := make(chan struct{}, 1)
|
|
|
|
|
- startupErr := make(chan error)
|
|
|
|
|
- go func() {
|
|
|
|
|
- for range frameTicker {
|
|
|
|
|
- err := c.recv()
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- select {
|
|
|
|
|
- case startupErr <- err:
|
|
|
|
|
- case <-ctx.Done():
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- return
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }()
|
|
|
|
|
-
|
|
|
|
|
- go func() {
|
|
|
|
|
- defer close(frameTicker)
|
|
|
|
|
- err := c.startup(ctx, frameTicker)
|
|
|
|
|
- select {
|
|
|
|
|
- case startupErr <- err:
|
|
|
|
|
- case <-ctx.Done():
|
|
|
|
|
- }
|
|
|
|
|
- }()
|
|
|
|
|
|
|
+ startup := &startupCoordinator{
|
|
|
|
|
+ frameTicker: make(chan struct{}),
|
|
|
|
|
+ conn: c,
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- select {
|
|
|
|
|
- case err := <-startupErr:
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- c.Close()
|
|
|
|
|
- return nil, err
|
|
|
|
|
- }
|
|
|
|
|
- case <-ctx.Done():
|
|
|
|
|
- c.Close()
|
|
|
|
|
- return nil, errors.New("gocql: no response to connection startup within timeout")
|
|
|
|
|
|
|
+ if err := startup.setupConn(ctx); err != nil {
|
|
|
|
|
+ c.close()
|
|
|
|
|
+ return nil, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// dont coalesce startup frames
|
|
// dont coalesce startup frames
|
|
@@ -300,27 +274,98 @@ func (c *Conn) Read(p []byte) (n int, err error) {
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
|
|
|
|
|
- m := map[string]string{
|
|
|
|
|
- "CQL_VERSION": c.cfg.CQLVersion,
|
|
|
|
|
- }
|
|
|
|
|
|
|
+type startupCoordinator struct {
|
|
|
|
|
+ conn *Conn
|
|
|
|
|
+ frameTicker chan struct{}
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
- if c.compressor != nil {
|
|
|
|
|
- m["COMPRESSION"] = c.compressor.Name()
|
|
|
|
|
|
|
+func (s *startupCoordinator) setupConn(ctx context.Context) error {
|
|
|
|
|
+ startupErr := make(chan error)
|
|
|
|
|
+ go func() {
|
|
|
|
|
+ for range s.frameTicker {
|
|
|
|
|
+ err := s.conn.recv()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ select {
|
|
|
|
|
+ case startupErr <- err:
|
|
|
|
|
+ case <-ctx.Done():
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }()
|
|
|
|
|
+
|
|
|
|
|
+ go func() {
|
|
|
|
|
+ defer close(s.frameTicker)
|
|
|
|
|
+ err := s.options(ctx)
|
|
|
|
|
+ select {
|
|
|
|
|
+ case startupErr <- err:
|
|
|
|
|
+ case <-ctx.Done():
|
|
|
|
|
+ }
|
|
|
|
|
+ }()
|
|
|
|
|
+
|
|
|
|
|
+ select {
|
|
|
|
|
+ case err := <-startupErr:
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+ case <-ctx.Done():
|
|
|
|
|
+ return errors.New("gocql: no response to connection startup within timeout")
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ return nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (s *startupCoordinator) write(ctx context.Context, frame frameWriter) (frame, error) {
|
|
|
select {
|
|
select {
|
|
|
- case frameTicker <- struct{}{}:
|
|
|
|
|
|
|
+ case s.frameTicker <- struct{}{}:
|
|
|
case <-ctx.Done():
|
|
case <-ctx.Done():
|
|
|
- return ctx.Err()
|
|
|
|
|
|
|
+ return nil, ctx.Err()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil)
|
|
|
|
|
|
|
+ framer, err := s.conn.exec(ctx, frame, nil)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return framer.parseFrame()
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (s *startupCoordinator) options(ctx context.Context) error {
|
|
|
|
|
+ frame, err := s.write(ctx, &writeOptionsFrame{})
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- frame, err := framer.parseFrame()
|
|
|
|
|
|
|
+ supported, ok := frame.(*supportedFrame)
|
|
|
|
|
+ if !ok {
|
|
|
|
|
+ return NewErrProtocol("Unknown type of response to startup frame: %T", frame)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return s.startup(ctx, supported.supported)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string) error {
|
|
|
|
|
+ m := map[string]string{
|
|
|
|
|
+ "CQL_VERSION": s.conn.cfg.CQLVersion,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if s.conn.compressor != nil {
|
|
|
|
|
+ comp := supported["COMPRESSION"]
|
|
|
|
|
+ name := s.conn.compressor.Name()
|
|
|
|
|
+ for _, compressor := range comp {
|
|
|
|
|
+ if compressor == name {
|
|
|
|
|
+ m["COMPRESSION"] = compressor
|
|
|
|
|
+ break
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if _, ok := m["COMPRESSION"]; !ok {
|
|
|
|
|
+ s.conn.compressor = nil
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ frame, err := s.write(ctx, &writeStartupFrame{opts: m})
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
@@ -331,37 +376,25 @@ func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
|
|
|
case *readyFrame:
|
|
case *readyFrame:
|
|
|
return nil
|
|
return nil
|
|
|
case *authenticateFrame:
|
|
case *authenticateFrame:
|
|
|
- return c.authenticateHandshake(ctx, v, frameTicker)
|
|
|
|
|
|
|
+ return s.authenticateHandshake(ctx, v)
|
|
|
default:
|
|
default:
|
|
|
return NewErrProtocol("Unknown type of response to startup frame: %s", v)
|
|
return NewErrProtocol("Unknown type of response to startup frame: %s", v)
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error {
|
|
|
|
|
- if c.auth == nil {
|
|
|
|
|
|
|
+func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame) error {
|
|
|
|
|
+ if s.conn.auth == nil {
|
|
|
return fmt.Errorf("authentication required (using %q)", authFrame.class)
|
|
return fmt.Errorf("authentication required (using %q)", authFrame.class)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- resp, challenger, err := c.auth.Challenge([]byte(authFrame.class))
|
|
|
|
|
|
|
+ resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.class))
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
req := &writeAuthResponseFrame{data: resp}
|
|
req := &writeAuthResponseFrame{data: resp}
|
|
|
-
|
|
|
|
|
for {
|
|
for {
|
|
|
- select {
|
|
|
|
|
- case frameTicker <- struct{}{}:
|
|
|
|
|
- case <-ctx.Done():
|
|
|
|
|
- return ctx.Err()
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- framer, err := c.exec(ctx, req, nil)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return err
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- frame, err := framer.parseFrame()
|
|
|
|
|
|
|
+ frame, err := s.write(ctx, req)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|