Bläddra i källkod

store the hostInfo on the connection

Chris Bannister 9 år sedan
förälder
incheckning
2d49157aa7
6 ändrade filer med 40 tillägg och 10 borttagningar
  1. 6 1
      conn.go
  2. 2 1
      conn_test.go
  3. 2 2
      connectionpool.go
  4. 27 4
      control.go
  5. 1 0
      events.go
  6. 2 2
      session.go

+ 6 - 1
conn.go

@@ -139,6 +139,8 @@ type Conn struct {
 	currentKeyspace string
 	started         bool
 
+	host *HostInfo
+
 	session *Session
 
 	closed int32
@@ -148,7 +150,9 @@ type Conn struct {
 }
 
 // Connect establishes a connection to a Cassandra node.
-func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
+func Connect(host *HostInfo, addr string, cfg *ConnConfig,
+	errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
+
 	var (
 		err  error
 		conn net.Conn
@@ -196,6 +200,7 @@ func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, sessio
 		quit:         make(chan struct{}),
 		session:      session,
 		streams:      streams.New(cfg.ProtoVersion),
+		host:         host,
 	}
 
 	if cfg.Keepalive > 0 {

+ 2 - 1
conn_test.go

@@ -416,7 +416,8 @@ func TestStream0(t *testing.T) {
 		}
 	})
 
-	conn, err := Connect(srv.Address, &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
+	host := &HostInfo{peer: srv.Address}
+	conn, err := Connect(host, srv.Address, &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
 	if err != nil {
 		t.Fatal(err)
 	}

+ 2 - 2
connectionpool.go

@@ -246,7 +246,7 @@ func (p *policyConnPool) addHost(host *HostInfo) {
 		pool = newHostConnPool(
 			p.session,
 			host,
-			host.Port(),
+			host.Port(), // TODO: if port == 0 use pool.port?
 			p.numConns,
 			p.keyspace,
 			p.connPolicy(),
@@ -506,7 +506,7 @@ func (pool *hostConnPool) connect() (err error) {
 	// try to connect
 	var conn *Conn
 	for i := 0; i < maxAttempts; i++ {
-		conn, err = pool.session.connect(pool.addr, pool)
+		conn, err = pool.session.connect(pool.addr, pool, pool.host)
 		if err == nil {
 			break
 		}

+ 27 - 4
control.go

@@ -99,9 +99,28 @@ func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) {
 	// shuffle endpoints so not all drivers will connect to the same initial
 	// node.
 	for _, addr := range shuffled {
-		conn, err = c.session.connect(JoinHostPort(addr, c.session.cfg.Port), c)
+		if addr == "" {
+			return nil, fmt.Errorf("control: invalid address: %q", addr)
+		}
+
+		port := c.session.cfg.Port
+		addr = JoinHostPort(addr, port)
+		host, portStr, err := net.SplitHostPort(addr)
+		if err != nil {
+			host = addr
+			port = c.session.cfg.Port
+			err = nil
+		} else {
+			port, err = strconv.Atoi(portStr)
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		hostInfo, _ := c.session.ring.addHostIfMissing(&HostInfo{peer: host, port: port})
+		conn, err = c.session.connect(addr, c, hostInfo)
 		if err == nil {
-			return
+			return conn, err
 		}
 
 		log.Printf("gocql: unable to dial control conn %v: %v\n", addr, err)
@@ -111,6 +130,10 @@ func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) {
 }
 
 func (c *controlConn) connect(endpoints []string) error {
+	if len(endpoints) == 0 {
+		return errors.New("control: no endpoints specified")
+	}
+
 	conn, err := c.shuffleDial(endpoints)
 	if err != nil {
 		return fmt.Errorf("control: unable to connect: %v", err)
@@ -200,7 +223,7 @@ func (c *controlConn) reconnect(refreshring bool) {
 	var newConn *Conn
 	if addr != "" {
 		// try to connect to the old host
-		conn, err := c.session.connect(addr, c)
+		conn, err := c.session.connect(addr, c, oldConn.host)
 		if err != nil {
 			// host is dead
 			// TODO: this is replicated in a few places
@@ -222,7 +245,7 @@ func (c *controlConn) reconnect(refreshring bool) {
 		}
 
 		var err error
-		newConn, err = c.session.connect(conn.addr, c)
+		newConn, err = c.session.connect(conn.addr, c, conn.host)
 		if err != nil {
 			// TODO: add log handler for things like this
 			return

+ 1 - 0
events.go

@@ -249,6 +249,7 @@ func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
 			time.Sleep(t)
 		}
 
+		host.setPort(port)
 		s.pool.hostUp(host)
 		host.setState(NodeUp)
 		return

+ 2 - 2
session.go

@@ -583,8 +583,8 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
 	return applied, iter, iter.err
 }
 
-func (s *Session) connect(addr string, errorHandler ConnErrorHandler) (*Conn, error) {
-	return Connect(addr, s.connCfg, errorHandler, s)
+func (s *Session) connect(addr string, errorHandler ConnErrorHandler, host *HostInfo) (*Conn, error) {
+	return Connect(host, addr, s.connCfg, errorHandler, s)
 }
 
 // Query represents a CQL statement that can be executed.