
start replacing host discovery with events

Chris Bannister, 10 years ago
parent
commit 5c1118aa01
10 changed files with 239 additions and 148 deletions
1. cassandra_test.go (+0 -8)
2. cluster.go (+2 -13)
3. conn.go (+0 -1)
4. conn_test.go (+22 -21)
5. connectionpool.go (+44 -19)
6. control.go (+67 -8)
7. events.go (+10 -14)
8. policies.go (+44 -18)
9. policies_test.go (+16 -17)
10. session.go (+34 -29)

cassandra_test.go (+0 -8)

@@ -165,7 +165,6 @@ func TestAuthentication(t *testing.T) {
 func TestRingDiscovery(t *testing.T) {
 	cluster := createCluster()
 	cluster.Hosts = clusterHosts[:1]
-	cluster.DiscoverHosts = true
 
 	session := createSessionFromCluster(cluster, t)
 	defer session.Close()
@@ -649,10 +648,6 @@ func TestCreateSessionTimeout(t *testing.T) {
 		session.Close()
 		t.Fatal("expected ErrNoConnectionsStarted, but no error was returned.")
 	}
-
-	if err != ErrNoConnectionsStarted {
-		t.Fatalf("expected ErrNoConnectionsStarted, but received %v", err)
-	}
 }
 
 type FullName struct {
@@ -2001,7 +1996,6 @@ func TestRoutingKey(t *testing.T) {
 func TestTokenAwareConnPool(t *testing.T) {
 	cluster := createCluster()
 	cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy())
-	cluster.DiscoverHosts = true
 
 	session := createSessionFromCluster(cluster, t)
 	defer session.Close()
@@ -2379,9 +2373,7 @@ func TestDiscoverViaProxy(t *testing.T) {
 	proxyAddr := proxy.Addr().String()
 
 	cluster := createCluster()
-	cluster.DiscoverHosts = true
 	cluster.NumConns = 1
-	cluster.Discovery.Sleep = 100 * time.Millisecond
 	// initial host is the proxy address
 	cluster.Hosts = []string{proxyAddr}
 

cluster.go (+2 -13)

@@ -40,16 +40,6 @@ func initStmtsLRU(max int) {
 	}
 }
 
-// To enable periodic node discovery enable DiscoverHosts in ClusterConfig
-type DiscoveryConfig struct {
-	// If not empty will filter all discoverred hosts to a single Data Centre (default: "")
-	DcFilter string
-	// If not empty will filter all discoverred hosts to a single Rack (default: "")
-	RackFilter string
-	// The interval to check for new hosts (default: 30s)
-	Sleep time.Duration
-}
-
 // PoolConfig configures the connection pool used by the driver, it defaults to
 // using a round robbin host selection policy and a round robbin connection selection
 // policy for each host.
@@ -94,12 +84,10 @@ type ClusterConfig struct {
 	Authenticator     Authenticator     // authenticator (default: nil)
 	RetryPolicy       RetryPolicy       // Default retry policy to use for queries (default: 0)
 	SocketKeepalive   time.Duration     // The keepalive period to use, enabled if > 0 (default: 0)
-	DiscoverHosts     bool              // If set, gocql will attempt to automatically discover other members of the Cassandra cluster (default: false)
 	MaxPreparedStmts  int               // Sets the maximum cache size for prepared statements globally for gocql (default: 1000)
 	MaxRoutingKeyInfo int               // Sets the maximum cache size for query info about statements for each session (default: 1000)
 	PageSize          int               // Default page size to use for created sessions (default: 5000)
 	SerialConsistency SerialConsistency // Sets the consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL (default: unset)
-	Discovery         DiscoveryConfig
 	SslOpts           *SslOptions
 	DefaultTimestamp  bool // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. (default: true, only enabled for protocol 3 and above)
 	// PoolConfig configures the underlying connection pool, allowing the
@@ -110,6 +98,8 @@ type ClusterConfig struct {
 	// receiving a schema change frame. (deault: 60s)
 	MaxWaitSchemaAgreement time.Duration
 
+	HostFilter
+
 	// internal config for testing
 	disableControlConn bool
 }
@@ -124,7 +114,6 @@ func NewCluster(hosts ...string) *ClusterConfig {
 		Port:                   9042,
 		NumConns:               2,
 		Consistency:            Quorum,
-		DiscoverHosts:          false,
 		MaxPreparedStmts:       defaultMaxPreparedStmts,
 		MaxRoutingKeyInfo:      1000,
 		PageSize:               5000,

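A rough sketch of how the removed DiscoveryConfig.DcFilter could be expressed through the HostFilter field this commit embeds in ClusterConfig. It is illustrative only, not code from the commit: it assumes HostFilter is an interface with an Accept(HostInfo) bool method (consistent with the s.hostFilter.Accept(*hostInfo) call added in events.go below) and that HostInfo carries a DataCenter field.

// Hypothetical data-centre filter; the type name and addresses are placeholders.
type dcFilter struct{ dc string }

func (f dcFilter) Accept(host HostInfo) bool {
	// accept every host when no data centre is configured
	return f.dc == "" || host.DataCenter == f.dc
}

func newDC1Cluster() *ClusterConfig {
	cluster := NewCluster("192.168.1.1", "192.168.1.2")
	cluster.HostFilter = dcFilter{dc: "DC1"} // HostFilter is an embedded field
	return cluster
}
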
conn.go (+0 -1)

@@ -141,7 +141,6 @@ type Conn struct {
 }
 
 // Connect establishes a connection to a Cassandra node.
-// You must also call the Serve method before you can execute any queries.
 func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
 	var (
 		err  error

conn_test.go (+22 -21)

@@ -50,12 +50,18 @@ func TestJoinHostPort(t *testing.T) {
 	}
 }
 
+func testCluster(addr string, proto protoVersion) *ClusterConfig {
+	cluster := NewCluster(addr)
+	cluster.ProtoVersion = int(proto)
+	cluster.disableControlConn = true
+	return cluster
+}
+
 func TestSimple(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
-	cluster.ProtoVersion = int(defaultProto)
+	cluster := testCluster(srv.Address, defaultProto)
 	db, err := cluster.CreateSession()
 	if err != nil {
 		t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
@@ -94,18 +100,19 @@ func TestSSLSimpleNoClientCert(t *testing.T) {
 	}
 }
 
-func createTestSslCluster(hosts string, proto uint8, useClientCert bool) *ClusterConfig {
-	cluster := NewCluster(hosts)
+func createTestSslCluster(addr string, proto protoVersion, useClientCert bool) *ClusterConfig {
+	cluster := testCluster(addr, proto)
 	sslOpts := &SslOptions{
 		CaPath:                 "testdata/pki/ca.crt",
 		EnableHostVerification: false,
 	}
+
 	if useClientCert {
 		sslOpts.CertPath = "testdata/pki/gocql.crt"
 		sslOpts.KeyPath = "testdata/pki/gocql.key"
 	}
+
 	cluster.SslOpts = sslOpts
-	cluster.ProtoVersion = int(proto)
 	return cluster
 }
 
@@ -115,28 +122,23 @@ func TestClosed(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
-	cluster.ProtoVersion = int(defaultProto)
-
-	session, err := cluster.CreateSession()
-	defer session.Close()
+	session, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 		t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
 	}
 
+	session.Close()
+
 	if err := session.Query("void").Exec(); err != ErrSessionClosed {
 		t.Fatalf("0x%x: expected %#v, got %#v", defaultProto, ErrSessionClosed, err)
 	}
 }
 
-func newTestSession(addr string, proto uint8) (*Session, error) {
-	cluster := NewCluster(addr)
-	cluster.ProtoVersion = int(proto)
-	return cluster.CreateSession()
+func newTestSession(addr string, proto protoVersion) (*Session, error) {
+	return testCluster(addr, proto).CreateSession()
 }
 
 func TestTimeout(t *testing.T) {
-
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
@@ -197,7 +199,7 @@ func TestStreams_Protocol1(t *testing.T) {
 
 	// TODO: these are more like session tests and should instead operate
 	// on a single Conn
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, protoVersion1)
 	cluster.NumConns = 1
 	cluster.ProtoVersion = 1
 
@@ -229,7 +231,7 @@ func TestStreams_Protocol3(t *testing.T) {
 
 	// TODO: these are more like session tests and should instead operate
 	// on a single Conn
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, protoVersion3)
 	cluster.NumConns = 1
 	cluster.ProtoVersion = 3
 
@@ -356,7 +358,6 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 
 	db, err := cluster.CreateSession()
 	if err != nil {
-		db.Close()
 		t.Fatalf("failed to create new session: %v", err)
 	}
 
@@ -377,7 +378,7 @@ func TestQueryTimeout(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, defaultProto)
 	// Set the timeout arbitrarily low so that the query hits the timeout in a
 	// timely manner.
 	cluster.Timeout = 1 * time.Millisecond
@@ -418,7 +419,7 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, defaultProto)
 	// Set the timeout arbitrarily low so that the query hits the timeout in a
 	// timely manner.
 	cluster.Timeout = 1 * time.Millisecond
@@ -442,7 +443,7 @@ func TestQueryTimeoutClose(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, defaultProto)
 	// Set the timeout arbitrarily low so that the query hits the timeout in a
 	// timely manner.
 	cluster.Timeout = 1000 * time.Millisecond

connectionpool.go (+44 -19)

@@ -71,16 +71,15 @@ type policyConnPool struct {
 	hostConnPools map[string]*hostConnPool
 }
 
-func newPolicyConnPool(session *Session, hostPolicy HostSelectionPolicy,
-	connPolicy func() ConnSelectionPolicy) (*policyConnPool, error) {
+func connConfig(session *Session) (*ConnConfig, error) {
+	cfg := session.cfg
 
 	var (
 		err       error
 		tlsConfig *tls.Config
 	)
 
-	cfg := session.cfg
-
+	// TODO(zariel): move tls config setup into session init.
 	if cfg.SslOpts != nil {
 		tlsConfig, err = setupTLSConfig(cfg.SslOpts)
 		if err != nil {
@@ -88,28 +87,41 @@ func newPolicyConnPool(session *Session, hostPolicy HostSelectionPolicy,
 		}
 	}
 
+	return &ConnConfig{
+		ProtoVersion:  cfg.ProtoVersion,
+		CQLVersion:    cfg.CQLVersion,
+		Timeout:       cfg.Timeout,
+		NumStreams:    cfg.NumStreams,
+		Compressor:    cfg.Compressor,
+		Authenticator: cfg.Authenticator,
+		Keepalive:     cfg.SocketKeepalive,
+		tlsConfig:     tlsConfig,
+	}, nil
+}
+
+func newPolicyConnPool(session *Session, hostPolicy HostSelectionPolicy,
+	connPolicy func() ConnSelectionPolicy) (*policyConnPool, error) {
+
+	connCfg, err := connConfig(session)
+	if err != nil {
+		return nil, err
+	}
+
 	// create the pool
 	pool := &policyConnPool{
-		session:  session,
-		port:     cfg.Port,
-		numConns: cfg.NumConns,
-		connCfg: &ConnConfig{
-			ProtoVersion:  cfg.ProtoVersion,
-			CQLVersion:    cfg.CQLVersion,
-			Timeout:       cfg.Timeout,
-			Compressor:    cfg.Compressor,
-			Authenticator: cfg.Authenticator,
-			Keepalive:     cfg.SocketKeepalive,
-			tlsConfig:     tlsConfig,
-		},
-		keyspace:      cfg.Keyspace,
+		session:       session,
+		port:          session.cfg.Port,
+		numConns:      session.cfg.NumConns,
+		connCfg:       connCfg,
+		keyspace:      session.cfg.Keyspace,
 		hostPolicy:    hostPolicy,
 		connPolicy:    connPolicy,
 		hostConnPools: map[string]*hostConnPool{},
 	}
 
-	hosts := make([]HostInfo, len(cfg.Hosts))
-	for i, hostAddr := range cfg.Hosts {
+	// TODO(zariel): fetch this from session metadata.
+	hosts := make([]HostInfo, len(session.cfg.Hosts))
+	for i, hostAddr := range session.cfg.Hosts {
 		hosts[i].Peer = hostAddr
 	}
 
@@ -241,6 +253,7 @@ func (p *policyConnPool) addHost(host *HostInfo) {
 }
 
 func (p *policyConnPool) removeHost(addr string) {
+	p.hostPolicy.RemoveHost(addr)
 	p.mu.Lock()
 
 	pool, ok := p.hostConnPools[addr]
@@ -255,6 +268,18 @@ func (p *policyConnPool) removeHost(addr string) {
 	pool.Close()
 }
 
+func (p *policyConnPool) hostUp(host *HostInfo) {
+	// TODO(zariel): have a set of up hosts and down hosts, we can internally
+	// detect down hosts, then try to reconnect to them.
+	p.addHost(host)
+}
+
+func (p *policyConnPool) hostDown(addr string) {
+	// TODO(zariel): mark host as down so we can try to connect to it later, for
+	// now just treat it as removed.
+	p.removeHost(addr)
+}
+
 // hostConnPool is a connection pool for a single host.
 // Connection selection is based on a provided ConnSelectionPolicy
 type hostConnPool struct {

control.go (+67 -8)

@@ -4,23 +4,26 @@ import (
 	"errors"
 	"fmt"
 	"log"
+	"math/rand"
 	"net"
+	"sync"
 	"sync/atomic"
 	"time"
 )
 
-// Ensure that the atomic variable is aligned to a 64bit boundary 
+// Ensure that the atomic variable is aligned to a 64bit boundary
 // so that atomic operations can be applied on 32bit architectures.
 type controlConn struct {
-	connecting uint64
+	connecting int64
 
 	session *Session
 
-	conn       atomic.Value
+	conn atomic.Value
 
 	retry RetryPolicy
 
-	quit chan struct{}
+	closeWg sync.WaitGroup
+	quit    chan struct{}
 }
 
 func createControlConn(session *Session) *controlConn {
@@ -31,12 +34,14 @@ func createControlConn(session *Session) *controlConn {
 	}
 
 	control.conn.Store((*Conn)(nil))
-	go control.heartBeat()
 
 	return control
 }
 
 func (c *controlConn) heartBeat() {
+	c.closeWg.Add(1)
+	defer c.closeWg.Done()
+
 	for {
 		select {
 		case <-c.quit:
@@ -62,12 +67,61 @@ func (c *controlConn) heartBeat() {
 		c.reconnect(true)
 		// time.Sleep(5 * time.Second)
 		continue
+	}
+}
+
+func (c *controlConn) connect(endpoints []string) error {
+	// initial connection attempt, try to connect to each endpoint to get an initial
+	// list of nodes.
+
+	// shuffle endpoints so not all drivers will connect to the same initial
+	// node.
+	r := rand.New(rand.NewSource(time.Now().UnixNano()))
+	perm := r.Perm(len(endpoints))
+	shuffled := make([]string, len(endpoints))
+
+	for i, endpoint := range endpoints {
+		shuffled[perm[i]] = endpoint
+	}
 
+	connCfg, err := connConfig(c.session)
+	if err != nil {
+		return err
 	}
+
+	// store that we are not connected so that reconnect wont happen if we error
+	atomic.StoreInt64(&c.connecting, -1)
+
+	var (
+		conn *Conn
+	)
+
+	for _, addr := range shuffled {
+		conn, err = Connect(JoinHostPort(addr, c.session.cfg.Port), connCfg, c, c.session)
+		if err != nil {
+			log.Printf("gocql: unable to dial %v: %v\n", addr, err)
+			continue
+		}
+
+		// we should fetch the initial ring here and update initial host data. So that
+		// when we return from here we have a ring topology ready to go.
+		break
+	}
+
+	if conn == nil {
+		// this is fatal, not going to connect a session
+		return err
+	}
+
+	c.conn.Store(conn)
+	atomic.StoreInt64(&c.connecting, 0)
+	go c.heartBeat()
+
+	return nil
 }
 
 func (c *controlConn) reconnect(refreshring bool) {
-	if !atomic.CompareAndSwapUint64(&c.connecting, 0, 1) {
+	if !atomic.CompareAndSwapInt64(&c.connecting, 0, 1) {
 		return
 	}
 
@@ -77,10 +131,10 @@ func (c *controlConn) reconnect(refreshring bool) {
 		if success {
 			go func() {
 				time.Sleep(500 * time.Millisecond)
-				atomic.StoreUint64(&c.connecting, 0)
+				atomic.StoreInt64(&c.connecting, 0)
 			}()
 		} else {
-			atomic.StoreUint64(&c.connecting, 0)
+			atomic.StoreInt64(&c.connecting, 0)
 		}
 	}()
 
@@ -120,7 +174,7 @@ func (c *controlConn) reconnect(refreshring bool) {
 		oldConn.Close()
 	}
 
-	if refreshring && c.session.cfg.DiscoverHosts {
+	if refreshring {
 		c.session.hostSource.refreshRing()
 	}
 }
@@ -242,6 +296,11 @@ func (c *controlConn) addr() string {
 func (c *controlConn) close() {
 	// TODO: handle more gracefully
 	close(c.quit)
+	c.closeWg.Wait()
+	conn := c.conn.Load().(*Conn)
+	if conn != nil {
+		conn.Close()
+	}
 }
 
 var errNoControl = errors.New("gocql: no control connection available")

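The connect/reconnect code above treats the connecting flag as a signed state: -1 while the initial connect is running or has failed (so error handlers cannot trigger a reconnect yet), 0 when idle, and 1 while a reconnect is in flight, with CompareAndSwapInt64 ensuring only one goroutine reconnects at a time. A minimal, self-contained sketch of that guard pattern; the names are illustrative, not gocql's:

package main

import (
	"fmt"
	"sync/atomic"
)

// guard mimics the connecting flag: -1 = initial connect in progress,
// 0 = idle/connected, 1 = reconnect in flight.
type guard struct {
	connecting int64
}

func (g *guard) tryReconnect() bool {
	// only the caller that flips 0 -> 1 performs the reconnect
	if !atomic.CompareAndSwapInt64(&g.connecting, 0, 1) {
		return false
	}
	defer atomic.StoreInt64(&g.connecting, 0)
	// ... dial a new connection and swap it in here ...
	return true
}

func main() {
	g := &guard{}
	atomic.StoreInt64(&g.connecting, -1) // initial connect started
	fmt.Println(g.tryReconnect())        // false: blocked while connecting == -1
	atomic.StoreInt64(&g.connecting, 0)  // initial connect succeeded
	fmt.Println(g.tryReconnect())        // true: this caller reconnects
}
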
events.go (+10 -14)

@@ -43,36 +43,32 @@ func (s *Session) handleEvent(framer *framer) {
 }
 
 func (s *Session) handleNewNode(host net.IP, port int) {
-	if !s.cfg.DiscoverHosts || s.control == nil {
+	// TODO(zariel): need to be able to filter discovered nodes
+	if s.control == nil {
 		return
 	}
 
-	if s.control.addr() == host.String() {
-		go s.control.reconnect(false)
-	}
-
 	hostInfo, err := s.control.fetchHostInfo(host, port)
 	if err != nil {
 		log.Printf("gocql: unable to fetch host info for %v: %v\n", host, err)
 		return
 	}
 
-	s.pool.addHost(hostInfo)
+	if s.hostFilter.Accept(*hostInfo) {
+		s.pool.addHost(hostInfo)
+	}
 }
 
 func (s *Session) handleRemovedNode(host net.IP, port int) {
-	if !s.cfg.DiscoverHosts {
-		return
-	}
-
+	// we remove all nodes but only add ones which pass the filter
 	s.pool.removeHost(host.String())
 }
 
 func (s *Session) handleNodeUp(host net.IP, port int) {
-	// even if were not disconvering new nodes we should still handle nodes going
-	// up.
-
-	s.pool.hostUp(host.String())
+	// TODO(zariel): handle this case even when not discovering, just mark the
+	// host up.
+	// TODO: implement this properly not as newNode
+	s.handleNewNode(host, port)
 }
 
 func (s *Session) handleNodeDown(host net.IP, port int) {

policies.go (+44 -18)

@@ -55,7 +55,7 @@ func (c *cowHostList) add(host HostInfo) {
 	c.mu.Unlock()
 }
 
-func (c *cowHostList) remove(host HostInfo) {
+func (c *cowHostList) remove(addr string) {
 	c.mu.Lock()
 	l := c.get()
 	size := len(l)
@@ -67,7 +67,7 @@ func (c *cowHostList) remove(host HostInfo) {
 	found := false
 	newL := make([]HostInfo, 0, size)
 	for i := 0; i < len(l); i++ {
-		if host.Peer != l[i].Peer && host.HostId != l[i].HostId {
+		if l[i].Peer != addr {
 			newL = append(newL, l[i])
 		} else {
 			found = true
@@ -124,6 +124,8 @@ func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
 
 type HostStateNotifier interface {
 	AddHost(host *HostInfo)
+	RemoveHost(addr string)
+	// TODO(zariel): add host up/down
 }
 
 // HostSelectionPolicy is an interface for selecting
@@ -149,7 +151,7 @@ type NextHost func() SelectedHost
 // RoundRobinHostPolicy is a round-robin load balancing policy, where each host
 // is tried sequentially for each query.
 func RoundRobinHostPolicy() HostSelectionPolicy {
-	return &roundRobinHostPolicy{hosts: []HostInfo{}}
+	return &roundRobinHostPolicy{}
 }
 
 type roundRobinHostPolicy struct {
@@ -179,7 +181,7 @@ func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
 		// always increment pos to evenly distribute traffic in case of
 		// failures
 		pos := atomic.AddUint32(&r.pos, 1)
-		if i >= len(r.hosts) {
+		if i >= len(hosts) {
 			return nil
 		}
-		host := &r.hosts[(pos)%uint32(len(r.hosts))]
+		host := &hosts[(pos)%uint32(len(hosts))]
@@ -192,6 +194,10 @@ func (r *roundRobinHostPolicy) AddHost(host *HostInfo) {
 	r.hosts.add(*host)
 }
 
+func (r *roundRobinHostPolicy) RemoveHost(addr string) {
+	r.hosts.remove(addr)
+}
+
 // selectedRoundRobinHost is a host returned by the roundRobinHostPolicy and
 // implements the SelectedHost interface
 type selectedRoundRobinHost struct {
@@ -210,24 +216,25 @@ func (host selectedRoundRobinHost) Mark(err error) {
 // selected based on the partition key, so queries are sent to the host which
 // owns the partition. Fallback is used when routing information is not available.
 func TokenAwareHostPolicy(fallback HostSelectionPolicy) HostSelectionPolicy {
-	return &tokenAwareHostPolicy{fallback: fallback, hosts: []HostInfo{}}
+	return &tokenAwareHostPolicy{fallback: fallback}
 }
 
 type tokenAwareHostPolicy struct {
+	hosts       cowHostList
 	mu          sync.RWMutex
-	hosts       []HostInfo
 	partitioner string
 	tokenRing   *tokenRing
 	fallback    HostSelectionPolicy
 }
 
 func (t *tokenAwareHostPolicy) SetHosts(hosts []HostInfo) {
+	t.hosts.set(hosts)
+
 	t.mu.Lock()
 	defer t.mu.Unlock()
 
 	// always update the fallback
 	t.fallback.SetHosts(hosts)
-	t.hosts = hosts
 
 	t.resetTokenRing()
 }
@@ -245,19 +252,19 @@ func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
 }
 
 func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
+	t.hosts.add(*host)
+
 	t.mu.Lock()
-	defer t.mu.Unlock()
+	t.resetTokenRing()
+	t.mu.Unlock()
+}
 
-	t.fallback.AddHost(host)
-	for i := range t.hosts {
-		h := &t.hosts[i]
-		if h.HostId == host.HostId && h.Peer == host.Peer {
-			return
-		}
-	}
+func (t *tokenAwareHostPolicy) RemoveHost(addr string) {
+	t.hosts.remove(addr)
 
-	t.hosts = append(t.hosts, *host)
+	t.mu.Lock()
 	t.resetTokenRing()
+	t.mu.Unlock()
 }
 
 func (t *tokenAwareHostPolicy) resetTokenRing() {
@@ -267,7 +274,8 @@ func (t *tokenAwareHostPolicy) resetTokenRing() {
 	}
 
 	// create a new token ring
-	tokenRing, err := newTokenRing(t.partitioner, t.hosts)
+	hosts := t.hosts.get()
+	tokenRing, err := newTokenRing(t.partitioner, hosts)
 	if err != nil {
 		log.Printf("Unable to update the token ring due to error: %s", err)
 		return
@@ -309,6 +317,7 @@ func (t *tokenAwareHostPolicy) Pick(qry *Query) NextHost {
 		hostReturned bool
 		fallbackIter NextHost
 	)
+
 	return func() SelectedHost {
 		if !hostReturned {
 			hostReturned = true
@@ -365,8 +374,8 @@ func HostPoolHostPolicy(hp hostpool.HostPool) HostSelectionPolicy {
 
 type hostPoolHostPolicy struct {
 	hp      hostpool.HostPool
-	hostMap map[string]HostInfo
 	mu      sync.RWMutex
+	hostMap map[string]HostInfo
 }
 
 func (r *hostPoolHostPolicy) SetHosts(hosts []HostInfo) {
@@ -402,6 +411,23 @@ func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
 	r.hostMap[host.Peer] = *host
 }
 
+func (r *hostPoolHostPolicy) RemoveHost(addr string) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if _, ok := r.hostMap[addr]; !ok {
+		return
+	}
+
+	delete(r.hostMap, addr)
+	hosts := make([]string, 0, len(r.hostMap))
+	for addr := range r.hostMap {
+		hosts = append(hosts, addr)
+	}
+
+	r.hp.SetHosts(hosts)
+}
+
 func (r *hostPoolHostPolicy) SetPartitioner(partitioner string) {
 	// noop
 }

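The cowHostList used by the round-robin and token-aware policies is, as the name suggests, a copy-on-write host list: readers take a snapshot via get() while add and remove rebuild the slice under a lock (see the hunks above). A standalone sketch of the same idea, with assumed field and method names rather than gocql's actual implementation:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// cowList: readers load an immutable snapshot lock-free, writers copy the
// slice under a mutex and atomically publish the new copy.
type cowList struct {
	mu   sync.Mutex   // serialises writers
	list atomic.Value // holds an immutable []string snapshot
}

func (c *cowList) get() []string {
	l, _ := c.list.Load().([]string)
	return l
}

func (c *cowList) add(addr string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	old := c.get()
	newL := make([]string, len(old), len(old)+1)
	copy(newL, old)
	c.list.Store(append(newL, addr))
}

func (c *cowList) remove(addr string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	old := c.get()
	newL := make([]string, 0, len(old))
	for _, a := range old {
		if a != addr {
			newL = append(newL, a)
		}
	}
	c.list.Store(newL)
}

func main() {
	var c cowList
	c.add("10.0.0.1")
	c.add("10.0.0.2")
	c.remove("10.0.0.1")
	fmt.Println(c.get()) // prints [10.0.0.2]
}
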
policies_test.go (+16 -17)

@@ -23,30 +23,29 @@ func TestRoundRobinHostPolicy(t *testing.T) {
 
 	policy.SetHosts(hosts)
 
-	// the first host selected is actually at [1], but this is ok for RR
 	// interleaved iteration should always increment the host
 	iterA := policy.Pick(nil)
-	if actual := iterA(); actual.Info() != &hosts[1] {
-		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostId)
-	}
-	iterB := policy.Pick(nil)
-	if actual := iterB(); actual.Info() != &hosts[0] {
+	if actual := iterA(); actual.Info() != &hosts[0] {
 		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostId)
 	}
+	iterB := policy.Pick(nil)
 	if actual := iterB(); actual.Info() != &hosts[1] {
 		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostId)
 	}
-	if actual := iterA(); actual.Info() != &hosts[0] {
+	if actual := iterB(); actual.Info() != &hosts[0] {
 		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostId)
 	}
-
-	iterC := policy.Pick(nil)
-	if actual := iterC(); actual.Info() != &hosts[1] {
+	if actual := iterA(); actual.Info() != &hosts[1] {
 		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostId)
 	}
+
+	iterC := policy.Pick(nil)
 	if actual := iterC(); actual.Info() != &hosts[0] {
 		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostId)
 	}
+	if actual := iterC(); actual.Info() != &hosts[1] {
+		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostId)
+	}
 }
 
 // Tests of the token-aware host selection policy implementation with a
@@ -76,13 +75,13 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 
 	// the token ring is not setup without the partitioner, but the fallback
 	// should work
-	if actual := policy.Pick(nil)(); actual.Info().Peer != "1" {
-		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer)
+	if actual := policy.Pick(nil)(); actual.Info().Peer != "0" {
+		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer)
 	}
 
 	query.RoutingKey([]byte("30"))
-	if actual := policy.Pick(query)(); actual.Info().Peer != "2" {
-		t.Errorf("Expected peer 2 but was %s", actual.Info().Peer)
+	if actual := policy.Pick(query)(); actual.Info().Peer != "1" {
+		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer)
 	}
 
 	policy.SetPartitioner("OrderedPartitioner")
@@ -94,15 +93,15 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer)
 	}
 	// rest are round robin
+	if actual := iter(); actual.Info().Peer != "2" {
+		t.Errorf("Expected peer 2 but was %s", actual.Info().Peer)
+	}
 	if actual := iter(); actual.Info().Peer != "3" {
 		t.Errorf("Expected peer 3 but was %s", actual.Info().Peer)
 	}
 	if actual := iter(); actual.Info().Peer != "0" {
 		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer)
 	}
-	if actual := iter(); actual.Info().Peer != "2" {
-		t.Errorf("Expected peer 2 but was %s", actual.Info().Peer)
-	}
 }
 
 // Tests of the host pool host selection policy implementation

session.go (+34 -29)

@@ -10,7 +10,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"log"
 	"strings"
 	"sync"
 	"time"
@@ -39,6 +38,8 @@ type Session struct {
 	hostSource          *ringDescriber
 	mu                  sync.RWMutex
 
+	hostFilter HostFilter
+
 	control *controlConn
 
 	cfg ClusterConfig
@@ -66,49 +67,53 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		pageSize: cfg.PageSize,
 	}
 
-	pool, err := cfg.PoolConfig.buildPool(s)
-	if err != nil {
-		return nil, err
-	}
-	s.pool = pool
-
-	// See if there are any connections in the pool
-	if pool.Size() == 0 {
-		s.Close()
-		return nil, ErrNoConnectionsStarted
-	}
-
 	s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
 
 	// I think it might be a good idea to simplify this and make it always discover
 	// hosts, maybe with more filters.
-	if cfg.DiscoverHosts {
-		s.hostSource = &ringDescriber{
-			session:    s,
-			dcFilter:   cfg.Discovery.DcFilter,
-			rackFilter: cfg.Discovery.RackFilter,
-			closeChan:  make(chan bool),
-		}
+	s.hostSource = &ringDescriber{
+		session:   s,
+		closeChan: make(chan bool),
 	}
 
 	if !cfg.disableControlConn {
 		s.control = createControlConn(s)
-		s.control.reconnect(false)
+		if err := s.control.connect(cfg.Hosts); err != nil {
+			s.control.close()
+			return nil, err
+		}
 
-		// need to setup host source to check for rpc_address in system.local
-		localHasRPCAddr, err := checkSystemLocal(s.control)
+		// need to setup host source to check for broadcast_address in system.local
+		localHasRPCAddr, _ := checkSystemLocal(s.control)
+		s.hostSource.localHasRpcAddr = localHasRPCAddr
+		hosts, _, err := s.hostSource.GetHosts()
 		if err != nil {
-			log.Printf("gocql: unable to verify if system.local table contains rpc_address, falling back to connection address: %v", err)
+			s.control.close()
+			return nil, err
 		}
 
-		if cfg.DiscoverHosts {
-			s.hostSource.localHasRpcAddr = localHasRPCAddr
+		pool, err := cfg.PoolConfig.buildPool(s)
+		if err != nil {
+			return nil, err
 		}
+		s.pool = pool
+		// TODO(zariel): this should be used to create initial metadata
+		s.pool.SetHosts(hosts)
+	} else {
+		// TODO(zariel): remove branch for creating pools
+		pool, err := cfg.PoolConfig.buildPool(s)
+		if err != nil {
+			return nil, err
+		}
+		s.pool = pool
 	}
 
-	if cfg.DiscoverHosts {
-		s.hostSource.refreshRing()
-		go s.hostSource.run(cfg.Discovery.Sleep)
+	// TODO(zariel): we probably dont need this any more as we verify that we
+	// can connect to one of the endpoints supplied by using the control conn.
+	// See if there are any connections in the pool
+	if s.pool.Size() == 0 {
+		s.Close()
+		return nil, ErrNoConnectionsStarted
 	}
 
 	return s, nil
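
With this change host discovery is always on: NewSession dials the control connection against the configured endpoints, fetches the ring, and builds the pool from the discovered hosts, so the old DiscoverHosts and Discovery knobs are gone. A minimal usage sketch under that model; the addresses are placeholders:

package main

import (
	"log"

	"github.com/gocql/gocql"
)

func main() {
	// The listed hosts only bootstrap the control connection; the rest of the
	// ring is discovered from system tables and server events.
	cluster := gocql.NewCluster("10.0.0.1", "10.0.0.2")

	session, err := cluster.CreateSession()
	if err != nil {
		// either no endpoint could be dialled or the initial ring fetch failed
		log.Fatal(err)
	}
	defer session.Close()

	// use the session as usual; no DiscoverHosts or Discovery settings required
}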