浏览代码

further API improvements

Christoph Hack 12 年之前
父节点
当前提交
ac8fdad0db
共有 7 个文件被更改,包括 230 次插入186 次删除
  1. 82 0
      cluster.go
  2. 12 11
      conn.go
  3. 17 0
      frame.go
  4. 0 23
      gocql.go
  5. 8 16
      gocql_test.go
  6. 83 75
      session.go
  7. 28 61
      topology.go

+ 82 - 0
cluster.go

@@ -0,0 +1,82 @@
+// Copyright (c) 2012 The gocql Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gocql
+
+import (
+	"fmt"
+	"strings"
+	"sync"
+	"time"
+)
+
+// Cluster sets up and maintains the node configuration of a Cassandra
+// cluster.
+//
+// It has a varity of attributes that can be used to modify the behavior
+// to fit the most common use cases. Applications that requre a different
+// a setup should compose the nodes on their own.
+type Cluster struct {
+	Hosts       []string
+	CQLVersion  string
+	Timeout     time.Duration
+	DefaultPort int
+	Keyspace    string
+	ConnPerHost int
+	DelayMin    time.Duration
+	DelayMax    time.Duration
+
+	pool     *RoundRobin
+	initOnce sync.Once
+	boot     sync.WaitGroup
+	bootOnce sync.Once
+}
+
+func NewCluster(hosts ...string) *Cluster {
+	c := &Cluster{
+		Hosts:       hosts,
+		CQLVersion:  "3.0.0",
+		Timeout:     200 * time.Millisecond,
+		DefaultPort: 9042,
+	}
+	return c
+}
+
+func (c *Cluster) init() {
+	for i := 0; i < len(c.Hosts); i++ {
+		addr := strings.TrimSpace(c.Hosts[i])
+		if strings.IndexByte(addr, ':') < 0 {
+			addr = fmt.Sprintf("%s:%d", addr, c.DefaultPort)
+		}
+		go c.connect(addr)
+	}
+	c.pool = NewRoundRobin()
+	<-time.After(c.Timeout)
+}
+
+func (c *Cluster) connect(addr string) {
+	delay := c.DelayMin
+	for {
+		conn, err := Connect(addr, c.CQLVersion, c.Timeout)
+		if err != nil {
+			<-time.After(delay)
+			if delay *= 2; delay > c.DelayMax {
+				delay = c.DelayMax
+			}
+			continue
+		}
+		c.pool.AddNode(conn)
+		go func() {
+			conn.Serve()
+			c.pool.RemoveNode(conn)
+			c.connect(addr)
+		}()
+		return
+	}
+}
+
+func (c *Cluster) CreateSession() *Session {
+	c.initOnce.Do(c.init)
+	return NewSession(c.pool)
+}

+ 12 - 11
conn.go

@@ -5,6 +5,7 @@
 package gocql
 
 import (
+	"fmt"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -30,8 +31,8 @@ type Conn struct {
 
 // Connect establishes a connection to a Cassandra node.
 // You must also call the Serve method before you can execute any queries.
-func Connect(addr string, cfg *Config) (*Conn, error) {
-	conn, err := net.DialTimeout("tcp", addr, cfg.Timeout)
+func Connect(addr, version string, timeout time.Duration) (*Conn, error) {
+	conn, err := net.DialTimeout("tcp", addr, timeout)
 	if err != nil {
 		return nil, err
 	}
@@ -40,24 +41,24 @@ func Connect(addr string, cfg *Config) (*Conn, error) {
 		uniq:    make(chan uint8, 128),
 		calls:   make([]callReq, 128),
 		prep:    make(map[string]*queryInfo),
-		timeout: cfg.Timeout,
+		timeout: timeout,
 	}
 	for i := 0; i < cap(c.uniq); i++ {
 		c.uniq <- uint8(i)
 	}
 
-	if err := c.init(cfg); err != nil {
+	if err := c.init(version); err != nil {
 		return nil, err
 	}
 
 	return c, nil
 }
 
-func (c *Conn) init(cfg *Config) error {
+func (c *Conn) init(version string) error {
 	req := make(frame, headerSize, defaultFrameSize)
 	req.setHeader(protoRequest, 0, 0, opStartup)
 	req.writeStringMap(map[string]string{
-		"CQL_VERSION": cfg.CQLVersion,
+		"CQL_VERSION": version,
 	})
 	resp, err := c.callSimple(req)
 	if err != nil {
@@ -204,7 +205,7 @@ func (c *Conn) prepareStatement(stmt string) (*queryInfo, error) {
 	c.prep[stmt] = info
 	c.prepMu.Unlock()
 
-	frame := make(frame, headerSize, headerSize+512)
+	frame := make(frame, headerSize, defaultFrameSize)
 	frame.setHeader(protoRequest, 0, 0, opPrepare)
 	frame.writeLongString(stmt)
 	frame.setLength(len(frame) - headerSize)
@@ -268,7 +269,7 @@ func (c *Conn) ExecuteBatch(batch *Batch) error {
 			frame.writeBytes(val)
 		}
 	}
-	frame.writeShort(uint16(batch.Cons))
+	frame.writeConsistency(batch.Cons)
 
 	frame, err := c.call(frame)
 	if err != nil {
@@ -289,6 +290,7 @@ func (c *Conn) Close() {
 func (c *Conn) executeQuery(query *Query) (frame, error) {
 	var info *queryInfo
 	if len(query.Args) > 0 {
+		fmt.Println("ARGS:", query.Args)
 		var err error
 		info, err = c.prepareStatement(query.Stmt)
 		if err != nil {
@@ -296,7 +298,7 @@ func (c *Conn) executeQuery(query *Query) (frame, error) {
 		}
 	}
 
-	frame := make(frame, headerSize, headerSize+512)
+	frame := make(frame, headerSize, defaultFrameSize)
 	if info == nil {
 		frame.setHeader(protoRequest, 0, 0, opQuery)
 		frame.writeLongString(query.Stmt)
@@ -304,7 +306,7 @@ func (c *Conn) executeQuery(query *Query) (frame, error) {
 		frame.setHeader(protoRequest, 0, 0, opExecute)
 		frame.writeShortBytes(info.id)
 	}
-	frame.writeShort(uint16(query.Cons))
+	frame.writeConsistency(query.Cons)
 	flags := uint8(0)
 	if len(query.Args) > 0 {
 		flags |= flagQueryValues
@@ -320,7 +322,6 @@ func (c *Conn) executeQuery(query *Query) (frame, error) {
 			frame.writeBytes(val)
 		}
 	}
-	frame.setLength(len(frame) - headerSize)
 
 	frame, err := c.call(frame)
 	if err != nil {

+ 17 - 0
frame.go

@@ -269,3 +269,20 @@ func (f *frame) readErrorFrame() (err error) {
 	desc := f.readString()
 	return Error{code, desc}
 }
+
+func (f *frame) writeConsistency(c Consistency) {
+	f.writeShort(consistencyCodes[c])
+}
+
+var consistencyCodes = []uint16{
+	Any:         0x0000,
+	One:         0x0001,
+	Two:         0x0002,
+	Three:       0x0003,
+	Quorum:      0x0004,
+	All:         0x0005,
+	LocalQuorum: 0x0006,
+	EachQuorum:  0x0007,
+	Serial:      0x0008,
+	LocalSerial: 0x0009,
+}

+ 0 - 23
gocql.go

@@ -19,14 +19,6 @@ type ColumnInfo struct {
 	TypeInfo *TypeInfo
 }
 
-type BatchType int
-
-const (
-	LoggedBatch   BatchType = 0
-	UnloggedBatch BatchType = 1
-	CounterBatch  BatchType = 2
-)
-
 /*
 type Batch struct {
 	queries []*Query
@@ -47,21 +39,6 @@ func (b *Batch) Apply() error {
 	return nil
 } */
 
-type Consistency uint16
-
-const (
-	Any Consistency = iota
-	One
-	Two
-	Three
-	Quorum
-	All
-	LocalQuorum
-	EachQuorum
-	Serial
-	LocalSerial
-)
-
 type Error struct {
 	Code    int
 	Message string

+ 8 - 16
gocql_test.go

@@ -5,6 +5,7 @@
 package gocql
 
 import (
+	"fmt"
 	"io"
 	"net"
 	"strings"
@@ -94,6 +95,7 @@ func (srv *TestServer) process(frame frame, conn net.Conn) {
 			frame.writeInt(0)
 		}
 	default:
+		fmt.Println("unsupproted:", frame)
 		frame = frame[:headerSize]
 		frame.setHeader(protoResponse, 0, frame[2], opError)
 		frame.writeInt(0)
@@ -123,10 +125,8 @@ func TestSimple(t *testing.T) {
 	srv := NewTestServer(t)
 	defer srv.Stop()
 
-	db := NewSession(Config{
-		Nodes:       []string{srv.Address},
-		Consistency: Quorum,
-	})
+	db := NewCluster(srv.Address).CreateSession()
+
 	if err := db.Query("void").Exec(); err != nil {
 		t.Error(err)
 	}
@@ -136,10 +136,7 @@ func TestTimeout(t *testing.T) {
 	srv := NewTestServer(t)
 	defer srv.Stop()
 
-	db := NewSession(Config{
-		Nodes:       []string{srv.Address},
-		Consistency: Quorum,
-	})
+	db := NewCluster(srv.Address).CreateSession()
 
 	go func() {
 		<-time.After(1 * time.Second)
@@ -155,10 +152,7 @@ func TestSlowQuery(t *testing.T) {
 	srv := NewTestServer(t)
 	defer srv.Stop()
 
-	db := NewSession(Config{
-		Nodes:       []string{srv.Address},
-		Consistency: Quorum,
-	})
+	db := NewCluster(srv.Address).CreateSession()
 
 	if err := db.Query("slow").Exec(); err != nil {
 		t.Fatal(err)
@@ -173,10 +167,8 @@ func TestRoundRobin(t *testing.T) {
 		addrs[i] = servers[i].Address
 		defer servers[i].Stop()
 	}
-	db := NewSession(Config{
-		Nodes:       addrs,
-		Consistency: Quorum,
-	})
+	db := NewCluster(addrs...).CreateSession()
+
 	time.Sleep(1 * time.Second)
 
 	var wg sync.WaitGroup

+ 83 - 75
session.go

@@ -6,112 +6,74 @@ package gocql
 
 import (
 	"errors"
-	"fmt"
-	"strings"
-	"time"
 )
 
-type Config struct {
-	Nodes       []string
-	CQLVersion  string
-	Keyspace    string
-	Consistency Consistency
-	DefaultPort int
-	Timeout     time.Duration
-	NodePicker  NodePicker
-	Reconnector Reconnector
-}
-
-func (c *Config) normalize() {
-	if c.CQLVersion == "" {
-		c.CQLVersion = "3.0.0"
-	}
-	if c.DefaultPort == 0 {
-		c.DefaultPort = 9042
-	}
-	if c.Timeout <= 0 {
-		c.Timeout = 200 * time.Millisecond
-	}
-	if c.NodePicker == nil {
-		c.NodePicker = NewRoundRobinPicker()
-	}
-	if c.Reconnector == nil {
-		c.Reconnector = NewExponentialReconnector(1*time.Second, 10*time.Minute)
-	}
-	for i := 0; i < len(c.Nodes); i++ {
-		c.Nodes[i] = strings.TrimSpace(c.Nodes[i])
-		if strings.IndexByte(c.Nodes[i], ':') < 0 {
-			c.Nodes[i] = fmt.Sprintf("%s:%d", c.Nodes[i], c.DefaultPort)
-		}
-	}
+// Session is the interface used by users to interact with the database.
+//
+// It extends the Node interface by adding a convinient query builder and
+// automatically sets a default consinstency level on all operations
+// that do not have a consistency level set.
+type Session struct {
+	Node Node
+	Cons Consistency
 }
 
-type Session struct {
-	cfg         *Config
-	pool        NodePicker
-	reconnector Reconnector
-	keyspace    string
-	nohosts     chan bool
-}
-
-func NewSession(cfg Config) *Session {
-	cfg.normalize()
-	s := &Session{
-		cfg:         &cfg,
-		nohosts:     make(chan bool),
-		reconnector: cfg.Reconnector,
-		pool:        cfg.NodePicker,
+// NewSession wraps an existing Node.
+func NewSession(node Node) *Session {
+	if s, ok := node.(*Session); ok {
+		return &Session{Node: s.Node}
 	}
-	for _, address := range cfg.Nodes {
-		go s.reconnector.Reconnect(s, address)
-	}
-	return s
+	return &Session{Node: node, Cons: Quorum}
 }
 
+// Query can be used to build new queries that should be executed on this
+// session.
 func (s *Session) Query(stmt string, args ...interface{}) QueryBuilder {
-	return QueryBuilder{
-		&Query{
-			Stmt: stmt,
-			Args: args,
-			Cons: s.cfg.Consistency,
-		},
-		s,
-	}
+	return QueryBuilder{NewQuery(stmt, args...), s}
 }
 
+// Do can be used to modify a copy of an existing query before it is
+// executed on this session.
 func (s *Session) Do(qry *Query) QueryBuilder {
 	q := *qry
 	return QueryBuilder{&q, s}
 }
 
+// Close closes all connections. The session is unuseable after this
+// operation.
 func (s *Session) Close() {
-	return
+	s.Node.Close()
 }
 
+// ExecuteBatch executes a Batch on the underlying Node.
 func (s *Session) ExecuteBatch(batch *Batch) error {
-	return nil
+	if batch.Cons == 0 {
+		batch.Cons = s.Cons
+	}
+	return s.Node.ExecuteBatch(batch)
 }
 
+// ExecuteQuery executes a Query on the underlying Node.
 func (s *Session) ExecuteQuery(qry *Query) (*Iter, error) {
-	node := s.pool.Pick(qry)
-	if node == nil {
-		<-time.After(s.cfg.Timeout)
-		node = s.pool.Pick(qry)
+	if qry.Cons == 0 {
+		qry.Cons = s.Cons
 	}
-	if node == nil {
-		return nil, ErrNoHostAvailable
-	}
-	return node.ExecuteQuery(qry)
+	return s.Node.ExecuteQuery(qry)
 }
 
 type Query struct {
 	Stmt     string
 	Args     []interface{}
 	Cons     Consistency
+	Token    string
 	PageSize int
 	Trace    bool
 }
 
+func NewQuery(stmt string, args ...interface{}) *Query {
+	return &Query{Stmt: stmt, Args: args}
+}
+
 type QueryBuilder struct {
 	qry *Query
 	ctx Node
@@ -126,6 +88,11 @@ func (b QueryBuilder) Consistency(cons Consistency) QueryBuilder {
 	return b
 }
 
+func (b QueryBuilder) Token(token string) QueryBuilder {
+	b.qry.Token = token
+	return b
+}
+
 func (b QueryBuilder) Trace(trace bool) QueryBuilder {
 	b.qry.Trace = trace
 	return b
@@ -220,15 +187,56 @@ type Batch struct {
 	Cons    Consistency
 }
 
-func NewBatch(typ BatchType, cons Consistency) *Batch {
-	return &Batch{Type: typ, Cons: cons}
+func NewBatch(typ BatchType) *Batch {
+	return &Batch{Type: typ}
 }
 
 func (b *Batch) Query(stmt string, args ...interface{}) {
 	b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args})
 }
 
+type BatchType int
+
+const (
+	LoggedBatch   BatchType = 0
+	UnloggedBatch BatchType = 1
+	CounterBatch  BatchType = 2
+)
+
 type BatchEntry struct {
 	Stmt string
 	Args []interface{}
 }
+
+type Consistency int
+
+const (
+	Any Consistency = 1 + iota
+	One
+	Two
+	Three
+	Quorum
+	All
+	LocalQuorum
+	EachQuorum
+	Serial
+	LocalSerial
+)
+
+var consinstencyNames = []string{
+	0:           "default",
+	Any:         "any",
+	One:         "one",
+	Two:         "two",
+	Three:       "three",
+	Quorum:      "quorum",
+	All:         "all",
+	LocalQuorum: "localquorum",
+	EachQuorum:  "eachquorum",
+	Serial:      "serial",
+	LocalSerial: "localserial",
+}
+
+func (c Consistency) String() string {
+	return consinstencyNames[c]
+}

+ 28 - 61
topology.go

@@ -3,7 +3,6 @@ package gocql
 import (
 	"sync"
 	"sync/atomic"
-	"time"
 )
 
 type Node interface {
@@ -12,29 +11,23 @@ type Node interface {
 	Close()
 }
 
-type NodePicker interface {
-	AddNode(node Node)
-	RemoveNode(node Node)
-	Pick(qry *Query) Node
-}
-
-type RoundRobinPicker struct {
+type RoundRobin struct {
 	pool []Node
 	pos  uint32
 	mu   sync.RWMutex
 }
 
-func NewRoundRobinPicker() *RoundRobinPicker {
-	return &RoundRobinPicker{}
+func NewRoundRobin() *RoundRobin {
+	return &RoundRobin{}
 }
 
-func (r *RoundRobinPicker) AddNode(node Node) {
+func (r *RoundRobin) AddNode(node Node) {
 	r.mu.Lock()
 	r.pool = append(r.pool, node)
 	r.mu.Unlock()
 }
 
-func (r *RoundRobinPicker) RemoveNode(node Node) {
+func (r *RoundRobin) RemoveNode(node Node) {
 	r.mu.Lock()
 	n := len(r.pool)
 	for i := 0; i < n; i++ {
@@ -47,7 +40,23 @@ func (r *RoundRobinPicker) RemoveNode(node Node) {
 	r.mu.Unlock()
 }
 
-func (r *RoundRobinPicker) Pick(query *Query) Node {
+func (r *RoundRobin) ExecuteQuery(qry *Query) (*Iter, error) {
+	node := r.pick()
+	if node == nil {
+		return nil, ErrNoHostAvailable
+	}
+	return node.ExecuteQuery(qry)
+}
+
+func (r *RoundRobin) ExecuteBatch(batch *Batch) error {
+	node := r.pick()
+	if node == nil {
+		return ErrNoHostAvailable
+	}
+	return node.ExecuteBatch(batch)
+}
+
+func (r *RoundRobin) pick() Node {
 	pos := atomic.AddUint32(&r.pos, 1)
 	var node Node
 	r.mu.RLock()
@@ -58,53 +67,11 @@ func (r *RoundRobinPicker) Pick(query *Query) Node {
 	return node
 }
 
-type Reconnector interface {
-	Reconnect(session *Session, address string)
-}
-
-type ExponentialReconnector struct {
-	baseDelay time.Duration
-	maxDelay  time.Duration
-}
-
-func NewExponentialReconnector(baseDelay, maxDelay time.Duration) *ExponentialReconnector {
-	return &ExponentialReconnector{baseDelay, maxDelay}
-}
-
-func (e *ExponentialReconnector) Reconnect(session *Session, address string) {
-	delay := e.baseDelay
-	for {
-		conn, err := Connect(address, session.cfg)
-		if err != nil {
-			<-time.After(delay)
-			if delay *= 2; delay > e.maxDelay {
-				delay = e.maxDelay
-			}
-			continue
-		}
-		node := &Host{conn}
-		go func() {
-			conn.Serve()
-			session.pool.RemoveNode(node)
-			e.Reconnect(session, address)
-		}()
-		session.pool.AddNode(node)
-		return
+func (r *RoundRobin) Close() {
+	r.mu.Lock()
+	for i := 0; i < len(r.pool); i++ {
+		r.pool[i].Close()
 	}
-}
-
-type Host struct {
-	conn *Conn
-}
-
-func (h *Host) ExecuteQuery(qry *Query) (*Iter, error) {
-	return h.conn.ExecuteQuery(qry)
-}
-
-func (h *Host) ExecuteBatch(batch *Batch) error {
-	return nil
-}
-
-func (h *Host) Close() {
-	h.conn.conn.Close()
+	r.pool = nil
+	r.mu.Unlock()
 }