소스 검색

add queryExecutor which executes queries.

Simplify how queries and batches are executed by making them both
implement the same interface which allows the queryExecutor to handle
both cases.

Remove conn selection policy and let the connection pool make decisions
about which conn to use.
Chris Bannister 9 년 전
부모
커밋
a04d2a5e3f
13개의 변경된 파일302개의 추가작업 그리고 340개의 파일을 삭제
  1. 13 15
      cassandra_test.go
  2. 1 15
      cluster.go
  3. 0 1
      conn_test.go
  4. 25 74
      connectionpool.go
  5. 3 3
      control.go
  6. 4 0
      events.go
  7. 1 1
      host_source.go
  8. 37 106
      policies.go
  9. 20 36
      policies_test.go
  10. 65 0
      query_executor.go
  11. 30 0
      ring.go
  12. 96 87
      session.go
  13. 7 2
      session_test.go

+ 13 - 15
cassandra_test.go

@@ -1019,6 +1019,14 @@ func TestBatchQueryInfo(t *testing.T) {
 	}
 }
 
+func getRandomConn(t *testing.T, session *Session) *Conn {
+	conn := session.getConn()
+	if conn == nil {
+		t.Fatal("unable to get a connection")
+	}
+	return conn
+}
+
 func injectInvalidPreparedStatement(t *testing.T, session *Session, table string) (string, *Conn) {
 	if err := createTable(session, `CREATE TABLE gocql_test.`+table+` (
 			foo   varchar,
@@ -1029,7 +1037,8 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	}
 
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
-	_, conn := session.pool.Pick(nil)
+
+	conn := getRandomConn(t, session)
 
 	flight := new(inflightPrepare)
 	key := session.stmtsLRU.keyFor(conn.addr, "", stmt)
@@ -1060,7 +1069,7 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 
 func TestPrepare_MissingSchemaPrepare(t *testing.T) {
 	s := createSession(t)
-	_, conn := s.pool.Pick(nil)
+	conn := getRandomConn(t, s)
 	defer s.Close()
 
 	insertQry := &Query{stmt: "INSERT INTO invalidschemaprep (val) VALUES (?)", values: []interface{}{5}, cons: s.cons,
@@ -1108,7 +1117,7 @@ func TestQueryInfo(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
 
-	_, conn := session.pool.Pick(nil)
+	conn := getRandomConn(t, session)
 	info, err := conn.prepareStatement("SELECT release_version, host_id FROM system.local WHERE key = ?", nil)
 
 	if err != nil {
@@ -1982,18 +1991,7 @@ func TestNegativeStream(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
 
-	var conn *Conn
-	for i := 0; i < 5; i++ {
-		if conn != nil {
-			break
-		}
-
-		_, conn = session.pool.Pick(nil)
-	}
-
-	if conn == nil {
-		t.Fatal("no connections available in the pool")
-	}
+	conn := getRandomConn(t, session)
 
 	const stream = -50
 	writer := frameWriterFunc(func(f *framer, streamID int) error {

+ 1 - 15
cluster.go

@@ -16,24 +16,10 @@ type PoolConfig struct {
 	// HostSelectionPolicy sets the policy for selecting which host to use for a
 	// given query (default: RoundRobinHostPolicy())
 	HostSelectionPolicy HostSelectionPolicy
-
-	// ConnSelectionPolicy sets the policy factory for selecting a connection to use for
-	// each host for a query (default: RoundRobinConnPolicy())
-	ConnSelectionPolicy func() ConnSelectionPolicy
 }
 
 func (p PoolConfig) buildPool(session *Session) *policyConnPool {
-	hostSelection := p.HostSelectionPolicy
-	if hostSelection == nil {
-		hostSelection = RoundRobinHostPolicy()
-	}
-
-	connSelection := p.ConnSelectionPolicy
-	if connSelection == nil {
-		connSelection = RoundRobinConnPolicy()
-	}
-
-	return newPolicyConnPool(session, hostSelection, connSelection)
+	return newPolicyConnPool(session)
 }
 
 type DiscoveryConfig struct {

+ 0 - 1
conn_test.go

@@ -284,7 +284,6 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 
 	cluster := createTestSslCluster(srv.Address, defaultProto, true)
 	cluster.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
-	cluster.PoolConfig.ConnSelectionPolicy = RoundRobinConnPolicy()
 
 	db, err := cluster.CreateSession()
 	if err != nil {

+ 25 - 74
connectionpool.go

@@ -14,6 +14,7 @@ import (
 	"math/rand"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 )
 
@@ -65,8 +66,6 @@ type policyConnPool struct {
 	keyspace string
 
 	mu            sync.RWMutex
-	hostPolicy    HostSelectionPolicy
-	connPolicy    func() ConnSelectionPolicy
 	hostConnPools map[string]*hostConnPool
 
 	endpoints []string
@@ -99,17 +98,13 @@ func connConfig(session *Session) (*ConnConfig, error) {
 	}, nil
 }
 
-func newPolicyConnPool(session *Session, hostPolicy HostSelectionPolicy,
-	connPolicy func() ConnSelectionPolicy) *policyConnPool {
-
+func newPolicyConnPool(session *Session) *policyConnPool {
 	// create the pool
 	pool := &policyConnPool{
 		session:       session,
 		port:          session.cfg.Port,
 		numConns:      session.cfg.NumConns,
 		keyspace:      session.cfg.Keyspace,
-		hostPolicy:    hostPolicy,
-		connPolicy:    connPolicy,
 		hostConnPools: map[string]*hostConnPool{},
 	}
 
@@ -150,7 +145,6 @@ func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
 				p.port,
 				p.numConns,
 				p.keyspace,
-				p.connPolicy(),
 			)
 		}(host)
 	}
@@ -170,13 +164,6 @@ func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
 		delete(p.hostConnPools, addr)
 		go pool.Close()
 	}
-
-	// update the policy
-	p.hostPolicy.SetHosts(hosts)
-}
-
-func (p *policyConnPool) SetPartitioner(partitioner string) {
-	p.hostPolicy.SetPartitioner(partitioner)
 }
 
 func (p *policyConnPool) Size() int {
@@ -197,41 +184,10 @@ func (p *policyConnPool) getPool(addr string) (pool *hostConnPool, ok bool) {
 	return
 }
 
-func (p *policyConnPool) Pick(qry *Query) (SelectedHost, *Conn) {
-	nextHost := p.hostPolicy.Pick(qry)
-
-	var (
-		host SelectedHost
-		conn *Conn
-	)
-
-	p.mu.RLock()
-	defer p.mu.RUnlock()
-	for conn == nil {
-		host = nextHost()
-		if host == nil {
-			break
-		} else if host.Info() == nil {
-			panic(fmt.Sprintf("policy %T returned no host info: %+v", p.hostPolicy, host))
-		}
-
-		pool, ok := p.hostConnPools[host.Info().Peer()]
-		if !ok {
-			continue
-		}
-
-		conn = pool.Pick(qry)
-	}
-	return host, conn
-}
-
 func (p *policyConnPool) Close() {
 	p.mu.Lock()
 	defer p.mu.Unlock()
 
-	// remove the hosts from the policy
-	p.hostPolicy.SetHosts(nil)
-
 	// close the pools
 	for addr, pool := range p.hostConnPools {
 		delete(p.hostConnPools, addr)
@@ -249,7 +205,6 @@ func (p *policyConnPool) addHost(host *HostInfo) {
 			host.Port(), // TODO: if port == 0 use pool.port?
 			p.numConns,
 			p.keyspace,
-			p.connPolicy(),
 		)
 
 		p.hostConnPools[host.Peer()] = pool
@@ -257,17 +212,10 @@ func (p *policyConnPool) addHost(host *HostInfo) {
 	p.mu.Unlock()
 
 	pool.fill()
-
-	// update policy
-	// TODO: policy should not have conns, it should have hosts and return a host
-	// iter which the pool will use to serve conns
-	p.hostPolicy.AddHost(host)
 }
 
 func (p *policyConnPool) removeHost(addr string) {
-	p.hostPolicy.RemoveHost(addr)
 	p.mu.Lock()
-
 	pool, ok := p.hostConnPools[addr]
 	if !ok {
 		p.mu.Unlock()
@@ -301,12 +249,13 @@ type hostConnPool struct {
 	addr     string
 	size     int
 	keyspace string
-	policy   ConnSelectionPolicy
 	// protection for conns, closed, filling
 	mu      sync.RWMutex
 	conns   []*Conn
 	closed  bool
 	filling bool
+
+	pos uint32
 }
 
 func (h *hostConnPool) String() string {
@@ -317,7 +266,7 @@ func (h *hostConnPool) String() string {
 }
 
 func newHostConnPool(session *Session, host *HostInfo, port, size int,
-	keyspace string, policy ConnSelectionPolicy) *hostConnPool {
+	keyspace string) *hostConnPool {
 
 	pool := &hostConnPool{
 		session:  session,
@@ -326,7 +275,6 @@ func newHostConnPool(session *Session, host *HostInfo, port, size int,
 		addr:     JoinHostPort(host.Peer(), port),
 		size:     size,
 		keyspace: keyspace,
-		policy:   policy,
 		conns:    make([]*Conn, 0, size),
 		filling:  false,
 		closed:   false,
@@ -337,16 +285,15 @@ func newHostConnPool(session *Session, host *HostInfo, port, size int,
 }
 
 // Pick a connection from this connection pool for the given query.
-func (pool *hostConnPool) Pick(qry *Query) *Conn {
+func (pool *hostConnPool) Pick() *Conn {
 	pool.mu.RLock()
+	defer pool.mu.RUnlock()
+
 	if pool.closed {
-		pool.mu.RUnlock()
 		return nil
 	}
 
 	size := len(pool.conns)
-	pool.mu.RUnlock()
-
 	if size < pool.size {
 		// try to fill the pool
 		go pool.fill()
@@ -356,7 +303,23 @@ func (pool *hostConnPool) Pick(qry *Query) *Conn {
 		}
 	}
 
-	return pool.policy.Pick(qry)
+	pos := int(atomic.AddUint32(&pool.pos, 1) - 1)
+
+	var (
+		leastBusyConn    *Conn
+		streamsAvailable int
+	)
+
+	// find the conn which has the most available streams, this is racy
+	for i := 0; i < size; i++ {
+		conn := pool.conns[(pos+i)%size]
+		if streams := conn.AvailableStreams(); streams > streamsAvailable {
+			leastBusyConn = conn
+			streamsAvailable = streams
+		}
+	}
+
+	return leastBusyConn
 }
 
 //Size returns the number of connections currently active in the pool
@@ -543,10 +506,6 @@ func (pool *hostConnPool) connect() (err error) {
 
 	pool.conns = append(pool.conns, conn)
 
-	conns := make([]*Conn, len(pool.conns))
-	copy(conns, pool.conns)
-	pool.policy.SetConns(conns)
-
 	return nil
 }
 
@@ -573,11 +532,6 @@ func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
 			// remove the connection, not preserving order
 			pool.conns[i], pool.conns = pool.conns[len(pool.conns)-1], pool.conns[:len(pool.conns)-1]
 
-			// update the policy
-			conns := make([]*Conn, len(pool.conns))
-			copy(conns, pool.conns)
-			pool.policy.SetConns(conns)
-
 			// lost a connection, so fill the pool
 			go pool.fill()
 			break
@@ -590,9 +544,6 @@ func (pool *hostConnPool) drainLocked() {
 	conns := pool.conns
 	pool.conns = nil
 
-	// update the policy
-	pool.policy.SetConns(nil)
-
 	// close the connections
 	for _, conn := range conns {
 		conn.Close()

+ 3 - 3
control.go

@@ -238,14 +238,14 @@ func (c *controlConn) reconnect(refreshring bool) {
 	// TODO: should have our own roundrobbin for hosts so that we can try each
 	// in succession and guantee that we get a different host each time.
 	if newConn == nil {
-		_, conn := c.session.pool.Pick(nil)
-		if conn == nil {
+		host := c.session.ring.rrHost()
+		if host == nil {
 			c.connect(c.session.ring.endpoints)
 			return
 		}
 
 		var err error
-		newConn, err = c.session.connect(conn.addr, c, conn.host)
+		newConn, err = c.session.connect(host.Peer(), c, host)
 		if err != nil {
 			// TODO: add log handler for things like this
 			return

+ 4 - 0
events.go

@@ -201,6 +201,7 @@ func (s *Session) handleNewNode(host net.IP, port int, waitForBinary bool) {
 	}
 
 	s.pool.addHost(hostInfo)
+	s.policy.AddHost(hostInfo)
 	hostInfo.setState(NodeUp)
 
 	if s.control != nil {
@@ -222,6 +223,7 @@ func (s *Session) handleRemovedNode(ip net.IP, port int) {
 	}
 
 	host.setState(NodeDown)
+	s.policy.RemoveHost(addr)
 	s.pool.removeHost(addr)
 	s.ring.removeHost(addr)
 
@@ -251,6 +253,7 @@ func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
 
 		host.setPort(port)
 		s.pool.hostUp(host)
+		s.policy.HostUp(host)
 		host.setState(NodeUp)
 		return
 	}
@@ -270,5 +273,6 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
 	}
 
 	host.setState(NodeDown)
+	s.policy.HostDown(addr)
 	s.pool.hostDown(addr)
 }

+ 1 - 1
host_source.go

@@ -390,6 +390,6 @@ func (r *ringDescriber) refreshRing() error {
 		}
 	}
 
-	r.session.pool.SetPartitioner(partitioner)
+	r.session.metadata.setPartitioner(partitioner)
 	return nil
 }

+ 37 - 106
policies.go

@@ -162,17 +162,17 @@ func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
 type HostStateNotifier interface {
 	AddHost(host *HostInfo)
 	RemoveHost(addr string)
-	// TODO(zariel): add host up/down
+	HostUp(host *HostInfo)
+	HostDown(addr string)
 }
 
 // HostSelectionPolicy is an interface for selecting
 // the most appropriate host to execute a given query.
 type HostSelectionPolicy interface {
 	HostStateNotifier
-	SetHosts
 	SetPartitioner
 	//Pick returns an iteration function over selected hosts
-	Pick(*Query) NextHost
+	Pick(ExecutableQuery) NextHost
 }
 
 // SelectedHost is an interface returned when picking a host from a host
@@ -182,6 +182,14 @@ type SelectedHost interface {
 	Mark(error)
 }
 
+type selectedHost HostInfo
+
+func (host *selectedHost) Info() *HostInfo {
+	return (*HostInfo)(host)
+}
+
+func (host *selectedHost) Mark(err error) {}
+
 // NextHost is an iteration function over picked hosts
 type NextHost func() SelectedHost
 
@@ -197,15 +205,11 @@ type roundRobinHostPolicy struct {
 	mu    sync.RWMutex
 }
 
-func (r *roundRobinHostPolicy) SetHosts(hosts []*HostInfo) {
-	r.hosts.set(hosts)
-}
-
 func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {
 	// noop
 }
 
-func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
+func (r *roundRobinHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	// i is used to limit the number of attempts to find a host
 	// to the number of hosts known to this policy
 	var i int
@@ -223,7 +227,7 @@ func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
 		}
 		host := hosts[(pos)%uint32(len(hosts))]
 		i++
-		return selectedRoundRobinHost{host}
+		return (*selectedHost)(host)
 	}
 }
 
@@ -235,18 +239,12 @@ func (r *roundRobinHostPolicy) RemoveHost(addr string) {
 	r.hosts.remove(addr)
 }
 
-// selectedRoundRobinHost is a host returned by the roundRobinHostPolicy and
-// implements the SelectedHost interface
-type selectedRoundRobinHost struct {
-	info *HostInfo
+func (r *roundRobinHostPolicy) HostUp(host *HostInfo) {
+	r.AddHost(host)
 }
 
-func (host selectedRoundRobinHost) Info() *HostInfo {
-	return host.info
-}
-
-func (host selectedRoundRobinHost) Mark(err error) {
-	// noop
+func (r *roundRobinHostPolicy) HostDown(addr string) {
+	r.RemoveHost(addr)
 }
 
 // TokenAwareHostPolicy is a token aware host selection policy, where hosts are
@@ -264,18 +262,6 @@ type tokenAwareHostPolicy struct {
 	fallback    HostSelectionPolicy
 }
 
-func (t *tokenAwareHostPolicy) SetHosts(hosts []*HostInfo) {
-	t.hosts.set(hosts)
-
-	t.mu.Lock()
-	defer t.mu.Unlock()
-
-	// always update the fallback
-	t.fallback.SetHosts(hosts)
-
-	t.resetTokenRing()
-}
-
 func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
@@ -299,12 +285,21 @@ func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
 
 func (t *tokenAwareHostPolicy) RemoveHost(addr string) {
 	t.hosts.remove(addr)
+	t.fallback.RemoveHost(addr)
 
 	t.mu.Lock()
 	t.resetTokenRing()
 	t.mu.Unlock()
 }
 
+func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) {
+	t.AddHost(host)
+}
+
+func (t *tokenAwareHostPolicy) HostDown(addr string) {
+	t.RemoveHost(addr)
+}
+
 func (t *tokenAwareHostPolicy) resetTokenRing() {
 	if t.partitioner == "" {
 		// partitioner not yet set
@@ -323,14 +318,9 @@ func (t *tokenAwareHostPolicy) resetTokenRing() {
 	t.tokenRing = tokenRing
 }
 
-func (t *tokenAwareHostPolicy) Pick(qry *Query) NextHost {
+func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	if qry == nil {
 		return t.fallback.Pick(qry)
-	} else if qry.binding != nil && len(qry.values) == 0 {
-		// If this query was created using session.Bind we wont have the query
-		// values yet, so we have to pass down to the next policy.
-		// TODO: Remove this and handle this case
-		return t.fallback.Pick(qry)
 	}
 
 	routingKey, err := qry.GetRoutingKey()
@@ -359,7 +349,7 @@ func (t *tokenAwareHostPolicy) Pick(qry *Query) NextHost {
 	return func() SelectedHost {
 		if !hostReturned {
 			hostReturned = true
-			return selectedTokenAwareHost{host}
+			return (*selectedHost)(host)
 		}
 
 		// fallback
@@ -378,20 +368,6 @@ func (t *tokenAwareHostPolicy) Pick(qry *Query) NextHost {
 	}
 }
 
-// selectedTokenAwareHost is a host returned by the tokenAwareHostPolicy and
-// implements the SelectedHost interface
-type selectedTokenAwareHost struct {
-	info *HostInfo
-}
-
-func (host selectedTokenAwareHost) Info() *HostInfo {
-	return host.info
-}
-
-func (host selectedTokenAwareHost) Mark(err error) {
-	// noop
-}
-
 // HostPoolHostPolicy is a host policy which uses the bitly/go-hostpool library
 // to distribute queries between hosts and prevent sending queries to
 // unresponsive hosts. When creating the host pool that is passed to the policy
@@ -466,11 +442,19 @@ func (r *hostPoolHostPolicy) RemoveHost(addr string) {
 	r.hp.SetHosts(hosts)
 }
 
+func (r *hostPoolHostPolicy) HostUp(host *HostInfo) {
+	r.AddHost(host)
+}
+
+func (r *hostPoolHostPolicy) HostDown(addr string) {
+	r.RemoveHost(addr)
+}
+
 func (r *hostPoolHostPolicy) SetPartitioner(partitioner string) {
 	// noop
 }
 
-func (r *hostPoolHostPolicy) Pick(qry *Query) NextHost {
+func (r *hostPoolHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	return func() SelectedHost {
 		r.mu.RLock()
 		defer r.mu.RUnlock()
@@ -516,56 +500,3 @@ func (host selectedHostPoolHost) Mark(err error) {
 
 	host.hostR.Mark(err)
 }
-
-//ConnSelectionPolicy is an interface for selecting an
-//appropriate connection for executing a query
-type ConnSelectionPolicy interface {
-	SetConns(conns []*Conn)
-	Pick(*Query) *Conn
-}
-
-type roundRobinConnPolicy struct {
-	// pos is still used to evenly distribute queries amongst connections.
-	pos   uint32
-	conns atomic.Value // *[]*Conn
-}
-
-func RoundRobinConnPolicy() func() ConnSelectionPolicy {
-	return func() ConnSelectionPolicy {
-		p := &roundRobinConnPolicy{}
-		var conns []*Conn
-		p.conns.Store(&conns)
-		return p
-	}
-}
-
-func (r *roundRobinConnPolicy) SetConns(conns []*Conn) {
-	// NOTE: we do not need to lock here due to the conneciton pool is already
-	// holding its own mutex over the conn seleciton policy
-	r.conns.Store(&conns)
-}
-
-func (r *roundRobinConnPolicy) Pick(qry *Query) *Conn {
-	conns := *(r.conns.Load().(*[]*Conn))
-	if len(conns) == 0 {
-		return nil
-	}
-
-	pos := int(atomic.AddUint32(&r.pos, 1) - 1)
-
-	var (
-		leastBusyConn    *Conn
-		streamsAvailable int
-	)
-
-	// find the conn which has the most available streams, this is racy
-	for i := 0; i < len(conns); i++ {
-		conn := conns[(pos+i)%len(conns)]
-		if streams := conn.AvailableStreams(); streams > streamsAvailable {
-			leastBusyConn = conn
-			streamsAvailable = streams
-		}
-	}
-
-	return leastBusyConn
-}

+ 20 - 36
policies_test.go

@@ -6,7 +6,6 @@ package gocql
 
 import (
 	"fmt"
-	"github.com/gocql/gocql/internal/streams"
 	"testing"
 
 	"github.com/hailocab/go-hostpool"
@@ -16,12 +15,14 @@ import (
 func TestRoundRobinHostPolicy(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 
-	hosts := []*HostInfo{
+	hosts := [...]*HostInfo{
 		{hostId: "0"},
 		{hostId: "1"},
 	}
 
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.AddHost(host)
+	}
 
 	// interleaved iteration should always increment the host
 	iterA := policy.Pick(nil)
@@ -65,13 +66,15 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 	}
 
 	// set the hosts
-	hosts := []*HostInfo{
+	hosts := [...]*HostInfo{
 		{peer: "0", tokens: []string{"00"}},
 		{peer: "1", tokens: []string{"25"}},
 		{peer: "2", tokens: []string{"50"}},
 		{peer: "3", tokens: []string{"75"}},
 	}
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.AddHost(host)
+	}
 
 	// the token ring is not setup without the partitioner, but the fallback
 	// should work
@@ -108,12 +111,14 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 func TestHostPoolHostPolicy(t *testing.T) {
 	policy := HostPoolHostPolicy(hostpool.New(nil))
 
-	hosts := []*HostInfo{
+	hosts := [...]*HostInfo{
 		{hostId: "0", peer: "0"},
 		{hostId: "1", peer: "1"},
 	}
 
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.AddHost(host)
+	}
 
 	// the first host selected is actually at [1], but this is ok for RR
 	// interleaved iteration should always increment the host
@@ -143,35 +148,11 @@ func TestHostPoolHostPolicy(t *testing.T) {
 	actualD.Mark(nil)
 }
 
-// Tests of the round-robin connection selection policy implementation
-func TestRoundRobinConnPolicy(t *testing.T) {
-	policy := RoundRobinConnPolicy()()
-
-	conn0 := &Conn{streams: streams.New(1)}
-	conn1 := &Conn{streams: streams.New(1)}
-	conn := []*Conn{
-		conn0,
-		conn1,
-	}
-
-	policy.SetConns(conn)
-
-	if actual := policy.Pick(nil); actual != conn0 {
-		t.Error("Expected conn1")
-	}
-	if actual := policy.Pick(nil); actual != conn1 {
-		t.Error("Expected conn0")
-	}
-	if actual := policy.Pick(nil); actual != conn0 {
-		t.Error("Expected conn1")
-	}
-}
-
 func TestRoundRobinNilHostInfo(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 
 	host := &HostInfo{hostId: "host-1"}
-	policy.SetHosts([]*HostInfo{host})
+	policy.AddHost(host)
 
 	iter := policy.Pick(nil)
 	next := iter()
@@ -195,13 +176,15 @@ func TestRoundRobinNilHostInfo(t *testing.T) {
 func TestTokenAwareNilHostInfo(t *testing.T) {
 	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 
-	hosts := []*HostInfo{
+	hosts := [...]*HostInfo{
 		{peer: "0", tokens: []string{"00"}},
 		{peer: "1", tokens: []string{"25"}},
 		{peer: "2", tokens: []string{"50"}},
 		{peer: "3", tokens: []string{"75"}},
 	}
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.AddHost(host)
+	}
 	policy.SetPartitioner("OrderedPartitioner")
 
 	query := &Query{}
@@ -218,8 +201,9 @@ func TestTokenAwareNilHostInfo(t *testing.T) {
 	}
 
 	// Empty the hosts to trigger the panic when using the fallback.
-	hosts = []*HostInfo{}
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.RemoveHost(host.Peer())
+	}
 
 	next = iter()
 	if next != nil {

+ 65 - 0
query_executor.go

@@ -0,0 +1,65 @@
+package gocql
+
+import (
+	"time"
+)
+
+type ExecutableQuery interface {
+	execute(conn *Conn) *Iter
+	attempt(time.Duration)
+	retryPolicy() RetryPolicy
+	GetRoutingKey() ([]byte, error)
+	RetryableQuery
+}
+
+type queryExecutor struct {
+	pool   *policyConnPool
+	policy HostSelectionPolicy
+}
+
+func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
+	rt := qry.retryPolicy()
+	hostIter := q.policy.Pick(qry)
+
+	var iter *Iter
+	for hostResponse := hostIter(); hostResponse != nil; hostResponse = hostIter() {
+		host := hostResponse.Info()
+		if !host.IsUp() {
+			continue
+		}
+
+		pool, ok := q.pool.getPool(host.Peer())
+		if !ok {
+			continue
+		}
+
+		conn := pool.Pick()
+		if conn == nil {
+			continue
+		}
+
+		start := time.Now()
+		iter = qry.execute(conn)
+
+		qry.attempt(time.Since(start))
+
+		// Update host
+		hostResponse.Mark(iter.err)
+
+		// Exit for loop if the query was successful
+		if iter.err == nil {
+			return iter, nil
+		}
+
+		if rt == nil || !rt.Attempt(qry) {
+			// What do here? Should we just return an error here?
+			break
+		}
+	}
+
+	if iter == nil {
+		return nil, ErrNoConnections
+	}
+
+	return iter, nil
+}

+ 30 - 0
ring.go

@@ -2,6 +2,7 @@ package gocql
 
 import (
 	"sync"
+	"sync/atomic"
 )
 
 type ring struct {
@@ -9,13 +10,27 @@ type ring struct {
 	// to in the case it can not reach any of its hosts. They are also used to boot
 	// strap the initial connection.
 	endpoints []string
+
 	// hosts are the set of all hosts in the cassandra ring that we know of
 	mu    sync.RWMutex
 	hosts map[string]*HostInfo
 
+	hostList []*HostInfo
+	pos      uint32
+
 	// TODO: we should store the ring metadata here also.
 }
 
+func (r *ring) rrHost() *HostInfo {
+	// TODO: should we filter hosts that get used here? These hosts will be used
+	// for the control connection, should we also provide an iterator?
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+
+	pos := int(atomic.AddUint32(&r.pos, 1) - 1)
+	return r.hostList[pos%len(r.hostList)]
+}
+
 func (r *ring) getHost(addr string) *HostInfo {
 	r.mu.RLock()
 	host := r.hosts[addr]
@@ -73,3 +88,18 @@ func (r *ring) removeHost(addr string) bool {
 	r.mu.Unlock()
 	return ok
 }
+
+type clusterMetadata struct {
+	mu          sync.RWMutex
+	partitioner string
+}
+
+func (c *clusterMetadata) setPartitioner(partitioner string) {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	if c.partitioner != partitioner {
+		// TODO: update other things now
+		c.partitioner = partitioner
+	}
+}

+ 96 - 87
session.go

@@ -31,7 +31,6 @@ import (
 // and automatically sets a default consinstency level on all operations
 // that do not have a consistency level set.
 type Session struct {
-	pool                *policyConnPool
 	cons                Consistency
 	pageSize            int
 	prefetch            float64
@@ -39,11 +38,17 @@ type Session struct {
 	schemaDescriber     *schemaDescriber
 	trace               Tracer
 	hostSource          *ringDescriber
-	ring                ring
 	stmtsLRU            *preparedLRU
 
 	connCfg *ConnConfig
 
+	executor *queryExecutor
+	pool     *policyConnPool
+	policy   HostSelectionPolicy
+
+	ring     ring
+	metadata clusterMetadata
+
 	mu sync.RWMutex
 
 	control *controlConn
@@ -116,7 +121,17 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		closeChan: make(chan bool),
 	}
 
-	s.pool = cfg.PoolConfig.buildPool(s)
+	pool := cfg.PoolConfig.buildPool(s)
+	if cfg.PoolConfig.HostSelectionPolicy == nil {
+		cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
+	}
+
+	s.pool = pool
+	s.policy = cfg.PoolConfig.HostSelectionPolicy
+	s.executor = &queryExecutor{
+		pool:   pool,
+		policy: cfg.PoolConfig.HostSelectionPolicy,
+	}
 
 	var hosts []*HostInfo
 
@@ -284,44 +299,17 @@ func (s *Session) Closed() bool {
 }
 
 func (s *Session) executeQuery(qry *Query) *Iter {
-
 	// fail fast
 	if s.Closed() {
 		return &Iter{err: ErrSessionClosed}
 	}
 
-	var iter *Iter
-	qry.attempts = 0
-	qry.totalLatency = 0
-	for {
-		host, conn := s.pool.Pick(qry)
-
-		qry.attempts++
-		//Assign the error unavailable to the iterator
-		if conn == nil {
-			if qry.rt == nil || !qry.rt.Attempt(qry) {
-				iter = &Iter{err: ErrNoConnections}
-				break
-			}
-
-			continue
-		}
-
-		t := time.Now()
-		iter = conn.executeQuery(qry)
-		qry.totalLatency += time.Now().Sub(t).Nanoseconds()
-
-		// Update host
-		host.Mark(iter.err)
-
-		// Exit for loop if the query was successful
-		if iter.err == nil {
-			break
-		}
-
-		if qry.rt == nil || !qry.rt.Attempt(qry) {
-			break
-		}
+	iter, err := s.executor.executeQuery(qry)
+	if err != nil {
+		return &Iter{err: err}
+	}
+	if iter == nil {
+		panic("nil iter")
 	}
 
 	return iter
@@ -348,6 +336,28 @@ func (s *Session) KeyspaceMetadata(keyspace string) (*KeyspaceMetadata, error) {
 	return s.schemaDescriber.getSchema(keyspace)
 }
 
+func (s *Session) getConn() *Conn {
+	hosts := s.ring.allHosts()
+	var conn *Conn
+	for _, host := range hosts {
+		if !host.IsUp() {
+			continue
+		}
+
+		pool, ok := s.pool.getPool(host.Peer())
+		if !ok {
+			continue
+		}
+
+		conn = pool.Pick()
+		if conn != nil {
+			return conn
+		}
+	}
+
+	return nil
+}
+
 // returns routing key indexes and type info
 func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	s.routingKeyInfoCache.mu.Lock()
@@ -384,26 +394,23 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 		partitionKey []*ColumnMetadata
 	)
 
-	// get the query info for the statement
-	host, conn := s.pool.Pick(nil)
+	conn := s.getConn()
 	if conn == nil {
-		// no connections
-		inflight.err = ErrNoConnections
-		// don't cache this error
-		s.routingKeyInfoCache.Remove(stmt)
+		// TODO: better error?
+		inflight.err = errors.New("gocql: unable to fetch preapred info: no connection avilable")
 		return nil, inflight.err
 	}
 
+	// get the query info for the statement
 	info, inflight.err = conn.prepareStatement(stmt, nil)
 	if inflight.err != nil {
 		// don't cache this error
 		s.routingKeyInfoCache.Remove(stmt)
-		host.Mark(inflight.err)
 		return nil, inflight.err
 	}
 
-	// Mark host as OK
-	host.Mark(nil)
+	// TODO: it would be nice to mark hosts here but as we are not using the policies
+	// to fetch hosts we cant
 
 	if info.request.colCount == 0 {
 		// no arguments, no routing key, and no error
@@ -455,6 +462,7 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 		indexes: make([]int, size),
 		types:   make([]TypeInfo, size),
 	}
+
 	for keyIndex, keyColumn := range partitionKey {
 		// set an indicator for checking if the mapping is missing
 		routingKeyInfo.indexes[keyIndex] = -1
@@ -482,6 +490,10 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	return routingKeyInfo, nil
 }
 
+func (b *Batch) execute(conn *Conn) *Iter {
+	return conn.executeBatch(b)
+}
+
 func (s *Session) executeBatch(batch *Batch) *Iter {
 	// fail fast
 	if s.Closed() {
@@ -495,45 +507,9 @@ func (s *Session) executeBatch(batch *Batch) *Iter {
 		return &Iter{err: ErrTooManyStmts}
 	}
 
-	var iter *Iter
-	batch.attempts = 0
-	batch.totalLatency = 0
-	for {
-		host, conn := s.pool.Pick(nil)
-
-		batch.attempts++
-		if conn == nil {
-			if batch.rt == nil || !batch.rt.Attempt(batch) {
-				// Assign the error unavailable and break loop
-				iter = &Iter{err: ErrNoConnections}
-				break
-			}
-
-			continue
-		}
-
-		if conn == nil {
-			iter = &Iter{err: ErrNoConnections}
-			break
-		}
-
-		t := time.Now()
-
-		iter = conn.executeBatch(batch)
-
-		batch.totalLatency += time.Since(t).Nanoseconds()
-		// Exit loop if operation executed correctly
-		if iter.err == nil {
-			host.Mark(nil)
-			break
-		}
-
-		// Mark host with error if returned from Close
-		host.Mark(iter.Close())
-
-		if batch.rt == nil || !batch.rt.Attempt(batch) {
-			break
-		}
+	iter, err := s.executor.executeQuery(batch)
+	if err != nil {
+		return &Iter{err: err}
 	}
 
 	return iter
@@ -680,6 +656,20 @@ func (q *Query) RoutingKey(routingKey []byte) *Query {
 	return q
 }
 
+func (q *Query) execute(conn *Conn) *Iter {
+	return conn.executeQuery(q)
+}
+
+func (q *Query) attempt(d time.Duration) {
+	q.attempts++
+	q.totalLatency += d.Nanoseconds()
+	// TODO: track latencies per host and things as well instead of just total
+}
+
+func (q *Query) retryPolicy() RetryPolicy {
+	return q.rt
+}
+
 // GetRoutingKey gets the routing key to use for routing this query. If
 // a routing key has not been explicitly set, then the routing key will
 // be constructed if possible using the keyspace's schema and the query
@@ -689,6 +679,11 @@ func (q *Query) RoutingKey(routingKey []byte) *Query {
 func (q *Query) GetRoutingKey() ([]byte, error) {
 	if q.routingKey != nil {
 		return q.routingKey, nil
+	} else if q.binding != nil && len(q.values) == 0 {
+		// If this query was created using session.Bind we wont have the query
+		// values yet, so we have to pass down to the next policy.
+		// TODO: Remove this and handle this case
+		return nil, nil
 	}
 
 	// try to determine the routing key
@@ -816,8 +811,7 @@ func (q *Query) NoSkipMetadata() *Query {
 
 // Exec executes the query without returning any rows.
 func (q *Query) Exec() error {
-	iter := q.Iter()
-	return iter.Close()
+	return q.Iter().Close()
 }
 
 func isUseStatement(stmt string) bool {
@@ -1107,6 +1101,10 @@ func (b *Batch) Bind(stmt string, bind func(q *QueryInfo) ([]interface{}, error)
 	b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, binding: bind})
 }
 
+func (b *Batch) retryPolicy() RetryPolicy {
+	return b.rt
+}
+
 // RetryPolicy sets the retry policy to use when executing the batch operation
 func (b *Batch) RetryPolicy(r RetryPolicy) *Batch {
 	b.rt = r
@@ -1141,6 +1139,17 @@ func (b *Batch) DefaultTimestamp(enable bool) *Batch {
 	return b
 }
 
+func (b *Batch) attempt(d time.Duration) {
+	b.attempts++
+	b.totalLatency += d.Nanoseconds()
+	// TODO: track latencies per host and things as well instead of just total
+}
+
+func (b *Batch) GetRoutingKey() ([]byte, error) {
+	// TODO: use the first statement in the batch as the routing key?
+	return nil, nil
+}
+
 type BatchType byte
 
 const (
@@ -1285,7 +1294,7 @@ var (
 	ErrTooManyStmts  = errors.New("too many statements")
 	ErrUseStmt       = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explaination.")
 	ErrSessionClosed = errors.New("session has been closed")
-	ErrNoConnections = errors.New("no connections available")
+	ErrNoConnections = errors.New("qocql: no hosts available in the pool")
 	ErrNoKeyspace    = errors.New("no keyspace provided")
 	ErrNoMetadata    = errors.New("no metadata available")
 )

+ 7 - 2
session_test.go

@@ -12,11 +12,16 @@ func TestSessionAPI(t *testing.T) {
 	cfg := &ClusterConfig{}
 
 	s := &Session{
-		cfg:  *cfg,
-		cons: Quorum,
+		cfg:    *cfg,
+		cons:   Quorum,
+		policy: RoundRobinHostPolicy(),
 	}
 
 	s.pool = cfg.PoolConfig.buildPool(s)
+	s.executor = &queryExecutor{
+		pool:   s.pool,
+		policy: s.policy,
+	}
 	defer s.Close()
 
 	s.SetConsistency(All)