View source

Merge pull request #989 from Zariel/refactor-host-discovery

refactor host discovery
Chris Bannister, 8 years ago
parent commit 843f6b1d22
13 changed files with 253 additions and 332 deletions
  1. cassandra_test.go (+21 -39)
  2. common_test.go (+0 -1)
  3. conn.go (+41 -27)
  4. conn_test.go (+32 -38)
  5. control.go (+49 -44)
  6. events.go (+5 -5)
  7. helpers.go (+23 -10)
  8. host_source.go (+53 -126)
  9. host_source_test.go (+3 -30)
  10. query_executor.go (+10 -4)
  11. ring.go (+2 -0)
  12. session.go (+10 -4)
  13. session_connect_test.go (+4 -4)

cassandra_test.go (+21 -39)

@@ -2362,31 +2362,26 @@ func TestDiscoverViaProxy(t *testing.T) {
 	if err != nil {
 		t.Fatalf("unable to create proxy listener: %v", err)
 	}
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
 
 	var (
-		wg         sync.WaitGroup
 		mu         sync.Mutex
 		proxyConns []net.Conn
 		closed     bool
 	)
 
-	go func(wg *sync.WaitGroup) {
+	go func() {
 		cassandraAddr := JoinHostPort(clusterHosts[0], 9042)
 
 		cassandra := func() (net.Conn, error) {
 			return net.Dial("tcp", cassandraAddr)
 		}
 
-		proxyFn := func(wg *sync.WaitGroup, from, to net.Conn) {
-			defer wg.Done()
-
+		proxyFn := func(errs chan error, from, to net.Conn) {
 			_, err := io.Copy(to, from)
 			if err != nil {
-				mu.Lock()
-				if !closed {
-					t.Error(err)
-				}
-				mu.Unlock()
+				errs <- err
 			}
 		}
 
@@ -2394,29 +2389,22 @@ func TestDiscoverViaProxy(t *testing.T) {
 		// for both the read and write side of the TCP connection to close before
 		// returning.
 		handle := func(conn net.Conn) error {
-			defer conn.Close()
-
 			cass, err := cassandra()
 			if err != nil {
 				return err
 			}
-
-			mu.Lock()
-			proxyConns = append(proxyConns, cass)
-			mu.Unlock()
-
 			defer cass.Close()
 
-			var wg sync.WaitGroup
-			wg.Add(1)
-			go proxyFn(&wg, conn, cass)
-
-			wg.Add(1)
-			go proxyFn(&wg, cass, conn)
+			errs := make(chan error, 2)
+			go proxyFn(errs, conn, cass)
+			go proxyFn(errs, cass, conn)
 
-			wg.Wait()
-
-			return nil
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case err := <-errs:
+				return err
+			}
 		}
 
 		for {
@@ -2436,19 +2424,19 @@ func TestDiscoverViaProxy(t *testing.T) {
 			proxyConns = append(proxyConns, conn)
 			mu.Unlock()
 
-			wg.Add(1)
 			go func(conn net.Conn) {
-				defer wg.Done()
+				defer conn.Close()
 
 				if err := handle(conn); err != nil {
-					t.Error(err)
-					return
+					mu.Lock()
+					if !closed {
+						t.Error(err)
+					}
+					mu.Unlock()
 				}
 			}(conn)
 		}
-	}(&wg)
-
-	defer wg.Wait()
+	}()
 
 	proxyAddr := proxy.Addr().String()
 
@@ -2460,11 +2448,6 @@ func TestDiscoverViaProxy(t *testing.T) {
 	session := createSessionFromCluster(cluster, t)
 	defer session.Close()
 
-	if session.hostSource.localHost.BroadcastAddress() == nil {
-		t.Skip("Target cluster does not have broadcast_address in system.local.")
-		goto close
-	}
-
 	// we shouldnt need this but to be safe
 	time.Sleep(1 * time.Second)
 
@@ -2476,7 +2459,6 @@ func TestDiscoverViaProxy(t *testing.T) {
 	}
 	session.pool.mu.RUnlock()
 
-close:
 	mu.Lock()
 	closed = true
 	if err := proxy.Close(); err != nil {

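The test proxy above now signals failure through a buffered error channel plus a cancellable context instead of a WaitGroup. A minimal sketch of that pattern, with illustrative names (proxyPair is not part of gocql):

    package proxysketch

    import (
        "context"
        "io"
        "net"
    )

    // proxyPair copies bytes in both directions between client and backend.
    // The channel has capacity 2 so both copy goroutines can report an
    // error without blocking, even after proxyPair has already returned.
    func proxyPair(ctx context.Context, client, backend net.Conn) error {
        errs := make(chan error, 2)

        copyConn := func(dst, src net.Conn) {
            if _, err := io.Copy(dst, src); err != nil {
                errs <- err
            }
        }
        go copyConn(backend, client)
        go copyConn(client, backend)

        select {
        case <-ctx.Done():
            return ctx.Err() // test teardown cancelled us
        case err := <-errs:
            return err // first failing direction decides the result
        }
    }
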
common_test.go (+0 -1)

@@ -102,7 +102,6 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
 		panic(err)
 	}
 	defer session.Close()
-	defer tb.Log("closing keyspace session")
 
 	err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace)
 	if err != nil {

conn.go (+41 -27)

@@ -141,8 +141,6 @@ type Conn struct {
 	version         uint8
 	currentKeyspace string
 
-	host *HostInfo
-
 	session *Session
 
 	closed int32
@@ -152,14 +150,12 @@ type Conn struct {
 }
 
 // Connect establishes a connection to a Cassandra node.
-func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
+func (s *Session) dial(ip net.IP, port int, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
 	// TODO(zariel): remove these
-	if host == nil {
-		panic("host is nil")
-	} else if len(host.ConnectAddress()) == 0 {
-		panic(fmt.Sprintf("host missing connect ip address: %v", host))
-	} else if host.Port() == 0 {
-		panic(fmt.Sprintf("host missing port: %v", host))
+	if len(ip) == 0 || ip.IsUnspecified() {
+		panic(fmt.Sprintf("host missing connect ip address: %v", ip))
+	} else if port == 0 {
+		panic(fmt.Sprintf("host missing port: %v", port))
 	}
 
 	var (
@@ -172,9 +168,8 @@ func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, ses
 	}
 
 	// TODO(zariel): handle ipv6 zone
-	translatedPeer, translatedPort := session.cfg.translateAddressPort(host.ConnectAddress(), host.Port())
+	translatedPeer, translatedPort := s.cfg.translateAddressPort(ip, port)
 	addr := (&net.TCPAddr{IP: translatedPeer, Port: translatedPort}).String()
-	//addr := (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String()
 
 	if cfg.tlsConfig != nil {
 		// the TLS config is safe to be reused by connections but it must not
@@ -200,9 +195,8 @@ func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, ses
 		compressor:   cfg.Compressor,
 		auth:         cfg.Authenticator,
 		quit:         make(chan struct{}),
-		session:      session,
+		session:      s,
 		streams:      streams.New(cfg.ProtoVersion),
-		host:         host,
 	}
 
 	if cfg.Keepalive > 0 {
@@ -405,13 +399,20 @@ func (c *Conn) closeWithError(err error) {
 
 	// if error was nil then unblock the quit channel
 	close(c.quit)
-	c.conn.Close()
+	cerr := c.close()
 
 	if err != nil {
 		c.errorHandler.HandleError(c, err, true)
+	} else if cerr != nil {
+		// TODO(zariel): is it a good idea to do this?
+		c.errorHandler.HandleError(c, cerr, true)
 	}
 }
 
+func (c *Conn) close() error {
+	return c.conn.Close()
+}
+
 func (c *Conn) Close() {
 	c.closeWithError(nil)
 }
@@ -420,15 +421,9 @@ func (c *Conn) Close() {
 // to execute any queries. This method runs as long as the connection is
 // open and is therefore usually called in a separate goroutine.
 func (c *Conn) serve() {
-	var (
-		err error
-	)
-
-	for {
+	var err error
+	for err == nil {
 		err = c.recv()
-		if err != nil {
-			break
-		}
 	}
 
 	c.closeWithError(err)
@@ -887,8 +882,9 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 
 		if len(x.meta.pagingState) > 0 && !qry.disableAutoPage {
 			iter.next = &nextIter{
-				qry: *qry,
-				pos: int((1 - qry.prefetch) * float64(x.numRows)),
+				qry:  *qry,
+				pos:  int((1 - qry.prefetch) * float64(x.numRows)),
+				conn: c,
 			}
 
 			iter.next.qry.pageState = copyBytes(x.meta.pagingState)
@@ -1100,7 +1096,7 @@ func (c *Conn) query(statement string, values ...interface{}) (iter *Iter) {
 
 func (c *Conn) awaitSchemaAgreement() (err error) {
 	const (
-		peerSchemas  = "SELECT schema_version FROM system.peers"
+		peerSchemas  = "SELECT schema_version, peer FROM system.peers"
 		localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
 	)
 
@@ -1113,9 +1109,10 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 		versions = make(map[string]struct{})
 
 		var schemaVersion string
-		for iter.Scan(&schemaVersion) {
+		var peer string
+		for iter.Scan(&schemaVersion, &peer) {
 			if schemaVersion == "" {
-				Logger.Println("skipping peer entry with empty schema_version")
+				Logger.Printf("skipping peer entry with empty schema_version: peer=%q", peer)
 				continue
 			}
 
@@ -1158,6 +1155,23 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 	return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas)
 }
 
+const localHostInfo = "SELECT * FROM system.local WHERE key='local'"
+
+func (c *Conn) localHostInfo() (*HostInfo, error) {
+	row, err := c.query(localHostInfo).rowMap()
+	if err != nil {
+		return nil, err
+	}
+
+	// TODO(zariel): avoid doing this here
+	host, err := hostInfoFromMap(row, c.session.cfg.Port)
+	if err != nil {
+		return nil, err
+	}
+
+	return c.session.ring.addOrUpdate(host), nil
+}
+
 var (
 	ErrQueryArgLength    = errors.New("gocql: query argument length mismatch")
 	ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")

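With host identity removed from Conn, the control connection now asks the node to describe itself via system.local (see localHostInfo above) rather than trusting the address that was dialed. A sketch of the equivalent query through the public gocql API, assuming a reachable local test node:

    package main

    import (
        "fmt"
        "log"

        "github.com/gocql/gocql"
    )

    func main() {
        cluster := gocql.NewCluster("127.0.0.1") // assumed local test node
        session, err := cluster.CreateSession()
        if err != nil {
            log.Fatal(err)
        }
        defer session.Close()

        // Ask the connected node for its own identity.
        var dc, rack, release string
        if err := session.Query(
            `SELECT data_center, rack, release_version FROM system.local WHERE key='local'`,
        ).Scan(&dc, &rack, &release); err != nil {
            log.Fatal(err)
        }
        fmt.Printf("dc=%s rack=%s version=%s\n", dc, rack, release)
    }
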
conn_test.go (+32 -38)

@@ -574,7 +574,13 @@ func TestStream0(t *testing.T) {
 		}
 	})
 
-	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, createTestSession())
+	s, err := srv.session()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer s.Close()
+
+	conn, err := s.connect(srv.host(), errorHandler)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -609,7 +615,13 @@ func TestConnClosedBlocked(t *testing.T) {
 		t.Log(err)
 	})
 
-	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, createTestSession())
+	s, err := srv.session()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer s.Close()
+
+	conn, err := s.connect(srv.host(), errorHandler)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -737,6 +749,10 @@ type TestServer struct {
 	closed bool
 }
 
+func (srv *TestServer) session() (*Session, error) {
+	return testCluster(srv.Address, protoVersion(srv.protocol)).CreateSession()
+}
+
 func (srv *TestServer) host() *HostInfo {
 	host, err := hostInfo(srv.Address, 9042)
 	if err != nil {
@@ -756,13 +772,7 @@ func (srv *TestServer) closeWatch() {
 
 func (srv *TestServer) serve() {
 	defer srv.listen.Close()
-	for {
-		select {
-		case <-srv.ctx.Done():
-			return
-		default:
-		}
-
+	for !srv.isClosed() {
 		conn, err := srv.listen.Accept()
 		if err != nil {
 			break
@@ -770,26 +780,13 @@ func (srv *TestServer) serve() {
 
 		go func(conn net.Conn) {
 			defer conn.Close()
-			for {
-				select {
-				case <-srv.ctx.Done():
-					return
-				default:
-				}
-
+			for !srv.isClosed() {
 				framer, err := srv.readFrame(conn)
 				if err != nil {
 					if err == io.EOF {
 						return
 					}
-
-					select {
-					case <-srv.ctx.Done():
-						return
-					default:
-					}
-
-					srv.t.Error(err)
+					srv.errorLocked(err)
 					return
 				}
 
@@ -824,16 +821,19 @@ func (srv *TestServer) Stop() {
 	srv.closeLocked()
 }
 
+func (srv *TestServer) errorLocked(err interface{}) {
+	srv.mu.Lock()
+	defer srv.mu.Unlock()
+	if srv.closed {
+		return
+	}
+	srv.t.Error(err)
+}
+
 func (srv *TestServer) process(f *framer) {
 	head := f.header
 	if head == nil {
-		select {
-		case <-srv.ctx.Done():
-			return
-		default:
-		}
-
-		srv.t.Error("process frame with a nil header")
+		srv.errorLocked("process frame with a nil header")
 		return
 	}
 
@@ -901,13 +901,7 @@ func (srv *TestServer) process(f *framer) {
 	f.wbuf[0] = srv.protocol | 0x80
 
 	if err := f.finishWrite(); err != nil {
-		select {
-		case <-srv.ctx.Done():
-			return
-		default:
-		}
-
-		srv.t.Error(err)
+		srv.errorLocked(err)
 	}
 }
 

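The repeated ctx.Done() checks before each srv.t.Error call are replaced by one mutex-guarded reporter that drops errors raised after shutdown. A stripped-down sketch of the idea (this TestServer is a stand-in, not the gocql type):

    package conntest

    import (
        "sync"
        "testing"
    )

    type TestServer struct {
        t      *testing.T
        mu     sync.Mutex
        closed bool
    }

    func (srv *TestServer) Stop() {
        srv.mu.Lock()
        srv.closed = true
        srv.mu.Unlock()
    }

    // errorLocked reports err unless the server has already been stopped,
    // so errors caused by the shutdown itself do not fail the test.
    func (srv *TestServer) errorLocked(err interface{}) {
        srv.mu.Lock()
        defer srv.mu.Unlock()
        if srv.closed {
            return
        }
        srv.t.Error(err)
    }
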
control.go (+49 -44)

@@ -47,7 +47,7 @@ func createControlConn(session *Session) *controlConn {
 		retry:   &SimpleRetryPolicy{NumRetries: 3},
 	}
 
-	control.conn.Store((*Conn)(nil))
+	control.conn.Store((*connHost)(nil))
 
 	return control
 }
@@ -197,14 +197,20 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
 	handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
 		// we should never get here, but if we do it means we connected to a
 		// host successfully which means our attempted protocol version worked
+		if !closed {
+			c.Close()
+		}
 	})
 
 	var err error
 	for _, host := range hosts {
 		var conn *Conn
-		conn, err = Connect(host, &connCfg, handler, c.session)
-		if err == nil {
+		conn, err = c.session.dial(host.ConnectAddress(), host.Port(), &connCfg, handler)
+		if conn != nil {
 			conn.Close()
+		}
+
+		if err == nil {
 			return connCfg.ProtoVersion, nil
 		}
 
@@ -239,35 +245,31 @@ func (c *controlConn) connect(hosts []*HostInfo) error {
 	return nil
 }
 
+type connHost struct {
+	conn *Conn
+	host *HostInfo
+}
+
 func (c *controlConn) setupConn(conn *Conn) error {
 	if err := c.registerEvents(conn); err != nil {
 		conn.Close()
 		return err
 	}
 
-	c.conn.Store(conn)
-
-	if v, ok := conn.conn.RemoteAddr().(*net.TCPAddr); ok {
-		c.session.handleNodeUp(copyBytes(v.IP), v.Port, false)
-		return nil
-	}
-
-	host, portstr, err := net.SplitHostPort(conn.conn.RemoteAddr().String())
-	if err != nil {
-		return err
-	}
-
-	port, err := strconv.Atoi(portstr)
+	// TODO(zariel): do we need to fetch host info everytime
+	// the control conn connects? Surely we have it cached?
+	host, err := conn.localHostInfo()
 	if err != nil {
 		return err
 	}
 
-	ip := net.ParseIP(host)
-	if ip == nil {
-		return fmt.Errorf("invalid remote addr: addr=%v host=%q", conn.conn.RemoteAddr(), host)
+	ch := &connHost{
+		conn: conn,
+		host: host,
 	}
 
-	c.session.handleNodeUp(ip, port, false)
+	c.conn.Store(ch)
+	// c.session.handleNodeUp(host.ConnectAddress(), host.Port(), false)
 
 	return nil
 }
@@ -312,10 +314,10 @@ func (c *controlConn) reconnect(refreshring bool) {
 	// connection pool
 
 	var host *HostInfo
-	oldConn := c.conn.Load().(*Conn)
-	if oldConn != nil {
-		host = oldConn.host
-		oldConn.Close()
+	ch := c.getConn()
+	if ch != nil {
+		host = ch.host
+		ch.conn.Close()
 	}
 
 	var newConn *Conn
@@ -364,21 +366,25 @@ func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
 		return
 	}
 
-	oldConn := c.conn.Load().(*Conn)
-	if oldConn != conn {
+	oldConn := c.getConn()
+	if oldConn.conn != conn {
 		return
 	}
 
 	c.reconnect(true)
 }
 
+func (c *controlConn) getConn() *connHost {
+	return c.conn.Load().(*connHost)
+}
+
 func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
-	conn := c.conn.Load().(*Conn)
-	if conn == nil {
+	ch := c.getConn()
+	if ch == nil {
 		return nil, errNoControl
 	}
 
-	framer, err := conn.exec(context.Background(), w, nil)
+	framer, err := ch.conn.exec(context.Background(), w, nil)
 	if err != nil {
 		return nil, err
 	}
@@ -386,13 +392,13 @@ func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
 	return framer.parseFrame()
 }
 
-func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter {
+func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter {
 	const maxConnectAttempts = 5
 	connectAttempts := 0
 
 	for i := 0; i < maxConnectAttempts; i++ {
-		conn := c.conn.Load().(*Conn)
-		if conn == nil {
+		ch := c.getConn()
+		if ch == nil {
 			if connectAttempts > maxConnectAttempts {
 				break
 			}
@@ -403,12 +409,18 @@ func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter {
 			continue
 		}
 
-		return fn(conn)
+		return fn(ch)
 	}
 
 	return &Iter{err: errNoControl}
 }
 
+func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter {
+	return c.withConnHost(func(ch *connHost) *Iter {
+		return fn(ch.conn)
+	})
+}
+
 // query will return nil if the connection is closed or nil
 func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) {
 	q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil)
@@ -437,21 +449,14 @@ func (c *controlConn) awaitSchemaAgreement() error {
 	}).err
 }
 
-func (c *controlConn) GetHostInfo() *HostInfo {
-	conn := c.conn.Load().(*Conn)
-	if conn == nil {
-		return nil
-	}
-	return conn.host
-}
-
 func (c *controlConn) close() {
 	if atomic.CompareAndSwapInt32(&c.started, 1, -1) {
 		c.quit <- struct{}{}
 	}
-	conn := c.conn.Load().(*Conn)
-	if conn != nil {
-		conn.Close()
+
+	ch := c.getConn()
+	if ch != nil {
+		ch.conn.Close()
 	}
 }
 

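The control connection keeps its state in an atomic.Value, but the stored type changes from *Conn to *connHost. atomic.Value requires every Store to use the same concrete type, which is why createControlConn seeds it with a typed nil. A reduced sketch of the convention:

    package controlsketch

    import "sync/atomic"

    // connHost pairs the open control connection with the host it is
    // connected to; reduced here to a placeholder field.
    type connHost struct {
        addr string
    }

    type controlConn struct {
        conn atomic.Value // always holds a *connHost, possibly nil
    }

    func newControlConn() *controlConn {
        c := &controlConn{}
        // Seed with a typed nil: atomic.Value panics if Store is later
        // called with a different concrete type, and seeding keeps the
        // type assertion in getConn unconditional.
        c.conn.Store((*connHost)(nil))
        return c
    }

    func (c *controlConn) getConn() *connHost {
        return c.conn.Load().(*connHost) // safe: stored type is invariant
    }
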
events.go (+5 -5)

@@ -175,14 +175,12 @@ func (s *Session) handleNodeEvent(frames []frame) {
 
 func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
 	// Get host info and apply any filters to the host
-	hostInfo, err := s.hostSource.GetHostInfo(ip, port)
+	hostInfo, err := s.hostSource.getHostInfo(ip, port)
 	if err != nil {
 		Logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
 		return
-	}
-
-	// If hostInfo is nil, this host was filtered out by cfg.HostFilter
-	if hostInfo == nil {
+	} else if hostInfo == nil {
+		// If hostInfo is nil, this host was filtered out by cfg.HostFilter
 		return
 	}
 
@@ -199,7 +197,9 @@ func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
 	s.pool.addHost(hostInfo)
 	s.policy.AddHost(hostInfo)
 	hostInfo.setState(NodeUp)
+
 	if s.control != nil && !s.cfg.IgnorePeerAddr {
+		// TODO(zariel): debounce ring refresh
 		s.hostSource.refreshRing()
 	}
 }

helpers.go (+23 -10)

@@ -205,30 +205,43 @@ func (iter *Iter) RowData() (RowData, error) {
 		return RowData{}, iter.err
 	}
 
-	columns := make([]string, 0)
-	values := make([]interface{}, 0)
+	columns := make([]string, 0, len(iter.Columns()))
+	values := make([]interface{}, 0, len(iter.Columns()))
 
 	for _, column := range iter.Columns() {
-
-		switch c := column.TypeInfo.(type) {
-		case TupleTypeInfo:
+		if c, ok := column.TypeInfo.(TupleTypeInfo); !ok {
+			val := column.TypeInfo.New()
+			columns = append(columns, column.Name)
+			values = append(values, val)
+		} else {
 			for i, elem := range c.Elems {
 				columns = append(columns, TupleColumnName(column.Name, i))
 				values = append(values, elem.New())
 			}
-		default:
-			val := column.TypeInfo.New()
-			columns = append(columns, column.Name)
-			values = append(values, val)
 		}
 	}
+
 	rowData := RowData{
 		Columns: columns,
 		Values:  values,
 	}
+
 	return rowData, nil
 }
 
+// TODO(zariel): is it worth exporting this?
+func (iter *Iter) rowMap() (map[string]interface{}, error) {
+	if iter.err != nil {
+		return nil, iter.err
+	}
+
+	rowData, _ := iter.RowData()
+	iter.Scan(rowData.Values...)
+	m := make(map[string]interface{}, len(rowData.Columns))
+	rowData.rowMap(m)
+	return m, nil
+}
+
 // SliceMap is a helper function to make the API easier to use
 // returns the data from the query in the form of []map[string]interface{}
 func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
@@ -240,7 +253,7 @@ func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
 	rowData, _ := iter.RowData()
 	dataToReturn := make([]map[string]interface{}, 0)
 	for iter.Scan(rowData.Values...) {
-		m := make(map[string]interface{})
+		m := make(map[string]interface{}, len(rowData.Columns))
 		rowData.rowMap(m)
 		dataToReturn = append(dataToReturn, m)
 	}

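The new unexported rowMap is essentially a single-row variant of MapScan. A sketch of the same effect through the public API, assuming a reachable test node:

    package main

    import (
        "fmt"
        "log"

        "github.com/gocql/gocql"
    )

    func main() {
        cluster := gocql.NewCluster("127.0.0.1") // assumed test node
        session, err := cluster.CreateSession()
        if err != nil {
            log.Fatal(err)
        }
        defer session.Close()

        // Fetch exactly one row as a map keyed by column name.
        row := make(map[string]interface{})
        iter := session.Query(`SELECT * FROM system.local WHERE key='local'`).Iter()
        if !iter.MapScan(row) {
            log.Fatal("no row returned from system.local")
        }
        if err := iter.Close(); err != nil {
            log.Fatal(err)
        }
        fmt.Printf("cluster_name=%v partitioner=%v\n", row["cluster_name"], row["partitioner"])
    }
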
host_source.go (+53 -126)

@@ -366,7 +366,6 @@ type ringDescriber struct {
 	session         *Session
 	mu              sync.Mutex
 	prevHosts       []*HostInfo
-	localHost       *HostInfo
 	prevPartitioner string
 }
 
@@ -388,13 +387,13 @@ func checkSystemSchema(control *controlConn) (bool, error) {
 
 // Given a map that represents a row from either system.local or system.peers
 // return as much information as we can in *HostInfo
-func (r *ringDescriber) hostInfoFromMap(row map[string]interface{}) (*HostInfo, error) {
+func hostInfoFromMap(row map[string]interface{}, defaultPort int) (*HostInfo, error) {
 	const assertErrorMsg = "Assertion failed for %s"
 	var ok bool
 
 	// Default to our connected port if the cluster doesn't have port information
 	host := HostInfo{
-		port: r.session.cfg.Port,
+		port: defaultPort,
 	}
 
 	for key, value := range row {
@@ -489,83 +488,44 @@ func (r *ringDescriber) hostInfoFromMap(row map[string]interface{}) (*HostInfo,
 	return &host, nil
 }
 
-// Ask the control node for it's local host information
-func (r *ringDescriber) GetLocalHostInfo() (*HostInfo, error) {
-	it := r.session.control.query("SELECT * FROM system.local WHERE key='local'")
-	if it == nil {
-		return nil, errors.New("Attempted to query 'system.local' on a closed control connection")
-	}
-	host, err := r.extractHostInfo(it)
-	if err != nil {
-		return nil, err
-	}
-
-	if host.invalidConnectAddr() {
-		host.SetConnectAddress(r.session.control.GetHostInfo().ConnectAddress())
-	}
-
-	return host, nil
-}
-
-// Given an ip address and port, return a peer that matched the ip address
-func (r *ringDescriber) GetPeerHostInfo(ip net.IP, port int) (*HostInfo, error) {
-	it := r.session.control.query("SELECT * FROM system.peers WHERE peer=?", ip)
-	if it == nil {
-		return nil, errors.New("Attempted to query 'system.peers' on a closed control connection")
-	}
-	return r.extractHostInfo(it)
-}
-
-func (r *ringDescriber) extractHostInfo(it *Iter) (*HostInfo, error) {
-	row := make(map[string]interface{})
-
-	// expect only 1 row
-	it.MapScan(row)
-	if err := it.Close(); err != nil {
-		return nil, err
-	}
-
-	// extract all available info about the host
-	return r.hostInfoFromMap(row)
-}
-
 // Ask the control node for host info on all it's known peers
-func (r *ringDescriber) GetClusterPeerInfo() ([]*HostInfo, error) {
+func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) {
 	var hosts []*HostInfo
+	iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
+		hosts = append(hosts, ch.host)
+		return ch.conn.query("SELECT * FROM system.peers")
+	})
 
-	// Ask the node for a list of it's peers
-	it := r.session.control.query("SELECT * FROM system.peers")
-	if it == nil {
-		return nil, errors.New("Attempted to query 'system.peers' on a closed connection")
+	if iter == nil {
+		return nil, errNoControl
 	}
 
-	for {
-		row := make(map[string]interface{})
-		if !it.MapScan(row) {
-			break
-		}
+	rows, err := iter.SliceMap()
+	if err != nil {
+		// TODO(zariel): make typed error
+		return nil, fmt.Errorf("unable to fetch peer host info: %s", err)
+	}
+
+	for _, row := range rows {
 		// extract all available info about the peer
-		host, err := r.hostInfoFromMap(row)
+		host, err := hostInfoFromMap(row, r.session.cfg.Port)
 		if err != nil {
 			return nil, err
-		}
-
-		// If it's not a valid peer
-		if !r.IsValidPeer(host) {
-			Logger.Printf("Found invalid peer '%+v' "+
+		} else if !isValidPeer(host) {
+			// If it's not a valid peer
+			Logger.Printf("Found invalid peer '%s' "+
 				"Likely due to a gossip or snitch issue, this host will be ignored", host)
 			continue
 		}
+
 		hosts = append(hosts, host)
 	}
-	if it.err != nil {
-		return nil, fmt.Errorf("while scanning 'system.peers' table: %s", it.err)
-	}
+
 	return hosts, nil
 }
 
 // Return true if the host is a valid peer
-func (r *ringDescriber) IsValidPeer(host *HostInfo) bool {
+func isValidPeer(host *HostInfo) bool {
 	return !(len(host.RPCAddress()) == 0 ||
 		host.hostId == "" ||
 		host.dataCenter == "" ||
@@ -578,84 +538,47 @@ func (r *ringDescriber) GetHosts() ([]*HostInfo, string, error) {
 	r.mu.Lock()
 	defer r.mu.Unlock()
 
-	// Update the localHost info with data from the connected host
-	localHost, err := r.GetLocalHostInfo()
+	hosts, err := r.getClusterPeerInfo()
 	if err != nil {
 		return r.prevHosts, r.prevPartitioner, err
-	} else if localHost.invalidConnectAddr() {
-		panic(fmt.Sprintf("unable to get localhost connect address: %v", localHost))
 	}
 
-	// Update our list of hosts by querying the cluster
-	hosts, err := r.GetClusterPeerInfo()
-	if err != nil {
-		return r.prevHosts, r.prevPartitioner, err
-	}
-
-	hosts = append(hosts, localHost)
-
-	// Filter the hosts if filter is provided
-	filteredHosts := hosts
-	if r.session.cfg.HostFilter != nil {
-		filteredHosts = filteredHosts[:0]
-		for _, host := range hosts {
-			if r.session.cfg.HostFilter.Accept(host) {
-				filteredHosts = append(filteredHosts, host)
-			}
-		}
+	var partitioner string
+	if len(hosts) > 0 {
+		partitioner = hosts[0].Partitioner()
 	}
 
-	r.prevHosts = filteredHosts
-	r.prevPartitioner = localHost.partitioner
-	r.localHost = localHost
-
-	return filteredHosts, localHost.partitioner, nil
+	return hosts, partitioner, nil
 }
 
 // Given an ip/port return HostInfo for the specified ip/port
-func (r *ringDescriber) GetHostInfo(ip net.IP, port int) (*HostInfo, error) {
-	// TODO(thrawn01): Is IgnorePeerAddr still useful now that we have DisableInitialHostLookup?
-	// TODO(thrawn01): should we also check for DisableInitialHostLookup and return if true?
-
-	// Ignore the port and connect address and use the address/port we already have
-	if r.session.control == nil || r.session.cfg.IgnorePeerAddr {
-		return &HostInfo{connectAddress: ip, port: port}, nil
-	}
-
-	// Attempt to get the host info for our control connection
-	controlHost := r.session.control.GetHostInfo()
-	if controlHost == nil {
-		return nil, errors.New("invalid control connection")
-	}
-
-	var (
-		host *HostInfo
-		err  error
-	)
+func (r *ringDescriber) getHostInfo(ip net.IP, port int) (*HostInfo, error) {
+	var host *HostInfo
+	iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
+		if ch.host.ConnectAddress().Equal(ip) {
+			host = ch.host
+			return nil
+		}
 
-	// If we are asking about the same node our control connection has a connection too
-	if controlHost.ConnectAddress().Equal(ip) {
-		host, err = r.GetLocalHostInfo()
-	} else {
-		host, err = r.GetPeerHostInfo(ip, port)
-	}
+		return ch.conn.query("SELECT * FROM system.peers WHERE peer=?", ip)
+	})
 
-	// No host was found matching this ip/port
-	if err != nil {
-		return nil, err
-	}
+	if iter != nil {
+		row, err := iter.rowMap()
+		if err != nil {
+			return nil, err
+		}
 
-	if controlHost.ConnectAddress().Equal(ip) {
-		// Always respect the provided control node address and disregard the ip address
-		// the cassandra node provides. We do this as we are already connected and have a
-		// known valid ip address. This insulates gocql from client connection issues stemming
-		// from node misconfiguration. For instance when a node is run from a container, by
-		// default the node will report its ip address as 127.0.0.1 which is typically invalid.
-		host.SetConnectAddress(ip)
+		host, err = hostInfoFromMap(row, port)
+		if err != nil {
+			return nil, err
+		}
+	} else if host == nil {
+		return nil, errors.New("unable to fetch host info: invalid control connection")
 	}
 
 	if host.invalidConnectAddr() {
-		return nil, fmt.Errorf("host ConnectAddress invalid: %v", host)
+		return nil, fmt.Errorf("host ConnectAddress invalid ip=%v: %v", ip, host)
 	}
 
 	return host, nil
@@ -675,6 +598,10 @@ func (r *ringDescriber) refreshRing() error {
 
 	// TODO: move this to session
 	for _, h := range hosts {
+		if filter := r.session.cfg.HostFilter; filter != nil && !filter.Accept(h) {
+			continue
+		}
+
 		if host, ok := r.session.ring.addHostIfMissing(h); !ok {
 			r.session.pool.addHost(h)
 			r.session.policy.AddHost(h)

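Host filtering moves from GetHosts into refreshRing, so GetHosts now returns the unfiltered peer list (which is why TestGetHostsWithFilter is deleted above). The public configuration is unchanged; a sketch of a HostFilter that refreshRing will consult for every discovered host, with an example address to exclude:

    package filtersketch

    import (
        "net"

        "github.com/gocql/gocql"
    )

    func newFilteredCluster() *gocql.ClusterConfig {
        skip := net.ParseIP("127.0.0.3") // example address to exclude

        cluster := gocql.NewCluster("127.0.0.1", "127.0.0.2")
        cluster.HostFilter = gocql.HostFilterFunc(func(host *gocql.HostInfo) bool {
            // Returning false drops the host from the pool and the policy.
            return !host.ConnectAddress().Equal(skip)
        })
        return cluster
    }
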
host_source_test.go (+3 -30)

@@ -3,7 +3,6 @@
 package gocql
 
 import (
-	"fmt"
 	"net"
 	"testing"
 )
@@ -50,7 +49,6 @@ func TestCassVersionBefore(t *testing.T) {
 }
 
 func TestIsValidPeer(t *testing.T) {
-	ring := ringDescriber{}
 	host := &HostInfo{
 		rpcAddress: net.ParseIP("0.0.0.0"),
 		rack:       "myRack",
@@ -59,12 +57,12 @@ func TestIsValidPeer(t *testing.T) {
 		tokens:     []string{"0", "1"},
 	}
 
-	if !ring.IsValidPeer(host) {
+	if !isValidPeer(host) {
 		t.Errorf("expected %+v to be a valid peer", host)
 	}
 
 	host.rack = ""
-	if ring.IsValidPeer(host) {
+	if isValidPeer(host) {
 		t.Errorf("expected %+v to NOT be a valid peer", host)
 	}
 }
@@ -76,33 +74,8 @@ func TestGetHosts(t *testing.T) {
 	hosts, partitioner, err := session.hostSource.GetHosts()
 
 	assertTrue(t, "err == nil", err == nil)
-	assertTrue(t, "len(hosts) == 3", len(hosts) == 3)
+	assertEqual(t, "len(hosts)", len(clusterHosts), len(hosts))
 	assertTrue(t, "len(partitioner) != 0", len(partitioner) != 0)
-
-}
-
-func TestGetHostsWithFilter(t *testing.T) {
-	filterHostIP := net.ParseIP("127.0.0.3")
-	cluster := createCluster()
-
-	// Filter to remove one of the localhost nodes
-	cluster.HostFilter = HostFilterFunc(func(host *HostInfo) bool {
-		if host.ConnectAddress().Equal(filterHostIP) {
-			return false
-		}
-		return true
-	})
-	session := createSessionFromCluster(cluster, t)
-
-	hosts, partitioner, err := session.hostSource.GetHosts()
-	assertTrue(t, "err == nil", err == nil)
-	assertTrue(t, "len(hosts) == 2", len(hosts) == 2)
-	assertTrue(t, "len(partitioner) != 0", len(partitioner) != 0)
-	for _, host := range hosts {
-		if host.ConnectAddress().Equal(filterHostIP) {
-			t.Fatal(fmt.Sprintf("Did not expect to see '%q' in host list", filterHostIP))
-		}
-	}
 }
 
 func TestHostInfo_ConnectAddress(t *testing.T) {

query_executor.go (+10 -4)

@@ -17,6 +17,15 @@ type queryExecutor struct {
 	policy HostSelectionPolicy
 }
 
+func (q *queryExecutor) attemptQuery(qry ExecutableQuery, conn *Conn) *Iter {
+	start := time.Now()
+	iter := qry.execute(conn)
+
+	qry.attempt(time.Since(start))
+
+	return iter
+}
+
 func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 	rt := qry.retryPolicy()
 	hostIter := q.policy.Pick(qry)
@@ -38,10 +47,7 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 			continue
 		}
 
-		start := time.Now()
-		iter = qry.execute(conn)
-
-		qry.attempt(time.Since(start))
+		iter = q.attemptQuery(qry, conn)
 
 		// Update host
 		hostResponse.Mark(iter.err)

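attemptQuery extracts execution plus latency accounting into one helper so the paging path added in session.go can record attempts exactly like executeQuery does. A reduced sketch of the extracted shape, with stand-in types:

    package execsketch

    import "time"

    type Conn struct{}
    type Iter struct{ err error }

    // ExecutableQuery is reduced to the two methods attemptQuery needs.
    type ExecutableQuery interface {
        execute(conn *Conn) *Iter
        attempt(time.Duration)
    }

    func attemptQuery(qry ExecutableQuery, conn *Conn) *Iter {
        start := time.Now()
        iter := qry.execute(conn)
        qry.attempt(time.Since(start)) // feeds query metrics/observers
        return iter
    }
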
ring.go (+2 -0)

@@ -64,6 +64,8 @@ func (r *ring) currentHosts() map[string]*HostInfo {
 }
 
 func (r *ring) addHost(host *HostInfo) bool {
+	// TODO(zariel): key all host info by HostID instead of
+	// ip addresses
 	if host.invalidConnectAddr() {
 		panic(fmt.Sprintf("invalid host: %v", host))
 	}

session.go (+10 -4)

@@ -639,7 +639,7 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
 }
 
 func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
-	return Connect(host, s.connCfg, errorHandler, s)
+	return s.dial(host.ConnectAddress(), host.Port(), s.connCfg, errorHandler)
 }
 
 // Query represents a CQL statement that can be executed.
@@ -1057,14 +1057,14 @@ type Scanner interface {
 	// scanned into with Scan.
 	// Next must be called before every call to Scan.
 	Next() bool
-	
+
 	// Scan copies the current row's columns into dest. If the length of dest does not equal
 	// the number of columns returned in the row an error is returned. If an error is encountered
 	// when unmarshalling a column into the value in dest an error is returned and the row is invalidated
 	// until the next call to Next.
 	// Next must be called before calling Scan, if it is not an error is returned.
 	Scan(...interface{}) error
-	
+
 	// Err returns the if there was one during iteration that resulted in iteration being unable to complete.
 	// Err will also release resources held by the iterator, the Scanner should not used after being called.
 	Err() error
@@ -1301,11 +1301,17 @@ type nextIter struct {
 	pos  int
 	once sync.Once
 	next *Iter
+	conn *Conn
 }
 
 func (n *nextIter) fetch() *Iter {
 	n.once.Do(func() {
-		n.next = n.qry.session.executeQuery(&n.qry)
+		iter := n.qry.session.executor.attemptQuery(&n.qry, n.conn)
+		if iter != nil && iter.err == nil {
+			n.next = iter
+		} else {
+			n.next = n.qry.session.executeQuery(&n.qry)
+		}
 	})
 	return n.next
 }

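nextIter now remembers the connection that produced the current page and tries it first, falling back to full host selection only if that attempt fails. A sketch of the fallback, with stand-in types and names:

    package pagingsketch

    type Conn struct{}
    type Iter struct{ err error }
    type Query struct{}

    // Stand-ins for attemptQuery on a known conn and for the full
    // policy-driven execution path.
    func execOnConn(q *Query, c *Conn) *Iter { return &Iter{} }
    func executeQuery(q *Query) *Iter        { return &Iter{} }

    func fetchPage(q *Query, prev *Conn) *Iter {
        if prev != nil {
            if iter := execOnConn(q, prev); iter != nil && iter.err == nil {
                // The node that served this page also serves the next one,
                // keeping its paging state useful.
                return iter
            }
        }
        // Connection gone or errored: re-run host selection.
        return executeQuery(q)
    }
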
session_connect_test.go (+4 -4)

@@ -102,10 +102,10 @@ func TestSession_connect_WithNoTranslator(t *testing.T) {
 
 	go srvr.Serve()
 
-	Connect(&HostInfo{
+	session.connect(&HostInfo{
 		connectAddress: srvr.Addr,
 		port:           srvr.Port,
-	}, session.connCfg, testConnErrorHandler(t), session)
+	}, testConnErrorHandler(t))
 
 	assertConnectionEventually(t, 500*time.Millisecond, srvr)
 }
@@ -122,10 +122,10 @@ func TestSession_connect_WithTranslator(t *testing.T) {
 	go srvr.Serve()
 
 	// the provided address will be translated
-	Connect(&HostInfo{
+	session.connect(&HostInfo{
 		connectAddress: net.ParseIP("10.10.10.10"),
 		port:           5432,
-	}, session.connCfg, testConnErrorHandler(t), session)
+	}, testConnErrorHandler(t))
 
 	assertConnectionEventually(t, 500*time.Millisecond, srvr)
 }