Browse Source

auth will needs to read more than one startup response frame

Chris Bannister 9 years ago
parent
commit
1b219436fa
1 changed files with 19 additions and 7 deletions
  1. 19 7
      conn.go

+ 19 - 7
conn.go

@@ -207,17 +207,27 @@ func Connect(host *HostInfo, addr string, cfg *ConnConfig,
 		c.setKeepalive(cfg.Keepalive)
 		c.setKeepalive(cfg.Keepalive)
 	}
 	}
 
 
-	started := make(chan error, 1)
+	frameTicker := make(chan struct{}, 1)
+	startupErr := make(chan error, 1)
 	go func() {
 	go func() {
-		started <- c.recv()
+		for range frameTicker {
+			err := c.recv()
+			startupErr <- err
+			if err != nil {
+				return
+			}
+		}
 	}()
 	}()
 
 
-	if err := c.startup(); err != nil {
+	err = c.startup(frameTicker)
+	close(frameTicker)
+	if err != nil {
 		conn.Close()
 		conn.Close()
 		return nil, err
 		return nil, err
 	}
 	}
+
 	select {
 	select {
-	case err := <-started:
+	case err := <-startupErr:
 		if err != nil {
 		if err != nil {
 			log.Println(err)
 			log.Println(err)
 			c.Close()
 			c.Close()
@@ -264,7 +274,7 @@ func (c *Conn) Read(p []byte) (n int, err error) {
 	return
 	return
 }
 }
 
 
-func (c *Conn) startup() error {
+func (c *Conn) startup(frameTicker chan struct{}) error {
 	m := map[string]string{
 	m := map[string]string{
 		"CQL_VERSION": c.cfg.CQLVersion,
 		"CQL_VERSION": c.cfg.CQLVersion,
 	}
 	}
@@ -273,6 +283,7 @@ func (c *Conn) startup() error {
 		m["COMPRESSION"] = c.compressor.Name()
 		m["COMPRESSION"] = c.compressor.Name()
 	}
 	}
 
 
+	frameTicker <- struct{}{}
 	framer, err := c.exec(&writeStartupFrame{opts: m}, nil)
 	framer, err := c.exec(&writeStartupFrame{opts: m}, nil)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -289,13 +300,13 @@ func (c *Conn) startup() error {
 	case *readyFrame:
 	case *readyFrame:
 		return nil
 		return nil
 	case *authenticateFrame:
 	case *authenticateFrame:
-		return c.authenticateHandshake(v)
+		return c.authenticateHandshake(v, frameTicker)
 	default:
 	default:
 		return NewErrProtocol("Unknown type of response to startup frame: %s", v)
 		return NewErrProtocol("Unknown type of response to startup frame: %s", v)
 	}
 	}
 }
 }
 
 
-func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
+func (c *Conn) authenticateHandshake(authFrame *authenticateFrame, frameTicker chan struct{}) error {
 	if c.auth == nil {
 	if c.auth == nil {
 		return fmt.Errorf("authentication required (using %q)", authFrame.class)
 		return fmt.Errorf("authentication required (using %q)", authFrame.class)
 	}
 	}
@@ -308,6 +319,7 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
 	req := &writeAuthResponseFrame{data: resp}
 	req := &writeAuthResponseFrame{data: resp}
 
 
 	for {
 	for {
+		frameTicker <- struct{}{}
 		framer, err := c.exec(req, nil)
 		framer, err := c.exec(req, nil)
 		if err != nil {
 		if err != nil {
 			return err
 			return err