
start replacing host discovery with events

Chris Bannister 10 years ago
parent
commit
5c1118aa01
10 changed files with 239 additions and 148 deletions
  1. cassandra_test.go (+0 -8)
  2. cluster.go (+2 -13)
  3. conn.go (+0 -1)
  4. conn_test.go (+22 -21)
  5. connectionpool.go (+44 -19)
  6. control.go (+67 -8)
  7. events.go (+10 -14)
  8. policies.go (+44 -18)
  9. policies_test.go (+16 -17)
  10. session.go (+34 -29)

+ 0 - 8
cassandra_test.go

@@ -165,7 +165,6 @@ func TestAuthentication(t *testing.T) {
 func TestRingDiscovery(t *testing.T) {
 	cluster := createCluster()
 	cluster.Hosts = clusterHosts[:1]
-	cluster.DiscoverHosts = true
 
 	session := createSessionFromCluster(cluster, t)
 	defer session.Close()
@@ -649,10 +648,6 @@ func TestCreateSessionTimeout(t *testing.T) {
 		session.Close()
 		t.Fatal("expected ErrNoConnectionsStarted, but no error was returned.")
 	}
-
-	if err != ErrNoConnectionsStarted {
-		t.Fatalf("expected ErrNoConnectionsStarted, but received %v", err)
-	}
 }
 
 type FullName struct {
@@ -2001,7 +1996,6 @@ func TestRoutingKey(t *testing.T) {
 func TestTokenAwareConnPool(t *testing.T) {
 	cluster := createCluster()
 	cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy())
-	cluster.DiscoverHosts = true
 
 	session := createSessionFromCluster(cluster, t)
 	defer session.Close()
@@ -2379,9 +2373,7 @@ func TestDiscoverViaProxy(t *testing.T) {
 	proxyAddr := proxy.Addr().String()
 
 	cluster := createCluster()
-	cluster.DiscoverHosts = true
 	cluster.NumConns = 1
-	cluster.Discovery.Sleep = 100 * time.Millisecond
 	// initial host is the proxy address
 	cluster.Hosts = []string{proxyAddr}
 

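For reference, with DiscoverHosts and the Discovery settings gone from the tests above, callers no longer opt in to ring discovery at all. A minimal usage sketch of the post-commit behaviour (the address and keyspace name are illustrative):

    package main

    import (
    	"log"

    	"github.com/gocql/gocql"
    )

    func main() {
    	// Only seed hosts are listed; the rest of the ring is discovered via
    	// the control connection and server topology/status events.
    	cluster := gocql.NewCluster("192.168.1.1")
    	cluster.Keyspace = "example"

    	session, err := cluster.CreateSession()
    	if err != nil {
    		log.Fatal(err)
    	}
    	defer session.Close()
    }
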
+ 2 - 13
cluster.go

@@ -40,16 +40,6 @@ func initStmtsLRU(max int) {
 	}
 }
 
-// To enable periodic node discovery enable DiscoverHosts in ClusterConfig
-type DiscoveryConfig struct {
-	// If not empty will filter all discoverred hosts to a single Data Centre (default: "")
-	DcFilter string
-	// If not empty will filter all discoverred hosts to a single Rack (default: "")
-	RackFilter string
-	// The interval to check for new hosts (default: 30s)
-	Sleep time.Duration
-}
-
 // PoolConfig configures the connection pool used by the driver; it defaults to
 // using a round robin host selection policy and a round robin connection selection
 // policy for each host.
@@ -94,12 +84,10 @@ type ClusterConfig struct {
 	Authenticator     Authenticator     // authenticator (default: nil)
 	RetryPolicy       RetryPolicy       // Default retry policy to use for queries (default: 0)
 	SocketKeepalive   time.Duration     // The keepalive period to use, enabled if > 0 (default: 0)
-	DiscoverHosts     bool              // If set, gocql will attempt to automatically discover other members of the Cassandra cluster (default: false)
 	MaxPreparedStmts  int               // Sets the maximum cache size for prepared statements globally for gocql (default: 1000)
 	MaxRoutingKeyInfo int               // Sets the maximum cache size for query info about statements for each session (default: 1000)
 	PageSize          int               // Default page size to use for created sessions (default: 5000)
 	SerialConsistency SerialConsistency // Sets the consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL (default: unset)
-	Discovery         DiscoveryConfig
 	SslOpts           *SslOptions
 	DefaultTimestamp  bool // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. (default: true, only enabled for protocol 3 and above)
 	// PoolConfig configures the underlying connection pool, allowing the
@@ -110,6 +98,8 @@ type ClusterConfig struct {
 	// receiving a schema change frame. (default: 60s)
 	MaxWaitSchemaAgreement time.Duration
 
+	HostFilter
+
 	// internal config for testing
 	disableControlConn bool
 }
@@ -124,7 +114,6 @@ func NewCluster(hosts ...string) *ClusterConfig {
 		Port:                   9042,
 		NumConns:               2,
 		Consistency:            Quorum,
-		DiscoverHosts:          false,
 		MaxPreparedStmts:       defaultMaxPreparedStmts,
 		MaxRoutingKeyInfo:      1000,
 		PageSize:               5000,

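ClusterConfig now embeds a HostFilter in place of the removed DiscoveryConfig. Below is a hedged sketch of what an implementation might look like; the single-method Accept(gocql.HostInfo) bool shape is inferred from the s.hostFilter.Accept(*hostInfo) call in events.go, and the DataCenter field on HostInfo is likewise an assumption.

    package main

    import "github.com/gocql/gocql"

    // dcFilter is a hypothetical HostFilter that keeps only hosts from one
    // data centre, standing in for the removed Discovery.DcFilter option.
    type dcFilter struct {
    	dc string
    }

    func (f dcFilter) Accept(host gocql.HostInfo) bool {
    	return f.dc == "" || host.DataCenter == f.dc
    }

    func main() {
    	cluster := gocql.NewCluster("10.0.0.1")
    	cluster.HostFilter = dcFilter{dc: "dc1"} // embedded field added in this commit
    	_ = cluster
    }
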
+ 0 - 1
conn.go

@@ -141,7 +141,6 @@ type Conn struct {
 }
 
 // Connect establishes a connection to a Cassandra node.
-// You must also call the Serve method before you can execute any queries.
 func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
 	var (
 		err  error

+ 22 - 21
conn_test.go

@@ -50,12 +50,18 @@ func TestJoinHostPort(t *testing.T) {
 	}
 }
 
+func testCluster(addr string, proto protoVersion) *ClusterConfig {
+	cluster := NewCluster(addr)
+	cluster.ProtoVersion = int(proto)
+	cluster.disableControlConn = true
+	return cluster
+}
+
 func TestSimple(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
-	cluster.ProtoVersion = int(defaultProto)
+	cluster := testCluster(srv.Address, defaultProto)
 	db, err := cluster.CreateSession()
 	if err != nil {
 		t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
@@ -94,18 +100,19 @@ func TestSSLSimpleNoClientCert(t *testing.T) {
 	}
 }
 
-func createTestSslCluster(hosts string, proto uint8, useClientCert bool) *ClusterConfig {
-	cluster := NewCluster(hosts)
+func createTestSslCluster(addr string, proto protoVersion, useClientCert bool) *ClusterConfig {
+	cluster := testCluster(addr, proto)
 	sslOpts := &SslOptions{
 		CaPath:                 "testdata/pki/ca.crt",
 		EnableHostVerification: false,
 	}
+
 	if useClientCert {
 		sslOpts.CertPath = "testdata/pki/gocql.crt"
 		sslOpts.KeyPath = "testdata/pki/gocql.key"
 	}
+
 	cluster.SslOpts = sslOpts
-	cluster.ProtoVersion = int(proto)
 	return cluster
 }
 
@@ -115,28 +122,23 @@ func TestClosed(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
-	cluster.ProtoVersion = int(defaultProto)
-
-	session, err := cluster.CreateSession()
-	defer session.Close()
+	session, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 		t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
 	}
 
+	session.Close()
+
 	if err := session.Query("void").Exec(); err != ErrSessionClosed {
 		t.Fatalf("0x%x: expected %#v, got %#v", defaultProto, ErrSessionClosed, err)
 	}
 }
 
-func newTestSession(addr string, proto uint8) (*Session, error) {
-	cluster := NewCluster(addr)
-	cluster.ProtoVersion = int(proto)
-	return cluster.CreateSession()
+func newTestSession(addr string, proto protoVersion) (*Session, error) {
+	return testCluster(addr, proto).CreateSession()
 }
 
 func TestTimeout(t *testing.T) {
-
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
@@ -197,7 +199,7 @@ func TestStreams_Protocol1(t *testing.T) {
 
 	// TODO: these are more like session tests and should instead operate
 	// on a single Conn
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, protoVersion1)
 	cluster.NumConns = 1
 	cluster.ProtoVersion = 1
 
@@ -229,7 +231,7 @@ func TestStreams_Protocol3(t *testing.T) {
 
 	// TODO: these are more like session tests and should instead operate
 	// on a single Conn
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, protoVersion3)
 	cluster.NumConns = 1
 	cluster.ProtoVersion = 3
 
@@ -356,7 +358,6 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 
 	db, err := cluster.CreateSession()
 	if err != nil {
-		db.Close()
 		t.Fatalf("failed to create new session: %v", err)
 	}
 
@@ -377,7 +378,7 @@ func TestQueryTimeout(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, defaultProto)
 	// Set the timeout arbitrarily low so that the query hits the timeout in a
 	// timely manner.
 	cluster.Timeout = 1 * time.Millisecond
@@ -418,7 +419,7 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, defaultProto)
 	// Set the timeout arbitrarily low so that the query hits the timeout in a
 	// timely manner.
 	cluster.Timeout = 1 * time.Millisecond
@@ -442,7 +443,7 @@ func TestQueryTimeoutClose(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, defaultProto)
 	// Set the timeout arbitrarily low so that the query hits the timeout in a
 	// timely manner.
 	cluster.Timeout = 1000 * time.Millisecond

+ 44 - 19
connectionpool.go

@@ -71,16 +71,15 @@ type policyConnPool struct {
 	hostConnPools map[string]*hostConnPool
 }
 
-func newPolicyConnPool(session *Session, hostPolicy HostSelectionPolicy,
-	connPolicy func() ConnSelectionPolicy) (*policyConnPool, error) {
+func connConfig(session *Session) (*ConnConfig, error) {
+	cfg := session.cfg
 
 	var (
 		err       error
 		tlsConfig *tls.Config
 	)
 
-	cfg := session.cfg
-
+	// TODO(zariel): move tls config setup into session init.
 	if cfg.SslOpts != nil {
 		tlsConfig, err = setupTLSConfig(cfg.SslOpts)
 		if err != nil {
@@ -88,28 +87,41 @@ func newPolicyConnPool(session *Session, hostPolicy HostSelectionPolicy,
 		}
 	}
 
+	return &ConnConfig{
+		ProtoVersion:  cfg.ProtoVersion,
+		CQLVersion:    cfg.CQLVersion,
+		Timeout:       cfg.Timeout,
+		NumStreams:    cfg.NumStreams,
+		Compressor:    cfg.Compressor,
+		Authenticator: cfg.Authenticator,
+		Keepalive:     cfg.SocketKeepalive,
+		tlsConfig:     tlsConfig,
+	}, nil
+}
+
+func newPolicyConnPool(session *Session, hostPolicy HostSelectionPolicy,
+	connPolicy func() ConnSelectionPolicy) (*policyConnPool, error) {
+
+	connCfg, err := connConfig(session)
+	if err != nil {
+		return nil, err
+	}
+
 	// create the pool
 	pool := &policyConnPool{
-		session:  session,
-		port:     cfg.Port,
-		numConns: cfg.NumConns,
-		connCfg: &ConnConfig{
-			ProtoVersion:  cfg.ProtoVersion,
-			CQLVersion:    cfg.CQLVersion,
-			Timeout:       cfg.Timeout,
-			Compressor:    cfg.Compressor,
-			Authenticator: cfg.Authenticator,
-			Keepalive:     cfg.SocketKeepalive,
-			tlsConfig:     tlsConfig,
-		},
-		keyspace:      cfg.Keyspace,
+		session:       session,
+		port:          session.cfg.Port,
+		numConns:      session.cfg.NumConns,
+		connCfg:       connCfg,
+		keyspace:      session.cfg.Keyspace,
 		hostPolicy:    hostPolicy,
 		connPolicy:    connPolicy,
 		hostConnPools: map[string]*hostConnPool{},
 	}
 
-	hosts := make([]HostInfo, len(cfg.Hosts))
-	for i, hostAddr := range cfg.Hosts {
+	// TODO(zariel): fetch this from session metadata.
+	hosts := make([]HostInfo, len(session.cfg.Hosts))
+	for i, hostAddr := range session.cfg.Hosts {
 		hosts[i].Peer = hostAddr
 	}
 
@@ -241,6 +253,7 @@ func (p *policyConnPool) addHost(host *HostInfo) {
 }
 
 func (p *policyConnPool) removeHost(addr string) {
+	p.hostPolicy.RemoveHost(addr)
 	p.mu.Lock()
 
 	pool, ok := p.hostConnPools[addr]
@@ -255,6 +268,18 @@ func (p *policyConnPool) removeHost(addr string) {
 	pool.Close()
 }
 
+func (p *policyConnPool) hostUp(host *HostInfo) {
+	// TODO(zariel): have a set of up hosts and down hosts, we can internally
+	// detect down hosts, then try to reconnect to them.
+	p.addHost(host)
+}
+
+func (p *policyConnPool) hostDown(addr string) {
+	// TODO(zariel): mark host as down so we can try to connect to it later, for
+	// now just treat it as removed.
+	p.removeHost(addr)
+}
+
 // hostConnPool is a connection pool for a single host.
 // Connection selection is based on a provided ConnSelectionPolicy
 type hostConnPool struct {

+ 67 - 8
control.go

@@ -4,23 +4,26 @@ import (
 	"errors"
 	"fmt"
 	"log"
+	"math/rand"
 	"net"
+	"sync"
 	"sync/atomic"
 	"time"
 )
 
-// Ensure that the atomic variable is aligned to a 64bit boundary 
+// Ensure that the atomic variable is aligned to a 64bit boundary
 // so that atomic operations can be applied on 32bit architectures.
 type controlConn struct {
 	connecting uint64
 
 	session *Session
 
-	conn       atomic.Value
+	conn atomic.Value
 
 	retry RetryPolicy
 
-	quit chan struct{}
+	closeWg sync.WaitGroup
+	quit    chan struct{}
 }
 
 func createControlConn(session *Session) *controlConn {
@@ -31,12 +34,14 @@ func createControlConn(session *Session) *controlConn {
 	}
 
 	control.conn.Store((*Conn)(nil))
-	go control.heartBeat()
 
 	return control
 }
 
 func (c *controlConn) heartBeat() {
+	c.closeWg.Add(1)
+	defer c.closeWg.Done()
+
 	for {
 		select {
 		case <-c.quit:
@@ -62,12 +67,61 @@ func (c *controlConn) heartBeat() {
 		c.reconnect(true)
 		// time.Sleep(5 * time.Second)
 		continue
+	}
+}
+
+func (c *controlConn) connect(endpoints []string) error {
+	// initial connection attempt, try to connect to each endpoint to get an initial
+	// list of nodes.
+
+	// shuffle endpoints so not all drivers will connect to the same initial
+	// node.
+	r := rand.New(rand.NewSource(time.Now().UnixNano()))
+	perm := r.Perm(len(endpoints))
+	shuffled := make([]string, len(endpoints))
+
+	for i, endpoint := range endpoints {
+		shuffled[perm[i]] = endpoint
+	}
 
+	connCfg, err := connConfig(c.session)
+	if err != nil {
+		return err
 	}
+
+	// store that we are not connected so that reconnect won't happen if we error
+	atomic.StoreInt64(&c.connecting, -1)
+
+	var (
+		conn *Conn
+	)
+
+	for _, addr := range shuffled {
+		conn, err = Connect(JoinHostPort(addr, c.session.cfg.Port), connCfg, c, c.session)
+		if err != nil {
+			log.Printf("gocql: unable to dial %v: %v\n", addr, err)
+			continue
+		}
+
+		// we should fetch the initial ring here and update initial host data, so that
+		// when we return from here we have a ring topology ready to go.
+		break
+	}
+
+	if conn == nil {
+		// this is fatal: without a control connection the session cannot be created
+		return err
+	}
+
+	c.conn.Store(conn)
+	atomic.StoreInt64(&c.connecting, 0)
+	go c.heartBeat()
+
+	return nil
 }
 
 func (c *controlConn) reconnect(refreshring bool) {
-	if !atomic.CompareAndSwapUint64(&c.connecting, 0, 1) {
+	if !atomic.CompareAndSwapInt64(&c.connecting, 0, 1) {
 		return
 	}
 
@@ -77,10 +131,10 @@ func (c *controlConn) reconnect(refreshring bool) {
 		if success {
 			go func() {
 				time.Sleep(500 * time.Millisecond)
-				atomic.StoreUint64(&c.connecting, 0)
+				atomic.StoreInt64(&c.connecting, 0)
 			}()
 		} else {
-			atomic.StoreUint64(&c.connecting, 0)
+			atomic.StoreInt64(&c.connecting, 0)
 		}
 	}()
 
@@ -120,7 +174,7 @@ func (c *controlConn) reconnect(refreshring bool) {
 		oldConn.Close()
 	}
 
-	if refreshring && c.session.cfg.DiscoverHosts {
+	if refreshring {
 		c.session.hostSource.refreshRing()
 	}
 }
@@ -242,6 +296,11 @@ func (c *controlConn) addr() string {
 func (c *controlConn) close() {
 	// TODO: handle more gracefully
 	close(c.quit)
+	c.closeWg.Wait()
+	conn := c.conn.Load().(*Conn)
+	if conn != nil {
+		conn.Close()
+	}
 }
 
 var errNoControl = errors.New("gocql: no control connection available")

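The comment at the top of controlConn notes that the atomic counter must sit on a 64-bit boundary; Go's sync/atomic only guarantees that alignment for the first word of an allocated struct on 32-bit platforms. An illustrative, self-contained sketch of that rule (the counter type below is hypothetical):

    package main

    import (
    	"fmt"
    	"sync/atomic"
    )

    // counter keeps its 64-bit field first, mirroring controlConn.connecting:
    // on 32-bit ARM/x86 sync/atomic requires 64-bit alignment, and only the
    // first word of an allocated struct is guaranteed to have it.
    type counter struct {
    	n     int64
    	label string
    }

    func main() {
    	c := &counter{label: "reconnects"}
    	atomic.AddInt64(&c.n, 1)
    	fmt.Println(c.label, atomic.LoadInt64(&c.n))
    }
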
+ 10 - 14
events.go

@@ -43,36 +43,32 @@ func (s *Session) handleEvent(framer *framer) {
 }
 
 func (s *Session) handleNewNode(host net.IP, port int) {
-	if !s.cfg.DiscoverHosts || s.control == nil {
+	// TODO(zariel): need to be able to filter discovered nodes
+	if s.control == nil {
 		return
 	}
 
-	if s.control.addr() == host.String() {
-		go s.control.reconnect(false)
-	}
-
 	hostInfo, err := s.control.fetchHostInfo(host, port)
 	if err != nil {
 		log.Printf("gocql: unable to fetch host info for %v: %v\n", host, err)
 		return
 	}
 
-	s.pool.addHost(hostInfo)
+	if s.hostFilter.Accept(*hostInfo) {
+		s.pool.addHost(hostInfo)
+	}
 }
 
 func (s *Session) handleRemovedNode(host net.IP, port int) {
-	if !s.cfg.DiscoverHosts {
-		return
-	}
-
+	// we remove all nodes but only add ones which pass the filter
 	s.pool.removeHost(host.String())
 }
 
 func (s *Session) handleNodeUp(host net.IP, port int) {
-	// even if were not disconvering new nodes we should still handle nodes going
-	// up.
-
-	s.pool.hostUp(host.String())
+	// TODO(zariel): handle this case even when not discovering, just mark the
+	// host up.
+	// TODO: implement this properly not as newNode
+	s.handleNewNode(host, port)
 }
 
 func (s *Session) handleNodeDown(host net.IP, port int) {

+ 44 - 18
policies.go

@@ -55,7 +55,7 @@ func (c *cowHostList) add(host HostInfo) {
 	c.mu.Unlock()
 }
 
-func (c *cowHostList) remove(host HostInfo) {
+func (c *cowHostList) remove(addr string) {
 	c.mu.Lock()
 	l := c.get()
 	size := len(l)
@@ -67,7 +67,7 @@ func (c *cowHostList) remove(host HostInfo) {
 	found := false
 	newL := make([]HostInfo, 0, size)
 	for i := 0; i < len(l); i++ {
-		if host.Peer != l[i].Peer && host.HostId != l[i].HostId {
+		if l[i].Peer != addr {
 			newL = append(newL, l[i])
 		} else {
 			found = true
@@ -124,6 +124,8 @@ func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
 
 type HostStateNotifier interface {
 	AddHost(host *HostInfo)
+	RemoveHost(addr string)
+	// TODO(zariel): add host up/down
 }
 
 // HostSelectionPolicy is an interface for selecting
@@ -149,7 +151,7 @@ type NextHost func() SelectedHost
 // RoundRobinHostPolicy is a round-robin load balancing policy, where each host
 // is tried sequentially for each query.
 func RoundRobinHostPolicy() HostSelectionPolicy {
-	return &roundRobinHostPolicy{hosts: []HostInfo{}}
+	return &roundRobinHostPolicy{}
 }
 
 type roundRobinHostPolicy struct {
@@ -179,7 +181,7 @@ func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
 		// always increment pos to evenly distribute traffic in case of
 		// failures
 		pos := atomic.AddUint32(&r.pos, 1)
-		if i >= len(r.hosts) {
+		if i >= len(hosts) {
 			return nil
 		}
 		host := &r.hosts[(pos)%uint32(len(r.hosts))]
@@ -192,6 +194,10 @@ func (r *roundRobinHostPolicy) AddHost(host *HostInfo) {
 	r.hosts.add(*host)
 }
 
+func (r *roundRobinHostPolicy) RemoveHost(addr string) {
+	r.hosts.remove(addr)
+}
+
 // selectedRoundRobinHost is a host returned by the roundRobinHostPolicy and
 // implements the SelectedHost interface
 type selectedRoundRobinHost struct {
@@ -210,24 +216,25 @@ func (host selectedRoundRobinHost) Mark(err error) {
 // selected based on the partition key, so queries are sent to the host which
 // owns the partition. Fallback is used when routing information is not available.
 func TokenAwareHostPolicy(fallback HostSelectionPolicy) HostSelectionPolicy {
-	return &tokenAwareHostPolicy{fallback: fallback, hosts: []HostInfo{}}
+	return &tokenAwareHostPolicy{fallback: fallback}
 }
 
 type tokenAwareHostPolicy struct {
+	hosts       cowHostList
 	mu          sync.RWMutex
-	hosts       []HostInfo
 	partitioner string
 	tokenRing   *tokenRing
 	fallback    HostSelectionPolicy
 }
 
 func (t *tokenAwareHostPolicy) SetHosts(hosts []HostInfo) {
+	t.hosts.set(hosts)
+
 	t.mu.Lock()
 	defer t.mu.Unlock()
 
 	// always update the fallback
 	t.fallback.SetHosts(hosts)
-	t.hosts = hosts
 
 	t.resetTokenRing()
 }
@@ -245,19 +252,19 @@ func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
 }
 
 func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
+	t.hosts.add(*host)
+
 	t.mu.Lock()
-	defer t.mu.Unlock()
+	t.resetTokenRing()
+	t.mu.Unlock()
+}
 
-	t.fallback.AddHost(host)
-	for i := range t.hosts {
-		h := &t.hosts[i]
-		if h.HostId == host.HostId && h.Peer == host.Peer {
-			return
-		}
-	}
+func (t *tokenAwareHostPolicy) RemoveHost(addr string) {
+	t.hosts.remove(addr)
 
-	t.hosts = append(t.hosts, *host)
+	t.mu.Lock()
 	t.resetTokenRing()
+	t.mu.Unlock()
 }
 
 func (t *tokenAwareHostPolicy) resetTokenRing() {
@@ -267,7 +274,8 @@ func (t *tokenAwareHostPolicy) resetTokenRing() {
 	}
 
 	// create a new token ring
-	tokenRing, err := newTokenRing(t.partitioner, t.hosts)
+	hosts := t.hosts.get()
+	tokenRing, err := newTokenRing(t.partitioner, hosts)
 	if err != nil {
 		log.Printf("Unable to update the token ring due to error: %s", err)
 		return
@@ -309,6 +317,7 @@ func (t *tokenAwareHostPolicy) Pick(qry *Query) NextHost {
 		hostReturned bool
 		fallbackIter NextHost
 	)
+
 	return func() SelectedHost {
 		if !hostReturned {
 			hostReturned = true
@@ -365,8 +374,8 @@ func HostPoolHostPolicy(hp hostpool.HostPool) HostSelectionPolicy {
 
 type hostPoolHostPolicy struct {
 	hp      hostpool.HostPool
-	hostMap map[string]HostInfo
 	mu      sync.RWMutex
+	hostMap map[string]HostInfo
 }
 
 func (r *hostPoolHostPolicy) SetHosts(hosts []HostInfo) {
@@ -402,6 +411,23 @@ func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
 	r.hostMap[host.Peer] = *host
 }
 
+func (r *hostPoolHostPolicy) RemoveHost(addr string) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if _, ok := r.hostMap[addr]; !ok {
+		return
+	}
+
+	delete(r.hostMap, addr)
+	hosts := make([]string, 0, len(r.hostMap))
+	for addr := range r.hostMap {
+		hosts = append(hosts, addr)
+	}
+
+	r.hp.SetHosts(hosts)
+}
+
 func (r *hostPoolHostPolicy) SetPartitioner(partitioner string) {
 	// noop
 }

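The policies above now keep hosts in a cowHostList; its add/remove/get calls appear in the hunks, but its definition is outside this diff. A rough sketch of the copy-on-write pattern those calls imply, with all names assumed: writers copy the slice under a mutex and publish it through an atomic.Value, while readers load it without locking.

    package main

    import (
    	"fmt"
    	"sync"
    	"sync/atomic"
    )

    // hostList is an illustrative copy-on-write list in the spirit of
    // cowHostList: get() is lock-free, add/remove copy then swap.
    type hostList struct {
    	mu   sync.Mutex
    	list atomic.Value // always holds a []string
    }

    func (h *hostList) get() []string {
    	l, _ := h.list.Load().([]string)
    	return l
    }

    func (h *hostList) add(addr string) {
    	h.mu.Lock()
    	defer h.mu.Unlock()
    	old := h.get()
    	next := make([]string, len(old), len(old)+1)
    	copy(next, old)
    	h.list.Store(append(next, addr))
    }

    func (h *hostList) remove(addr string) {
    	h.mu.Lock()
    	defer h.mu.Unlock()
    	next := make([]string, 0, len(h.get()))
    	for _, a := range h.get() {
    		if a != addr {
    			next = append(next, a)
    		}
    	}
    	h.list.Store(next)
    }

    func main() {
    	var hosts hostList
    	hosts.add("10.0.0.1")
    	hosts.add("10.0.0.2")
    	hosts.remove("10.0.0.1")
    	fmt.Println(hosts.get()) // [10.0.0.2]
    }
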
+ 16 - 17
policies_test.go

@@ -23,30 +23,29 @@ func TestRoundRobinHostPolicy(t *testing.T) {
 
 	policy.SetHosts(hosts)
 
-	// the first host selected is actually at [1], but this is ok for RR
 	// interleaved iteration should always increment the host
 	iterA := policy.Pick(nil)
-	if actual := iterA(); actual.Info() != &hosts[1] {
-		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostId)
-	}
-	iterB := policy.Pick(nil)
-	if actual := iterB(); actual.Info() != &hosts[0] {
+	if actual := iterA(); actual.Info() != &hosts[0] {
 		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostId)
 	}
+	iterB := policy.Pick(nil)
 	if actual := iterB(); actual.Info() != &hosts[1] {
 		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostId)
 	}
-	if actual := iterA(); actual.Info() != &hosts[0] {
+	if actual := iterB(); actual.Info() != &hosts[0] {
 		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostId)
 	}
-
-	iterC := policy.Pick(nil)
-	if actual := iterC(); actual.Info() != &hosts[1] {
+	if actual := iterA(); actual.Info() != &hosts[1] {
 		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostId)
 	}
+
+	iterC := policy.Pick(nil)
 	if actual := iterC(); actual.Info() != &hosts[0] {
 		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostId)
 	}
+	if actual := iterC(); actual.Info() != &hosts[1] {
+		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostId)
+	}
 }
 
 // Tests of the token-aware host selection policy implementation with a
@@ -76,13 +75,13 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 
 	// the token ring is not setup without the partitioner, but the fallback
 	// should work
-	if actual := policy.Pick(nil)(); actual.Info().Peer != "1" {
-		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer)
+	if actual := policy.Pick(nil)(); actual.Info().Peer != "0" {
+		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer)
 	}
 
 	query.RoutingKey([]byte("30"))
-	if actual := policy.Pick(query)(); actual.Info().Peer != "2" {
-		t.Errorf("Expected peer 2 but was %s", actual.Info().Peer)
+	if actual := policy.Pick(query)(); actual.Info().Peer != "1" {
+		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer)
 	}
 
 	policy.SetPartitioner("OrderedPartitioner")
@@ -94,15 +93,15 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer)
 	}
 	// rest are round robin
+	if actual := iter(); actual.Info().Peer != "2" {
+		t.Errorf("Expected peer 2 but was %s", actual.Info().Peer)
+	}
 	if actual := iter(); actual.Info().Peer != "3" {
 		t.Errorf("Expected peer 3 but was %s", actual.Info().Peer)
 	}
 	if actual := iter(); actual.Info().Peer != "0" {
 		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer)
 	}
-	if actual := iter(); actual.Info().Peer != "2" {
-		t.Errorf("Expected peer 2 but was %s", actual.Info().Peer)
-	}
 }
 
 // Tests of the host pool host selection policy implementation

+ 34 - 29
session.go

@@ -10,7 +10,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"log"
 	"strings"
 	"sync"
 	"time"
@@ -39,6 +38,8 @@ type Session struct {
 	hostSource          *ringDescriber
 	mu                  sync.RWMutex
 
+	hostFilter HostFilter
+
 	control *controlConn
 
 	cfg ClusterConfig
@@ -66,49 +67,53 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		pageSize: cfg.PageSize,
 	}
 
-	pool, err := cfg.PoolConfig.buildPool(s)
-	if err != nil {
-		return nil, err
-	}
-	s.pool = pool
-
-	// See if there are any connections in the pool
-	if pool.Size() == 0 {
-		s.Close()
-		return nil, ErrNoConnectionsStarted
-	}
-
 	s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
 
 	// I think it might be a good idea to simplify this and make it always discover
 	// hosts, maybe with more filters.
-	if cfg.DiscoverHosts {
-		s.hostSource = &ringDescriber{
-			session:    s,
-			dcFilter:   cfg.Discovery.DcFilter,
-			rackFilter: cfg.Discovery.RackFilter,
-			closeChan:  make(chan bool),
-		}
+	s.hostSource = &ringDescriber{
+		session:   s,
+		closeChan: make(chan bool),
 	}
 
 	if !cfg.disableControlConn {
 		s.control = createControlConn(s)
-		s.control.reconnect(false)
+		if err := s.control.connect(cfg.Hosts); err != nil {
+			s.control.close()
+			return nil, err
+		}
 
-		// need to setup host source to check for rpc_address in system.local
-		localHasRPCAddr, err := checkSystemLocal(s.control)
+		// need to set up host source to check for broadcast_address in system.local
+		localHasRPCAddr, _ := checkSystemLocal(s.control)
+		s.hostSource.localHasRpcAddr = localHasRPCAddr
+		hosts, _, err := s.hostSource.GetHosts()
 		if err != nil {
-			log.Printf("gocql: unable to verify if system.local table contains rpc_address, falling back to connection address: %v", err)
+			s.control.close()
+			return nil, err
 		}
 
-		if cfg.DiscoverHosts {
-			s.hostSource.localHasRpcAddr = localHasRPCAddr
+		pool, err := cfg.PoolConfig.buildPool(s)
+		if err != nil {
+			return nil, err
 		}
+		s.pool = pool
+		// TODO(zariel): this should be used to create initial metadata
+		s.pool.SetHosts(hosts)
+	} else {
+		// TODO(zariel): remove branch for creating pools
+		pool, err := cfg.PoolConfig.buildPool(s)
+		if err != nil {
+			return nil, err
+		}
+		s.pool = pool
 	}
 
-	if cfg.DiscoverHosts {
-		s.hostSource.refreshRing()
-		go s.hostSource.run(cfg.Discovery.Sleep)
+	// TODO(zariel): we probably don't need this any more as we verify that we
+	// can connect to one of the endpoints supplied by using the control conn.
+	// See if there are any connections in the pool
+	if s.pool.Size() == 0 {
+		s.Close()
+		return nil, ErrNoConnectionsStarted
 	}
 
 	return s, nil