Merge pull request #697 from Zariel/query-executor

add queryExecutor which executes queries.
Chris Bannister 9 years ago
parent
commit
51009831ae
13 changed files with 319 additions and 342 deletions
  1. cassandra_test.go (+13, -15)
  2. cluster.go (+1, -15)
  3. conn_test.go (+2, -3)
  4. connectionpool.go (+25, -74)
  5. control.go (+3, -3)
  6. events.go (+4, -0)
  7. host_source.go (+1, -1)
  8. policies.go (+37, -106)
  9. policies_test.go (+32, -36)
  10. query_executor.go (+65, -0)
  11. ring.go (+33, -0)
  12. session.go (+96, -87)
  13. session_test.go (+7, -2)

+ 13 - 15
cassandra_test.go

@@ -1019,6 +1019,14 @@ func TestBatchQueryInfo(t *testing.T) {
 	}
 }
 
+func getRandomConn(t *testing.T, session *Session) *Conn {
+	conn := session.getConn()
+	if conn == nil {
+		t.Fatal("unable to get a connection")
+	}
+	return conn
+}
+
 func injectInvalidPreparedStatement(t *testing.T, session *Session, table string) (string, *Conn) {
 	if err := createTable(session, `CREATE TABLE gocql_test.`+table+` (
 			foo   varchar,
@@ -1029,7 +1037,8 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	}
 
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
-	_, conn := session.pool.Pick(nil)
+
+	conn := getRandomConn(t, session)
 
 	flight := new(inflightPrepare)
 	key := session.stmtsLRU.keyFor(conn.addr, "", stmt)
@@ -1060,7 +1069,7 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 
 func TestPrepare_MissingSchemaPrepare(t *testing.T) {
 	s := createSession(t)
-	_, conn := s.pool.Pick(nil)
+	conn := getRandomConn(t, s)
 	defer s.Close()
 
 	insertQry := &Query{stmt: "INSERT INTO invalidschemaprep (val) VALUES (?)", values: []interface{}{5}, cons: s.cons,
@@ -1108,7 +1117,7 @@ func TestQueryInfo(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
 
-	_, conn := session.pool.Pick(nil)
+	conn := getRandomConn(t, session)
 	info, err := conn.prepareStatement("SELECT release_version, host_id FROM system.local WHERE key = ?", nil)
 
 	if err != nil {
@@ -1982,18 +1991,7 @@ func TestNegativeStream(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
 
-	var conn *Conn
-	for i := 0; i < 5; i++ {
-		if conn != nil {
-			break
-		}
-
-		_, conn = session.pool.Pick(nil)
-	}
-
-	if conn == nil {
-		t.Fatal("no connections available in the pool")
-	}
+	conn := getRandomConn(t, session)
 
 	const stream = -50
 	writer := frameWriterFunc(func(f *framer, streamID int) error {

+ 1 - 15
cluster.go

@@ -16,24 +16,10 @@ type PoolConfig struct {
 	// HostSelectionPolicy sets the policy for selecting which host to use for a
 	// given query (default: RoundRobinHostPolicy())
 	HostSelectionPolicy HostSelectionPolicy
-
-	// ConnSelectionPolicy sets the policy factory for selecting a connection to use for
-	// each host for a query (default: RoundRobinConnPolicy())
-	ConnSelectionPolicy func() ConnSelectionPolicy
 }
 
 func (p PoolConfig) buildPool(session *Session) *policyConnPool {
-	hostSelection := p.HostSelectionPolicy
-	if hostSelection == nil {
-		hostSelection = RoundRobinHostPolicy()
-	}
-
-	connSelection := p.ConnSelectionPolicy
-	if connSelection == nil {
-		connSelection = RoundRobinConnPolicy()
-	}
-
-	return newPolicyConnPool(session, hostSelection, connSelection)
+	return newPolicyConnPool(session)
 }
 
 type DiscoveryConfig struct {
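
With ConnSelectionPolicy removed, only the host selection policy remains configurable on PoolConfig; connection selection now lives inside the pool itself. A minimal usage sketch of the new surface (the contact point below is hypothetical):

	// hedged sketch: configure only a HostSelectionPolicy; there is no longer
	// a ConnSelectionPolicy to set
	cluster := NewCluster("192.0.2.1")
	cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy())
	session, err := cluster.CreateSession()
	if err != nil {
		// handle the connection error
	}
	defer session.Close()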

+ 2 - 3
conn_test.go

@@ -187,8 +187,8 @@ func TestQueryRetry(t *testing.T) {
 		t.Fatalf("expected requests %v to match query attemps %v", requests, attempts)
 	}
 
-	//Minus 1 from the requests variable since there is the initial query attempt
-	if requests-1 != int64(rt.NumRetries) {
+	// the query will only be attempted once, but is being retried
+	if requests != int64(rt.NumRetries) {
 		t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, requests-1)
 	}
 }
@@ -284,7 +284,6 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 
 	cluster := createTestSslCluster(srv.Address, defaultProto, true)
 	cluster.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
-	cluster.PoolConfig.ConnSelectionPolicy = RoundRobinConnPolicy()
 
 	db, err := cluster.CreateSession()
 	if err != nil {

+ 25 - 74
connectionpool.go

@@ -14,6 +14,7 @@ import (
 	"math/rand"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 )
 
@@ -65,8 +66,6 @@ type policyConnPool struct {
 	keyspace string
 
 	mu            sync.RWMutex
-	hostPolicy    HostSelectionPolicy
-	connPolicy    func() ConnSelectionPolicy
 	hostConnPools map[string]*hostConnPool
 
 	endpoints []string
@@ -99,17 +98,13 @@ func connConfig(session *Session) (*ConnConfig, error) {
 	}, nil
 }
 
-func newPolicyConnPool(session *Session, hostPolicy HostSelectionPolicy,
-	connPolicy func() ConnSelectionPolicy) *policyConnPool {
-
+func newPolicyConnPool(session *Session) *policyConnPool {
 	// create the pool
 	pool := &policyConnPool{
 		session:       session,
 		port:          session.cfg.Port,
 		numConns:      session.cfg.NumConns,
 		keyspace:      session.cfg.Keyspace,
-		hostPolicy:    hostPolicy,
-		connPolicy:    connPolicy,
 		hostConnPools: map[string]*hostConnPool{},
 	}
 
@@ -150,7 +145,6 @@ func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
 				p.port,
 				p.numConns,
 				p.keyspace,
-				p.connPolicy(),
 			)
 		}(host)
 	}
@@ -170,13 +164,6 @@ func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
 		delete(p.hostConnPools, addr)
 		go pool.Close()
 	}
-
-	// update the policy
-	p.hostPolicy.SetHosts(hosts)
-}
-
-func (p *policyConnPool) SetPartitioner(partitioner string) {
-	p.hostPolicy.SetPartitioner(partitioner)
 }
 
 func (p *policyConnPool) Size() int {
@@ -197,41 +184,10 @@ func (p *policyConnPool) getPool(addr string) (pool *hostConnPool, ok bool) {
 	return
 }
 
-func (p *policyConnPool) Pick(qry *Query) (SelectedHost, *Conn) {
-	nextHost := p.hostPolicy.Pick(qry)
-
-	var (
-		host SelectedHost
-		conn *Conn
-	)
-
-	p.mu.RLock()
-	defer p.mu.RUnlock()
-	for conn == nil {
-		host = nextHost()
-		if host == nil {
-			break
-		} else if host.Info() == nil {
-			panic(fmt.Sprintf("policy %T returned no host info: %+v", p.hostPolicy, host))
-		}
-
-		pool, ok := p.hostConnPools[host.Info().Peer()]
-		if !ok {
-			continue
-		}
-
-		conn = pool.Pick(qry)
-	}
-	return host, conn
-}
-
 func (p *policyConnPool) Close() {
 	p.mu.Lock()
 	defer p.mu.Unlock()
 
-	// remove the hosts from the policy
-	p.hostPolicy.SetHosts(nil)
-
 	// close the pools
 	for addr, pool := range p.hostConnPools {
 		delete(p.hostConnPools, addr)
@@ -249,7 +205,6 @@ func (p *policyConnPool) addHost(host *HostInfo) {
 			host.Port(), // TODO: if port == 0 use pool.port?
 			p.numConns,
 			p.keyspace,
-			p.connPolicy(),
 		)
 
 		p.hostConnPools[host.Peer()] = pool
@@ -257,17 +212,10 @@ func (p *policyConnPool) addHost(host *HostInfo) {
 	p.mu.Unlock()
 
 	pool.fill()
-
-	// update policy
-	// TODO: policy should not have conns, it should have hosts and return a host
-	// iter which the pool will use to serve conns
-	p.hostPolicy.AddHost(host)
 }
 
 func (p *policyConnPool) removeHost(addr string) {
-	p.hostPolicy.RemoveHost(addr)
 	p.mu.Lock()
-
 	pool, ok := p.hostConnPools[addr]
 	if !ok {
 		p.mu.Unlock()
@@ -301,12 +249,13 @@ type hostConnPool struct {
 	addr     string
 	size     int
 	keyspace string
-	policy   ConnSelectionPolicy
 	// protection for conns, closed, filling
 	mu      sync.RWMutex
 	conns   []*Conn
 	closed  bool
 	filling bool
+
+	pos uint32
 }
 
 func (h *hostConnPool) String() string {
@@ -317,7 +266,7 @@ func (h *hostConnPool) String() string {
 }
 
 func newHostConnPool(session *Session, host *HostInfo, port, size int,
-	keyspace string, policy ConnSelectionPolicy) *hostConnPool {
+	keyspace string) *hostConnPool {
 
 	pool := &hostConnPool{
 		session:  session,
@@ -326,7 +275,6 @@ func newHostConnPool(session *Session, host *HostInfo, port, size int,
 		addr:     JoinHostPort(host.Peer(), port),
 		size:     size,
 		keyspace: keyspace,
-		policy:   policy,
 		conns:    make([]*Conn, 0, size),
 		filling:  false,
 		closed:   false,
@@ -337,16 +285,15 @@ func newHostConnPool(session *Session, host *HostInfo, port, size int,
 }
 
 // Pick a connection from this connection pool for the given query.
-func (pool *hostConnPool) Pick(qry *Query) *Conn {
+func (pool *hostConnPool) Pick() *Conn {
 	pool.mu.RLock()
+	defer pool.mu.RUnlock()
+
 	if pool.closed {
-		pool.mu.RUnlock()
 		return nil
 	}
 
 	size := len(pool.conns)
-	pool.mu.RUnlock()
-
 	if size < pool.size {
 		// try to fill the pool
 		go pool.fill()
@@ -356,7 +303,23 @@ func (pool *hostConnPool) Pick(qry *Query) *Conn {
 		}
 	}
 
-	return pool.policy.Pick(qry)
+	pos := int(atomic.AddUint32(&pool.pos, 1) - 1)
+
+	var (
+		leastBusyConn    *Conn
+		streamsAvailable int
+	)
+
+	// find the conn which has the most available streams, this is racy
+	for i := 0; i < size; i++ {
+		conn := pool.conns[(pos+i)%size]
+		if streams := conn.AvailableStreams(); streams > streamsAvailable {
+			leastBusyConn = conn
+			streamsAvailable = streams
+		}
+	}
+
+	return leastBusyConn
 }
 
 //Size returns the number of connections currently active in the pool
@@ -543,10 +506,6 @@ func (pool *hostConnPool) connect() (err error) {
 
 	pool.conns = append(pool.conns, conn)
 
-	conns := make([]*Conn, len(pool.conns))
-	copy(conns, pool.conns)
-	pool.policy.SetConns(conns)
-
 	return nil
 }
 
@@ -573,11 +532,6 @@ func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
 			// remove the connection, not preserving order
 			pool.conns[i], pool.conns = pool.conns[len(pool.conns)-1], pool.conns[:len(pool.conns)-1]
 
-			// update the policy
-			conns := make([]*Conn, len(pool.conns))
-			copy(conns, pool.conns)
-			pool.policy.SetConns(conns)
-
 			// lost a connection, so fill the pool
 			go pool.fill()
 			break
@@ -590,9 +544,6 @@ func (pool *hostConnPool) drainLocked() {
 	conns := pool.conns
 	pool.conns = nil
 
-	// update the policy
-	pool.policy.SetConns(nil)
-
 	// close the connections
 	for _, conn := range conns {
 		conn.Close()
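
hostConnPool.Pick now folds the old round-robin connection policy into the pool: an atomic counter gives a rotating start position and the connection with the most free streams wins. A standalone sketch of the same selection, using a hypothetical helper name:

	// pickLeastBusy mirrors the selection above: scan from a rotating offset
	// and keep the connection reporting the most available streams.
	func pickLeastBusy(conns []*Conn, pos int) *Conn {
		var (
			leastBusyConn    *Conn
			streamsAvailable int
		)
		for i := 0; i < len(conns); i++ {
			conn := conns[(pos+i)%len(conns)]
			if streams := conn.AvailableStreams(); streams > streamsAvailable {
				leastBusyConn = conn
				streamsAvailable = streams
			}
		}
		return leastBusyConn
	}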

+ 3 - 3
control.go

@@ -238,14 +238,14 @@ func (c *controlConn) reconnect(refreshring bool) {
 	// TODO: should have our own roundrobbin for hosts so that we can try each
 	// in succession and guantee that we get a different host each time.
 	if newConn == nil {
-		_, conn := c.session.pool.Pick(nil)
-		if conn == nil {
+		host := c.session.ring.rrHost()
+		if host == nil {
 			c.connect(c.session.ring.endpoints)
 			return
 		}
 
 		var err error
-		newConn, err = c.session.connect(conn.addr, c, conn.host)
+		newConn, err = c.session.connect(host.Peer(), c, host)
 		if err != nil {
 			// TODO: add log handler for things like this
 			return

+ 4 - 0
events.go

@@ -201,6 +201,7 @@ func (s *Session) handleNewNode(host net.IP, port int, waitForBinary bool) {
 	}
 
 	s.pool.addHost(hostInfo)
+	s.policy.AddHost(hostInfo)
 	hostInfo.setState(NodeUp)
 
 	if s.control != nil {
@@ -222,6 +223,7 @@ func (s *Session) handleRemovedNode(ip net.IP, port int) {
 	}
 
 	host.setState(NodeDown)
+	s.policy.RemoveHost(addr)
 	s.pool.removeHost(addr)
 	s.ring.removeHost(addr)
 
@@ -251,6 +253,7 @@ func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
 
 		host.setPort(port)
 		s.pool.hostUp(host)
+		s.policy.HostUp(host)
 		host.setState(NodeUp)
 		return
 	}
@@ -270,5 +273,6 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
 	}
 
 	host.setState(NodeDown)
+	s.policy.HostDown(addr)
 	s.pool.hostDown(addr)
 }

+ 1 - 1
host_source.go

@@ -390,6 +390,6 @@ func (r *ringDescriber) refreshRing() error {
 		}
 	}
 
-	r.session.pool.SetPartitioner(partitioner)
+	r.session.metadata.setPartitioner(partitioner)
 	return nil
 }

+ 37 - 106
policies.go

@@ -162,17 +162,17 @@ func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
 type HostStateNotifier interface {
 	AddHost(host *HostInfo)
 	RemoveHost(addr string)
-	// TODO(zariel): add host up/down
+	HostUp(host *HostInfo)
+	HostDown(addr string)
 }
 
 // HostSelectionPolicy is an interface for selecting
 // the most appropriate host to execute a given query.
 type HostSelectionPolicy interface {
 	HostStateNotifier
-	SetHosts
 	SetPartitioner
 	//Pick returns an iteration function over selected hosts
-	Pick(*Query) NextHost
+	Pick(ExecutableQuery) NextHost
 }
 
 // SelectedHost is an interface returned when picking a host from a host
@@ -182,6 +182,14 @@ type SelectedHost interface {
 	Mark(error)
 }
 
+type selectedHost HostInfo
+
+func (host *selectedHost) Info() *HostInfo {
+	return (*HostInfo)(host)
+}
+
+func (host *selectedHost) Mark(err error) {}
+
 // NextHost is an iteration function over picked hosts
 type NextHost func() SelectedHost
 
@@ -197,15 +205,11 @@ type roundRobinHostPolicy struct {
 	mu    sync.RWMutex
 }
 
-func (r *roundRobinHostPolicy) SetHosts(hosts []*HostInfo) {
-	r.hosts.set(hosts)
-}
-
 func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {
 	// noop
 }
 
-func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
+func (r *roundRobinHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	// i is used to limit the number of attempts to find a host
 	// to the number of hosts known to this policy
 	var i int
@@ -223,7 +227,7 @@ func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
 		}
 		host := hosts[(pos)%uint32(len(hosts))]
 		i++
-		return selectedRoundRobinHost{host}
+		return (*selectedHost)(host)
 	}
 }
 
@@ -235,18 +239,12 @@ func (r *roundRobinHostPolicy) RemoveHost(addr string) {
 	r.hosts.remove(addr)
 }
 
-// selectedRoundRobinHost is a host returned by the roundRobinHostPolicy and
-// implements the SelectedHost interface
-type selectedRoundRobinHost struct {
-	info *HostInfo
+func (r *roundRobinHostPolicy) HostUp(host *HostInfo) {
+	r.AddHost(host)
 }
 
-func (host selectedRoundRobinHost) Info() *HostInfo {
-	return host.info
-}
-
-func (host selectedRoundRobinHost) Mark(err error) {
-	// noop
+func (r *roundRobinHostPolicy) HostDown(addr string) {
+	r.RemoveHost(addr)
 }
 
 // TokenAwareHostPolicy is a token aware host selection policy, where hosts are
@@ -264,18 +262,6 @@ type tokenAwareHostPolicy struct {
 	fallback    HostSelectionPolicy
 }
 
-func (t *tokenAwareHostPolicy) SetHosts(hosts []*HostInfo) {
-	t.hosts.set(hosts)
-
-	t.mu.Lock()
-	defer t.mu.Unlock()
-
-	// always update the fallback
-	t.fallback.SetHosts(hosts)
-
-	t.resetTokenRing()
-}
-
 func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
@@ -299,12 +285,21 @@ func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
 
 func (t *tokenAwareHostPolicy) RemoveHost(addr string) {
 	t.hosts.remove(addr)
+	t.fallback.RemoveHost(addr)
 
 	t.mu.Lock()
 	t.resetTokenRing()
 	t.mu.Unlock()
 }
 
+func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) {
+	t.AddHost(host)
+}
+
+func (t *tokenAwareHostPolicy) HostDown(addr string) {
+	t.RemoveHost(addr)
+}
+
 func (t *tokenAwareHostPolicy) resetTokenRing() {
 	if t.partitioner == "" {
 		// partitioner not yet set
@@ -323,14 +318,9 @@ func (t *tokenAwareHostPolicy) resetTokenRing() {
 	t.tokenRing = tokenRing
 }
 
-func (t *tokenAwareHostPolicy) Pick(qry *Query) NextHost {
+func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	if qry == nil {
 		return t.fallback.Pick(qry)
-	} else if qry.binding != nil && len(qry.values) == 0 {
-		// If this query was created using session.Bind we wont have the query
-		// values yet, so we have to pass down to the next policy.
-		// TODO: Remove this and handle this case
-		return t.fallback.Pick(qry)
 	}
 
 	routingKey, err := qry.GetRoutingKey()
@@ -359,7 +349,7 @@ func (t *tokenAwareHostPolicy) Pick(qry *Query) NextHost {
 	return func() SelectedHost {
 		if !hostReturned {
 			hostReturned = true
-			return selectedTokenAwareHost{host}
+			return (*selectedHost)(host)
 		}
 
 		// fallback
@@ -378,20 +368,6 @@ func (t *tokenAwareHostPolicy) Pick(qry *Query) NextHost {
 	}
 }
 
-// selectedTokenAwareHost is a host returned by the tokenAwareHostPolicy and
-// implements the SelectedHost interface
-type selectedTokenAwareHost struct {
-	info *HostInfo
-}
-
-func (host selectedTokenAwareHost) Info() *HostInfo {
-	return host.info
-}
-
-func (host selectedTokenAwareHost) Mark(err error) {
-	// noop
-}
-
 // HostPoolHostPolicy is a host policy which uses the bitly/go-hostpool library
 // to distribute queries between hosts and prevent sending queries to
 // unresponsive hosts. When creating the host pool that is passed to the policy
@@ -466,11 +442,19 @@ func (r *hostPoolHostPolicy) RemoveHost(addr string) {
 	r.hp.SetHosts(hosts)
 }
 
+func (r *hostPoolHostPolicy) HostUp(host *HostInfo) {
+	r.AddHost(host)
+}
+
+func (r *hostPoolHostPolicy) HostDown(addr string) {
+	r.RemoveHost(addr)
+}
+
 func (r *hostPoolHostPolicy) SetPartitioner(partitioner string) {
 	// noop
 }
 
-func (r *hostPoolHostPolicy) Pick(qry *Query) NextHost {
+func (r *hostPoolHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	return func() SelectedHost {
 		r.mu.RLock()
 		defer r.mu.RUnlock()
@@ -516,56 +500,3 @@ func (host selectedHostPoolHost) Mark(err error) {
 
 	host.hostR.Mark(err)
 }
-
-//ConnSelectionPolicy is an interface for selecting an
-//appropriate connection for executing a query
-type ConnSelectionPolicy interface {
-	SetConns(conns []*Conn)
-	Pick(*Query) *Conn
-}
-
-type roundRobinConnPolicy struct {
-	// pos is still used to evenly distribute queries amongst connections.
-	pos   uint32
-	conns atomic.Value // *[]*Conn
-}
-
-func RoundRobinConnPolicy() func() ConnSelectionPolicy {
-	return func() ConnSelectionPolicy {
-		p := &roundRobinConnPolicy{}
-		var conns []*Conn
-		p.conns.Store(&conns)
-		return p
-	}
-}
-
-func (r *roundRobinConnPolicy) SetConns(conns []*Conn) {
-	// NOTE: we do not need to lock here due to the conneciton pool is already
-	// holding its own mutex over the conn seleciton policy
-	r.conns.Store(&conns)
-}
-
-func (r *roundRobinConnPolicy) Pick(qry *Query) *Conn {
-	conns := *(r.conns.Load().(*[]*Conn))
-	if len(conns) == 0 {
-		return nil
-	}
-
-	pos := int(atomic.AddUint32(&r.pos, 1) - 1)
-
-	var (
-		leastBusyConn    *Conn
-		streamsAvailable int
-	)
-
-	// find the conn which has the most available streams, this is racy
-	for i := 0; i < len(conns); i++ {
-		conn := conns[(pos+i)%len(conns)]
-		if streams := conn.AvailableStreams(); streams > streamsAvailable {
-			leastBusyConn = conn
-			streamsAvailable = streams
-		}
-	}
-
-	return leastBusyConn
-}
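
HostSelectionPolicy now embeds the extended HostStateNotifier (with HostUp/HostDown), Pick takes an ExecutableQuery instead of *Query, and SetHosts is gone. A skeleton of what a custom policy has to cover, assuming the same delegation the built-in policies use (customPolicy is purely illustrative, not part of this change):

	// customPolicy is a hypothetical in-package skeleton; the built-in
	// policies simply delegate HostUp/HostDown to AddHost/RemoveHost.
	type customPolicy struct{}

	func (p *customPolicy) AddHost(host *HostInfo) { /* start tracking host */ }
	func (p *customPolicy) RemoveHost(addr string) { /* stop tracking host */ }
	func (p *customPolicy) HostUp(host *HostInfo)  { p.AddHost(host) }
	func (p *customPolicy) HostDown(addr string)   { p.RemoveHost(addr) }
	func (p *customPolicy) SetPartitioner(string)  { /* noop */ }

	func (p *customPolicy) Pick(qry ExecutableQuery) NextHost {
		// no hosts tracked in this sketch, so the iterator is immediately empty
		return func() SelectedHost { return nil }
	}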

+ 32 - 36
policies_test.go

@@ -6,7 +6,6 @@ package gocql
 
 import (
 	"fmt"
-	"github.com/gocql/gocql/internal/streams"
 	"testing"
 
 	"github.com/hailocab/go-hostpool"
@@ -16,12 +15,14 @@ import (
 func TestRoundRobinHostPolicy(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 
-	hosts := []*HostInfo{
+	hosts := [...]*HostInfo{
 		{hostId: "0"},
 		{hostId: "1"},
 	}
 
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.AddHost(host)
+	}
 
 	// interleaved iteration should always increment the host
 	iterA := policy.Pick(nil)
@@ -65,13 +66,15 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 	}
 
 	// set the hosts
-	hosts := []*HostInfo{
+	hosts := [...]*HostInfo{
 		{peer: "0", tokens: []string{"00"}},
 		{peer: "1", tokens: []string{"25"}},
 		{peer: "2", tokens: []string{"50"}},
 		{peer: "3", tokens: []string{"75"}},
 	}
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.AddHost(host)
+	}
 
 	// the token ring is not setup without the partitioner, but the fallback
 	// should work
@@ -108,12 +111,14 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 func TestHostPoolHostPolicy(t *testing.T) {
 	policy := HostPoolHostPolicy(hostpool.New(nil))
 
-	hosts := []*HostInfo{
+	hosts := [...]*HostInfo{
 		{hostId: "0", peer: "0"},
 		{hostId: "1", peer: "1"},
 	}
 
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.AddHost(host)
+	}
 
 	// the first host selected is actually at [1], but this is ok for RR
 	// interleaved iteration should always increment the host
@@ -143,35 +148,11 @@ func TestHostPoolHostPolicy(t *testing.T) {
 	actualD.Mark(nil)
 }
 
-// Tests of the round-robin connection selection policy implementation
-func TestRoundRobinConnPolicy(t *testing.T) {
-	policy := RoundRobinConnPolicy()()
-
-	conn0 := &Conn{streams: streams.New(1)}
-	conn1 := &Conn{streams: streams.New(1)}
-	conn := []*Conn{
-		conn0,
-		conn1,
-	}
-
-	policy.SetConns(conn)
-
-	if actual := policy.Pick(nil); actual != conn0 {
-		t.Error("Expected conn1")
-	}
-	if actual := policy.Pick(nil); actual != conn1 {
-		t.Error("Expected conn0")
-	}
-	if actual := policy.Pick(nil); actual != conn0 {
-		t.Error("Expected conn1")
-	}
-}
-
 func TestRoundRobinNilHostInfo(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 
 	host := &HostInfo{hostId: "host-1"}
-	policy.SetHosts([]*HostInfo{host})
+	policy.AddHost(host)
 
 	iter := policy.Pick(nil)
 	next := iter()
@@ -195,13 +176,15 @@ func TestRoundRobinNilHostInfo(t *testing.T) {
 func TestTokenAwareNilHostInfo(t *testing.T) {
 	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 
-	hosts := []*HostInfo{
+	hosts := [...]*HostInfo{
 		{peer: "0", tokens: []string{"00"}},
 		{peer: "1", tokens: []string{"25"}},
 		{peer: "2", tokens: []string{"50"}},
 		{peer: "3", tokens: []string{"75"}},
 	}
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.AddHost(host)
+	}
 	policy.SetPartitioner("OrderedPartitioner")
 
 	query := &Query{}
@@ -218,8 +201,9 @@ func TestTokenAwareNilHostInfo(t *testing.T) {
 	}
 
 	// Empty the hosts to trigger the panic when using the fallback.
-	hosts = []*HostInfo{}
-	policy.SetHosts(hosts)
+	for _, host := range hosts {
+		policy.RemoveHost(host.Peer())
+	}
 
 	next = iter()
 	if next != nil {
@@ -257,3 +241,15 @@ func TestCOWList_Add(t *testing.T) {
 		}
 	}
 }
+
+func TestSimpleRetryPolicy(t *testing.T) {
+	q := &Query{}
+	rt := &SimpleRetryPolicy{NumRetries: 2}
+	if !rt.Attempt(q) {
+		t.Fatal("should allow retry after 0 attempts")
+	}
+	q.attempts = 5
+	if rt.Attempt(q) {
+		t.Fatal("should not allow retry after passing threshold")
+	}
+}

+ 65 - 0
query_executor.go

@@ -0,0 +1,65 @@
+package gocql
+
+import (
+	"time"
+)
+
+type ExecutableQuery interface {
+	execute(conn *Conn) *Iter
+	attempt(time.Duration)
+	retryPolicy() RetryPolicy
+	GetRoutingKey() ([]byte, error)
+	RetryableQuery
+}
+
+type queryExecutor struct {
+	pool   *policyConnPool
+	policy HostSelectionPolicy
+}
+
+func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
+	rt := qry.retryPolicy()
+	hostIter := q.policy.Pick(qry)
+
+	var iter *Iter
+	for hostResponse := hostIter(); hostResponse != nil; hostResponse = hostIter() {
+		host := hostResponse.Info()
+		if !host.IsUp() {
+			continue
+		}
+
+		pool, ok := q.pool.getPool(host.Peer())
+		if !ok {
+			continue
+		}
+
+		conn := pool.Pick()
+		if conn == nil {
+			continue
+		}
+
+		start := time.Now()
+		iter = qry.execute(conn)
+
+		qry.attempt(time.Since(start))
+
+		// Update host
+		hostResponse.Mark(iter.err)
+
+		// Exit for loop if the query was successful
+		if iter.err == nil {
+			return iter, nil
+		}
+
+		if rt == nil || !rt.Attempt(qry) {
+			// What to do here? Should we just return an error?
+			break
+		}
+	}
+
+	if iter == nil {
+		return nil, ErrNoConnections
+	}
+
+	return iter, nil
+}
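
ExecutableQuery is what the executor drives: execute runs the request on the picked connection, attempt records latency, and retryPolicy gates further hosts. A minimal in-package stub satisfying it, assuming RetryableQuery exposes Attempts and GetConsistency as SimpleRetryPolicy expects (stubQuery is hypothetical, for illustration only; the real implementations are Query and Batch):

	// stubQuery shows the surface the executor relies on; execute would
	// normally run the request on conn instead of returning an empty Iter.
	type stubQuery struct {
		attempts     int
		totalLatency int64
	}

	func (q *stubQuery) execute(conn *Conn) *Iter { return &Iter{} }
	func (q *stubQuery) attempt(d time.Duration) {
		q.attempts++
		q.totalLatency += d.Nanoseconds()
	}
	func (q *stubQuery) retryPolicy() RetryPolicy       { return &SimpleRetryPolicy{NumRetries: 2} }
	func (q *stubQuery) GetRoutingKey() ([]byte, error) { return nil, nil }
	func (q *stubQuery) Attempts() int                  { return q.attempts }
	func (q *stubQuery) GetConsistency() Consistency    { return Quorum }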

+ 33 - 0
ring.go

@@ -2,6 +2,7 @@ package gocql
 
 import (
 	"sync"
+	"sync/atomic"
 )
 
 type ring struct {
@@ -9,13 +10,30 @@ type ring struct {
 	// to in the case it can not reach any of its hosts. They are also used to boot
 	// strap the initial connection.
 	endpoints []string
+
 	// hosts are the set of all hosts in the cassandra ring that we know of
 	mu    sync.RWMutex
 	hosts map[string]*HostInfo
 
+	hostList []*HostInfo
+	pos      uint32
+
 	// TODO: we should store the ring metadata here also.
 }
 
+func (r *ring) rrHost() *HostInfo {
+	// TODO: should we filter hosts that get used here? These hosts will be used
+	// for the control connection, should we also provide an iterator?
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+	if len(r.hostList) == 0 {
+		return nil
+	}
+
+	pos := int(atomic.AddUint32(&r.pos, 1) - 1)
+	return r.hostList[pos%len(r.hostList)]
+}
+
 func (r *ring) getHost(addr string) *HostInfo {
 	r.mu.RLock()
 	host := r.hosts[addr]
@@ -73,3 +91,18 @@ func (r *ring) removeHost(addr string) bool {
 	r.mu.Unlock()
 	return ok
 }
+
+type clusterMetadata struct {
+	mu          sync.RWMutex
+	partitioner string
+}
+
+func (c *clusterMetadata) setPartitioner(partitioner string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if c.partitioner != partitioner {
+		// TODO: update other things now
+		c.partitioner = partitioner
+	}
+}

+ 96 - 87
session.go

@@ -31,7 +31,6 @@ import (
 // and automatically sets a default consinstency level on all operations
 // that do not have a consistency level set.
 type Session struct {
-	pool                *policyConnPool
 	cons                Consistency
 	pageSize            int
 	prefetch            float64
@@ -39,11 +38,17 @@ type Session struct {
 	schemaDescriber     *schemaDescriber
 	trace               Tracer
 	hostSource          *ringDescriber
-	ring                ring
 	stmtsLRU            *preparedLRU
 
 	connCfg *ConnConfig
 
+	executor *queryExecutor
+	pool     *policyConnPool
+	policy   HostSelectionPolicy
+
+	ring     ring
+	metadata clusterMetadata
+
 	mu sync.RWMutex
 
 	control *controlConn
@@ -116,7 +121,17 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		closeChan: make(chan bool),
 	}
 
-	s.pool = cfg.PoolConfig.buildPool(s)
+	pool := cfg.PoolConfig.buildPool(s)
+	if cfg.PoolConfig.HostSelectionPolicy == nil {
+		cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
+	}
+
+	s.pool = pool
+	s.policy = cfg.PoolConfig.HostSelectionPolicy
+	s.executor = &queryExecutor{
+		pool:   pool,
+		policy: cfg.PoolConfig.HostSelectionPolicy,
+	}
 
 	var hosts []*HostInfo
 
@@ -284,44 +299,17 @@ func (s *Session) Closed() bool {
 }
 
 func (s *Session) executeQuery(qry *Query) *Iter {
-
 	// fail fast
 	if s.Closed() {
 		return &Iter{err: ErrSessionClosed}
 	}
 
-	var iter *Iter
-	qry.attempts = 0
-	qry.totalLatency = 0
-	for {
-		host, conn := s.pool.Pick(qry)
-
-		qry.attempts++
-		//Assign the error unavailable to the iterator
-		if conn == nil {
-			if qry.rt == nil || !qry.rt.Attempt(qry) {
-				iter = &Iter{err: ErrNoConnections}
-				break
-			}
-
-			continue
-		}
-
-		t := time.Now()
-		iter = conn.executeQuery(qry)
-		qry.totalLatency += time.Now().Sub(t).Nanoseconds()
-
-		// Update host
-		host.Mark(iter.err)
-
-		// Exit for loop if the query was successful
-		if iter.err == nil {
-			break
-		}
-
-		if qry.rt == nil || !qry.rt.Attempt(qry) {
-			break
-		}
+	iter, err := s.executor.executeQuery(qry)
+	if err != nil {
+		return &Iter{err: err}
+	}
+	if iter == nil {
+		panic("nil iter")
 	}
 
 	return iter
@@ -348,6 +336,28 @@ func (s *Session) KeyspaceMetadata(keyspace string) (*KeyspaceMetadata, error) {
 	return s.schemaDescriber.getSchema(keyspace)
 }
 
+func (s *Session) getConn() *Conn {
+	hosts := s.ring.allHosts()
+	var conn *Conn
+	for _, host := range hosts {
+		if !host.IsUp() {
+			continue
+		}
+
+		pool, ok := s.pool.getPool(host.Peer())
+		if !ok {
+			continue
+		}
+
+		conn = pool.Pick()
+		if conn != nil {
+			return conn
+		}
+	}
+
+	return nil
+}
+
 // returns routing key indexes and type info
 func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	s.routingKeyInfoCache.mu.Lock()
@@ -384,26 +394,23 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 		partitionKey []*ColumnMetadata
 	)
 
-	// get the query info for the statement
-	host, conn := s.pool.Pick(nil)
+	conn := s.getConn()
 	if conn == nil {
-		// no connections
-		inflight.err = ErrNoConnections
-		// don't cache this error
-		s.routingKeyInfoCache.Remove(stmt)
+		// TODO: better error?
+		inflight.err = errors.New("gocql: unable to fetch prepared info: no connection available")
 		return nil, inflight.err
 	}
 
+	// get the query info for the statement
 	info, inflight.err = conn.prepareStatement(stmt, nil)
 	if inflight.err != nil {
 		// don't cache this error
 		s.routingKeyInfoCache.Remove(stmt)
-		host.Mark(inflight.err)
 		return nil, inflight.err
 	}
 
-	// Mark host as OK
-	host.Mark(nil)
+	// TODO: it would be nice to mark hosts here but as we are not using the policies
+	// to fetch hosts we can't
 
 	if info.request.colCount == 0 {
 		// no arguments, no routing key, and no error
@@ -455,6 +462,7 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 		indexes: make([]int, size),
 		types:   make([]TypeInfo, size),
 	}
+
 	for keyIndex, keyColumn := range partitionKey {
 		// set an indicator for checking if the mapping is missing
 		routingKeyInfo.indexes[keyIndex] = -1
@@ -482,6 +490,10 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	return routingKeyInfo, nil
 }
 
+func (b *Batch) execute(conn *Conn) *Iter {
+	return conn.executeBatch(b)
+}
+
 func (s *Session) executeBatch(batch *Batch) *Iter {
 	// fail fast
 	if s.Closed() {
@@ -495,45 +507,9 @@ func (s *Session) executeBatch(batch *Batch) *Iter {
 		return &Iter{err: ErrTooManyStmts}
 	}
 
-	var iter *Iter
-	batch.attempts = 0
-	batch.totalLatency = 0
-	for {
-		host, conn := s.pool.Pick(nil)
-
-		batch.attempts++
-		if conn == nil {
-			if batch.rt == nil || !batch.rt.Attempt(batch) {
-				// Assign the error unavailable and break loop
-				iter = &Iter{err: ErrNoConnections}
-				break
-			}
-
-			continue
-		}
-
-		if conn == nil {
-			iter = &Iter{err: ErrNoConnections}
-			break
-		}
-
-		t := time.Now()
-
-		iter = conn.executeBatch(batch)
-
-		batch.totalLatency += time.Since(t).Nanoseconds()
-		// Exit loop if operation executed correctly
-		if iter.err == nil {
-			host.Mark(nil)
-			break
-		}
-
-		// Mark host with error if returned from Close
-		host.Mark(iter.Close())
-
-		if batch.rt == nil || !batch.rt.Attempt(batch) {
-			break
-		}
+	iter, err := s.executor.executeQuery(batch)
+	if err != nil {
+		return &Iter{err: err}
 	}
 
 	return iter
@@ -680,6 +656,20 @@ func (q *Query) RoutingKey(routingKey []byte) *Query {
 	return q
 }
 
+func (q *Query) execute(conn *Conn) *Iter {
+	return conn.executeQuery(q)
+}
+
+func (q *Query) attempt(d time.Duration) {
+	q.attempts++
+	q.totalLatency += d.Nanoseconds()
+	// TODO: track latencies per host and things as well instead of just total
+}
+
+func (q *Query) retryPolicy() RetryPolicy {
+	return q.rt
+}
+
 // GetRoutingKey gets the routing key to use for routing this query. If
 // a routing key has not been explicitly set, then the routing key will
 // be constructed if possible using the keyspace's schema and the query
@@ -689,6 +679,11 @@ func (q *Query) RoutingKey(routingKey []byte) *Query {
 func (q *Query) GetRoutingKey() ([]byte, error) {
 	if q.routingKey != nil {
 		return q.routingKey, nil
+	} else if q.binding != nil && len(q.values) == 0 {
+		// If this query was created using session.Bind we wont have the query
+		// values yet, so we have to pass down to the next policy.
+		// TODO: Remove this and handle this case
+		return nil, nil
 	}
 
 	// try to determine the routing key
@@ -816,8 +811,7 @@ func (q *Query) NoSkipMetadata() *Query {
 
 // Exec executes the query without returning any rows.
 func (q *Query) Exec() error {
-	iter := q.Iter()
-	return iter.Close()
+	return q.Iter().Close()
 }
 
 func isUseStatement(stmt string) bool {
@@ -1107,6 +1101,10 @@ func (b *Batch) Bind(stmt string, bind func(q *QueryInfo) ([]interface{}, error)
 	b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, binding: bind})
 }
 
+func (b *Batch) retryPolicy() RetryPolicy {
+	return b.rt
+}
+
 // RetryPolicy sets the retry policy to use when executing the batch operation
 func (b *Batch) RetryPolicy(r RetryPolicy) *Batch {
 	b.rt = r
@@ -1141,6 +1139,17 @@ func (b *Batch) DefaultTimestamp(enable bool) *Batch {
 	return b
 }
 
+func (b *Batch) attempt(d time.Duration) {
+	b.attempts++
+	b.totalLatency += d.Nanoseconds()
+	// TODO: track latencies per host and things as well instead of just total
+}
+
+func (b *Batch) GetRoutingKey() ([]byte, error) {
+	// TODO: use the first statement in the batch as the routing key?
+	return nil, nil
+}
+
 type BatchType byte
 
 const (
@@ -1285,7 +1294,7 @@ var (
 	ErrTooManyStmts  = errors.New("too many statements")
 	ErrUseStmt       = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explaination.")
 	ErrSessionClosed = errors.New("session has been closed")
-	ErrNoConnections = errors.New("no connections available")
+	ErrNoConnections = errors.New("qocql: no hosts available in the pool")
 	ErrNoKeyspace    = errors.New("no keyspace provided")
 	ErrNoMetadata    = errors.New("no metadata available")
 )
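
Both Query and Batch now implement ExecutableQuery, so Session.executeQuery and Session.executeBatch funnel into the same queryExecutor. From the caller's side nothing changes; a hedged round-trip sketch (keyspace and table names are hypothetical):

	// both paths below end up in queryExecutor.executeQuery
	if err := session.Query(`INSERT INTO example.kv (k, v) VALUES (?, ?)`, "k1", "v1").Exec(); err != nil {
		// handle the query error
	}

	batch := session.NewBatch(LoggedBatch)
	batch.Query(`INSERT INTO example.kv (k, v) VALUES (?, ?)`, "k2", "v2")
	if err := session.ExecuteBatch(batch); err != nil {
		// handle the batch error
	}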

+ 7 - 2
session_test.go

@@ -12,11 +12,16 @@ func TestSessionAPI(t *testing.T) {
 	cfg := &ClusterConfig{}
 
 	s := &Session{
-		cfg:  *cfg,
-		cons: Quorum,
+		cfg:    *cfg,
+		cons:   Quorum,
+		policy: RoundRobinHostPolicy(),
 	}
 
 	s.pool = cfg.PoolConfig.buildPool(s)
+	s.executor = &queryExecutor{
+		pool:   s.pool,
+		policy: s.policy,
+	}
 	defer s.Close()
 
 	s.SetConsistency(All)