Browse Source

conn: check we can use compression in startup (#1215)

* conn: check we can use compression in startup

If the user asks to use compression check that the remote server
supports it by making an options call first.

Rejig conn startup to be self contained.

* simplify TestFrameHeaderObserver

* disable compression if the server doesnt support it
Chris Bannister 7 years ago
parent
commit
44e29ed5b8
2 changed files with 102 additions and 70 deletions
  1. 93 60
      conn.go
  2. 9 10
      conn_test.go

+ 93 - 60
conn.go

@@ -227,40 +227,14 @@ func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHa
 	}
 	defer cancel()
 
-	frameTicker := make(chan struct{}, 1)
-	startupErr := make(chan error)
-	go func() {
-		for range frameTicker {
-			err := c.recv()
-			if err != nil {
-				select {
-				case startupErr <- err:
-				case <-ctx.Done():
-				}
-
-				return
-			}
-		}
-	}()
-
-	go func() {
-		defer close(frameTicker)
-		err := c.startup(ctx, frameTicker)
-		select {
-		case startupErr <- err:
-		case <-ctx.Done():
-		}
-	}()
+	startup := &startupCoordinator{
+		frameTicker: make(chan struct{}),
+		conn:        c,
+	}
 
-	select {
-	case err := <-startupErr:
-		if err != nil {
-			c.Close()
-			return nil, err
-		}
-	case <-ctx.Done():
-		c.Close()
-		return nil, errors.New("gocql: no response to connection startup within timeout")
+	if err := startup.setupConn(ctx); err != nil {
+		c.close()
+		return nil, err
 	}
 
 	// dont coalesce startup frames
@@ -300,27 +274,98 @@ func (c *Conn) Read(p []byte) (n int, err error) {
 	return
 }
 
-func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
-	m := map[string]string{
-		"CQL_VERSION": c.cfg.CQLVersion,
-	}
+type startupCoordinator struct {
+	conn        *Conn
+	frameTicker chan struct{}
+}
 
-	if c.compressor != nil {
-		m["COMPRESSION"] = c.compressor.Name()
+func (s *startupCoordinator) setupConn(ctx context.Context) error {
+	startupErr := make(chan error)
+	go func() {
+		for range s.frameTicker {
+			err := s.conn.recv()
+			if err != nil {
+				select {
+				case startupErr <- err:
+				case <-ctx.Done():
+				}
+
+				return
+			}
+		}
+	}()
+
+	go func() {
+		defer close(s.frameTicker)
+		err := s.options(ctx)
+		select {
+		case startupErr <- err:
+		case <-ctx.Done():
+		}
+	}()
+
+	select {
+	case err := <-startupErr:
+		if err != nil {
+			return err
+		}
+	case <-ctx.Done():
+		return errors.New("gocql: no response to connection startup within timeout")
 	}
 
+	return nil
+}
+
+func (s *startupCoordinator) write(ctx context.Context, frame frameWriter) (frame, error) {
 	select {
-	case frameTicker <- struct{}{}:
+	case s.frameTicker <- struct{}{}:
 	case <-ctx.Done():
-		return ctx.Err()
+		return nil, ctx.Err()
 	}
 
-	framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil)
+	framer, err := s.conn.exec(ctx, frame, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	return framer.parseFrame()
+}
+
+func (s *startupCoordinator) options(ctx context.Context) error {
+	frame, err := s.write(ctx, &writeOptionsFrame{})
 	if err != nil {
 		return err
 	}
 
-	frame, err := framer.parseFrame()
+	supported, ok := frame.(*supportedFrame)
+	if !ok {
+		return NewErrProtocol("Unknown type of response to startup frame: %T", frame)
+	}
+
+	return s.startup(ctx, supported.supported)
+}
+
+func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string) error {
+	m := map[string]string{
+		"CQL_VERSION": s.conn.cfg.CQLVersion,
+	}
+
+	if s.conn.compressor != nil {
+		comp := supported["COMPRESSION"]
+		name := s.conn.compressor.Name()
+		for _, compressor := range comp {
+			if compressor == name {
+				m["COMPRESSION"] = compressor
+				break
+			}
+		}
+
+		if _, ok := m["COMPRESSION"]; !ok {
+			s.conn.compressor = nil
+		}
+	}
+
+	frame, err := s.write(ctx, &writeStartupFrame{opts: m})
 	if err != nil {
 		return err
 	}
@@ -331,37 +376,25 @@ func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
 	case *readyFrame:
 		return nil
 	case *authenticateFrame:
-		return c.authenticateHandshake(ctx, v, frameTicker)
+		return s.authenticateHandshake(ctx, v)
 	default:
 		return NewErrProtocol("Unknown type of response to startup frame: %s", v)
 	}
 }
 
-func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error {
-	if c.auth == nil {
+func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame) error {
+	if s.conn.auth == nil {
 		return fmt.Errorf("authentication required (using %q)", authFrame.class)
 	}
 
-	resp, challenger, err := c.auth.Challenge([]byte(authFrame.class))
+	resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.class))
 	if err != nil {
 		return err
 	}
 
 	req := &writeAuthResponseFrame{data: resp}
-
 	for {
-		select {
-		case frameTicker <- struct{}{}:
-		case <-ctx.Done():
-			return ctx.Err()
-		}
-
-		framer, err := c.exec(ctx, req, nil)
-		if err != nil {
-			return err
-		}
-
-		frame, err := framer.parseFrame()
+		frame, err := s.write(ctx, req)
 		if err != nil {
 			return err
 		}

+ 9 - 10
conn_test.go

@@ -940,18 +940,17 @@ func TestFrameHeaderObserver(t *testing.T) {
 	}
 
 	frames := observer.getFrames()
-
-	if len(frames) != 2 {
-		t.Fatalf("Expected to receive 2 frames, instead received %d", len(frames))
-	}
-	readyFrame := frames[0]
-	if readyFrame.Opcode != frameOp(opReady) {
-		t.Fatalf("Expected to receive ready frame, instead received frame of opcode %d", readyFrame.Opcode)
+	expFrames := []frameOp{opSupported, opReady, opResult}
+	if len(frames) != len(expFrames) {
+		t.Fatalf("Expected to receive %d frames, instead received %d", len(expFrames), len(frames))
 	}
-	voidResultFrame := frames[1]
-	if voidResultFrame.Opcode != frameOp(opResult) {
-		t.Fatalf("Expected to receive result frame, instead received frame of opcode %d", voidResultFrame.Opcode)
+
+	for i, op := range expFrames {
+		if op != frames[i].Opcode {
+			t.Fatalf("expected frame %d to be %v got %v", i, op, frames[i])
+		}
 	}
+	voidResultFrame := frames[2]
 	if voidResultFrame.Length != int32(4) {
 		t.Fatalf("Expected to receive frame with body length 4, instead received body length %d", voidResultFrame.Length)
 	}