Browse Source

added gocql_test app

Christoph Hack 12 years ago
parent
commit
f8bac0e311
7 changed files with 260 additions and 63 deletions
  1. 70 26
      cluster.go
  2. 16 4
      conn.go
  3. 0 31
      gocql.go
  4. 0 2
      gocql_test.go
  5. 142 0
      gocql_test/main.go
  6. 8 0
      session.go
  7. 24 0
      uuid/uuid.go

+ 70 - 26
cluster.go

@@ -26,11 +26,6 @@ type Cluster struct {
 	ConnPerHost int
 	DelayMin    time.Duration
 	DelayMax    time.Duration
-
-	pool     *RoundRobin
-	initOnce sync.Once
-	boot     sync.WaitGroup
-	bootOnce sync.Once
 }
 
 func NewCluster(hosts ...string) *Cluster {
@@ -39,44 +34,93 @@ func NewCluster(hosts ...string) *Cluster {
 		CQLVersion:  "3.0.0",
 		Timeout:     200 * time.Millisecond,
 		DefaultPort: 9042,
+		ConnPerHost: 2,
 	}
 	return c
 }
 
-func (c *Cluster) init() {
-	for i := 0; i < len(c.Hosts); i++ {
-		addr := strings.TrimSpace(c.Hosts[i])
+func (c *Cluster) CreateSession() *Session {
+	return NewSession(newClusterNode(c))
+}
+
+type clusterNode struct {
+	cfg      Cluster
+	hostPool *RoundRobin
+	connPool map[string]*RoundRobin
+	closed   bool
+	mu       sync.Mutex
+}
+
+func newClusterNode(cfg *Cluster) *clusterNode {
+	c := &clusterNode{
+		cfg:      *cfg,
+		hostPool: NewRoundRobin(),
+		connPool: make(map[string]*RoundRobin),
+	}
+	for i := 0; i < len(c.cfg.Hosts); i++ {
+		addr := strings.TrimSpace(c.cfg.Hosts[i])
 		if strings.IndexByte(addr, ':') < 0 {
-			addr = fmt.Sprintf("%s:%d", addr, c.DefaultPort)
+			addr = fmt.Sprintf("%s:%d", addr, c.cfg.DefaultPort)
+		}
+		for j := 0; j < c.cfg.ConnPerHost; j++ {
+			go c.connect(addr)
 		}
-		go c.connect(addr)
 	}
-	c.pool = NewRoundRobin()
-	<-time.After(c.Timeout)
+	<-time.After(c.cfg.Timeout)
+	return c
 }
 
-func (c *Cluster) connect(addr string) {
-	delay := c.DelayMin
+func (c *clusterNode) connect(addr string) {
+	delay := c.cfg.DelayMin
 	for {
-		conn, err := Connect(addr, c.CQLVersion, c.Timeout)
+		conn, err := Connect(addr, c.cfg.CQLVersion, c.cfg.Timeout)
 		if err != nil {
+			fmt.Println(err)
 			<-time.After(delay)
-			if delay *= 2; delay > c.DelayMax {
-				delay = c.DelayMax
+			if delay *= 2; delay > c.cfg.DelayMax {
+				delay = c.cfg.DelayMax
 			}
 			continue
 		}
-		c.pool.AddNode(conn)
-		go func() {
-			conn.Serve()
-			c.pool.RemoveNode(conn)
-			c.connect(addr)
-		}()
+		c.addConn(addr, conn)
 		return
 	}
 }
 
-func (c *Cluster) CreateSession() *Session {
-	c.initOnce.Do(c.init)
-	return NewSession(c.pool)
+func (c *clusterNode) addConn(addr string, conn *Conn) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	connPool := c.connPool[addr]
+	if connPool == nil {
+		connPool = NewRoundRobin()
+		c.connPool[addr] = connPool
+		c.hostPool.AddNode(connPool)
+	}
+	connPool.AddNode(conn)
+	go func() {
+		conn.Serve()
+		c.removeConn(addr, conn)
+	}()
+}
+
+func (c *clusterNode) removeConn(addr string, conn *Conn) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	pool := c.connPool[addr]
+	if pool == nil {
+		return
+	}
+	pool.RemoveNode(conn)
+}
+
+func (c *clusterNode) ExecuteQuery(qry *Query) (*Iter, error) {
+	return c.hostPool.ExecuteQuery(qry)
+}
+
+func (c *clusterNode) ExecuteBatch(batch *Batch) error {
+	return c.hostPool.ExecuteBatch(batch)
+}
+
+func (c *clusterNode) Close() {
+	c.hostPool.Close()
 }

+ 16 - 4
conn.go

@@ -5,7 +5,6 @@
 package gocql
 
 import (
-	"fmt"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -25,8 +24,9 @@ type Conn struct {
 	calls []callReq
 	nwait int32
 
-	prepMu sync.Mutex
-	prep   map[string]*queryInfo
+	prepMu   sync.Mutex
+	prep     map[string]*queryInfo
+	keyspace string
 }
 
 // Connect establishes a connection to a Cassandra node.
@@ -226,7 +226,20 @@ func (c *Conn) prepareStatement(stmt string) (*queryInfo, error) {
 	return info, nil
 }
 
+func (c *Conn) switchKeyspace(keyspace string) error {
+	if keyspace == "" || c.keyspace == keyspace {
+		return nil
+	}
+	if _, err := c.ExecuteQuery(&Query{Stmt: "USE " + keyspace}); err != nil {
+		return err
+	}
+	return nil
+}
+
 func (c *Conn) ExecuteQuery(qry *Query) (*Iter, error) {
+	if err := c.switchKeyspace(qry.Keyspace); err != nil {
+		return nil, err
+	}
 	frame, err := c.executeQuery(qry)
 	if err != nil {
 		return nil, err
@@ -290,7 +303,6 @@ func (c *Conn) Close() {
 func (c *Conn) executeQuery(query *Query) (frame, error) {
 	var info *queryInfo
 	if len(query.Args) > 0 {
-		fmt.Println("ARGS:", query.Args)
 		var err error
 		info, err = c.prepareStatement(query.Stmt)
 		if err != nil {

+ 0 - 31
gocql.go

@@ -8,37 +8,6 @@ import (
 	"errors"
 )
 
-type queryContext interface {
-	executeQuery(query *Query) (frame, error)
-}
-
-type ColumnInfo struct {
-	Keyspace string
-	Table    string
-	Name     string
-	TypeInfo *TypeInfo
-}
-
-/*
-type Batch struct {
-	queries []*Query
-	ctx     queryContext
-	cons    Consistency
-}
-
-func (b *Batch) Query(stmt string, args ...interface{}) *Query {
-	return &Query{
-		stmt: stmt,
-		args: args,
-		cons: b.cons,
-		//ctx:  b,
-	}
-}
-
-func (b *Batch) Apply() error {
-	return nil
-} */
-
 type Error struct {
 	Code    int
 	Message string

+ 0 - 2
gocql_test.go

@@ -5,7 +5,6 @@
 package gocql
 
 import (
-	"fmt"
 	"io"
 	"net"
 	"strings"
@@ -95,7 +94,6 @@ func (srv *TestServer) process(frame frame, conn net.Conn) {
 			frame.writeInt(0)
 		}
 	default:
-		fmt.Println("unsupproted:", frame)
 		frame = frame[:headerSize]
 		frame.setHeader(protoResponse, 0, frame[2], opError)
 		frame.writeInt(0)

+ 142 - 0
gocql_test/main.go

@@ -0,0 +1,142 @@
+package main
+
+import (
+	"fmt"
+	"log"
+	"reflect"
+	"sort"
+	"time"
+
+	"github.com/tux21b/gocql"
+	"github.com/tux21b/gocql/uuid"
+)
+
+var session *gocql.Session
+
+func init() {
+	cluster := gocql.NewCluster("127.0.0.1")
+	cluster.ConnPerHost = 1
+	session = cluster.CreateSession()
+}
+
+type Page struct {
+	Title       string
+	RevId       uuid.UUID
+	Body        string
+	Views       int64
+	Protected   bool
+	Modified    time.Time
+	Tags        []string
+	Attachments map[string]Attachment
+}
+
+type Attachment []byte
+
+func initSchema() error {
+	session.Query("DROP KEYSPACE gocql_test").Exec()
+
+	if err := session.Query(`CREATE KEYSPACE gocql_test
+		WITH replication = {
+			'class' : 'SimpleStrategy',
+			'replication_factor' : 1
+		}`).Exec(); err != nil {
+		return err
+	}
+
+	if err := session.Query("USE gocql_test").Exec(); err != nil {
+		return err
+	}
+
+	if err := session.Query(`CREATE TABLE page (
+			title       varchar,
+			revid       timeuuid,
+			body        varchar,
+			views       bigint,
+			protected   boolean,
+			modified    timestamp,
+			tags        set<varchar>,
+			attachments map<varchar, text>,
+			PRIMARY KEY (title, revid)
+		)`).Exec(); err != nil {
+		fmt.Println("create err")
+		return err
+	}
+
+	if err := session.Query(`CREATE TABLE page_stats (
+			title varchar,
+			views counter,
+			PRIMARY KEY (title)
+		)`).Exec(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+var pageTestData = []*Page{
+	&Page{
+		Title:    "Frontpage",
+		RevId:    uuid.TimeUUID(),
+		Body:     "Welcome to this wiki page!",
+		Modified: time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC),
+		Tags:     []string{"start", "important", "test"},
+		Attachments: map[string]Attachment{
+			"logo":    Attachment("\x00company logo\x00"),
+			"favicon": Attachment("favicon.ico"),
+		},
+	},
+}
+
+func insertTestData() error {
+	for _, page := range pageTestData {
+		if err := session.Query(`INSERT INTO page
+			(title, revid, body, views, protected, modified, tags, attachments)
+			VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
+			page.Title, page.RevId, page.Body, page.Views, page.Protected,
+			page.Modified, page.Tags, page.Attachments).Exec(); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func getPage(title string, revid uuid.UUID) (*Page, error) {
+	p := new(Page)
+	err := session.Query(`SELECT title, revid, body, views, protected, modified,
+		tags, attachments
+		FROM page WHERE title = ? AND revid = ?`, title, revid).Scan(
+		&p.Title, &p.RevId, &p.Body, &p.Views, &p.Protected, &p.Modified,
+		&p.Tags, &p.Attachments)
+	return p, err
+}
+
+func main() {
+	if err := initSchema(); err != nil {
+		log.Fatal("initSchema: ", err)
+	}
+
+	if err := insertTestData(); err != nil {
+		log.Fatal("insertTestData: ", err)
+	}
+
+	var count int
+	if err := session.Query("SELECT COUNT(*) FROM page").Scan(&count); err != nil {
+		log.Fatal("getCount: ", err)
+	}
+	if count != len(pageTestData) {
+		log.Println("count: expected %d, got %d", len(pageTestData), count)
+	}
+
+	for _, original := range pageTestData {
+		page, err := getPage(original.Title, original.RevId)
+		if err != nil {
+			log.Print("getPage: ", err)
+			continue
+		}
+		sort.Sort(sort.StringSlice(page.Tags))
+		sort.Sort(sort.StringSlice(original.Tags))
+		if !reflect.DeepEqual(page, original) {
+			log.Printf("page: expected %#v, got %#v\n", original, page)
+		}
+	}
+}

+ 8 - 0
session.go

@@ -68,6 +68,7 @@ type Query struct {
 	Token    string
 	PageSize int
 	Trace    bool
+	Keyspace string
 }
 
 func NewQuery(stmt string, args ...interface{}) *Query {
@@ -240,3 +241,10 @@ var consinstencyNames = []string{
 func (c Consistency) String() string {
 	return consinstencyNames[c]
 }
+
+type ColumnInfo struct {
+	Keyspace string
+	Table    string
+	Name     string
+	TypeInfo *TypeInfo
+}

+ 24 - 0
uuid/uuid.go

@@ -14,6 +14,8 @@ import (
 	"io"
 	"net"
 	"time"
+
+	"github.com/tux21b/gocql"
 )
 
 type UUID [16]byte
@@ -191,3 +193,25 @@ func (u UUID) Time() time.Time {
 	nsec := t - sec
 	return time.Unix(int64(sec)+timeBase, int64(nsec))
 }
+
+func (u UUID) MarshalCQL(info *gocql.TypeInfo) ([]byte, error) {
+	switch info.Type {
+	case gocql.TypeUUID, gocql.TypeTimeUUID:
+		return u[:], nil
+	}
+	return gocql.Marshal(info, u[:])
+}
+
+func (u *UUID) UnmarshalCQL(info *gocql.TypeInfo, data []byte) error {
+	switch info.Type {
+	case gocql.TypeUUID, gocql.TypeTimeUUID:
+		*u = FromBytes(data)
+		return nil
+	}
+	var val []byte
+	if err := gocql.Unmarshal(info, data, &val); err != nil {
+		return err
+	}
+	*u = FromBytes(val)
+	return nil
+}