소스 검색

published first draft of the v2 rewrite

Christoph Hack 12 년 전
부모
커밋
c9fd6df0c6
8개의 변경된 파일891개의 추가작업 그리고 991개의 파일을 삭제
  1. 48 40
      README.md
  2. 247 0
      binary.go
  3. 176 0
      conn.go
  4. 0 240
      convert.go
  5. 9 0
      doc.go
  6. 139 542
      gocql.go
  7. 68 169
      gocql_test.go
  8. 204 0
      marshal.go

+ 48 - 40
README.md

@@ -1,24 +1,22 @@
 gocql
 =====
 
-The gocql package provides a database/sql driver for CQL, the Cassandra
-query language.
-
-This package requires a recent version of Cassandra (≥ 1.2) that supports
-CQL 3.0 and the new native protocol. The native protocol is still considered
-beta and must be enabled manually in Cassandra 1.2 by setting
-"start_native_transport" to true in conf/cassandra.yaml.
-
-**Note:** gocql requires the tip version of Go, as some changes in the 
-`database/sql` have not made it into 1.0.x yet. There is 
-[a fork](https://github.com/titanous/gocql) that backports these changes 
-to Go 1.0.3.
+Package gocql implements a fast and robust Cassandra driver for the
+Go programming language.
 
 Installation
 ------------
 
     go get github.com/tux21b/gocql
 
+
+Features
+--------
+
+* Modern Cassandra client for Cassandra 2.0
+* Built-In support for UUIDs (version 1 and 4)
+
+
 Example
 -------
 
@@ -26,48 +24,58 @@ Example
 package main
 
 import (
-	"database/sql"
 	"fmt"
-	_ "github.com/tux21b/gocql"
+	"github.com/tux21b/gocql"
+	"log"
 )
 
 func main() {
-	db, err := sql.Open("gocql", "localhost:9042 keyspace=system")
-	if err != nil {
-		fmt.Println("Open error:", err)
+	// connect to your cluster
+	db := gocql.NewSession(gocql.Config{
+		Nodes: []string{
+			"192.168.1.1",
+			"192.168.1.2",
+			"192.168.1.3",
+		},
+		Keyspace:    "example",       // (optional)
+		Consistency: gocql.ConQuorum, // (optional)
+	})
+	defer db.Close()
+
+	// simple query
+	var title, text string
+	if err := db.Query("SELECT title, text FROM posts WHERE title = ?",
+		"Lorem Ipsum").Scan(&title, &text); err != nil {
+		log.Fatal(err)
 	}
+	fmt.Println(title, text)
 
-	rows, err := db.Query("SELECT keyspace_name FROM schema_keyspaces")
-	if err != nil {
-		fmt.Println("Query error:", err)
+	// iterator example
+	var titles []string
+	iter := db.Query("SELECT title FROM posts").Iter()
+	for iter.Scan(&title) {
+		titles = append(titles, title)
+	}
+	if err := iter.Close(); err != nil {
+		log.Fatal(err)
 	}
+	fmt.Println(titles)
 
-	for rows.Next() {
-		var keyspace string
-		err = rows.Scan(&keyspace)
-		if err != nil {
-			fmt.Println("Scan error:", err)
-		}
-		fmt.Println(keyspace)
+	// insertion example (with custom consistency level)
+	if err := db.Query("INSERT INTO posts (title, text) VALUES (?, ?)",
+		"New Title", "foobar").Consistency(gocql.ConAny).Exec(); err != nil {
+		log.Fatal(err)
 	}
 
-	if err = rows.Err(); err != nil {
-		fmt.Println("Iteration error:", err)
-		return
+	// prepared queries
+	query := gocql.NewQuery("SELECT text FROM posts WHERE title = ?")
+	if err := db.Do(query, "New Title").Scan(&text); err != nil {
+		log.Fatal(err)
 	}
+	fmt.Println(text)
 }
 ```
 
-Please see `gocql_test.go` for some more advanced examples.
-
-Features
---------
-
-* Modern Cassandra client that is based on Cassandra's new native protocol
-* Compatible with Go's `database/sql` package
-* Built-In support for UUIDs (version 1 and 4)
-* Optional frame compression (using snappy)
-
 License
 -------
 

+ 247 - 0
binary.go

@@ -0,0 +1,247 @@
+package gocql
+
+import (
+	"errors"
+	"net"
+)
+
+const (
+	protoRequest  byte = 0x02
+	protoResponse byte = 0x82
+
+	opError         byte = 0x00
+	opStartup       byte = 0x01
+	opReady         byte = 0x02
+	opAuthenticate  byte = 0x03
+	opOptions       byte = 0x05
+	opSupported     byte = 0x06
+	opQuery         byte = 0x07
+	opResult        byte = 0x08
+	opPrepare       byte = 0x09
+	opExecute       byte = 0x0A
+	opRegister      byte = 0x0B
+	opEvent         byte = 0x0C
+	opBatch         byte = 0x0D
+	opAuthChallenge byte = 0x0E
+	opAuthResponse  byte = 0x0F
+	opAuthSuccess   byte = 0x10
+
+	resultKindVoid          = 1
+	resultKindRows          = 2
+	resultKindKeyspace      = 3
+	resultKindPrepared      = 4
+	resultKindSchemaChanged = 5
+
+	flagQueryValues uint8 = 1
+
+	headerSize = 8
+)
+
+var ErrInvalid = errors.New("invalid response")
+
+type buffer []byte
+
+func (b *buffer) writeInt(v int32) {
+	p := b.grow(4)
+	(*b)[p] = byte(v >> 24)
+	(*b)[p+1] = byte(v >> 16)
+	(*b)[p+2] = byte(v >> 8)
+	(*b)[p+3] = byte(v)
+}
+
+func (b *buffer) writeShort(v uint16) {
+	p := b.grow(2)
+	(*b)[p] = byte(v >> 8)
+	(*b)[p+1] = byte(v)
+}
+
+func (b *buffer) writeString(v string) {
+	b.writeShort(uint16(len(v)))
+	p := b.grow(len(v))
+	copy((*b)[p:], v)
+}
+
+func (b *buffer) writeLongString(v string) {
+	b.writeInt(int32(len(v)))
+	p := b.grow(len(v))
+	copy((*b)[p:], v)
+}
+
+func (b *buffer) writeUUID() {
+}
+
+func (b *buffer) writeStringList(v []string) {
+	b.writeShort(uint16(len(v)))
+	for i := range v {
+		b.writeString(v[i])
+	}
+}
+
+func (b *buffer) writeByte(v byte) {
+	p := b.grow(1)
+	(*b)[p] = v
+}
+
+func (b *buffer) writeBytes(v []byte) {
+	if v == nil {
+		b.writeInt(-1)
+		return
+	}
+	b.writeInt(int32(len(v)))
+	p := b.grow(len(v))
+	copy((*b)[p:], v)
+}
+
+func (b *buffer) writeShortBytes(v []byte) {
+	b.writeShort(uint16(len(v)))
+	p := b.grow(len(v))
+	copy((*b)[p:], v)
+}
+
+func (b *buffer) writeInet(ip net.IP, port int) {
+	p := b.grow(1 + len(ip))
+	(*b)[p] = byte(len(ip))
+	copy((*b)[p+1:], ip)
+	b.writeInt(int32(port))
+}
+
+func (b *buffer) writeConsistency() {
+}
+
+func (b *buffer) writeStringMap(v map[string]string) {
+	b.writeShort(uint16(len(v)))
+	for key, value := range v {
+		b.writeString(key)
+		b.writeString(value)
+	}
+}
+
+func (b *buffer) writeStringMultimap(v map[string][]string) {
+	b.writeShort(uint16(len(v)))
+	for key, values := range v {
+		b.writeString(key)
+		b.writeStringList(values)
+	}
+}
+
+func (b *buffer) setHeader(version, flags, stream, opcode uint8) {
+	(*b)[0] = version
+	(*b)[1] = flags
+	(*b)[2] = stream
+	(*b)[3] = opcode
+}
+
+func (b *buffer) setLength(length int) {
+	(*b)[4] = byte(length >> 24)
+	(*b)[5] = byte(length >> 16)
+	(*b)[6] = byte(length >> 8)
+	(*b)[7] = byte(length)
+}
+
+func (b *buffer) Length() int {
+	return int((*b)[4])<<24 | int((*b)[5])<<16 | int((*b)[6])<<8 | int((*b)[7])
+}
+
+func (b *buffer) grow(n int) int {
+	if len(*b)+n >= cap(*b) {
+		buf := make(buffer, len(*b), len(*b)*2+n)
+		copy(buf, *b)
+		*b = buf
+	}
+	p := len(*b)
+	*b = (*b)[:p+n]
+	return p
+}
+
+func (b *buffer) skipHeader() {
+	*b = (*b)[headerSize:]
+}
+
+func (b *buffer) readInt() int {
+	if len(*b) < 4 {
+		panic(ErrInvalid)
+	}
+	v := int((*b)[0])<<24 | int((*b)[1])<<16 | int((*b)[2])<<8 | int((*b)[3])
+	*b = (*b)[4:]
+	return v
+}
+
+func (b *buffer) readShort() uint16 {
+	if len(*b) < 2 {
+		panic(ErrInvalid)
+	}
+	v := uint16((*b)[0])<<8 | uint16((*b)[1])
+	*b = (*b)[2:]
+	return v
+}
+
+func (b *buffer) readString() string {
+	n := int(b.readShort())
+	if len(*b) < n {
+		panic(ErrInvalid)
+	}
+	v := string((*b)[:n])
+	*b = (*b)[n:]
+	return v
+}
+
+func (b *buffer) readBytes() []byte {
+	n := b.readInt()
+	if n < 0 {
+		return nil
+	}
+	if len(*b) < n {
+		panic(ErrInvalid)
+	}
+	v := (*b)[:n]
+	*b = (*b)[n:]
+	return v
+}
+
+func (b *buffer) readShortBytes() []byte {
+	n := int(b.readShort())
+	if len(*b) < n {
+		panic(ErrInvalid)
+	}
+	v := (*b)[:n]
+	*b = (*b)[n:]
+	return v
+}
+
+func (b *buffer) readTypeInfo() *TypeInfo {
+	x := b.readShort()
+	typ := &TypeInfo{Type: Type(x)}
+	switch typ.Type {
+	case TypeCustom:
+		typ.Custom = b.readString()
+	case TypeMap:
+		typ.Key = b.readTypeInfo()
+		fallthrough
+	case TypeList, TypeSet:
+		typ.Value = b.readTypeInfo()
+	}
+	return typ
+}
+
+func (b *buffer) readMetaData() []columnInfo {
+	flags := b.readInt()
+	numColumns := b.readInt()
+	globalKeyspace := ""
+	globalTable := ""
+	if flags&1 != 0 {
+		globalKeyspace = b.readString()
+		globalTable = b.readString()
+	}
+	info := make([]columnInfo, numColumns)
+	for i := 0; i < numColumns; i++ {
+		info[i].Keyspace = globalKeyspace
+		info[i].Table = globalTable
+		if flags&1 == 0 {
+			info[i].Keyspace = b.readString()
+			info[i].Table = b.readString()
+		}
+		info[i].Name = b.readString()
+		info[i].TypeInfo = b.readTypeInfo()
+	}
+	return info
+}

+ 176 - 0
conn.go

@@ -0,0 +1,176 @@
+package gocql
+
+import (
+	"io"
+	"net"
+	"sync"
+	"sync/atomic"
+)
+
+type queryInfo struct {
+	id    []byte
+	args  []columnInfo
+	rval  []columnInfo
+	avail chan bool
+}
+
+type connection struct {
+	conn    net.Conn
+	uniq    chan uint8
+	reply   []chan buffer
+	waiting uint64
+
+	prepMu sync.Mutex
+	prep   map[string]*queryInfo
+}
+
+func connect(addr string, cfg *Config) (*connection, error) {
+	conn, err := net.Dial("tcp", addr)
+	if err != nil {
+		return nil, err
+	}
+	c := &connection{
+		conn:  conn,
+		uniq:  make(chan uint8, 64),
+		reply: make([]chan buffer, 64),
+		prep:  make(map[string]*queryInfo),
+	}
+	for i := 0; i < cap(c.uniq); i++ {
+		c.uniq <- uint8(i)
+	}
+
+	go c.recv()
+
+	frame := make(buffer, headerSize)
+	frame.setHeader(protoRequest, 0, 0, opStartup)
+	frame.writeStringMap(map[string]string{
+		"CQL_VERSION": cfg.CQLVersion,
+	})
+	frame.setLength(len(frame) - headerSize)
+
+	frame = c.request(frame)
+
+	if cfg.Keyspace != "" {
+		qry := &Query{stmt: "USE " + cfg.Keyspace}
+		frame, err = c.executeQuery(qry)
+	}
+
+	return c, nil
+}
+
+func (c *connection) recv() {
+	for {
+		frame := make(buffer, headerSize, headerSize+512)
+		if _, err := io.ReadFull(c.conn, frame); err != nil {
+			return
+		}
+		if frame[0] != protoResponse {
+			continue
+		}
+		if length := frame.Length(); length > 0 {
+			frame.grow(frame.Length())
+			io.ReadFull(c.conn, frame[headerSize:])
+		}
+		c.dispatch(frame)
+	}
+	panic("not possible")
+}
+
+func (c *connection) request(frame buffer) buffer {
+	id := <-c.uniq
+	frame[2] = id
+	c.reply[id] = make(chan buffer, 1)
+
+	for {
+		w := atomic.LoadUint64(&c.waiting)
+		if atomic.CompareAndSwapUint64(&c.waiting, w, w|(1<<id)) {
+			break
+		}
+	}
+	c.conn.Write(frame)
+	resp := <-c.reply[id]
+	c.uniq <- id
+	return resp
+}
+
+func (c *connection) dispatch(frame buffer) {
+	id := frame[2]
+	if id >= 128 {
+		return
+	}
+	for {
+		w := atomic.LoadUint64(&c.waiting)
+		if w&(1<<id) == 0 {
+			return
+		}
+		if atomic.CompareAndSwapUint64(&c.waiting, w, w&^(1<<id)) {
+			break
+		}
+	}
+	c.reply[id] <- frame
+}
+
+func (c *connection) prepareQuery(stmt string) *queryInfo {
+	c.prepMu.Lock()
+	info := c.prep[stmt]
+	if info != nil {
+		c.prepMu.Unlock()
+		<-info.avail
+		return info
+	}
+	info = &queryInfo{avail: make(chan bool)}
+	c.prep[stmt] = info
+	c.prepMu.Unlock()
+
+	frame := make(buffer, headerSize, headerSize+512)
+	frame.setHeader(protoRequest, 0, 0, opPrepare)
+	frame.writeLongString(stmt)
+	frame.setLength(len(frame) - headerSize)
+
+	frame = c.request(frame)
+	frame.skipHeader()
+	frame.readInt() // kind
+	info.id = frame.readShortBytes()
+	info.args = frame.readMetaData()
+	info.rval = frame.readMetaData()
+	close(info.avail)
+	return info
+}
+
+func (c *connection) executeQuery(query *Query) (buffer, error) {
+	var info *queryInfo
+	if len(query.args) > 0 {
+		info = c.prepareQuery(query.stmt)
+	}
+
+	frame := make(buffer, headerSize, headerSize+512)
+	frame.setHeader(protoRequest, 0, 0, opQuery)
+	frame.writeLongString(query.stmt)
+	frame.writeShort(uint16(query.cons))
+	flags := uint8(0)
+	if len(query.args) > 0 {
+		flags |= flagQueryValues
+	}
+	frame.writeByte(flags)
+	if len(query.args) > 0 {
+		frame.writeShort(uint16(len(query.args)))
+		for i := 0; i < len(query.args); i++ {
+			val, err := Marshal(info.args[i].TypeInfo, query.args[i])
+			if err != nil {
+				return nil, err
+			}
+			frame.writeBytes(val)
+		}
+	}
+	frame.setLength(len(frame) - headerSize)
+
+	frame = c.request(frame)
+
+	if frame[3] == opError {
+		frame.skipHeader()
+		code := frame.readInt()
+		desc := frame.readString()
+		return nil, Error{code, desc}
+	}
+	return frame, nil
+}

+ 0 - 240
convert.go

@@ -1,240 +0,0 @@
-// Copyright (c) 2012 The gocql Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package gocql
-
-import (
-	"database/sql/driver"
-	"encoding/binary"
-	"fmt"
-	"github.com/tux21b/gocql/uuid"
-	"math"
-	"reflect"
-	"strconv"
-	"time"
-)
-
-const (
-	typeCustom    uint16 = 0x0000
-	typeAscii     uint16 = 0x0001
-	typeBigInt    uint16 = 0x0002
-	typeBlob      uint16 = 0x0003
-	typeBool      uint16 = 0x0004
-	typeCounter   uint16 = 0x0005
-	typeDecimal   uint16 = 0x0006
-	typeDouble    uint16 = 0x0007
-	typeFloat     uint16 = 0x0008
-	typeInt       uint16 = 0x0009
-	typeText      uint16 = 0x000A
-	typeTimestamp uint16 = 0x000B
-	typeUUID      uint16 = 0x000C
-	typeVarchar   uint16 = 0x000D
-	typeVarint    uint16 = 0x000E
-	typeTimeUUID  uint16 = 0x000F
-	typeList      uint16 = 0x0020
-	typeMap       uint16 = 0x0021
-	typeSet       uint16 = 0x0022
-)
-
-func decode(b []byte, t uint16) driver.Value {
-	switch t {
-	case typeBool:
-		if len(b) >= 1 && b[0] != 0 {
-			return true
-		}
-		return false
-	case typeBlob:
-		return b
-	case typeVarchar, typeText, typeAscii:
-		return b
-	case typeInt:
-		return int64(int32(binary.BigEndian.Uint32(b)))
-	case typeBigInt:
-		return int64(binary.BigEndian.Uint64(b))
-	case typeFloat:
-		return float64(math.Float32frombits(binary.BigEndian.Uint32(b)))
-	case typeDouble:
-		return math.Float64frombits(binary.BigEndian.Uint64(b))
-	case typeTimestamp:
-		t := int64(binary.BigEndian.Uint64(b))
-		sec := t / 1000
-		nsec := (t - sec*1000) * 1000000
-		return time.Unix(sec, nsec)
-	case typeUUID, typeTimeUUID:
-		return uuid.FromBytes(b)
-	default:
-		panic("unsupported type")
-	}
-	return b
-}
-
-type columnEncoder struct {
-	columnTypes []uint16
-}
-
-func (e *columnEncoder) ColumnConverter(idx int) ValueConverter {
-	switch e.columnTypes[idx] {
-	case typeInt:
-		return ValueConverter(encInt)
-	case typeBigInt:
-		return ValueConverter(encBigInt)
-	case typeFloat:
-		return ValueConverter(encFloat)
-	case typeDouble:
-		return ValueConverter(encDouble)
-	case typeBool:
-		return ValueConverter(encBool)
-	case typeVarchar, typeText, typeAscii:
-		return ValueConverter(encVarchar)
-	case typeBlob:
-		return ValueConverter(encBlob)
-	case typeTimestamp:
-		return ValueConverter(encTimestamp)
-	case typeUUID, typeTimeUUID:
-		return ValueConverter(encUUID)
-	}
-	panic("not implemented")
-}
-
-type ValueConverter func(v interface{}) (driver.Value, error)
-
-func (vc ValueConverter) ConvertValue(v interface{}) (driver.Value, error) {
-	return vc(v)
-}
-
-func encBool(v interface{}) (driver.Value, error) {
-	b, err := driver.Bool.ConvertValue(v)
-	if err != nil {
-		return nil, err
-	}
-	if b.(bool) {
-		return []byte{1}, nil
-	}
-	return []byte{0}, nil
-}
-
-func encInt(v interface{}) (driver.Value, error) {
-	x, err := driver.Int32.ConvertValue(v)
-	if err != nil {
-		return nil, err
-	}
-	b := make([]byte, 4)
-	binary.BigEndian.PutUint32(b, uint32(x.(int64)))
-	return b, nil
-}
-
-func encBigInt(v interface{}) (driver.Value, error) {
-	x := reflect.Indirect(reflect.ValueOf(v)).Interface()
-	b := make([]byte, 8)
-	binary.BigEndian.PutUint64(b, uint64(x.(int64)))
-	return b, nil
-}
-
-func encVarchar(v interface{}) (driver.Value, error) {
-	x, err := driver.String.ConvertValue(v)
-	if err != nil {
-		return nil, err
-	}
-	return []byte(x.(string)), nil
-}
-
-func encFloat(v interface{}) (driver.Value, error) {
-	x, err := driver.DefaultParameterConverter.ConvertValue(v)
-	if err != nil {
-		return nil, err
-	}
-	var f float64
-	switch x := x.(type) {
-	case float64:
-		f = x
-	case int64:
-		f = float64(x)
-	case []byte:
-		if f, err = strconv.ParseFloat(string(x), 64); err != nil {
-			return nil, err
-		}
-	default:
-		return nil, fmt.Errorf("can not convert %T to float64", x)
-	}
-	b := make([]byte, 4)
-	binary.BigEndian.PutUint32(b, math.Float32bits(float32(f)))
-	return b, nil
-}
-
-func encDouble(v interface{}) (driver.Value, error) {
-	x, err := driver.DefaultParameterConverter.ConvertValue(v)
-	if err != nil {
-		return nil, err
-	}
-	var f float64
-	switch x := x.(type) {
-	case float64:
-		f = x
-	case int64:
-		f = float64(x)
-	case []byte:
-		if f, err = strconv.ParseFloat(string(x), 64); err != nil {
-			return nil, err
-		}
-	default:
-		return nil, fmt.Errorf("can not convert %T to float64", x)
-	}
-	b := make([]byte, 8)
-	binary.BigEndian.PutUint64(b, math.Float64bits(f))
-	return b, nil
-}
-
-func encTimestamp(v interface{}) (driver.Value, error) {
-	x, err := driver.DefaultParameterConverter.ConvertValue(v)
-	if err != nil {
-		return nil, err
-	}
-	var millis int64
-	switch x := x.(type) {
-	case time.Time:
-		x = x.In(time.UTC)
-		millis = x.UnixNano() / 1000000
-	default:
-		return nil, fmt.Errorf("can not convert %T to a timestamp", x)
-	}
-	b := make([]byte, 8)
-	binary.BigEndian.PutUint64(b, uint64(millis))
-	return b, nil
-}
-
-func encBlob(v interface{}) (driver.Value, error) {
-	x, err := driver.DefaultParameterConverter.ConvertValue(v)
-	if err != nil {
-		return nil, err
-	}
-	var b []byte
-	switch x := x.(type) {
-	case string:
-		b = []byte(x)
-	case []byte:
-		b = x
-	default:
-		return nil, fmt.Errorf("can not convert %T to a []byte", x)
-	}
-	return b, nil
-}
-
-func encUUID(v interface{}) (driver.Value, error) {
-	var u uuid.UUID
-	switch v := v.(type) {
-	case string:
-		var err error
-		u, err = uuid.ParseUUID(v)
-		if err != nil {
-			return nil, err
-		}
-	case []byte:
-		u = uuid.FromBytes(v)
-	case uuid.UUID:
-		u = v
-	default:
-		return nil, fmt.Errorf("can not convert %T to a UUID", v)
-	}
-	return u.Bytes(), nil
-}

+ 9 - 0
doc.go

@@ -0,0 +1,9 @@
+// Copyright (c) 2012 The gocql Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package gocql implements a fast and robust Cassandra driver for the
+// Go programming language.
+package gocql
+
+// TODO(tux21b): write more docs.

+ 139 - 542
gocql.go

@@ -2,616 +2,213 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// The gocql package provides a database/sql driver for CQL, the Cassandra
-// query language.
-//
-// This package requires a recent version of Cassandra (≥ 1.2) that supports
-// CQL 3.0 and the new native protocol. The native protocol is still considered
-// beta and must be enabled manually in Cassandra 1.2 by setting
-// "start_native_transport" to true in conf/cassandra.yaml.
-//
-// Example Usage:
-//
-//     db, err := sql.Open("gocql", "localhost:9042 keyspace=system")
-//     // ...
-//     rows, err := db.Query("SELECT keyspace_name FROM schema_keyspaces")
-//     // ...
-//     for rows.Next() {
-//          var keyspace string
-//          err = rows.Scan(&keyspace)
-//          // ...
-//          fmt.Println(keyspace)
-//     }
-//     if err := rows.Err(); err != nil {
-//         // ...
-//     }
-//
 package gocql
 
 import (
-	"bytes"
-	"code.google.com/p/snappy-go/snappy"
-	"database/sql"
-	"database/sql/driver"
-	"encoding/binary"
+	"errors"
 	"fmt"
-	"io"
-	"net"
 	"strings"
-	"time"
 )
 
-const (
-	protoRequest  byte = 0x01
-	protoResponse byte = 0x81
-
-	opError        byte = 0x00
-	opStartup      byte = 0x01
-	opReady        byte = 0x02
-	opAuthenticate byte = 0x03
-	opCredentials  byte = 0x04
-	opOptions      byte = 0x05
-	opSupported    byte = 0x06
-	opQuery        byte = 0x07
-	opResult       byte = 0x08
-	opPrepare      byte = 0x09
-	opExecute      byte = 0x0A
-	opLAST         byte = 0x0A // not a real opcode -- used to check for valid opcodes
-
-	flagCompressed byte = 0x01
-
-	keyVersion     string = "CQL_VERSION"
-	keyCompression string = "COMPRESSION"
-	keyspaceQuery  string = "USE "
-)
-
-var consistencyLevels = map[string]byte{"any": 0x00, "one": 0x01, "two": 0x02,
-	"three": 0x03, "quorum": 0x04, "all": 0x05, "local_quorum": 0x06, "each_quorum": 0x07}
-
-type drv struct{}
-
-func (d drv) Open(name string) (driver.Conn, error) {
-	return Open(name)
+type Config struct {
+	Nodes       []string
+	CQLVersion  string
+	Keyspace    string
+	Consistency Consistency
+	DefaultPort int
 }
 
-type connection struct {
-	c       net.Conn
-	address string
-	alive   bool
-	pool    *pool
-}
-
-type pool struct {
-	connections []*connection
-	i           int
-	keyspace    string
-	version     string
-	compression string
-	consistency byte
-	dead        bool
-	stop        chan struct{}
-}
-
-func Open(name string) (*pool, error) {
-	parts := strings.Split(name, " ")
-	var addresses []string
-	if len(parts) >= 1 {
-		addresses = strings.Split(parts[0], ",")
-	}
-
-	version := "3.0.0"
-	var (
-		keyspace    string
-		compression string
-		consistency byte = 0x01
-		ok          bool
-	)
-	for i := 1; i < len(parts); i++ {
-		switch {
-		case parts[i] == "":
-			continue
-		case strings.HasPrefix(parts[i], "keyspace="):
-			keyspace = strings.TrimSpace(parts[i][9:])
-		case strings.HasPrefix(parts[i], "compression="):
-			compression = strings.TrimSpace(parts[i][12:])
-			if compression != "snappy" {
-				return nil, fmt.Errorf("unknown compression algorithm %q",
-					compression)
-			}
-		case strings.HasPrefix(parts[i], "version="):
-			version = strings.TrimSpace(parts[i][8:])
-		case strings.HasPrefix(parts[i], "consistency="):
-			cs := strings.TrimSpace(parts[i][12:])
-			if consistency, ok = consistencyLevels[cs]; !ok {
-				return nil, fmt.Errorf("unknown consistency level %q", cs)
-			}
-		default:
-			return nil, fmt.Errorf("unsupported option %q", parts[i])
-		}
-	}
-
-	pool := &pool{
-		keyspace:    keyspace,
-		version:     version,
-		compression: compression,
-		consistency: consistency,
-		stop:        make(chan struct{}),
-	}
-
-	for _, address := range addresses {
-		pool.connections = append(pool.connections, &connection{address: address, pool: pool})
-	}
-
-	pool.join()
-
-	return pool, nil
-}
-
-func (cn *connection) open() {
-	cn.alive = false
-
-	var err error
-	cn.c, err = net.Dial("tcp", cn.address)
-	if err != nil {
-		return
-	}
-
-	var (
-		version     = cn.pool.version
-		compression = cn.pool.compression
-		keyspace    = cn.pool.keyspace
-	)
-
-	b := &bytes.Buffer{}
-
-	if compression != "" {
-		binary.Write(b, binary.BigEndian, uint16(2))
-	} else {
-		binary.Write(b, binary.BigEndian, uint16(1))
-	}
-
-	binary.Write(b, binary.BigEndian, uint16(len(keyVersion)))
-	b.WriteString(keyVersion)
-	binary.Write(b, binary.BigEndian, uint16(len(version)))
-	b.WriteString(version)
-
-	if compression != "" {
-		binary.Write(b, binary.BigEndian, uint16(len(keyCompression)))
-		b.WriteString(keyCompression)
-		binary.Write(b, binary.BigEndian, uint16(len(compression)))
-		b.WriteString(compression)
+func (c *Config) normalize() {
+	if c.CQLVersion == "" {
+		c.CQLVersion = "3.0.0"
 	}
-
-	if err := cn.sendUncompressed(opStartup, b.Bytes()); err != nil {
-		return
-	}
-
-	opcode, _, err := cn.recv()
-	if err != nil {
-		return
-	}
-	if opcode != opReady {
-		return
+	if c.DefaultPort == 0 {
+		c.DefaultPort = 9042
 	}
-
-	if keyspace != "" {
-		cn.UseKeyspace(keyspace)
-	}
-
-	cn.alive = true
-}
-
-// close a connection actively, typically used when there's an error and we want to ensure
-// we don't repeatedly try to use the broken connection
-func (cn *connection) close() {
-	cn.c.Close()
-	cn.c = nil // ensure we generate ErrBadConn when cn gets reused
-	cn.alive = false
-
-	// Check if the entire pool is dead
-	for _, cn := range cn.pool.connections {
-		if cn.alive {
-			return
+	for i := 0; i < len(c.Nodes); i++ {
+		c.Nodes[i] = strings.TrimSpace(c.Nodes[i])
+		if strings.IndexByte(c.Nodes[i], ':') < 0 {
+			c.Nodes[i] = fmt.Sprintf("%s:%d", c.Nodes[i], c.DefaultPort)
 		}
 	}
-	cn.pool.dead = false
-}
-
-// explicitly send a request as uncompressed
-// This is only really needed for the "startup" handshake
-func (cn *connection) sendUncompressed(opcode byte, body []byte) error {
-	return cn._send(opcode, body, false)
 }
 
-func (cn *connection) send(opcode byte, body []byte) error {
-	return cn._send(opcode, body, cn.pool.compression == "snappy" && len(body) > 0)
+type Session struct {
+	cfg  *Config
+	pool []*connection
 }
 
-func (cn *connection) _send(opcode byte, body []byte, compression bool) error {
-	if cn.c == nil {
-		return driver.ErrBadConn
-	}
-	var flags byte = 0x00
-	if compression {
-		var err error
-		body, err = snappy.Encode(nil, body)
-		if err != nil {
-			return err
+func NewSession(cfg Config) *Session {
+	cfg.normalize()
+	pool := make([]*connection, 0, len(cfg.Nodes))
+	for _, address := range cfg.Nodes {
+		con, err := connect(address, &cfg)
+		if err == nil {
+			pool = append(pool, con)
 		}
-		flags = flagCompressed
-	}
-	frame := make([]byte, len(body)+8)
-	frame[0] = protoRequest
-	frame[1] = flags
-	frame[2] = 0
-	frame[3] = opcode
-	binary.BigEndian.PutUint32(frame[4:8], uint32(len(body)))
-	copy(frame[8:], body)
-	if _, err := cn.c.Write(frame); err != nil {
-		return err
 	}
-	return nil
+	return &Session{cfg: &cfg, pool: pool}
 }
 
-func (cn *connection) recv() (byte, []byte, error) {
-	if cn.c == nil {
-		return 0, nil, driver.ErrBadConn
-	}
-	header := make([]byte, 8)
-	if _, err := io.ReadFull(cn.c, header); err != nil {
-		cn.close() // better assume that the connection is broken (may have read some bytes)
-		return 0, nil, err
-	}
-	// verify that the frame starts with version==1 and req/resp flag==response
-	// this may be overly conservative in that future versions may be backwards compatible
-	// in that case simply amend the check...
-	if header[0] != protoResponse {
-		cn.close()
-		return 0, nil, fmt.Errorf("unsupported frame version or not a response: 0x%x (header=%v)", header[0], header)
-	}
-	// verify that the flags field has only a single flag set, again, this may
-	// be overly conservative if additional flags are backwards-compatible
-	if header[1] > 1 {
-		cn.close()
-		return 0, nil, fmt.Errorf("unsupported frame flags: 0x%x (header=%v)", header[1], header)
-	}
-	opcode := header[3]
-	if opcode > opLAST {
-		cn.close()
-		return 0, nil, fmt.Errorf("unknown opcode: 0x%x (header=%v)", opcode, header)
-	}
-	length := binary.BigEndian.Uint32(header[4:8])
-	var body []byte
-	if length > 0 {
-		if length > 256*1024*1024 { // spec says 256MB is max
-			cn.close()
-			return 0, nil, fmt.Errorf("frame too large: %d (header=%v)", length, header)
-		}
-		body = make([]byte, length)
-		if _, err := io.ReadFull(cn.c, body); err != nil {
-			cn.close() // better assume that the connection is broken
-			return 0, nil, err
-		}
-	}
-	if header[1]&flagCompressed != 0 && cn.pool.compression == "snappy" {
-		var err error
-		body, err = snappy.Decode(nil, body)
-		if err != nil {
-			cn.close()
-			return 0, nil, err
-		}
-	}
-	if opcode == opError {
-		code := binary.BigEndian.Uint32(body[0:4])
-		msglen := binary.BigEndian.Uint16(body[4:6])
-		msg := string(body[6 : 6+msglen])
-		return opcode, body, Error{Code: int(code), Msg: msg}
-	}
-	return opcode, body, nil
-}
-
-func (p *pool) conn() (*connection, error) {
-	if p.dead {
-		return nil, driver.ErrBadConn
-	}
-
-	totalConnections := len(p.connections)
-	start := p.i + 1 // make sure that we start from the next position in the ring
-
-	for i := 0; i < totalConnections; i++ {
-		idx := (i + start) % totalConnections
-		cn := p.connections[idx]
-		if cn.alive {
-			p.i = idx // set the new 'i' so the ring will start again in the right place
-			return cn, nil
-		}
+func (s *Session) Query(stmt string, args ...interface{}) *Query {
+	return &Query{
+		stmt: stmt,
+		args: args,
+		cons: s.cfg.Consistency,
+		ctx:  s,
 	}
-
-	// we've exhausted the pool, gonna have a bad time
-	p.dead = true
-	return nil, driver.ErrBadConn
 }
 
-func (p *pool) join() {
-	p.reconnect()
-
-	// Every 1 second, we want to try reconnecting to disconnected nodes
-	go func() {
-		for {
-			select {
-			case <-p.stop:
-				return
-			default:
-				p.reconnect()
-				time.Sleep(time.Second)
-			}
-		}
-	}()
+func (s *Session) executeQuery(query *Query) (buffer, error) {
+	// TODO(tux21b): do something clever here
+	return s.pool[0].executeQuery(query)
 }
 
-func (p *pool) reconnect() {
-	for _, cn := range p.connections {
-		if !cn.alive {
-			cn.open()
-		}
-	}
+func (s *Session) Close() {
+	return
 }
 
-func (p *pool) Begin() (driver.Tx, error) {
-	if p.dead {
-		return nil, driver.ErrBadConn
-	}
-	return p, nil
-}
+type Consistency uint16
 
-func (p *pool) Commit() error {
-	if p.dead {
-		return driver.ErrBadConn
-	}
-	return nil
-}
+const (
+	ConAny         Consistency = 0x0000
+	ConOne         Consistency = 0x0001
+	ConTwo         Consistency = 0x0002
+	ConThree       Consistency = 0x0003
+	ConQuorum      Consistency = 0x0004
+	ConAll         Consistency = 0x0005
+	ConLocalQuorum Consistency = 0x0006
+	ConEachQuorum  Consistency = 0x0007
+	ConSerial      Consistency = 0x0008
+	ConLocalSerial Consistency = 0x0009
+)
 
-func (p *pool) Close() error {
-	if p.dead {
-		return driver.ErrBadConn
-	}
-	for _, cn := range p.connections {
-		cn.close()
-	}
-	p.stop <- struct{}{}
-	p.dead = true
-	return nil
-}
+var ErrNotFound = errors.New("not found")
 
-func (p *pool) Rollback() error {
-	if p.dead {
-		return driver.ErrBadConn
+type Query struct {
+	stmt string
+	args []interface{}
+	cons Consistency
+	ctx  interface {
+		executeQuery(query *Query) (buffer, error)
 	}
-	return nil
 }
 
-func (p *pool) Prepare(query string) (driver.Stmt, error) {
-	// Explicitly check if the query is a "USE <keyspace>"
-	// Since it needs to be special cased and run on each server
-	if strings.HasPrefix(query, keyspaceQuery) {
-		keyspace := query[len(keyspaceQuery):]
-		p.UseKeyspace(keyspace)
-		return &statement{}, nil
-	}
-
-	for {
-		cn, err := p.conn()
-		if err != nil {
-			return nil, err
-		}
-		st, err := cn.Prepare(query)
-		if err != nil {
-			// the cn has gotten marked as dead already
-			if p.dead {
-				// The entire pool is dead, so we bubble up the ErrBadConn
-				return nil, driver.ErrBadConn
-			} else {
-				continue // Retry request on another cn
-			}
-		}
-		return st, nil
-	}
-}
+var ErrQueryUnbound = errors.New("can not execute unbound query")
 
-func (p *pool) UseKeyspace(keyspace string) {
-	p.keyspace = keyspace
-	for _, cn := range p.connections {
-		cn.UseKeyspace(keyspace)
-	}
+func NewQuery(stmt string) *Query {
+	return &Query{stmt: stmt, cons: ConQuorum}
 }
 
-func (cn *connection) UseKeyspace(keyspace string) error {
-	st, err := cn.Prepare(keyspaceQuery + keyspace)
+func (q *Query) Exec() error {
+	frame, err := q.request()
 	if err != nil {
 		return err
 	}
-	if _, err = st.Exec([]driver.Value{}); err != nil {
-		return err
-	}
-	return nil
-}
-
-func (cn *connection) Prepare(query string) (driver.Stmt, error) {
-	body := make([]byte, len(query)+4)
-	binary.BigEndian.PutUint32(body[0:4], uint32(len(query)))
-	copy(body[4:], []byte(query))
-	if err := cn.send(opPrepare, body); err != nil {
-		return nil, err
-	}
-	opcode, body, err := cn.recv()
-	if err != nil {
-		return nil, err
-	}
-	if opcode != opResult || binary.BigEndian.Uint32(body) != 4 {
-		return nil, fmt.Errorf("expected prepared result")
+	if frame[3] == opResult {
+		frame.skipHeader()
+		kind := frame.readInt()
+		if kind == 3 {
+			keyspace := frame.readString()
+			fmt.Println("set keyspace:", keyspace)
+		} else {
+		}
 	}
-	n := int(binary.BigEndian.Uint16(body[4:]))
-	prepared := body[6 : 6+n]
-	columns, meta, _ := parseMeta(body[6+n:])
-	return &statement{cn: cn, query: query,
-		prepared: prepared, columns: columns, meta: meta}, nil
-}
-
-type statement struct {
-	cn       *connection
-	query    string
-	prepared []byte
-	columns  []string
-	meta     []uint16
-}
-
-func (s *statement) Close() error {
 	return nil
 }
 
-func (st *statement) ColumnConverter(idx int) driver.ValueConverter {
-	return (&columnEncoder{st.meta}).ColumnConverter(idx)
+func (q *Query) request() (buffer, error) {
+	return q.ctx.executeQuery(q)
 }
 
-func (st *statement) NumInput() int {
-	return len(st.columns)
+func (q *Query) Consistency(cons Consistency) *Query {
+	q.cons = cons
+	return q
 }
 
-func parseMeta(body []byte) ([]string, []uint16, int) {
-	flags := binary.BigEndian.Uint32(body)
-	globalTableSpec := flags&1 == 1
-	columnCount := int(binary.BigEndian.Uint32(body[4:]))
-	i := 8
-	if globalTableSpec {
-		l := int(binary.BigEndian.Uint16(body[i:]))
-		keyspace := string(body[i+2 : i+2+l])
-		i += 2 + l
-		l = int(binary.BigEndian.Uint16(body[i:]))
-		tablename := string(body[i+2 : i+2+l])
-		i += 2 + l
-		_, _ = keyspace, tablename
-	}
-	columns := make([]string, columnCount)
-	meta := make([]uint16, columnCount)
-	for c := 0; c < columnCount; c++ {
-		l := int(binary.BigEndian.Uint16(body[i:]))
-		columns[c] = string(body[i+2 : i+2+l])
-		i += 2 + l
-		meta[c] = binary.BigEndian.Uint16(body[i:])
-		i += 2
+func (q *Query) Scan(values ...interface{}) error {
+	found := false
+	iter := q.Iter()
+	if iter.Scan(values...) {
+		found = true
 	}
-	return columns, meta, i
-}
-
-func (st *statement) exec(v []driver.Value) error {
-	sz := 6 + len(st.prepared)
-	for i := range v {
-		if b, ok := v[i].([]byte); ok {
-			sz += len(b) + 4
-		}
-	}
-	body, p := make([]byte, sz), 4+len(st.prepared)
-	binary.BigEndian.PutUint16(body, uint16(len(st.prepared)))
-	copy(body[2:], st.prepared)
-	binary.BigEndian.PutUint16(body[p-2:], uint16(len(v)))
-	for i := range v {
-		b, ok := v[i].([]byte)
-		if !ok {
-			return fmt.Errorf("unsupported type %T at column %d", v[i], i)
-		}
-		binary.BigEndian.PutUint32(body[p:], uint32(len(b)))
-		copy(body[p+4:], b)
-		p += 4 + len(b)
-	}
-	binary.BigEndian.PutUint16(body[p:], uint16(st.cn.pool.consistency))
-	if err := st.cn.send(opExecute, body); err != nil {
+	if err := iter.Close(); err != nil {
 		return err
+	} else if !found {
+		return ErrNotFound
 	}
 	return nil
 }
 
-func (st *statement) Exec(v []driver.Value) (driver.Result, error) {
-	if st.cn == nil {
-		return nil, nil
-	}
-	if err := st.exec(v); err != nil {
-		return nil, err
-	}
-	opcode, body, err := st.cn.recv()
+func (q *Query) Iter() *Iter {
+	iter := new(Iter)
+	frame, err := q.request()
 	if err != nil {
-		return nil, err
+		iter.err = err
+		return iter
 	}
-	_, _ = opcode, body
-	return nil, nil
-}
-
-func (st *statement) Query(v []driver.Value) (driver.Rows, error) {
-	if err := st.exec(v); err != nil {
-		return nil, err
+	frame.skipHeader()
+	kind := frame.readInt()
+	if kind == resultKindRows {
+		iter.setFrame(frame)
 	}
-	opcode, body, err := st.cn.recv()
-	if err != nil {
-		return nil, err
-	}
-	kind := binary.BigEndian.Uint32(body[0:4])
-	if opcode != opResult || kind != 2 {
-		return nil, fmt.Errorf("expected rows as result")
-	}
-	columns, meta, n := parseMeta(body[4:])
-	i := n + 4
-	rows := &rows{
-		columns: columns,
-		meta:    meta,
-		numRows: int(binary.BigEndian.Uint32(body[i:])),
-	}
-	i += 4
-	rows.body = body[i:]
-	return rows, nil
+	return iter
 }
 
-type rows struct {
-	columns []string
-	meta    []uint16
-	body    []byte
-	row     int
+type Iter struct {
+	err     error
+	pos     int
 	numRows int
+	info    []columnInfo
+	flags   int
+	frame   buffer
+}
+
+func (iter *Iter) setFrame(frame buffer) {
+	info := frame.readMetaData()
+	iter.flags = 0
+	iter.info = info
+	iter.numRows = frame.readInt()
+	iter.pos = 0
+	iter.err = nil
+	iter.frame = frame
+}
+
+func (iter *Iter) Scan(values ...interface{}) bool {
+	if iter.err != nil || iter.pos >= iter.numRows {
+		return false
+	}
+	iter.pos++
+	if len(values) != len(iter.info) {
+		iter.err = errors.New("count mismatch")
+		return false
+	}
+	for i := 0; i < len(values); i++ {
+		data := iter.frame.readBytes()
+		if err := Unmarshal(iter.info[i].TypeInfo, data, values[i]); err != nil {
+			iter.err = err
+			return false
+		}
+	}
+	return true
 }
 
-func (r *rows) Close() error {
-	return nil
-}
-
-func (r *rows) Columns() []string {
-	return r.columns
+func (iter *Iter) Close() error {
+	return iter.err
 }
 
-func (r *rows) Next(values []driver.Value) error {
-	if r.row >= r.numRows {
-		return io.EOF
-	}
-	for column := 0; column < len(r.columns); column++ {
-		n := int32(binary.BigEndian.Uint32(r.body))
-		r.body = r.body[4:]
-		if n >= 0 {
-			values[column] = decode(r.body[:n], r.meta[column])
-			r.body = r.body[n:]
-		} else {
-			values[column] = nil
-		}
-	}
-	r.row++
-	return nil
+type columnInfo struct {
+	Keyspace string
+	Table    string
+	Name     string
+	TypeInfo *TypeInfo
 }
 
 type Error struct {
-	Code int
-	Msg  string
+	Code    int
+	Message string
 }
 
 func (e Error) Error() string {
-	return e.Msg
-}
-
-func init() {
-	sql.Register("gocql", &drv{})
+	return e.Message
 }

+ 68 - 169
gocql_test.go

@@ -1,42 +1,41 @@
-// Copyright (c) 2012 The gocql Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
 package gocql
 
 import (
 	"bytes"
-	"database/sql"
-	"github.com/tux21b/gocql/uuid"
+	"fmt"
 	"testing"
 	"time"
 )
 
-func TestSimple(t *testing.T) {
-	db, err := sql.Open("gocql", "localhost:9042 keyspace=system")
-	if err != nil {
-		t.Fatal(err)
-	}
+func TestConnect(t *testing.T) {
+	db := NewSession(Config{
+		Nodes: []string{
+			"127.0.0.1",
+		},
+		Keyspace:    "system",
+		Consistency: ConQuorum,
+	})
+	defer db.Close()
 
-	rows, err := db.Query("SELECT keyspace_name FROM schema_keyspaces")
-	if err != nil {
-		t.Fatal(err)
+	for i := 0; i < 5; i++ {
+		db.Query("SELECT keyspace_name FROM schema_keyspaces WHERE keyspace_name = ?",
+			"system_auth").Exec()
 	}
 
-	for rows.Next() {
-		var keyspace string
-		if err := rows.Scan(&keyspace); err != nil {
-			t.Fatal(err)
-		}
+	var keyspace string
+	var durable bool
+	iter := db.Query("SELECT keyspace_name, durable_writes FROM schema_keyspaces").Iter()
+	for iter.Scan(&keyspace, &durable) {
+		fmt.Println("Keyspace:", keyspace, durable)
 	}
-	if err != nil {
-		t.Fatal(err)
+	if err := iter.Close(); err != nil {
+		fmt.Println(err)
 	}
 }
 
 type Page struct {
 	Title      string
-	RevID      uuid.UUID
+	RevID      int
 	Body       string
 	Hits       int
 	Protected  bool
@@ -45,67 +44,74 @@ type Page struct {
 }
 
 var pages = []*Page{
-	&Page{"Frontpage", uuid.TimeUUID(), "Hello world!", 0, false,
-		time.Date(2012, 8, 20, 10, 0, 0, 0, time.UTC), nil},
-	&Page{"Frontpage", uuid.TimeUUID(), "Hello modified world!", 0, false,
+	&Page{"Frontpage", 1, "Hello world!", 0, false,
+		time.Date(2012, 8, 20, 10, 0, 0, 0, time.UTC), []byte{}},
+	&Page{"Frontpage", 2, "Hello modified world!", 0, false,
 		time.Date(2012, 8, 22, 10, 0, 0, 0, time.UTC), []byte("img data\x00")},
-	&Page{"LoremIpsum", uuid.TimeUUID(), "Lorem ipsum dolor sit amet", 12,
-		true, time.Date(2012, 8, 22, 10, 0, 8, 0, time.UTC), nil},
+	&Page{"LoremIpsum", 3, "Lorem ipsum dolor sit amet", 12,
+		true, time.Date(2012, 8, 22, 10, 0, 8, 0, time.UTC), []byte{}},
 }
 
 func TestWiki(t *testing.T) {
-	db, err := sql.Open("gocql", "localhost:9042 compression=snappy")
-	if err != nil {
-		t.Fatal(err)
+	db := NewSession(Config{
+		Nodes:       []string{"localhost"},
+		Consistency: ConQuorum,
+	})
+
+	if err := db.Query("DROP KEYSPACE gocql_wiki").Exec(); err != nil {
+		t.Log("DROP KEYSPACE:", err)
 	}
-	db.Exec("DROP KEYSPACE gocql_wiki")
-	if _, err := db.Exec(`CREATE KEYSPACE gocql_wiki
-	                      WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }`); err != nil {
-		t.Fatal(err)
+
+	if err := db.Query(`CREATE KEYSPACE gocql_wiki
+		WITH replication = {
+			'class' : 'SimpleStrategy',
+			'replication_factor' : 1
+		}`).Exec(); err != nil {
+		t.Fatal("CREATE KEYSPACE:", err)
 	}
-	if _, err := db.Exec("USE gocql_wiki"); err != nil {
-		t.Fatal(err)
+
+	if err := db.Query("USE gocql_wiki").Exec(); err != nil {
+		t.Fatal("USE:", err)
 	}
 
-	if _, err := db.Exec(`CREATE TABLE page (
-        title varchar,
-        revid timeuuid,
-        body varchar,
-        hits int,
-        protected boolean,
-        modified timestamp,
-        attachment blob,
-        PRIMARY KEY (title, revid)
-        )`); err != nil {
-		t.Fatal(err)
+	if err := db.Query(`CREATE TABLE page (
+		title varchar,
+		revid int,
+		body varchar,
+		hits int,
+		protected boolean,
+		modified timestamp,
+		attachment blob,
+		PRIMARY KEY (title, revid)
+		)`).Exec(); err != nil {
+		t.Fatal("CREATE TABLE:", err)
 	}
+
 	for _, p := range pages {
-		if _, err := db.Exec(`INSERT INTO page (title, revid, body, hits,
-            protected, modified, attachment) VALUES (?, ?, ?, ?, ?, ?, ?);`,
+		if err := db.Query(`INSERT INTO page (title, revid, body, hits,
+			protected, modified, attachment) VALUES (?, ?, ?, ?, ?, ?, ?)`,
 			p.Title, p.RevID, p.Body, p.Hits, p.Protected, p.Modified,
-			p.Attachment); err != nil {
-			t.Fatal(err)
+			p.Attachment).Exec(); err != nil {
+			t.Fatal("INSERT:", err)
 		}
 	}
 
-	row := db.QueryRow(`SELECT count(*) FROM page`)
 	var count int
-	if err := row.Scan(&count); err != nil {
-		t.Error(err)
+	if err := db.Query("SELECT count(*) FROM page").Scan(&count); err != nil {
+		t.Fatal("COUNT:", err)
 	}
 	if count != len(pages) {
-		t.Fatalf("expected %d rows, got %d", len(pages), count)
+		t.Fatalf("COUNT: expected %d got %d", len(pages), count)
 	}
 
 	for _, page := range pages {
-		row := db.QueryRow(`SELECT title, revid, body, hits, protected,
-            modified, attachment
-            FROM page WHERE title = ? AND revid = ?`, page.Title, page.RevID)
+		qry := db.Query(`SELECT title, revid, body, hits, protected,
+			modified, attachment
+		    FROM page WHERE title = ? AND revid = ?`, page.Title, page.RevID)
 		var p Page
-		err := row.Scan(&p.Title, &p.RevID, &p.Body, &p.Hits, &p.Protected,
-			&p.Modified, &p.Attachment)
-		if err != nil {
-			t.Fatal(err)
+		if err := qry.Scan(&p.Title, &p.RevID, &p.Body, &p.Hits, &p.Protected,
+			&p.Modified, &p.Attachment); err != nil {
+			t.Fatal("SELECT PAGE:", err)
 		}
 		p.Modified = p.Modified.In(time.UTC)
 		if page.Title != p.Title || page.RevID != p.RevID ||
@@ -115,111 +121,4 @@ func TestWiki(t *testing.T) {
 			t.Errorf("expected %#v got %#v", *page, p)
 		}
 	}
-
-	row = db.QueryRow(`SELECT title, revid, body, hits, protected,
-        modified, attachment
-        FROM page WHERE title = ? ORDER BY revid DESC`, "Frontpage")
-	var p Page
-	if err := row.Scan(&p.Title, &p.RevID, &p.Body, &p.Hits, &p.Protected,
-		&p.Modified, &p.Attachment); err != nil {
-		t.Error(err)
-	}
-	p.Modified = p.Modified.In(time.UTC)
-	page := pages[1]
-	if page.Title != p.Title || page.RevID != p.RevID ||
-		page.Body != p.Body || page.Modified != p.Modified ||
-		page.Hits != p.Hits || page.Protected != p.Protected ||
-		!bytes.Equal(page.Attachment, p.Attachment) {
-		t.Errorf("expected %#v got %#v", *page, p)
-	}
-
-}
-
-func TestTypes(t *testing.T) {
-	db, err := sql.Open("gocql", "localhost:9042 compression=snappy")
-	if err != nil {
-		t.Fatal(err)
-	}
-	db.Exec("DROP KEYSPACE gocql_types")
-	if _, err := db.Exec(`CREATE KEYSPACE gocql_types
-	                      WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }`); err != nil {
-		t.Fatal(err)
-	}
-	if _, err := db.Exec("USE gocql_types"); err != nil {
-		t.Fatal(err)
-	}
-
-	if _, err := db.Exec(`CREATE TABLE stuff (
-        id bigint,
-		foo text,
-        PRIMARY KEY (id)
-        )`); err != nil {
-		t.Fatal(err)
-	}
-
-	id := int64(-1 << 63)
-
-	if _, err := db.Exec(`INSERT INTO stuff (id, foo) VALUES (?, ?);`, &id, "test"); err != nil {
-		t.Fatal(err)
-	}
-
-	var rid int64
-
-	row := db.QueryRow(`SELECT id FROM stuff WHERE id = ?`, id)
-
-	if err := row.Scan(&rid); err != nil {
-		t.Error(err)
-	}
-
-	if id != rid {
-		t.Errorf("expected %v got %v", id, rid)
-	}
-}
-
-
-func TestNullColumnValues(t *testing.T) {
-	db, err := sql.Open("gocql", "localhost:9042 compression=snappy")
-	if err != nil {
-		t.Fatal(err)
-	}
-	db.Exec("DROP KEYSPACE gocql_nullvalues")
-	if _, err := db.Exec(`CREATE KEYSPACE gocql_nullvalues
-	                      WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };`); err != nil {
-		t.Fatal(err)
-	}
-	if _, err := db.Exec("USE gocql_nullvalues"); err != nil {
-		t.Fatal(err)
-	}
-	if _, err := db.Exec(`CREATE TABLE stuff (
-        id bigint,
-        subid bigint,
-				foo text,
-				bar text,
-        PRIMARY KEY (id, subid)
-        )`); err != nil {
-		t.Fatal(err)
-	}
-	id := int64(-1 << 63)
-
-	if _, err := db.Exec(`INSERT INTO stuff (id, subid, foo) VALUES (?, ?, ?);`, id, int64(4), "test"); err != nil {
-		t.Fatal(err)
-	}
-
-	if _, err := db.Exec(`INSERT INTO stuff (id, subid, bar) VALUES (?, ?, ?);`, id, int64(6), "test2"); err != nil {
-		t.Fatal(err)
-	}
-
-	var rid int64
-	var sid int64
-	var data1 []byte
-	var data2 []byte
-	if rows, err := db.Query(`SELECT id, subid, foo, bar FROM stuff`); err == nil {
-			for rows.Next() {
-					if err := rows.Scan(&rid, &sid, &data1, &data2); err != nil {
-				t.Error(err)
-			}
-		}
-	} else {
-		t.Fatal(err)
-	}
 }

+ 204 - 0
marshal.go

@@ -0,0 +1,204 @@
+// Copyright (c) 2012 The gocql Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gocql
+
+import (
+	"fmt"
+	"time"
+)
+
+// Marshaler is the interface implemented by objects that can marshal
+// themselves into values understood by Cassandra.
+type Marshaler interface {
+	MarshalCQL(info *TypeInfo, value interface{}) ([]byte, error)
+}
+
+// Unmarshaler is the interface implemented by objects that can unmarshal
+// a Cassandra specific description of themselves.
+type Unmarshaler interface {
+	UnmarshalCQL(info *TypeInfo, data []byte, value interface{}) error
+}
+
+// Marshal returns the CQL encoding of the value for the Cassandra
+// internal type described by the info parameter.
+func Marshal(info *TypeInfo, value interface{}) ([]byte, error) {
+	if v, ok := value.(Marshaler); ok {
+		return v.MarshalCQL(info, value)
+	}
+	switch info.Type {
+	case TypeVarchar, TypeAscii, TypeBlob:
+		switch v := value.(type) {
+		case string:
+			return []byte(v), nil
+		case []byte:
+			return v, nil
+		}
+	case TypeBoolean:
+		if v, ok := value.(bool); ok {
+			if v {
+				return []byte{1}, nil
+			} else {
+				return []byte{0}, nil
+			}
+		}
+	case TypeInt:
+		switch v := value.(type) {
+		case int:
+			x := int32(v)
+			return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)}, nil
+		}
+	case TypeTimestamp:
+		if v, ok := value.(time.Time); ok {
+			x := v.In(time.UTC).UnixNano() / int64(time.Millisecond)
+			return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40),
+				byte(x >> 32), byte(x >> 24), byte(x >> 16),
+				byte(x >> 8), byte(x)}, nil
+		}
+	}
+	// TODO(tux21b): add reflection and a lot of other types
+	return nil, fmt.Errorf("can not marshal %T into %s", value, info)
+}
+
+// Unmarshal parses the CQL encoded data based on the info parameter that
+// describes the Cassandra internal data type and stores the result in the
+// value pointed by value.
+func Unmarshal(info *TypeInfo, data []byte, value interface{}) error {
+	if v, ok := value.(Unmarshaler); ok {
+		return v.UnmarshalCQL(info, data, value)
+	}
+	switch info.Type {
+	case TypeVarchar, TypeAscii, TypeBlob:
+		switch v := value.(type) {
+		case *string:
+			*v = string(data)
+			return nil
+		case *[]byte:
+			val := make([]byte, len(data))
+			copy(val, data)
+			*v = val
+			return nil
+		}
+	case TypeBoolean:
+		if v, ok := value.(*bool); ok && len(data) == 1 {
+			*v = data[0] != 0
+			return nil
+		}
+	case TypeBigInt:
+		if v, ok := value.(*int); ok && len(data) == 8 {
+			*v = int(data[0])<<56 | int(data[1])<<48 | int(data[2])<<40 |
+				int(data[3])<<32 | int(data[4])<<24 | int(data[5])<<16 |
+				int(data[6])<<8 | int(data[7])
+			return nil
+		}
+	case TypeInt:
+		if v, ok := value.(*int); ok && len(data) == 4 {
+			*v = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 |
+				int(data[3])
+			return nil
+		}
+	case TypeTimestamp:
+		if v, ok := value.(*time.Time); ok && len(data) == 8 {
+			x := int64(data[0])<<56 | int64(data[1])<<48 |
+				int64(data[2])<<40 | int64(data[3])<<32 |
+				int64(data[4])<<24 | int64(data[5])<<16 |
+				int64(data[6])<<8 | int64(data[7])
+			sec := x / 1000
+			nsec := (x - sec*1000) * 1000000
+			*v = time.Unix(sec, nsec)
+			return nil
+		}
+	}
+	// TODO(tux21b): add reflection and a lot of other basic types
+	return fmt.Errorf("can not unmarshal %s into %T", info, value)
+}
+
+// TypeInfo describes a Cassandra specific data type.
+type TypeInfo struct {
+	Type   Type
+	Key    *TypeInfo // only used for TypeMap
+	Value  *TypeInfo // only used for TypeMap, TypeList and TypeSet
+	Custom string    // only used for TypeCostum
+}
+
+// String returns a human readable name for the Cassandra datatype
+// described by t.
+func (t TypeInfo) String() string {
+	switch t.Type {
+	case TypeMap:
+		return fmt.Sprintf("%s(%s, %s)", t.Type, t.Key, t.Value)
+	case TypeList, TypeSet:
+		return fmt.Sprintf("%s(%s)", t.Type, t.Value)
+	case TypeCustom:
+		return fmt.Sprintf("%s(%s)", t.Type, t.Custom)
+	}
+	return t.Type.String()
+}
+
+// Type is the identifier of a Cassandra internal datatype.
+type Type int
+
+const (
+	TypeCustom    Type = 0x0000
+	TypeAscii     Type = 0x0001
+	TypeBigInt    Type = 0x0002
+	TypeBlob      Type = 0x0003
+	TypeBoolean   Type = 0x0004
+	TypeCounter   Type = 0x0005
+	TypeDecimal   Type = 0x0006
+	TypeDouble    Type = 0x0007
+	TypeFloat     Type = 0x0008
+	TypeInt       Type = 0x0009
+	TypeTimestamp Type = 0x000B
+	TypeUUID      Type = 0x000C
+	TypeVarchar   Type = 0x000D
+	TypeVarint    Type = 0x000E
+	TypeTimeUUID  Type = 0x000F
+	TypeInet      Type = 0x0010
+	TypeList      Type = 0x0020
+	TypeMap       Type = 0x0021
+	TypeSet       Type = 0x0022
+)
+
+// String returns the name of the identifier.
+func (t Type) String() string {
+	switch t {
+	case TypeCustom:
+		return "custom"
+	case TypeAscii:
+		return "ascii"
+	case TypeBigInt:
+		return "bigint"
+	case TypeBlob:
+		return "blob"
+	case TypeBoolean:
+		return "boolean"
+	case TypeCounter:
+		return "counter"
+	case TypeDecimal:
+		return "decimal"
+	case TypeFloat:
+		return "float"
+	case TypeInt:
+		return "int"
+	case TypeTimestamp:
+		return "timestamp"
+	case TypeUUID:
+		return "uuid"
+	case TypeVarchar:
+		return "varchar"
+	case TypeTimeUUID:
+		return "timeuuid"
+	case TypeInet:
+		return "inet"
+	case TypeList:
+		return "list"
+	case TypeMap:
+		return "map"
+	case TypeSet:
+		return "set"
+	default:
+		return "unknown"
+	}
+}