refactor host discovery

Simplify and refactor host discovery: replace the various methods for
discovering ring info so that the same methods are used everywhere.
Don't store host info on the connection; instead, the control connection
knows its own host and all others are discovered. Don't trigger hostUp
from the control connection; instead, use only the hosts discovered
during session init.

Clean things up throughout, and unexport Connect, making it a private API.
Chris Bannister 8 years ago
parent
commit
01f586cc26
13 changed files with 253 additions and 332 deletions
  1. cassandra_test.go (+21 -39)
  2. common_test.go (+0 -1)
  3. conn.go (+41 -27)
  4. conn_test.go (+32 -38)
  5. control.go (+49 -44)
  6. events.go (+5 -5)
  7. helpers.go (+23 -10)
  8. host_source.go (+53 -126)
  9. host_source_test.go (+3 -30)
  10. query_executor.go (+10 -4)
  11. ring.go (+2 -0)
  12. session.go (+10 -4)
  13. session_connect_test.go (+4 -4)

+ 21 - 39
cassandra_test.go

@@ -2362,31 +2362,26 @@ func TestDiscoverViaProxy(t *testing.T) {
 	if err != nil {
 		t.Fatalf("unable to create proxy listener: %v", err)
 	}
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
 
 	var (
-		wg         sync.WaitGroup
 		mu         sync.Mutex
 		proxyConns []net.Conn
 		closed     bool
 	)
 
-	go func(wg *sync.WaitGroup) {
+	go func() {
 		cassandraAddr := JoinHostPort(clusterHosts[0], 9042)
 
 		cassandra := func() (net.Conn, error) {
 			return net.Dial("tcp", cassandraAddr)
 		}
 
-		proxyFn := func(wg *sync.WaitGroup, from, to net.Conn) {
-			defer wg.Done()
-
+		proxyFn := func(errs chan error, from, to net.Conn) {
 			_, err := io.Copy(to, from)
 			if err != nil {
-				mu.Lock()
-				if !closed {
-					t.Error(err)
-				}
-				mu.Unlock()
+				errs <- err
 			}
 		}
 
@@ -2394,29 +2389,22 @@ func TestDiscoverViaProxy(t *testing.T) {
 		// for both the read and write side of the TCP connection to close before
 		// returning.
 		handle := func(conn net.Conn) error {
-			defer conn.Close()
-
 			cass, err := cassandra()
 			if err != nil {
 				return err
 			}
-
-			mu.Lock()
-			proxyConns = append(proxyConns, cass)
-			mu.Unlock()
-
 			defer cass.Close()
 
-			var wg sync.WaitGroup
-			wg.Add(1)
-			go proxyFn(&wg, conn, cass)
-
-			wg.Add(1)
-			go proxyFn(&wg, cass, conn)
+			errs := make(chan error, 2)
+			go proxyFn(errs, conn, cass)
+			go proxyFn(errs, cass, conn)
 
-			wg.Wait()
-
-			return nil
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case err := <-errs:
+				return err
+			}
 		}
 
 		for {
@@ -2436,19 +2424,19 @@ func TestDiscoverViaProxy(t *testing.T) {
 			proxyConns = append(proxyConns, conn)
 			mu.Unlock()
 
-			wg.Add(1)
 			go func(conn net.Conn) {
-				defer wg.Done()
+				defer conn.Close()
 
 				if err := handle(conn); err != nil {
-					t.Error(err)
-					return
+					mu.Lock()
+					if !closed {
+						t.Error(err)
+					}
+					mu.Unlock()
 				}
 			}(conn)
 		}
-	}(&wg)
-
-	defer wg.Wait()
+	}()
 
 	proxyAddr := proxy.Addr().String()
 
@@ -2460,11 +2448,6 @@ func TestDiscoverViaProxy(t *testing.T) {
 	session := createSessionFromCluster(cluster, t)
 	defer session.Close()
 
-	if session.hostSource.localHost.BroadcastAddress() == nil {
-		t.Skip("Target cluster does not have broadcast_address in system.local.")
-		goto close
-	}
-
 	// we shouldnt need this but to be safe
 	time.Sleep(1 * time.Second)
 
@@ -2476,7 +2459,6 @@ func TestDiscoverViaProxy(t *testing.T) {
 	}
 	session.pool.mu.RUnlock()
 
-close:
 	mu.Lock()
 	closed = true
 	if err := proxy.Close(); err != nil {

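The proxy in this test no longer tracks per-connection WaitGroups: each copy direction reports into a buffered error channel, and the handler returns on the first error or on context cancellation. A minimal sketch of that pattern, with illustrative names (proxyPair is not part of the commit):

package example

import (
	"context"
	"io"
	"net"
)

// proxyPair copies bytes in both directions between client and backend and
// returns as soon as either direction fails or ctx is cancelled.
func proxyPair(ctx context.Context, client, backend net.Conn) error {
	// Buffered so neither goroutine blocks sending its result after we return.
	errs := make(chan error, 2)

	copyStream := func(dst, src net.Conn) {
		_, err := io.Copy(dst, src)
		errs <- err
	}

	go copyStream(backend, client)
	go copyStream(client, backend)

	select {
	case <-ctx.Done():
		return ctx.Err()
	case err := <-errs:
		return err
	}
}
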
+ 0 - 1
common_test.go

@@ -102,7 +102,6 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
 		panic(err)
 	}
 	defer session.Close()
-	defer tb.Log("closing keyspace session")
 
 	err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace)
 	if err != nil {

+ 41 - 27
conn.go

@@ -141,8 +141,6 @@ type Conn struct {
 	version         uint8
 	currentKeyspace string
 
-	host *HostInfo
-
 	session *Session
 
 	closed int32
@@ -152,14 +150,12 @@ type Conn struct {
 }
 
 // Connect establishes a connection to a Cassandra node.
-func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
+func (s *Session) dial(ip net.IP, port int, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
 	// TODO(zariel): remove these
-	if host == nil {
-		panic("host is nil")
-	} else if len(host.ConnectAddress()) == 0 {
-		panic(fmt.Sprintf("host missing connect ip address: %v", host))
-	} else if host.Port() == 0 {
-		panic(fmt.Sprintf("host missing port: %v", host))
+	if len(ip) == 0 || ip.IsUnspecified() {
+		panic(fmt.Sprintf("host missing connect ip address: %v", ip))
+	} else if port == 0 {
+		panic(fmt.Sprintf("host missing port: %v", port))
 	}
 
 	var (
@@ -172,9 +168,8 @@ func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, ses
 	}
 
 	// TODO(zariel): handle ipv6 zone
-	translatedPeer, translatedPort := session.cfg.translateAddressPort(host.ConnectAddress(), host.Port())
+	translatedPeer, translatedPort := s.cfg.translateAddressPort(ip, port)
 	addr := (&net.TCPAddr{IP: translatedPeer, Port: translatedPort}).String()
-	//addr := (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String()
 
 	if cfg.tlsConfig != nil {
 		// the TLS config is safe to be reused by connections but it must not
@@ -200,9 +195,8 @@ func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, ses
 		compressor:   cfg.Compressor,
 		auth:         cfg.Authenticator,
 		quit:         make(chan struct{}),
-		session:      session,
+		session:      s,
 		streams:      streams.New(cfg.ProtoVersion),
-		host:         host,
 	}
 
 	if cfg.Keepalive > 0 {
@@ -405,13 +399,20 @@ func (c *Conn) closeWithError(err error) {
 
 	// if error was nil then unblock the quit channel
 	close(c.quit)
-	c.conn.Close()
+	cerr := c.close()
 
 	if err != nil {
 		c.errorHandler.HandleError(c, err, true)
+	} else if cerr != nil {
+		// TODO(zariel): is it a good idea to do this?
+		c.errorHandler.HandleError(c, cerr, true)
 	}
 }
 
+func (c *Conn) close() error {
+	return c.conn.Close()
+}
+
 func (c *Conn) Close() {
 	c.closeWithError(nil)
 }
@@ -420,15 +421,9 @@ func (c *Conn) Close() {
 // to execute any queries. This method runs as long as the connection is
 // open and is therefore usually called in a separate goroutine.
 func (c *Conn) serve() {
-	var (
-		err error
-	)
-
-	for {
+	var err error
+	for err == nil {
 		err = c.recv()
-		if err != nil {
-			break
-		}
 	}
 
 	c.closeWithError(err)
@@ -887,8 +882,9 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 
 		if len(x.meta.pagingState) > 0 && !qry.disableAutoPage {
 			iter.next = &nextIter{
-				qry: *qry,
-				pos: int((1 - qry.prefetch) * float64(x.numRows)),
+				qry:  *qry,
+				pos:  int((1 - qry.prefetch) * float64(x.numRows)),
+				conn: c,
 			}
 
 			iter.next.qry.pageState = copyBytes(x.meta.pagingState)
@@ -1100,7 +1096,7 @@ func (c *Conn) query(statement string, values ...interface{}) (iter *Iter) {
 
 func (c *Conn) awaitSchemaAgreement() (err error) {
 	const (
-		peerSchemas  = "SELECT schema_version FROM system.peers"
+		peerSchemas  = "SELECT schema_version, peer FROM system.peers"
 		localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
 	)
 
@@ -1113,9 +1109,10 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 		versions = make(map[string]struct{})
 
 		var schemaVersion string
-		for iter.Scan(&schemaVersion) {
+		var peer string
+		for iter.Scan(&schemaVersion, &peer) {
 			if schemaVersion == "" {
-				Logger.Println("skipping peer entry with empty schema_version")
+				Logger.Printf("skipping peer entry with empty schema_version: peer=%q", peer)
 				continue
 			}
 
@@ -1158,6 +1155,23 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 	return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas)
 }
 
+const localHostInfo = "SELECT * FROM system.local WHERE key='local'"
+
+func (c *Conn) localHostInfo() (*HostInfo, error) {
+	row, err := c.query(localHostInfo).rowMap()
+	if err != nil {
+		return nil, err
+	}
+
+	// TODO(zariel): avoid doing this here
+	host, err := hostInfoFromMap(row, c.session.cfg.Port)
+	if err != nil {
+		return nil, err
+	}
+
+	return c.session.ring.addOrUpdate(host), nil
+}
+
 var (
 	ErrQueryArgLength    = errors.New("gocql: query argument length mismatch")
 	ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")

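Conn.localHostInfo is new here: the control connection asks its own node for its system.local row and registers the result with the ring, which is how the control connection now "knows its own host". Roughly the same query expressed through gocql's public API (a sketch, not the internal code path):

package example

import (
	"github.com/gocql/gocql"
)

func localRow(session *gocql.Session) (map[string]interface{}, error) {
	// system.local holds exactly one row describing the connected node.
	iter := session.Query(`SELECT * FROM system.local WHERE key='local'`).Iter()
	row := make(map[string]interface{})
	iter.MapScan(row)
	if err := iter.Close(); err != nil {
		return nil, err
	}
	return row, nil
}
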
+ 32 - 38
conn_test.go

@@ -574,7 +574,13 @@ func TestStream0(t *testing.T) {
 		}
 	})
 
-	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, createTestSession())
+	s, err := srv.session()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer s.Close()
+
+	conn, err := s.connect(srv.host(), errorHandler)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -609,7 +615,13 @@ func TestConnClosedBlocked(t *testing.T) {
 		t.Log(err)
 	})
 
-	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, createTestSession())
+	s, err := srv.session()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer s.Close()
+
+	conn, err := s.connect(srv.host(), errorHandler)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -737,6 +749,10 @@ type TestServer struct {
 	closed bool
 }
 
+func (srv *TestServer) session() (*Session, error) {
+	return testCluster(srv.Address, protoVersion(srv.protocol)).CreateSession()
+}
+
 func (srv *TestServer) host() *HostInfo {
 	host, err := hostInfo(srv.Address, 9042)
 	if err != nil {
@@ -756,13 +772,7 @@ func (srv *TestServer) closeWatch() {
 
 func (srv *TestServer) serve() {
 	defer srv.listen.Close()
-	for {
-		select {
-		case <-srv.ctx.Done():
-			return
-		default:
-		}
-
+	for !srv.isClosed() {
 		conn, err := srv.listen.Accept()
 		if err != nil {
 			break
@@ -770,26 +780,13 @@ func (srv *TestServer) serve() {
 
 		go func(conn net.Conn) {
 			defer conn.Close()
-			for {
-				select {
-				case <-srv.ctx.Done():
-					return
-				default:
-				}
-
+			for !srv.isClosed() {
 				framer, err := srv.readFrame(conn)
 				if err != nil {
 					if err == io.EOF {
 						return
 					}
-
-					select {
-					case <-srv.ctx.Done():
-						return
-					default:
-					}
-
-					srv.t.Error(err)
+					srv.errorLocked(err)
 					return
 				}
 
@@ -824,16 +821,19 @@ func (srv *TestServer) Stop() {
 	srv.closeLocked()
 }
 
+func (srv *TestServer) errorLocked(err interface{}) {
+	srv.mu.Lock()
+	defer srv.mu.Unlock()
+	if srv.closed {
+		return
+	}
+	srv.t.Error(err)
+}
+
 func (srv *TestServer) process(f *framer) {
 	head := f.header
 	if head == nil {
-		select {
-		case <-srv.ctx.Done():
-			return
-		default:
-		}
-
-		srv.t.Error("process frame with a nil header")
+		srv.errorLocked("process frame with a nil header")
 		return
 	}
 
@@ -901,13 +901,7 @@ func (srv *TestServer) process(f *framer) {
 	f.wbuf[0] = srv.protocol | 0x80
 
 	if err := f.finishWrite(); err != nil {
-		select {
-		case <-srv.ctx.Done():
-			return
-		default:
-		}
-
-		srv.t.Error(err)
+		srv.errorLocked(err)
 	}
 }
 

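errorLocked centralizes the shutdown check that was previously repeated inline at every error site: it takes the server mutex, drops the error if the server has already been closed, and otherwise forwards it to the test. The guard matters because reporting through testing.T after the test has completed can panic the test runner. The shape of it, generically:

package example

import (
	"sync"
	"testing"
)

type testServer struct {
	t      *testing.T
	mu     sync.Mutex
	closed bool
}

// errorLocked drops errors that race with shutdown; reporting them after the
// test has finished (and closed the server) would abort the test binary.
func (srv *testServer) errorLocked(err interface{}) {
	srv.mu.Lock()
	defer srv.mu.Unlock()
	if srv.closed {
		return
	}
	srv.t.Error(err)
}
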
+ 49 - 44
control.go

@@ -47,7 +47,7 @@ func createControlConn(session *Session) *controlConn {
 		retry:   &SimpleRetryPolicy{NumRetries: 3},
 	}
 
-	control.conn.Store((*Conn)(nil))
+	control.conn.Store((*connHost)(nil))
 
 	return control
 }
@@ -197,14 +197,20 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
 	handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
 		// we should never get here, but if we do it means we connected to a
 		// host successfully which means our attempted protocol version worked
+		if !closed {
+			c.Close()
+		}
 	})
 
 	var err error
 	for _, host := range hosts {
 		var conn *Conn
-		conn, err = Connect(host, &connCfg, handler, c.session)
-		if err == nil {
+		conn, err = c.session.dial(host.ConnectAddress(), host.Port(), &connCfg, handler)
+		if conn != nil {
 			conn.Close()
+		}
+
+		if err == nil {
 			return connCfg.ProtoVersion, nil
 		}
 
@@ -239,35 +245,31 @@ func (c *controlConn) connect(hosts []*HostInfo) error {
 	return nil
 }
 
+type connHost struct {
+	conn *Conn
+	host *HostInfo
+}
+
 func (c *controlConn) setupConn(conn *Conn) error {
 	if err := c.registerEvents(conn); err != nil {
 		conn.Close()
 		return err
 	}
 
-	c.conn.Store(conn)
-
-	if v, ok := conn.conn.RemoteAddr().(*net.TCPAddr); ok {
-		c.session.handleNodeUp(copyBytes(v.IP), v.Port, false)
-		return nil
-	}
-
-	host, portstr, err := net.SplitHostPort(conn.conn.RemoteAddr().String())
-	if err != nil {
-		return err
-	}
-
-	port, err := strconv.Atoi(portstr)
+	// TODO(zariel): do we need to fetch host info everytime
+	// the control conn connects? Surely we have it cached?
+	host, err := conn.localHostInfo()
 	if err != nil {
 		return err
 	}
 
-	ip := net.ParseIP(host)
-	if ip == nil {
-		return fmt.Errorf("invalid remote addr: addr=%v host=%q", conn.conn.RemoteAddr(), host)
+	ch := &connHost{
+		conn: conn,
+		host: host,
 	}
 
-	c.session.handleNodeUp(ip, port, false)
+	c.conn.Store(ch)
+	// c.session.handleNodeUp(host.ConnectAddress(), host.Port(), false)
 
 	return nil
 }
@@ -312,10 +314,10 @@ func (c *controlConn) reconnect(refreshring bool) {
 	// connection pool
 
 	var host *HostInfo
-	oldConn := c.conn.Load().(*Conn)
-	if oldConn != nil {
-		host = oldConn.host
-		oldConn.Close()
+	ch := c.getConn()
+	if ch != nil {
+		host = ch.host
+		ch.conn.Close()
 	}
 
 	var newConn *Conn
@@ -364,21 +366,25 @@ func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
 		return
 	}
 
-	oldConn := c.conn.Load().(*Conn)
-	if oldConn != conn {
+	oldConn := c.getConn()
+	if oldConn.conn != conn {
 		return
 	}
 
 	c.reconnect(true)
 }
 
+func (c *controlConn) getConn() *connHost {
+	return c.conn.Load().(*connHost)
+}
+
 func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
-	conn := c.conn.Load().(*Conn)
-	if conn == nil {
+	ch := c.getConn()
+	if ch == nil {
 		return nil, errNoControl
 	}
 
-	framer, err := conn.exec(context.Background(), w, nil)
+	framer, err := ch.conn.exec(context.Background(), w, nil)
 	if err != nil {
 		return nil, err
 	}
@@ -386,13 +392,13 @@ func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
 	return framer.parseFrame()
 }
 
-func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter {
+func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter {
 	const maxConnectAttempts = 5
 	connectAttempts := 0
 
 	for i := 0; i < maxConnectAttempts; i++ {
-		conn := c.conn.Load().(*Conn)
-		if conn == nil {
+		ch := c.getConn()
+		if ch == nil {
 			if connectAttempts > maxConnectAttempts {
 				break
 			}
@@ -403,12 +409,18 @@ func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter {
 			continue
 		}
 
-		return fn(conn)
+		return fn(ch)
 	}
 
 	return &Iter{err: errNoControl}
 }
 
+func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter {
+	return c.withConnHost(func(ch *connHost) *Iter {
+		return fn(ch.conn)
+	})
+}
+
 // query will return nil if the connection is closed or nil
 func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) {
 	q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil)
@@ -437,21 +449,14 @@ func (c *controlConn) awaitSchemaAgreement() error {
 	}).err
 }
 
-func (c *controlConn) GetHostInfo() *HostInfo {
-	conn := c.conn.Load().(*Conn)
-	if conn == nil {
-		return nil
-	}
-	return conn.host
-}
-
 func (c *controlConn) close() {
 	if atomic.CompareAndSwapInt32(&c.started, 1, -1) {
 		c.quit <- struct{}{}
 	}
-	conn := c.conn.Load().(*Conn)
-	if conn != nil {
-		conn.Close()
+
+	ch := c.getConn()
+	if ch != nil {
+		ch.conn.Close()
 	}
 }
 

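The control connection now stores a *connHost, pairing the connection with the host it belongs to, in its atomic.Value. Seeding the value with a typed nil matters: sync/atomic.Value requires every Store to use one concrete type, and Load on a never-stored Value returns an untyped nil that would make the type assertion in getConn panic. In sketch form (connHost fields elided):

package example

import "sync/atomic"

type connHost struct {
	// conn *Conn and host *HostInfo, as in control.go
}

type controlConn struct {
	conn atomic.Value // always holds a *connHost, possibly nil
}

func newControlConn() *controlConn {
	c := &controlConn{}
	// Typed nil: Load now always returns something assertable to *connHost.
	c.conn.Store((*connHost)(nil))
	return c
}

func (c *controlConn) getConn() *connHost {
	return c.conn.Load().(*connHost)
}
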
+ 5 - 5
events.go

@@ -175,14 +175,12 @@ func (s *Session) handleNodeEvent(frames []frame) {
 
 func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
 	// Get host info and apply any filters to the host
-	hostInfo, err := s.hostSource.GetHostInfo(ip, port)
+	hostInfo, err := s.hostSource.getHostInfo(ip, port)
 	if err != nil {
 		Logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
 		return
-	}
-
-	// If hostInfo is nil, this host was filtered out by cfg.HostFilter
-	if hostInfo == nil {
+	} else if hostInfo == nil {
+		// If hostInfo is nil, this host was filtered out by cfg.HostFilter
 		return
 	}
 
@@ -199,7 +197,9 @@ func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
 	s.pool.addHost(hostInfo)
 	s.policy.AddHost(hostInfo)
 	hostInfo.setState(NodeUp)
+
 	if s.control != nil && !s.cfg.IgnorePeerAddr {
+		// TODO(zariel): debounce ring refresh
 		s.hostSource.refreshRing()
 	}
 }

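The new TODO asks for the ring refresh to be debounced, so that a burst of node events triggers only one refresh. One possible shape for such a helper — purely hypothetical, not part of this commit:

package example

import (
	"sync"
	"time"
)

// debouncer coalesces calls to trigger: fn runs once, delay after the last call.
type debouncer struct {
	mu    sync.Mutex
	timer *time.Timer
	delay time.Duration
	fn    func()
}

func (d *debouncer) trigger() {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.timer != nil {
		d.timer.Stop()
	}
	d.timer = time.AfterFunc(d.delay, d.fn)
}
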
+ 23 - 10
helpers.go

@@ -205,30 +205,43 @@ func (iter *Iter) RowData() (RowData, error) {
 		return RowData{}, iter.err
 	}
 
-	columns := make([]string, 0)
-	values := make([]interface{}, 0)
+	columns := make([]string, 0, len(iter.Columns()))
+	values := make([]interface{}, 0, len(iter.Columns()))
 
 	for _, column := range iter.Columns() {
-
-		switch c := column.TypeInfo.(type) {
-		case TupleTypeInfo:
+		if c, ok := column.TypeInfo.(TupleTypeInfo); !ok {
+			val := column.TypeInfo.New()
+			columns = append(columns, column.Name)
+			values = append(values, val)
+		} else {
 			for i, elem := range c.Elems {
 				columns = append(columns, TupleColumnName(column.Name, i))
 				values = append(values, elem.New())
 			}
-		default:
-			val := column.TypeInfo.New()
-			columns = append(columns, column.Name)
-			values = append(values, val)
 		}
 	}
+
 	rowData := RowData{
 		Columns: columns,
 		Values:  values,
 	}
+
 	return rowData, nil
 }
 
+// TODO(zariel): is it worth exporting this?
+func (iter *Iter) rowMap() (map[string]interface{}, error) {
+	if iter.err != nil {
+		return nil, iter.err
+	}
+
+	rowData, _ := iter.RowData()
+	iter.Scan(rowData.Values...)
+	m := make(map[string]interface{}, len(rowData.Columns))
+	rowData.rowMap(m)
+	return m, nil
+}
+
 // SliceMap is a helper function to make the API easier to use
 // returns the data from the query in the form of []map[string]interface{}
 func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
@@ -240,7 +253,7 @@ func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
 	rowData, _ := iter.RowData()
 	dataToReturn := make([]map[string]interface{}, 0)
 	for iter.Scan(rowData.Values...) {
-		m := make(map[string]interface{})
+		m := make(map[string]interface{}, len(rowData.Columns))
 		rowData.rowMap(m)
 		dataToReturn = append(dataToReturn, m)
 	}

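rowMap wraps RowData to return a single row keyed by column name; conn.localHostInfo and getHostInfo both lean on it. Approximately the same behaviour built from the exported pieces — note that RowData.Values holds pointers created by TypeInfo.New, so the map values need dereferencing (a sketch, error handling of Scan elided as in the original):

package example

import (
	"reflect"

	"github.com/gocql/gocql"
)

func firstRowAsMap(iter *gocql.Iter) (map[string]interface{}, error) {
	rowData, err := iter.RowData()
	if err != nil {
		return nil, err
	}
	// Scan the first (and, for system.local, only) row into the value slots.
	iter.Scan(rowData.Values...)
	m := make(map[string]interface{}, len(rowData.Columns))
	for i, name := range rowData.Columns {
		// Each value slot is a pointer; store what it points at.
		m[name] = reflect.Indirect(reflect.ValueOf(rowData.Values[i])).Interface()
	}
	return m, nil
}
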
+ 53 - 126
host_source.go

@@ -366,7 +366,6 @@ type ringDescriber struct {
 	session         *Session
 	mu              sync.Mutex
 	prevHosts       []*HostInfo
-	localHost       *HostInfo
 	prevPartitioner string
 }
 
@@ -388,13 +387,13 @@ func checkSystemSchema(control *controlConn) (bool, error) {
 
 // Given a map that represents a row from either system.local or system.peers
 // return as much information as we can in *HostInfo
-func (r *ringDescriber) hostInfoFromMap(row map[string]interface{}) (*HostInfo, error) {
+func hostInfoFromMap(row map[string]interface{}, defaultPort int) (*HostInfo, error) {
 	const assertErrorMsg = "Assertion failed for %s"
 	var ok bool
 
 	// Default to our connected port if the cluster doesn't have port information
 	host := HostInfo{
-		port: r.session.cfg.Port,
+		port: defaultPort,
 	}
 
 	for key, value := range row {
@@ -489,83 +488,44 @@ func (r *ringDescriber) hostInfoFromMap(row map[string]interface{}) (*HostInfo,
 	return &host, nil
 }
 
-// Ask the control node for it's local host information
-func (r *ringDescriber) GetLocalHostInfo() (*HostInfo, error) {
-	it := r.session.control.query("SELECT * FROM system.local WHERE key='local'")
-	if it == nil {
-		return nil, errors.New("Attempted to query 'system.local' on a closed control connection")
-	}
-	host, err := r.extractHostInfo(it)
-	if err != nil {
-		return nil, err
-	}
-
-	if host.invalidConnectAddr() {
-		host.SetConnectAddress(r.session.control.GetHostInfo().ConnectAddress())
-	}
-
-	return host, nil
-}
-
-// Given an ip address and port, return a peer that matched the ip address
-func (r *ringDescriber) GetPeerHostInfo(ip net.IP, port int) (*HostInfo, error) {
-	it := r.session.control.query("SELECT * FROM system.peers WHERE peer=?", ip)
-	if it == nil {
-		return nil, errors.New("Attempted to query 'system.peers' on a closed control connection")
-	}
-	return r.extractHostInfo(it)
-}
-
-func (r *ringDescriber) extractHostInfo(it *Iter) (*HostInfo, error) {
-	row := make(map[string]interface{})
-
-	// expect only 1 row
-	it.MapScan(row)
-	if err := it.Close(); err != nil {
-		return nil, err
-	}
-
-	// extract all available info about the host
-	return r.hostInfoFromMap(row)
-}
-
 // Ask the control node for host info on all it's known peers
-func (r *ringDescriber) GetClusterPeerInfo() ([]*HostInfo, error) {
+func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) {
 	var hosts []*HostInfo
+	iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
+		hosts = append(hosts, ch.host)
+		return ch.conn.query("SELECT * FROM system.peers")
+	})
 
-	// Ask the node for a list of it's peers
-	it := r.session.control.query("SELECT * FROM system.peers")
-	if it == nil {
-		return nil, errors.New("Attempted to query 'system.peers' on a closed connection")
+	if iter == nil {
+		return nil, errNoControl
 	}
 
-	for {
-		row := make(map[string]interface{})
-		if !it.MapScan(row) {
-			break
-		}
+	rows, err := iter.SliceMap()
+	if err != nil {
+		// TODO(zariel): make typed error
+		return nil, fmt.Errorf("unable to fetch peer host info: %s", err)
+	}
+
+	for _, row := range rows {
 		// extract all available info about the peer
-		host, err := r.hostInfoFromMap(row)
+		host, err := hostInfoFromMap(row, r.session.cfg.Port)
 		if err != nil {
 			return nil, err
-		}
-
-		// If it's not a valid peer
-		if !r.IsValidPeer(host) {
-			Logger.Printf("Found invalid peer '%+v' "+
+		} else if !isValidPeer(host) {
+			// If it's not a valid peer
+			Logger.Printf("Found invalid peer '%s' "+
 				"Likely due to a gossip or snitch issue, this host will be ignored", host)
 			continue
 		}
+
 		hosts = append(hosts, host)
 	}
-	if it.err != nil {
-		return nil, fmt.Errorf("while scanning 'system.peers' table: %s", it.err)
-	}
+
 	return hosts, nil
 }
 
 // Return true if the host is a valid peer
-func (r *ringDescriber) IsValidPeer(host *HostInfo) bool {
+func isValidPeer(host *HostInfo) bool {
 	return !(len(host.RPCAddress()) == 0 ||
 		host.hostId == "" ||
 		host.dataCenter == "" ||
@@ -578,84 +538,47 @@ func (r *ringDescriber) GetHosts() ([]*HostInfo, string, error) {
 	r.mu.Lock()
 	defer r.mu.Unlock()
 
-	// Update the localHost info with data from the connected host
-	localHost, err := r.GetLocalHostInfo()
+	hosts, err := r.getClusterPeerInfo()
 	if err != nil {
 		return r.prevHosts, r.prevPartitioner, err
-	} else if localHost.invalidConnectAddr() {
-		panic(fmt.Sprintf("unable to get localhost connect address: %v", localHost))
 	}
 
-	// Update our list of hosts by querying the cluster
-	hosts, err := r.GetClusterPeerInfo()
-	if err != nil {
-		return r.prevHosts, r.prevPartitioner, err
-	}
-
-	hosts = append(hosts, localHost)
-
-	// Filter the hosts if filter is provided
-	filteredHosts := hosts
-	if r.session.cfg.HostFilter != nil {
-		filteredHosts = filteredHosts[:0]
-		for _, host := range hosts {
-			if r.session.cfg.HostFilter.Accept(host) {
-				filteredHosts = append(filteredHosts, host)
-			}
-		}
+	var partitioner string
+	if len(hosts) > 0 {
+		partitioner = hosts[0].Partitioner()
 	}
 
-	r.prevHosts = filteredHosts
-	r.prevPartitioner = localHost.partitioner
-	r.localHost = localHost
-
-	return filteredHosts, localHost.partitioner, nil
+	return hosts, partitioner, nil
 }
 
 // Given an ip/port return HostInfo for the specified ip/port
-func (r *ringDescriber) GetHostInfo(ip net.IP, port int) (*HostInfo, error) {
-	// TODO(thrawn01): Is IgnorePeerAddr still useful now that we have DisableInitialHostLookup?
-	// TODO(thrawn01): should we also check for DisableInitialHostLookup and return if true?
-
-	// Ignore the port and connect address and use the address/port we already have
-	if r.session.control == nil || r.session.cfg.IgnorePeerAddr {
-		return &HostInfo{connectAddress: ip, port: port}, nil
-	}
-
-	// Attempt to get the host info for our control connection
-	controlHost := r.session.control.GetHostInfo()
-	if controlHost == nil {
-		return nil, errors.New("invalid control connection")
-	}
-
-	var (
-		host *HostInfo
-		err  error
-	)
+func (r *ringDescriber) getHostInfo(ip net.IP, port int) (*HostInfo, error) {
+	var host *HostInfo
+	iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
+		if ch.host.ConnectAddress().Equal(ip) {
+			host = ch.host
+			return nil
+		}
 
-	// If we are asking about the same node our control connection has a connection too
-	if controlHost.ConnectAddress().Equal(ip) {
-		host, err = r.GetLocalHostInfo()
-	} else {
-		host, err = r.GetPeerHostInfo(ip, port)
-	}
+		return ch.conn.query("SELECT * FROM system.peers WHERE peer=?", ip)
+	})
 
-	// No host was found matching this ip/port
-	if err != nil {
-		return nil, err
-	}
+	if iter != nil {
+		row, err := iter.rowMap()
+		if err != nil {
+			return nil, err
+		}
 
-	if controlHost.ConnectAddress().Equal(ip) {
-		// Always respect the provided control node address and disregard the ip address
-		// the cassandra node provides. We do this as we are already connected and have a
-		// known valid ip address. This insulates gocql from client connection issues stemming
-		// from node misconfiguration. For instance when a node is run from a container, by
-		// default the node will report its ip address as 127.0.0.1 which is typically invalid.
-		host.SetConnectAddress(ip)
+		host, err = hostInfoFromMap(row, port)
+		if err != nil {
+			return nil, err
+		}
+	} else if host == nil {
+		return nil, errors.New("unable to fetch host info: invalid control connection")
 	}
 
 	if host.invalidConnectAddr() {
-		return nil, fmt.Errorf("host ConnectAddress invalid: %v", host)
+		return nil, fmt.Errorf("host ConnectAddress invalid ip=%v: %v", ip, host)
 	}
 
 	return host, nil
@@ -675,6 +598,10 @@ func (r *ringDescriber) refreshRing() error {
 
 	// TODO: move this to session
 	for _, h := range hosts {
+		if filter := r.session.cfg.HostFilter; filter != nil && !filter.Accept(h) {
+			continue
+		}
+
 		if host, ok := r.session.ring.addHostIfMissing(h); !ok {
 			r.session.pool.addHost(h)
 			r.session.policy.AddHost(h)

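Host filtering moved out of GetHosts and into refreshRing, so GetHosts now returns the raw ring view and the filter is applied once, at the point where hosts are added (this is also why the dedicated filter test below could be dropped). Callers still configure it the same way; for example, excluding a single node by address, using gocql's public HostFilterFunc and ConnectAddress (the addresses are illustrative):

package example

import (
	"net"

	"github.com/gocql/gocql"
)

func clusterWithoutNode(addr string) *gocql.ClusterConfig {
	skip := net.ParseIP(addr)
	cluster := gocql.NewCluster("127.0.0.1")
	cluster.HostFilter = gocql.HostFilterFunc(func(host *gocql.HostInfo) bool {
		// Accept every discovered host except the one being excluded.
		return !host.ConnectAddress().Equal(skip)
	})
	return cluster
}
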
+ 3 - 30
host_source_test.go

@@ -3,7 +3,6 @@
 package gocql
 
 import (
-	"fmt"
 	"net"
 	"testing"
 )
@@ -50,7 +49,6 @@ func TestCassVersionBefore(t *testing.T) {
 }
 
 func TestIsValidPeer(t *testing.T) {
-	ring := ringDescriber{}
 	host := &HostInfo{
 		rpcAddress: net.ParseIP("0.0.0.0"),
 		rack:       "myRack",
@@ -59,12 +57,12 @@ func TestIsValidPeer(t *testing.T) {
 		tokens:     []string{"0", "1"},
 	}
 
-	if !ring.IsValidPeer(host) {
+	if !isValidPeer(host) {
 		t.Errorf("expected %+v to be a valid peer", host)
 	}
 
 	host.rack = ""
-	if ring.IsValidPeer(host) {
+	if isValidPeer(host) {
 		t.Errorf("expected %+v to NOT be a valid peer", host)
 	}
 }
@@ -76,33 +74,8 @@ func TestGetHosts(t *testing.T) {
 	hosts, partitioner, err := session.hostSource.GetHosts()
 
 	assertTrue(t, "err == nil", err == nil)
-	assertTrue(t, "len(hosts) == 3", len(hosts) == 3)
+	assertEqual(t, "len(hosts)", len(clusterHosts), len(hosts))
 	assertTrue(t, "len(partitioner) != 0", len(partitioner) != 0)
-
-}
-
-func TestGetHostsWithFilter(t *testing.T) {
-	filterHostIP := net.ParseIP("127.0.0.3")
-	cluster := createCluster()
-
-	// Filter to remove one of the localhost nodes
-	cluster.HostFilter = HostFilterFunc(func(host *HostInfo) bool {
-		if host.ConnectAddress().Equal(filterHostIP) {
-			return false
-		}
-		return true
-	})
-	session := createSessionFromCluster(cluster, t)
-
-	hosts, partitioner, err := session.hostSource.GetHosts()
-	assertTrue(t, "err == nil", err == nil)
-	assertTrue(t, "len(hosts) == 2", len(hosts) == 2)
-	assertTrue(t, "len(partitioner) != 0", len(partitioner) != 0)
-	for _, host := range hosts {
-		if host.ConnectAddress().Equal(filterHostIP) {
-			t.Fatal(fmt.Sprintf("Did not expect to see '%q' in host list", filterHostIP))
-		}
-	}
 }
 
 func TestHostInfo_ConnectAddress(t *testing.T) {

+ 10 - 4
query_executor.go

@@ -17,6 +17,15 @@ type queryExecutor struct {
 	policy HostSelectionPolicy
 }
 
+func (q *queryExecutor) attemptQuery(qry ExecutableQuery, conn *Conn) *Iter {
+	start := time.Now()
+	iter := qry.execute(conn)
+
+	qry.attempt(time.Since(start))
+
+	return iter
+}
+
 func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 	rt := qry.retryPolicy()
 	hostIter := q.policy.Pick(qry)
@@ -38,10 +47,7 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 			continue
 		}
 
-		start := time.Now()
-		iter = qry.execute(conn)
-
-		qry.attempt(time.Since(start))
+		iter = q.attemptQuery(qry, conn)
 
 		// Update host
 		hostResponse.Mark(iter.err)

+ 2 - 0
ring.go

@@ -64,6 +64,8 @@ func (r *ring) currentHosts() map[string]*HostInfo {
 }
 
 func (r *ring) addHost(host *HostInfo) bool {
+	// TODO(zariel): key all host info by HostID instead of
+	// ip addresses
 	if host.invalidConnectAddr() {
 		panic(fmt.Sprintf("invalid host: %v", host))
 	}

+ 10 - 4
session.go

@@ -639,7 +639,7 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
 }
 
 func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
-	return Connect(host, s.connCfg, errorHandler, s)
+	return s.dial(host.ConnectAddress(), host.Port(), s.connCfg, errorHandler)
 }
 
 // Query represents a CQL statement that can be executed.
@@ -1057,14 +1057,14 @@ type Scanner interface {
 	// scanned into with Scan.
 	// Next must be called before every call to Scan.
 	Next() bool
-	
+
 	// Scan copies the current row's columns into dest. If the length of dest does not equal
 	// the number of columns returned in the row an error is returned. If an error is encountered
 	// when unmarshalling a column into the value in dest an error is returned and the row is invalidated
 	// until the next call to Next.
 	// Next must be called before calling Scan, if it is not an error is returned.
 	Scan(...interface{}) error
-	
+
 	// Err returns the if there was one during iteration that resulted in iteration being unable to complete.
 	// Err will also release resources held by the iterator, the Scanner should not used after being called.
 	Err() error
@@ -1301,11 +1301,17 @@ type nextIter struct {
 	pos  int
 	once sync.Once
 	next *Iter
+	conn *Conn
 }
 
 func (n *nextIter) fetch() *Iter {
 	n.once.Do(func() {
-		n.next = n.qry.session.executeQuery(&n.qry)
+		iter := n.qry.session.executor.attemptQuery(&n.qry, n.conn)
+		if iter != nil && iter.err == nil {
+			n.next = iter
+		} else {
+			n.next = n.qry.session.executeQuery(&n.qry)
+		}
 	})
 	return n.next
 }

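nextIter now carries the connection that produced the current page. Since a paging state is tied to the node that issued it, fetch first retries that same connection via the queryExecutor.attemptQuery helper extracted above, and only falls back to normal host selection when that attempt fails. The control flow in sketch form, with stand-in types for the internal ones:

package example

type iter struct{ err error }
type conn struct{}

type nextIter struct {
	conn      *conn             // connection that served the previous page
	attemptOn func(*conn) *iter // one attempt on a specific connection
	execute   func() *iter      // normal path through the host policy
}

func (n *nextIter) fetch() *iter {
	// Prefer the node that issued the paging state; its view is consistent.
	if it := n.attemptOn(n.conn); it != nil && it.err == nil {
		return it
	}
	// Otherwise fall back to regular query execution.
	return n.execute()
}
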
+ 4 - 4
session_connect_test.go

@@ -102,10 +102,10 @@ func TestSession_connect_WithNoTranslator(t *testing.T) {
 
 	go srvr.Serve()
 
-	Connect(&HostInfo{
+	session.connect(&HostInfo{
 		connectAddress: srvr.Addr,
 		port:           srvr.Port,
-	}, session.connCfg, testConnErrorHandler(t), session)
+	}, testConnErrorHandler(t))
 
 	assertConnectionEventually(t, 500*time.Millisecond, srvr)
 }
@@ -122,10 +122,10 @@ func TestSession_connect_WithTranslator(t *testing.T) {
 	go srvr.Serve()
 
 	// the provided address will be translated
-	Connect(&HostInfo{
+	session.connect(&HostInfo{
 		connectAddress: net.ParseIP("10.10.10.10"),
 		port:           5432,
-	}, session.connCfg, testConnErrorHandler(t), session)
+	}, testConnErrorHandler(t))
 
 	assertConnectionEventually(t, 500*time.Millisecond, srvr)
 }