Browse Source

changed the cluster interface a bit. keyspace changes are now propagated correctly.

Christoph Hack 12 years ago
parent
commit
f22ff84aa7
6 changed files with 242 additions and 85 deletions
  1. 150 55
      cluster.go
  2. 74 23
      conn.go
  3. 3 3
      gocql_test.go
  4. 0 3
      gocql_test/main.go
  5. 0 1
      session.go
  6. 15 0
      topology.go

+ 150 - 55
cluster.go

@@ -17,110 +17,205 @@ import (
 // It has a variety of attributes that can be used to modify the behavior
 // It has a variety of attributes that can be used to modify the behavior
 // to fit the most common use cases. Applications that require a different
 // to fit the most common use cases. Applications that require a different
 // setup should compose the nodes on their own.
 // setup should compose the nodes on their own.
-type Cluster struct {
-	Hosts       []string
-	CQLVersion  string
-	Timeout     time.Duration
-	DefaultPort int
-	Keyspace    string
-	ConnPerHost int
-	DelayMin    time.Duration
-	DelayMax    time.Duration
+type ClusterConfig struct {
+	Hosts        []string
+	CQLVersion   string
+	ProtoVersion int
+	Timeout      time.Duration
+	DefaultPort  int
+	Keyspace     string
+	NumConn      int
+	NumStreams   int
+	DelayMin     time.Duration
+	DelayMax     time.Duration
+	StartupMin   int
 }
 }
 
 
-func NewCluster(hosts ...string) *Cluster {
-	c := &Cluster{
-		Hosts:       hosts,
-		CQLVersion:  "3.0.0",
-		Timeout:     200 * time.Millisecond,
-		DefaultPort: 9042,
-		ConnPerHost: 2,
+func NewCluster(hosts ...string) *ClusterConfig {
+	cfg := &ClusterConfig{
+		Hosts:        hosts,
+		CQLVersion:   "3.0.0",
+		ProtoVersion: 2,
+		Timeout:      200 * time.Millisecond,
+		DefaultPort:  9042,
+		NumConn:      2,
+		DelayMin:     1 * time.Second,
+		DelayMax:     10 * time.Minute,
+		StartupMin:   len(hosts)/2 + 1,
 	}
 	}
-	return c
+	return cfg
 }
 }
 
 
-func (c *Cluster) CreateSession() *Session {
-	return NewSession(newClusterNode(c))
+func (cfg *ClusterConfig) CreateSession() *Session {
+	impl := &clusterImpl{
+		cfg:      *cfg,
+		hostPool: NewRoundRobin(),
+		connPool: make(map[string]*RoundRobin),
+	}
+	impl.wgStart.Add(1)
+	impl.startup()
+	impl.wgStart.Wait()
+	return NewSession(impl)
 }
 }
 
 
-type clusterNode struct {
-	cfg      Cluster
+type clusterImpl struct {
+	cfg      ClusterConfig
 	hostPool *RoundRobin
 	hostPool *RoundRobin
 	connPool map[string]*RoundRobin
 	connPool map[string]*RoundRobin
-	closed   bool
-	mu       sync.Mutex
+	mu       sync.RWMutex
+
+	conns []*Conn
+
+	started bool
+	wgStart sync.WaitGroup
+
+	quit     bool
+	quitWait chan bool
+	quitOnce sync.Once
+
+	keyspace string
 }
 }
 
 
-func newClusterNode(cfg *Cluster) *clusterNode {
-	c := &clusterNode{
-		cfg:      *cfg,
-		hostPool: NewRoundRobin(),
-		connPool: make(map[string]*RoundRobin),
-	}
+func (c *clusterImpl) startup() {
 	for i := 0; i < len(c.cfg.Hosts); i++ {
 	for i := 0; i < len(c.cfg.Hosts); i++ {
 		addr := strings.TrimSpace(c.cfg.Hosts[i])
 		addr := strings.TrimSpace(c.cfg.Hosts[i])
 		if strings.IndexByte(addr, ':') < 0 {
 		if strings.IndexByte(addr, ':') < 0 {
 			addr = fmt.Sprintf("%s:%d", addr, c.cfg.DefaultPort)
 			addr = fmt.Sprintf("%s:%d", addr, c.cfg.DefaultPort)
 		}
 		}
-		for j := 0; j < c.cfg.ConnPerHost; j++ {
+		for j := 0; j < c.cfg.NumConn; j++ {
 			go c.connect(addr)
 			go c.connect(addr)
 		}
 		}
 	}
 	}
-	<-time.After(c.cfg.Timeout)
-	return c
 }
 }
 
 
-func (c *clusterNode) connect(addr string) {
+func (c *clusterImpl) connect(addr string) {
+	cfg := ConnConfig{
+		ProtoVersion: 2,
+		CQLVersion:   c.cfg.CQLVersion,
+		Timeout:      c.cfg.Timeout,
+		NumStreams:   c.cfg.NumStreams,
+	}
 	delay := c.cfg.DelayMin
 	delay := c.cfg.DelayMin
 	for {
 	for {
-		conn, err := Connect(addr, c.cfg.CQLVersion, c.cfg.Timeout)
+		conn, err := Connect(addr, cfg, c)
 		if err != nil {
 		if err != nil {
-			fmt.Println(err)
-			<-time.After(delay)
-			if delay *= 2; delay > c.cfg.DelayMax {
-				delay = c.cfg.DelayMax
+			select {
+			case <-time.After(delay):
+				if delay *= 2; delay > c.cfg.DelayMax {
+					delay = c.cfg.DelayMax
+				}
+				continue
+			case <-c.quitWait:
+				return
 			}
 			}
-			continue
 		}
 		}
-		c.addConn(addr, conn)
+		c.addConn(conn, "")
 		return
 		return
 	}
 	}
 }
 }
 
 
-func (c *clusterNode) addConn(addr string, conn *Conn) {
+func (c *clusterImpl) changeKeyspace(conn *Conn, keyspace string, connected bool) {
+	if err := conn.UseKeyspace(keyspace); err != nil {
+		conn.Close()
+		if connected {
+			c.removeConn(conn)
+		}
+		go c.connect(conn.Address())
+	}
+	if !connected {
+		c.addConn(conn, keyspace)
+	}
+}
+
+func (c *clusterImpl) addConn(conn *Conn, keyspace string) {
 	c.mu.Lock()
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	defer c.mu.Unlock()
-	connPool := c.connPool[addr]
+	if c.quit {
+		conn.Close()
+		return
+	}
+	if keyspace != c.keyspace && c.keyspace != "" {
+		go c.changeKeyspace(conn, c.keyspace, false)
+		return
+	}
+	connPool := c.connPool[conn.Address()]
 	if connPool == nil {
 	if connPool == nil {
 		connPool = NewRoundRobin()
 		connPool = NewRoundRobin()
-		c.connPool[addr] = connPool
+		c.connPool[conn.Address()] = connPool
 		c.hostPool.AddNode(connPool)
 		c.hostPool.AddNode(connPool)
+		if !c.started && c.hostPool.Size() >= c.cfg.StartupMin {
+			c.started = true
+			c.wgStart.Done()
+		}
 	}
 	}
 	connPool.AddNode(conn)
 	connPool.AddNode(conn)
-	go func() {
-		conn.Serve()
-		c.removeConn(addr, conn)
-	}()
+	c.conns = append(c.conns, conn)
 }
 }
 
 
-func (c *clusterNode) removeConn(addr string, conn *Conn) {
+func (c *clusterImpl) removeConn(conn *Conn) {
 	c.mu.Lock()
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	defer c.mu.Unlock()
-	pool := c.connPool[addr]
-	if pool == nil {
+	conn.Close()
+	connPool := c.connPool[conn.addr]
+	if connPool == nil {
+		return
+	}
+	connPool.RemoveNode(conn)
+	if connPool.Size() == 0 {
+		c.hostPool.RemoveNode(connPool)
+	}
+	for i := 0; i < len(c.conns); i++ {
+		if c.conns[i] == conn {
+			last := len(c.conns) - 1
+			c.conns[i], c.conns[last] = c.conns[last], c.conns[i]
+			c.conns = c.conns[:last]
+		}
+	}
+}
+
+func (c *clusterImpl) HandleError(conn *Conn, err error, closed bool) {
+	if !closed {
 		return
 		return
 	}
 	}
-	pool.RemoveNode(conn)
+	c.removeConn(conn)
+	go c.connect(conn.Address())
 }
 }
 
 
-func (c *clusterNode) ExecuteQuery(qry *Query) (*Iter, error) {
+func (c *clusterImpl) HandleKeyspace(conn *Conn, keyspace string) {
+	c.mu.Lock()
+	if c.keyspace == keyspace {
+		c.mu.Unlock()
+		return
+	}
+	c.keyspace = keyspace
+	conns := make([]*Conn, len(c.conns))
+	copy(conns, c.conns)
+	c.mu.Unlock()
+
+	for i := 0; i < len(conns); i++ {
+		if conns[i] == conn {
+			continue
+		}
+		c.changeKeyspace(conns[i], keyspace, true)
+	}
+}
+
+func (c *clusterImpl) ExecuteQuery(qry *Query) (*Iter, error) {
 	return c.hostPool.ExecuteQuery(qry)
 	return c.hostPool.ExecuteQuery(qry)
 }
 }
 
 
-func (c *clusterNode) ExecuteBatch(batch *Batch) error {
+func (c *clusterImpl) ExecuteBatch(batch *Batch) error {
 	return c.hostPool.ExecuteBatch(batch)
 	return c.hostPool.ExecuteBatch(batch)
 }
 }
 
 
-func (c *clusterNode) Close() {
-	c.hostPool.Close()
+func (c *clusterImpl) Close() {
+	c.quitOnce.Do(func() {
+		c.mu.Lock()
+		defer c.mu.Unlock()
+		c.quit = true
+		close(c.quitWait)
+		for i := 0; i < len(c.conns); i++ {
+			c.conns[i].Close()
+		}
+	})
 }
 }

+ 74 - 23
conn.go

@@ -13,6 +13,24 @@ import (
 
 
 const defaultFrameSize = 4096
 const defaultFrameSize = 4096
 
 
+type Cluster interface {
+	//HandleAuth(addr, method string) ([]byte, Challenger, error)
+	HandleError(conn *Conn, err error, closed bool)
+	HandleKeyspace(conn *Conn, keyspace string)
+}
+
+/* type Challenger interface {
+	Challenge(data []byte) ([]byte, error)
+} */
+
+type ConnConfig struct {
+	ProtoVersion int
+	CQLVersion   string
+	Keyspace     string
+	Timeout      time.Duration
+	NumStreams   int
+}
+
 // Conn is a single connection to a Cassandra node. It can be used to execute
 // Conn is a single connection to a Cassandra node. It can be used to execute
 // queries, but users are usually advised to use a more reliable, higher
 // queries, but users are usually advised to use a more reliable, higher
 // level API.
 // level API.
@@ -24,41 +42,51 @@ type Conn struct {
 	calls []callReq
 	calls []callReq
 	nwait int32
 	nwait int32
 
 
-	prepMu   sync.Mutex
-	prep     map[string]*queryInfo
+	prepMu sync.Mutex
+	prep   map[string]*queryInfo
+
+	cluster  Cluster
+	addr     string
 	keyspace string
 	keyspace string
 }
 }
 
 
 // Connect establishes a connection to a Cassandra node.
 // Connect establishes a connection to a Cassandra node.
 // You must also call the Serve method before you can execute any queries.
 // You must also call the Serve method before you can execute any queries.
-func Connect(addr, version string, timeout time.Duration) (*Conn, error) {
-	conn, err := net.DialTimeout("tcp", addr, timeout)
+func Connect(addr string, cfg ConnConfig, cluster Cluster) (*Conn, error) {
+	conn, err := net.DialTimeout("tcp", addr, cfg.Timeout)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+	if cfg.NumStreams <= 0 || cfg.NumStreams > 128 {
+		cfg.NumStreams = 128
+	}
 	c := &Conn{
 	c := &Conn{
 		conn:    conn,
 		conn:    conn,
-		uniq:    make(chan uint8, 128),
-		calls:   make([]callReq, 128),
+		uniq:    make(chan uint8, cfg.NumStreams),
+		calls:   make([]callReq, cfg.NumStreams),
 		prep:    make(map[string]*queryInfo),
 		prep:    make(map[string]*queryInfo),
-		timeout: timeout,
+		timeout: cfg.Timeout,
+		addr:    conn.RemoteAddr().String(),
+		cluster: cluster,
 	}
 	}
 	for i := 0; i < cap(c.uniq); i++ {
 	for i := 0; i < cap(c.uniq); i++ {
 		c.uniq <- uint8(i)
 		c.uniq <- uint8(i)
 	}
 	}
 
 
-	if err := c.init(version); err != nil {
+	if err := c.startup(&cfg); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
+	go c.serve()
+
 	return c, nil
 	return c, nil
 }
 }
 
 
-func (c *Conn) init(version string) error {
+func (c *Conn) startup(cfg *ConnConfig) error {
 	req := make(frame, headerSize, defaultFrameSize)
 	req := make(frame, headerSize, defaultFrameSize)
 	req.setHeader(protoRequest, 0, 0, opStartup)
 	req.setHeader(protoRequest, 0, 0, opStartup)
 	req.writeStringMap(map[string]string{
 	req.writeStringMap(map[string]string{
-		"CQL_VERSION": version,
+		"CQL_VERSION": cfg.CQLVersion,
 	})
 	})
 	resp, err := c.callSimple(req)
 	resp, err := c.callSimple(req)
 	if err != nil {
 	if err != nil {
@@ -69,21 +97,13 @@ func (c *Conn) init(version string) error {
 		return ErrProtocol
 		return ErrProtocol
 	}
 	}
 
 
-	/*	if cfg.Keyspace != "" {
-		qry := &Query{stmt: "USE " + cfg.Keyspace}
-		frame, err = c.executeQuery(qry)
-		if err != nil {
-			return err
-		}
-	} */
-
 	return nil
 	return nil
 }
 }
 
 
 // Serve starts the stream multiplexer for this connection, which is required
 // Serve starts the stream multiplexer for this connection, which is required
 // to execute any queries. This method runs as long as the connection is
 // to execute any queries. This method runs as long as the connection is
 // open and is therefore usually called in a separate goroutine.
 // open and is therefore usually called in a separate goroutine.
-func (c *Conn) Serve() error {
+func (c *Conn) serve() {
 	var err error
 	var err error
 	for {
 	for {
 		var frame frame
 		var frame frame
@@ -101,7 +121,7 @@ func (c *Conn) Serve() error {
 			req.resp <- callResp{nil, err}
 			req.resp <- callResp{nil, err}
 		}
 		}
 	}
 	}
-	return err
+	c.cluster.HandleError(c, err, true)
 }
 }
 
 
 func (c *Conn) recv() (frame, error) {
 func (c *Conn) recv() (frame, error) {
@@ -237,9 +257,6 @@ func (c *Conn) switchKeyspace(keyspace string) error {
 }
 }
 
 
 func (c *Conn) ExecuteQuery(qry *Query) (*Iter, error) {
 func (c *Conn) ExecuteQuery(qry *Query) (*Iter, error) {
-	if err := c.switchKeyspace(qry.Keyspace); err != nil {
-		return nil, err
-	}
 	frame, err := c.executeQuery(qry)
 	frame, err := c.executeQuery(qry)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -300,6 +317,10 @@ func (c *Conn) Close() {
 	c.conn.Close()
 	c.conn.Close()
 }
 }
 
 
+func (c *Conn) Address() string {
+	return c.addr
+}
+
 func (c *Conn) executeQuery(query *Query) (frame, error) {
 func (c *Conn) executeQuery(query *Query) (frame, error) {
 	var info *queryInfo
 	var info *queryInfo
 	if len(query.Args) > 0 {
 	if len(query.Args) > 0 {
@@ -340,6 +361,15 @@ func (c *Conn) executeQuery(query *Query) (frame, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
+	if frame[3] == opResult {
+		f := frame
+		f.skipHeader()
+		if f.readInt() == resultKindKeyspace {
+			keyspace := f.readString()
+			c.cluster.HandleKeyspace(c, keyspace)
+		}
+	}
+
 	if frame[3] == opError {
 	if frame[3] == opError {
 		frame.skipHeader()
 		frame.skipHeader()
 		code := frame.readInt()
 		code := frame.readInt()
@@ -349,6 +379,27 @@ func (c *Conn) executeQuery(query *Query) (frame, error) {
 	return frame, nil
 	return frame, nil
 }
 }
 
 
+func (c *Conn) UseKeyspace(keyspace string) error {
+	frame := make(frame, headerSize, defaultFrameSize)
+	frame.setHeader(protoRequest, 0, 0, opQuery)
+	frame.writeLongString("USE " + keyspace)
+	frame.writeConsistency(1)
+	frame.writeByte(0)
+
+	frame, err := c.call(frame)
+	if err != nil {
+		return err
+	}
+
+	if frame[3] == opError {
+		frame.skipHeader()
+		code := frame.readInt()
+		desc := frame.readString()
+		return Error{code, desc}
+	}
+	return nil
+}
+
 type queryInfo struct {
 type queryInfo struct {
 	id   []byte
 	id   []byte
 	args []ColumnInfo
 	args []ColumnInfo

+ 3 - 3
gocql_test.go

@@ -165,9 +165,9 @@ func TestRoundRobin(t *testing.T) {
 		addrs[i] = servers[i].Address
 		addrs[i] = servers[i].Address
 		defer servers[i].Stop()
 		defer servers[i].Stop()
 	}
 	}
-	db := NewCluster(addrs...).CreateSession()
-
-	time.Sleep(1 * time.Second)
+	cluster := NewCluster(addrs...)
+	cluster.StartupMin = len(addrs)
+	db := cluster.CreateSession()
 
 
 	var wg sync.WaitGroup
 	var wg sync.WaitGroup
 	wg.Add(5)
 	wg.Add(5)

+ 0 - 3
gocql_test/main.go

@@ -1,7 +1,6 @@
 package main
 package main
 
 
 import (
 import (
-	"fmt"
 	"log"
 	"log"
 	"reflect"
 	"reflect"
 	"sort"
 	"sort"
@@ -15,7 +14,6 @@ var session *gocql.Session
 
 
 func init() {
 func init() {
 	cluster := gocql.NewCluster("127.0.0.1")
 	cluster := gocql.NewCluster("127.0.0.1")
-	cluster.ConnPerHost = 1
 	session = cluster.CreateSession()
 	session = cluster.CreateSession()
 }
 }
 
 
@@ -58,7 +56,6 @@ func initSchema() error {
 			attachments map<varchar, text>,
 			attachments map<varchar, text>,
 			PRIMARY KEY (title, revid)
 			PRIMARY KEY (title, revid)
 		)`).Exec(); err != nil {
 		)`).Exec(); err != nil {
-		fmt.Println("create err")
 		return err
 		return err
 	}
 	}
 
 

+ 0 - 1
session.go

@@ -68,7 +68,6 @@ type Query struct {
 	Token    string
 	Token    string
 	PageSize int
 	PageSize int
 	Trace    bool
 	Trace    bool
-	Keyspace string
 }
 }
 
 
 func NewQuery(stmt string, args ...interface{}) *Query {
 func NewQuery(stmt string, args ...interface{}) *Query {

+ 15 - 0
topology.go

@@ -40,6 +40,21 @@ func (r *RoundRobin) RemoveNode(node Node) {
 	r.mu.Unlock()
 	r.mu.Unlock()
 }
 }
 
 
+func (r *RoundRobin) Size() int {
+	r.mu.RLock()
+	n := len(r.pool)
+	r.mu.RUnlock()
+	return n
+}
+
+func (r *RoundRobin) GetPool() []Node {
+	r.mu.RLock()
+	pool := make([]Node, len(r.pool))
+	copy(pool, r.pool)
+	r.mu.RUnlock()
+	return pool
+}
+
 func (r *RoundRobin) ExecuteQuery(qry *Query) (*Iter, error) {
 func (r *RoundRobin) ExecuteQuery(qry *Query) (*Iter, error) {
 	node := r.pick()
 	node := r.pick()
 	if node == nil {
 	if node == nil {