
add internal control connection

Add a control connection which is initially used for querying
system tables. It will periodically ping the host to ensure
that the connection is alive and reconnect on failures.

Updates #359
Chris Bannister, 10 years ago
commit e5248ed2fb
10 changed files with 316 additions and 135 deletions
  1. cassandra_test.go (+7 -13)
  2. cluster.go (+3 -0)
  3. conn.go (+6 -64)
  4. conn_test.go (+1 -0)
  5. connectionpool.go (+4 -4)
  6. control.go (+223 -0)
  7. frame.go (+11 -0)
  8. host_source.go (+23 -17)
  9. metadata.go (+8 -28)
  10. session.go (+30 -9)
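
The commit message above describes the control connection's job: periodically probe the host and swap in a fresh connection when the probe fails. Below is a minimal, self-contained sketch of that pattern for orientation only; the pinger type, the raw one-byte probe and the address are illustrative stand-ins, not gocql code (the real implementation in control.go further down probes with a CQL OPTIONS frame and picks replacement hosts from the session's pool).

package main

import (
	"log"
	"net"
	"sync/atomic"
	"time"
)

// pinger holds the current connection and replaces it when a probe fails.
type pinger struct {
	conn atomic.Value // always stores a *net.TCPConn
	quit chan struct{}
}

func (p *pinger) heartBeat(addr string) {
	for {
		select {
		case <-p.quit:
			return
		case <-time.After(5 * time.Second):
		}

		conn, _ := p.conn.Load().(*net.TCPConn)
		if conn != nil && probe(conn) == nil {
			continue // host answered; connection is alive
		}

		// probe failed (or we never connected): dial a replacement and swap it in
		newConn, err := net.DialTimeout("tcp", addr, 600*time.Millisecond)
		if err != nil {
			log.Println("control: reconnect failed:", err)
			continue
		}
		p.conn.Store(newConn.(*net.TCPConn))
		if conn != nil {
			conn.Close()
		}
	}
}

// probe stands in for the cheap request/response the real control connection
// performs (an OPTIONS frame answered by SUPPORTED).
func probe(c net.Conn) error {
	c.SetWriteDeadline(time.Now().Add(time.Second))
	_, err := c.Write([]byte{0x00})
	return err
}

func main() {
	p := &pinger{quit: make(chan struct{})}
	go p.heartBeat("127.0.0.1:9042")
	time.Sleep(30 * time.Second)
	close(p.quit)
}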

+ 7 - 13
cassandra_test.go

@@ -67,7 +67,7 @@ func createTable(s *Session, table string) error {
 		return err
 	}
 
-	return c.awaitSchemaAgreement()
+	return s.control.awaitSchemaAgreement()
 }
 
 func createCluster() *ClusterConfig {
@@ -101,28 +101,22 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
 		tb.Fatal("createSession:", err)
 	}
 
-	// should reuse the same conn apparently
-	conn := session.pool.Pick(nil)
-	if conn == nil {
-		tb.Fatal("no connections available in the pool")
-	}
-
-	err = conn.executeQuery(session.Query(`DROP KEYSPACE IF EXISTS ` + keyspace).Consistency(All)).Close()
+	err = session.Query(`DROP KEYSPACE IF EXISTS ` + keyspace).Exec()
 	if err != nil {
 		tb.Fatal(err)
 	}
 
-	if err = conn.awaitSchemaAgreement(); err != nil {
+	if err = session.control.awaitSchemaAgreement(); err != nil {
 		tb.Fatal(err)
 	}
 
-	query := session.Query(fmt.Sprintf(`CREATE KEYSPACE %s
+	err = session.Query(fmt.Sprintf(`CREATE KEYSPACE %s
 	WITH replication = {
 		'class' : 'SimpleStrategy',
 		'replication_factor' : %d
-	}`, keyspace, *flagRF)).Consistency(All)
+	}`, keyspace, *flagRF)).Exec()
 
-	if err = conn.executeQuery(query).Close(); err != nil {
+	if err != nil {
 		tb.Fatal(err)
 	}
 
@@ -130,7 +124,7 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
 	// cluster to settle.
 	// TODO(zariel): use events here to know when the cluster has resolved to the
 	// new schema version
-	if err = conn.awaitSchemaAgreement(); err != nil {
+	if err = session.control.awaitSchemaAgreement(); err != nil {
 		tb.Fatal(err)
 	}
 }

+ 3 - 0
cluster.go

@@ -106,6 +106,9 @@ type ClusterConfig struct {
 	// PoolConfig configures the underlying connection pool, allowing the
 	// configuration of host selection and connection selection policies.
 	PoolConfig PoolConfig
+
+	// internal config for testing
+	disableControlConn bool
 }
 
 // NewCluster generates a new config for the default cluster implementation.
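
Because the flag is unexported it can only be flipped from inside the package; the round-robin pool test later in this commit sets it so the fake servers it spins up are never probed. A test would use it roughly like this (the host address and test name are placeholders):

func TestWithoutControlConn(t *testing.T) {
	cluster := NewCluster("127.0.0.1") // placeholder host
	cluster.disableControlConn = true  // internal flag: skip creating the control connection
	session, err := cluster.CreateSession()
	if err != nil {
		t.Fatal(err)
	}
	defer session.Close()
}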

+ 6 - 64
conn.go

@@ -99,6 +99,7 @@ type Conn struct {
 	conn    net.Conn
 	r       *bufio.Reader
 	timeout time.Duration
+	cfg     *ConnConfig
 
 	headerBuf []byte
 
@@ -121,7 +122,7 @@ type Conn struct {
 
 // Connect establishes a connection to a Cassandra node.
 // You must also call the Serve method before you can execute any queries.
-func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
+func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
 	var (
 		err  error
 		conn net.Conn
@@ -166,6 +167,7 @@ func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn,
 	c := &Conn{
 		conn:         conn,
 		r:            bufio.NewReader(conn),
+		cfg:          cfg,
 		uniq:         make(chan int, cfg.NumStreams),
 		calls:        make([]callReq, cfg.NumStreams),
 		timeout:      cfg.Timeout,
@@ -191,7 +193,7 @@ func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn,
 
 	go c.serve()
 
-	if err := c.startup(&cfg); err != nil {
+	if err := c.startup(); err != nil {
 		conn.Close()
 		return nil, err
 	}
@@ -231,9 +233,9 @@ func (c *Conn) Read(p []byte) (n int, err error) {
 	return
 }
 
-func (c *Conn) startup(cfg *ConnConfig) error {
+func (c *Conn) startup() error {
 	m := map[string]string{
-		"CQL_VERSION": cfg.CQLVersion,
+		"CQL_VERSION": c.cfg.CQLVersion,
 	}
 
 	if c.compressor != nil {
@@ -884,66 +886,6 @@ func (c *Conn) setKeepalive(d time.Duration) error {
 	return nil
 }
 
-func (c *Conn) awaitSchemaAgreement() (err error) {
-
-	const (
-		// TODO(zariel): if we export this make this configurable
-		maxWaitTime = 60 * time.Second
-
-		peerSchemas  = "SELECT schema_version FROM system.peers"
-		localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
-	)
-
-	endDeadline := time.Now().Add(maxWaitTime)
-
-	for time.Now().Before(endDeadline) {
-		iter := c.executeQuery(&Query{
-			stmt: peerSchemas,
-			cons: One,
-		})
-
-		versions := make(map[string]struct{})
-
-		var schemaVersion string
-		for iter.Scan(&schemaVersion) {
-			versions[schemaVersion] = struct{}{}
-			schemaVersion = ""
-		}
-
-		if err = iter.Close(); err != nil {
-			goto cont
-		}
-
-		iter = c.executeQuery(&Query{
-			stmt: localSchemas,
-			cons: One,
-		})
-
-		for iter.Scan(&schemaVersion) {
-			versions[schemaVersion] = struct{}{}
-			schemaVersion = ""
-		}
-
-		if err = iter.Close(); err != nil {
-			goto cont
-		}
-
-		if len(versions) <= 1 {
-			return nil
-		}
-
-	cont:
-		time.Sleep(200 * time.Millisecond)
-	}
-
-	if err != nil {
-		return
-	}
-
-	// not exported
-	return errors.New("gocql: cluster schema versions not consistent")
-}
-
 type inflightPrepare struct {
 	info QueryInfo
 	err  error

+ 1 - 0
conn_test.go

@@ -341,6 +341,7 @@ func TestRoundRobinConnPoolRoundRobin(t *testing.T) {
 	cluster := NewCluster(addrs...)
 	cluster.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
 	cluster.PoolConfig.ConnSelectionPolicy = RoundRobinConnPolicy()
+	cluster.disableControlConn = true
 
 	db, err := cluster.CreateSession()
 	if err != nil {

+ 4 - 4
connectionpool.go

@@ -60,7 +60,7 @@ func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
 type policyConnPool struct {
 	port     int
 	numConns int
-	connCfg  ConnConfig
+	connCfg  *ConnConfig
 	keyspace string
 
 	mu            sync.RWMutex
@@ -88,7 +88,7 @@ func newPolicyConnPool(cfg *ClusterConfig, hostPolicy HostSelectionPolicy,
 	pool := &policyConnPool{
 		port:     cfg.Port,
 		numConns: cfg.NumConns,
-		connCfg: ConnConfig{
+		connCfg: &ConnConfig{
 			ProtoVersion:  cfg.ProtoVersion,
 			CQLVersion:    cfg.CQLVersion,
 			Timeout:       cfg.Timeout,
@@ -212,7 +212,7 @@ type hostConnPool struct {
 	port     int
 	addr     string
 	size     int
-	connCfg  ConnConfig
+	connCfg  *ConnConfig
 	keyspace string
 	policy   ConnSelectionPolicy
 	// protection for conns, closed, filling
@@ -222,7 +222,7 @@ type hostConnPool struct {
 	filling bool
 }
 
-func newHostConnPool(host string, port int, size int, connCfg ConnConfig,
+func newHostConnPool(host string, port int, size int, connCfg *ConnConfig,
 	keyspace string, policy ConnSelectionPolicy) *hostConnPool {
 
 	pool := &hostConnPool{

+ 223 - 0
control.go

@@ -0,0 +1,223 @@
+package gocql
+
+import (
+	"errors"
+	"fmt"
+	"sync/atomic"
+	"time"
+)
+
+type controlConn struct {
+	session *Session
+
+	conn       atomic.Value
+	connecting uint64
+
+	retry RetryPolicy
+
+	quit chan struct{}
+}
+
+func createControlConn(session *Session) *controlConn {
+	control := &controlConn{
+		session: session,
+		quit:    make(chan struct{}),
+		retry:   &SimpleRetryPolicy{NumRetries: 3},
+	}
+
+	control.conn.Store((*Conn)(nil))
+	control.reconnect()
+	go control.heartBeat()
+
+	return control
+}
+
+func (c *controlConn) heartBeat() {
+	for {
+		select {
+		case <-c.quit:
+			return
+		case <-time.After(5 * time.Second):
+		}
+
+		resp, err := c.writeFrame(&writeOptionsFrame{})
+		if err != nil {
+			goto reconn
+		}
+
+		switch resp.(type) {
+		case *supportedFrame:
+			continue
+		case error:
+			goto reconn
+		default:
+			panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp))
+		}
+
+	reconn:
+		c.reconnect()
+		time.Sleep(5 * time.Second)
+		continue
+
+	}
+}
+
+func (c *controlConn) reconnect() {
+	if !atomic.CompareAndSwapUint64(&c.connecting, 0, 1) {
+		return
+	}
+
+	success := false
+	defer func() {
+		// debounce reconnect a little
+		if success {
+			go func() {
+				time.Sleep(500 * time.Millisecond)
+				atomic.StoreUint64(&c.connecting, 0)
+			}()
+		} else {
+			atomic.StoreUint64(&c.connecting, 0)
+		}
+	}()
+
+	oldConn := c.conn.Load().(*Conn)
+
+	// TODO: should have our own round-robin for hosts so that we can try each
+	// in succession and guarantee that we get a different host each time.
+	conn := c.session.pool.Pick(nil)
+	if conn == nil {
+		return
+	}
+
+	newConn, err := Connect(conn.addr, conn.cfg, c)
+	if err != nil {
+		// TODO: add log handler for things like this
+		return
+	}
+
+	c.conn.Store(newConn)
+	success = true
+
+	if oldConn != nil {
+		oldConn.Close()
+	}
+}
+
+func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
+	if !closed {
+		return
+	}
+
+	oldConn := c.conn.Load().(*Conn)
+	if oldConn != conn {
+		panic("controlConn: got error for connection which we did not create")
+	}
+
+	c.reconnect()
+}
+
+func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
+	conn := c.conn.Load().(*Conn)
+	if conn == nil {
+		return nil, errNoControl
+	}
+
+	framer, err := conn.exec(w, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	return framer.parseFrame()
+}
+
+// query will return nil if the connection is closed or nil
+func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) {
+	q := c.session.Query(statement, values...).Consistency(One)
+
+	const maxConnectAttempts = 5
+	connectAttempts := 0
+
+	for {
+		conn := c.conn.Load().(*Conn)
+		if conn == nil {
+			if connectAttempts > maxConnectAttempts {
+				return &Iter{err: errNoControl}
+			}
+
+			connectAttempts++
+
+			c.reconnect()
+			continue
+		}
+
+		iter = conn.executeQuery(q)
+		if iter.err == nil {
+			break
+		}
+
+		if !c.retry.Attempt(q) {
+			break
+		}
+	}
+
+	return
+}
+
+func (c *controlConn) awaitSchemaAgreement() (err error) {
+
+	const (
+		// TODO(zariel): if we export this make this configurable
+		maxWaitTime = 60 * time.Second
+
+		peerSchemas  = "SELECT schema_version FROM system.peers"
+		localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
+	)
+
+	endDeadline := time.Now().Add(maxWaitTime)
+
+	for time.Now().Before(endDeadline) {
+		iter := c.query(peerSchemas)
+
+		versions := make(map[string]struct{})
+
+		var schemaVersion string
+		for iter.Scan(&schemaVersion) {
+			versions[schemaVersion] = struct{}{}
+			schemaVersion = ""
+		}
+
+		if err = iter.Close(); err != nil {
+			goto cont
+		}
+
+		iter = c.query(localSchemas)
+		for iter.Scan(&schemaVersion) {
+			versions[schemaVersion] = struct{}{}
+			schemaVersion = ""
+		}
+
+		if err = iter.Close(); err != nil {
+			goto cont
+		}
+
+		if len(versions) <= 1 {
+			return nil
+		}
+
+	cont:
+		time.Sleep(200 * time.Millisecond)
+	}
+
+	if err != nil {
+		return
+	}
+
+	// not exported
+	return errors.New("gocql: cluster schema versions not consistent")
+}
+func (c *controlConn) close() {
+	// TODO: handle more gracefully
+	close(c.quit)
+}
+
+var errNoControl = errors.New("gocql: no control connection available")
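
The rest of the commit drives the control connection through query and awaitSchemaAgreement. The helper below is hypothetical (refreshSchemaVersion does not exist in this diff), but it condenses how cassandra_test.go and host_source.go call those two entry points from inside the package:

// refreshSchemaVersion waits for schema agreement and then reads the local
// schema version over the dedicated control connection.
func refreshSchemaVersion(session *Session) (string, error) {
	// schema agreement after DDL, as in cassandra_test.go above
	if err := session.control.awaitSchemaAgreement(); err != nil {
		return "", err
	}

	// system-table reads go over the control connection, as in host_source.go below
	iter := session.control.query("SELECT schema_version FROM system.local WHERE key='local'")
	var version string
	iter.Scan(&version)
	// Close reports errNoControl if no control connection could be established
	return version, iter.Close()
}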

+ 11 - 0
frame.go

@@ -1397,6 +1397,17 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame) error {
 	return f.finishWrite()
 }
 
+type writeOptionsFrame struct{}
+
+func (w *writeOptionsFrame) writeFrame(framer *framer, streamID int) error {
+	return framer.writeOptionsFrame(streamID, w)
+}
+
+func (f *framer) writeOptionsFrame(stream int, _ *writeOptionsFrame) error {
+	f.writeHeader(f.flags, opOptions, stream)
+	return f.finishWrite()
+}
+
 func (f *framer) readByte() byte {
 	if len(f.rbuf) < 1 {
 		panic(fmt.Errorf("not enough bytes in buffer to read byte require 1 got: %d", len(f.rbuf)))
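
The new OPTIONS writer gives the heartbeat the cheapest possible request/response pair: OPTIONS carries no body and the server replies with SUPPORTED. A hypothetical in-package helper condensed from the heartBeat loop in control.go above would look like this:

// alive reports whether the control connection's host still answers an
// OPTIONS request.
func (c *controlConn) alive() bool {
	resp, err := c.writeFrame(&writeOptionsFrame{})
	if err != nil {
		return false
	}
	_, ok := resp.(*supportedFrame)
	return ok
}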

+ 23 - 17
host_source.go

@@ -27,14 +27,17 @@ type ringDescriber struct {
 func (r *ringDescriber) GetHosts() (hosts []HostInfo, partitioner string, err error) {
 	// we need conn to be the same because we need to query system.peers and system.local
 	// on the same node to get the whole cluster
+
+	iter := r.session.control.query("SELECT data_center, rack, host_id, tokens, partitioner FROM system.local")
+	if iter == nil {
+		return r.prevHosts, r.prevPartitioner, nil
+	}
+
 	conn := r.session.pool.Pick(nil)
 	if conn == nil {
 		return r.prevHosts, r.prevPartitioner, nil
 	}
 
-	query := r.session.Query("SELECT data_center, rack, host_id, tokens, partitioner FROM system.local")
-	iter := conn.executeQuery(query)
-
 	host := HostInfo{}
 	iter.Scan(&host.DataCenter, &host.Rack, &host.HostId, &host.Tokens, &partitioner)
 
@@ -53,8 +56,10 @@ func (r *ringDescriber) GetHosts() (hosts []HostInfo, partitioner string, err er
 
 	hosts = []HostInfo{host}
 
-	query = r.session.Query("SELECT peer, data_center, rack, host_id, tokens FROM system.peers")
-	iter = conn.executeQuery(query)
+	iter = r.session.control.query("SELECT peer, data_center, rack, host_id, tokens FROM system.peers")
+	if iter == nil {
+		return r.prevHosts, r.prevPartitioner, nil
+	}
 
 	host = HostInfo{}
 	for iter.Scan(&host.Peer, &host.DataCenter, &host.Rack, &host.HostId, &host.Tokens) {
@@ -93,20 +98,21 @@ func (h *ringDescriber) run(sleep time.Duration) {
 	}
 
 	for {
+		// if we have 0 hosts this will return the previous list of hosts to
+		// attempt to reconnect to the cluster, otherwise we would never find
+		// downed hosts again; could possibly have an optimisation to only
+		// try to add new hosts if GetHosts didn't error and the hosts didn't change.
+		hosts, partitioner, err := h.GetHosts()
+		if err != nil {
+			log.Println("RingDescriber: unable to get ring topology:", err)
+			continue
+		}
+
+		h.session.pool.SetHosts(hosts)
+		h.session.pool.SetPartitioner(partitioner)
+
 		select {
 		case <-time.After(sleep):
-			// if we have 0 hosts this will return the previous list of hosts to
-			// attempt to reconnect to the cluster otherwise we would never find
-			// downed hosts again, could possibly have an optimisation to only
-			// try to add new hosts if GetHosts didnt error and the hosts didnt change.
-			hosts, partitioner, err := h.GetHosts()
-			if err != nil {
-				log.Println("RingDescriber: unable to get ring topology:", err)
-				continue
-			}
-
-			h.session.pool.SetHosts(hosts)
-			h.session.pool.SetPartitioner(partitioner)
 		case <-h.closeChan:
 			return
 		}

+ 8 - 28
metadata.go

@@ -335,30 +335,18 @@ func componentColumnCountOfType(columns map[string]*ColumnMetadata, kind string)
 }
 
 // query only for the keyspace metadata for the specified keyspace from system.schema_keyspace
-func getKeyspaceMetadata(
-	session *Session,
-	keyspaceName string,
-) (*KeyspaceMetadata, error) {
-	query := session.Query(
-		`
+func getKeyspaceMetadata(session *Session, keyspaceName string) (*KeyspaceMetadata, error) {
+	const stmt = `
 		SELECT durable_writes, strategy_class, strategy_options
 		FROM system.schema_keyspaces
-		WHERE keyspace_name = ?
-		`,
-		keyspaceName,
-	)
-	// Set a routing key to avoid GetRoutingKey from computing the routing key
-	// TODO use a separate connection (pool) for system keyspace queries.
-	query.RoutingKey([]byte{})
+		WHERE keyspace_name = ?`
 
 	keyspace := &KeyspaceMetadata{Name: keyspaceName}
 	var strategyOptionsJSON []byte
 
-	err := query.Scan(
-		&keyspace.DurableWrites,
-		&keyspace.StrategyClass,
-		&strategyOptionsJSON,
-	)
+	iter := session.control.query(stmt, keyspaceName)
+	iter.Scan(&keyspace.DurableWrites, &keyspace.StrategyClass, &strategyOptionsJSON)
+	err := iter.Close()
 	if err != nil {
 		return nil, fmt.Errorf("Error querying keyspace schema: %v", err)
 	}
@@ -431,11 +419,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
 		}
 	}
 
-	// Set a routing key to avoid GetRoutingKey from computing the routing key
-	// TODO use a separate connection (pool) for system keyspace queries.
-	query := session.Query(stmt, keyspaceName)
-	query.RoutingKey([]byte{})
-	iter := query.Iter()
+	iter := session.control.query(stmt, keyspaceName)
 
 	tables := []TableMetadata{}
 	table := TableMetadata{Keyspace: keyspaceName}
@@ -560,11 +544,7 @@ func getColumnMetadata(
 
 	var indexOptionsJSON []byte
 
-	query := session.Query(stmt, keyspaceName)
-	// Set a routing key to avoid GetRoutingKey from computing the routing key
-	// TODO use a separate connection (pool) for system keyspace queries.
-	query.RoutingKey([]byte{})
-	iter := query.Iter()
+	iter := session.control.query(stmt, keyspaceName)
 
 	for scan(iter, &column, &indexOptionsJSON) {
 		var err error

+ 30 - 9
session.go

@@ -38,6 +38,8 @@ type Session struct {
 	hostSource          *ringDescriber
 	mu                  sync.RWMutex
 
+	control *controlConn
+
 	cfg ClusterConfig
 
 	closeMu  sync.RWMutex
@@ -86,6 +88,10 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 
 	s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
 
+	if !cfg.disableControlConn {
+		s.control = createControlConn(s)
+	}
+
 	if cfg.DiscoverHosts {
 		s.hostSource = &ringDescriber{
 			session:    s,
@@ -188,6 +194,10 @@ func (s *Session) Close() {
 	if s.hostSource != nil {
 		close(s.hostSource.closeChan)
 	}
+
+	if s.control != nil {
+		s.control.close()
+	}
 }
 
 func (s *Session) Closed() bool {
@@ -1055,29 +1065,40 @@ func (t *traceWriter) Trace(traceId []byte) {
 		coordinator string
 		duration    int
 	)
-	t.session.Query(`SELECT coordinator, duration
+	iter := t.session.control.query(`SELECT coordinator, duration
 			FROM system_traces.sessions
-			WHERE session_id = ?`, traceId).
-		Consistency(One).Scan(&coordinator, &duration)
+			WHERE session_id = ?`, traceId)
+
+	iter.Scan(&coordinator, &duration)
+	if err := iter.Close(); err != nil {
+		t.mu.Lock()
+		fmt.Fprintln(t.w, "Error:", err)
+		t.mu.Unlock()
+		return
+	}
 
-	iter := t.session.Query(`SELECT event_id, activity, source, source_elapsed
-			FROM system_traces.events
-			WHERE session_id = ?`, traceId).
-		Consistency(One).Iter()
 	var (
 		timestamp time.Time
 		activity  string
 		source    string
 		elapsed   int
 	)
-	t.mu.Lock()
-	defer t.mu.Unlock()
+
 	fmt.Fprintf(t.w, "Tracing session %016x (coordinator: %s, duration: %v):\n",
 		traceId, coordinator, time.Duration(duration)*time.Microsecond)
+
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	iter = t.session.control.query(`SELECT event_id, activity, source, source_elapsed
+			FROM system_traces.events
+			WHERE session_id = ?`, traceId)
+
 	for iter.Scan(&timestamp, &activity, &source, &elapsed) {
 		fmt.Fprintf(t.w, "%s: %s (source: %s, elapsed: %d)\n",
 			timestamp.Format("2006/01/02 15:04:05.999999"), activity, source, elapsed)
 	}
+
 	if err := iter.Close(); err != nil {
 		fmt.Fprintln(t.w, "Error:", err)
 	}