浏览代码

added nicer connection API

Christoph Hack 12 年之前
父节点
当前提交
30bc700a7b
共有 6 个文件被更改,包括 522 次插入379 次删除
  1. 0 258
      binary.go
  2. 115 70
      conn.go
  3. 265 0
      frame.go
  4. 45 43
      gocql.go
  5. 9 8
      gocql_test.go
  6. 88 0
      session.go

+ 0 - 258
binary.go

@@ -1,258 +0,0 @@
-// Copyright (c) 2012 The gocql Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package gocql
-
-import (
-	"errors"
-	"net"
-)
-
-const (
-	protoRequest  byte = 0x02
-	protoResponse byte = 0x82
-
-	opError         byte = 0x00
-	opStartup       byte = 0x01
-	opReady         byte = 0x02
-	opAuthenticate  byte = 0x03
-	opOptions       byte = 0x05
-	opSupported     byte = 0x06
-	opQuery         byte = 0x07
-	opResult        byte = 0x08
-	opPrepare       byte = 0x09
-	opExecute       byte = 0x0A
-	opRegister      byte = 0x0B
-	opEvent         byte = 0x0C
-	opBatch         byte = 0x0D
-	opAuthChallenge byte = 0x0E
-	opAuthResponse  byte = 0x0F
-	opAuthSuccess   byte = 0x10
-
-	resultKindVoid          = 1
-	resultKindRows          = 2
-	resultKindKeyspace      = 3
-	resultKindPrepared      = 4
-	resultKindSchemaChanged = 5
-
-	flagQueryValues uint8 = 1
-
-	headerSize = 8
-)
-
-var ErrInvalid = errors.New("invalid response")
-
-type buffer []byte
-
-func (b *buffer) writeInt(v int32) {
-	p := b.grow(4)
-	(*b)[p] = byte(v >> 24)
-	(*b)[p+1] = byte(v >> 16)
-	(*b)[p+2] = byte(v >> 8)
-	(*b)[p+3] = byte(v)
-}
-
-func (b *buffer) writeShort(v uint16) {
-	p := b.grow(2)
-	(*b)[p] = byte(v >> 8)
-	(*b)[p+1] = byte(v)
-}
-
-func (b *buffer) writeString(v string) {
-	b.writeShort(uint16(len(v)))
-	p := b.grow(len(v))
-	copy((*b)[p:], v)
-}
-
-func (b *buffer) writeLongString(v string) {
-	b.writeInt(int32(len(v)))
-	p := b.grow(len(v))
-	copy((*b)[p:], v)
-}
-
-func (b *buffer) writeUUID() {
-}
-
-func (b *buffer) writeStringList(v []string) {
-	b.writeShort(uint16(len(v)))
-	for i := range v {
-		b.writeString(v[i])
-	}
-}
-
-func (b *buffer) writeByte(v byte) {
-	p := b.grow(1)
-	(*b)[p] = v
-}
-
-func (b *buffer) writeBytes(v []byte) {
-	if v == nil {
-		b.writeInt(-1)
-		return
-	}
-	b.writeInt(int32(len(v)))
-	p := b.grow(len(v))
-	copy((*b)[p:], v)
-}
-
-func (b *buffer) writeShortBytes(v []byte) {
-	b.writeShort(uint16(len(v)))
-	p := b.grow(len(v))
-	copy((*b)[p:], v)
-}
-
-func (b *buffer) writeInet(ip net.IP, port int) {
-	p := b.grow(1 + len(ip))
-	(*b)[p] = byte(len(ip))
-	copy((*b)[p+1:], ip)
-	b.writeInt(int32(port))
-}
-
-func (b *buffer) writeStringMap(v map[string]string) {
-	b.writeShort(uint16(len(v)))
-	for key, value := range v {
-		b.writeString(key)
-		b.writeString(value)
-	}
-}
-
-func (b *buffer) writeStringMultimap(v map[string][]string) {
-	b.writeShort(uint16(len(v)))
-	for key, values := range v {
-		b.writeString(key)
-		b.writeStringList(values)
-	}
-}
-
-func (b *buffer) setHeader(version, flags, stream, opcode uint8) {
-	(*b)[0] = version
-	(*b)[1] = flags
-	(*b)[2] = stream
-	(*b)[3] = opcode
-}
-
-func (b *buffer) setLength(length int) {
-	(*b)[4] = byte(length >> 24)
-	(*b)[5] = byte(length >> 16)
-	(*b)[6] = byte(length >> 8)
-	(*b)[7] = byte(length)
-}
-
-func (b *buffer) Length() int {
-	return int((*b)[4])<<24 | int((*b)[5])<<16 | int((*b)[6])<<8 | int((*b)[7])
-}
-
-func (b *buffer) grow(n int) int {
-	if len(*b)+n >= cap(*b) {
-		buf := make(buffer, len(*b), len(*b)*2+n)
-		copy(buf, *b)
-		*b = buf
-	}
-	p := len(*b)
-	*b = (*b)[:p+n]
-	return p
-}
-
-func (b *buffer) skipHeader() {
-	*b = (*b)[headerSize:]
-}
-
-func (b *buffer) readInt() int {
-	if len(*b) < 4 {
-		panic(ErrInvalid)
-	}
-	v := int((*b)[0])<<24 | int((*b)[1])<<16 | int((*b)[2])<<8 | int((*b)[3])
-	*b = (*b)[4:]
-	return v
-}
-
-func (b *buffer) readShort() uint16 {
-	if len(*b) < 2 {
-		panic(ErrInvalid)
-	}
-	v := uint16((*b)[0])<<8 | uint16((*b)[1])
-	*b = (*b)[2:]
-	return v
-}
-
-func (b *buffer) readString() string {
-	n := int(b.readShort())
-	if len(*b) < n {
-		panic(ErrInvalid)
-	}
-	v := string((*b)[:n])
-	*b = (*b)[n:]
-	return v
-}
-
-func (b *buffer) readLongString() string {
-	n := b.readInt()
-	if len(*b) < n {
-		panic(ErrInvalid)
-	}
-	v := string((*b)[:n])
-	*b = (*b)[n:]
-	return v
-}
-
-func (b *buffer) readBytes() []byte {
-	n := b.readInt()
-	if n < 0 {
-		return nil
-	}
-	if len(*b) < n {
-		panic(ErrInvalid)
-	}
-	v := (*b)[:n]
-	*b = (*b)[n:]
-	return v
-}
-
-func (b *buffer) readShortBytes() []byte {
-	n := int(b.readShort())
-	if len(*b) < n {
-		panic(ErrInvalid)
-	}
-	v := (*b)[:n]
-	*b = (*b)[n:]
-	return v
-}
-
-func (b *buffer) readTypeInfo() *TypeInfo {
-	x := b.readShort()
-	typ := &TypeInfo{Type: Type(x)}
-	switch typ.Type {
-	case TypeCustom:
-		typ.Custom = b.readString()
-	case TypeMap:
-		typ.Key = b.readTypeInfo()
-		fallthrough
-	case TypeList, TypeSet:
-		typ.Value = b.readTypeInfo()
-	}
-	return typ
-}
-
-func (b *buffer) readMetaData() []columnInfo {
-	flags := b.readInt()
-	numColumns := b.readInt()
-	globalKeyspace := ""
-	globalTable := ""
-	if flags&1 != 0 {
-		globalKeyspace = b.readString()
-		globalTable = b.readString()
-	}
-	info := make([]columnInfo, numColumns)
-	for i := 0; i < numColumns; i++ {
-		info[i].Keyspace = globalKeyspace
-		info[i].Table = globalTable
-		if flags&1 == 0 {
-			info[i].Keyspace = b.readString()
-			info[i].Table = b.readString()
-		}
-		info[i].Name = b.readString()
-		info[i].TypeInfo = b.readTypeInfo()
-	}
-	return info
-}

+ 115 - 70
conn.go

@@ -11,60 +11,81 @@ import (
 	"time"
 )
 
-type connection struct {
-	conn     net.Conn
-	uniq     chan uint8
-	requests []frameRequest
-	nwait    int32
+const defaultFrameSize = 4096
+
+// Conn is a single connection to a Cassandra node. It can be used to execute
+// queries, but users are usually advised to use a more reliable, higher
+// level API.
+type Conn struct {
+	conn    net.Conn
+	timeout time.Duration
+
+	uniq  chan uint8
+	calls []callReq
+	nwait int32
 
 	prepMu sync.Mutex
 	prep   map[string]*queryInfo
-
-	timeout time.Duration
 }
 
-func connect(addr string, cfg *Config) (*connection, error) {
-	conn, err := net.Dial("tcp", addr)
+// Connect establishes a connection to a Cassandra node.
+// You must also call the Serve method before you can execute any queries.
+func Connect(addr string, cfg *Config) (*Conn, error) {
+	conn, err := net.DialTimeout("tcp", addr, cfg.Timeout)
 	if err != nil {
 		return nil, err
 	}
-	c := &connection{
-		conn:     conn,
-		uniq:     make(chan uint8, 64),
-		requests: make([]frameRequest, 64),
-		prep:     make(map[string]*queryInfo),
-		timeout:  cfg.Timeout,
+	c := &Conn{
+		conn:    conn,
+		uniq:    make(chan uint8, 128),
+		calls:   make([]callReq, 128),
+		prep:    make(map[string]*queryInfo),
+		timeout: cfg.Timeout,
 	}
 	for i := 0; i < cap(c.uniq); i++ {
 		c.uniq <- uint8(i)
 	}
 
-	go c.run()
+	if err := c.init(cfg); err != nil {
+		return nil, err
+	}
+
+	return c, nil
+}
 
-	frame := make(buffer, headerSize)
-	frame.setHeader(protoRequest, 0, 0, opStartup)
-	frame.writeStringMap(map[string]string{
+func (c *Conn) init(cfg *Config) error {
+	req := make(frame, headerSize, defaultFrameSize)
+	req.setHeader(protoRequest, 0, 0, opStartup)
+	req.writeStringMap(map[string]string{
 		"CQL_VERSION": cfg.CQLVersion,
 	})
-	frame.setLength(len(frame) - headerSize)
-
-	frame, err = c.request(frame)
+	resp, err := c.callSimple(req)
 	if err != nil {
-		return nil, err
+		return err
+	} else if resp[3] == opError {
+		return resp.readErrorFrame()
+	} else if resp[3] != opReady {
+		return ErrProtocol
 	}
 
-	if cfg.Keyspace != "" {
+	/*	if cfg.Keyspace != "" {
 		qry := &Query{stmt: "USE " + cfg.Keyspace}
 		frame, err = c.executeQuery(qry)
-	}
+		if err != nil {
+			return err
+		}
+	} */
 
-	return c, nil
+	return nil
 }
 
-func (c *connection) run() {
+// Serve starts the stream multiplexer for this connection, which is required
+// to execute any queries. This method runs as long as the connection is
+// open and is therefore usually called in a separate goroutine.
+func (c *Conn) Serve() error {
 	var err error
 	for {
-		var frame buffer
+		var frame frame
 		frame, err = c.recv()
 		if err != nil {
 			break
@@ -73,20 +94,21 @@ func (c *connection) run() {
 	}
 
 	c.conn.Close()
-	for id := 0; id < len(c.requests); id++ {
-		req := &c.requests[id]
+	for id := 0; id < len(c.calls); id++ {
+		req := &c.calls[id]
 		if atomic.LoadInt32(&req.active) == 1 {
-			req.reply <- frameReply{nil, err}
+			req.resp <- callResp{nil, err}
 		}
 	}
+	return err
 }
 
-func (c *connection) recv() (buffer, error) {
-	frame := make(buffer, headerSize, headerSize+512)
+func (c *Conn) recv() (frame, error) {
+	resp := make(frame, headerSize, headerSize+512)
 	c.conn.SetReadDeadline(time.Now().Add(c.timeout))
 	n, last, pinged := 0, 0, false
-	for n < len(frame) {
-		nn, err := c.conn.Read(frame[n:])
+	for n < len(resp) {
+		nn, err := c.conn.Read(resp[n:])
 		n += nn
 		if err != nil {
 			if err, ok := err.(net.Error); ok && err.Timeout() {
@@ -108,59 +130,68 @@ func (c *connection) recv() (buffer, error) {
 				return nil, err
 			}
 		}
-		if n == headerSize && len(frame) == headerSize {
-			if frame[0] != protoResponse {
+		if n == headerSize && len(resp) == headerSize {
+			if resp[0] != protoResponse {
 				return nil, ErrInvalid
 			}
-			frame.grow(frame.Length())
+			resp.grow(resp.Length())
 		}
 	}
-	return frame, nil
+	return resp, nil
 }
 
-func (c *connection) ping() error {
-	frame := make(buffer, headerSize, headerSize)
-	frame.setHeader(protoRequest, 0, 0, opOptions)
-	frame.setLength(0)
-
-	_, err := c.request(frame)
-	return err
+func (c *Conn) callSimple(req frame) (frame, error) {
+	req.setLength(len(req) - headerSize)
+	if _, err := c.conn.Write(req); err != nil {
+		c.conn.Close()
+		return nil, err
+	}
+	return c.recv()
 }
 
-func (c *connection) request(frame buffer) (buffer, error) {
+func (c *Conn) call(req frame) (frame, error) {
 	id := <-c.uniq
-	frame[2] = id
+	req[2] = id
 
-	req := &c.requests[id]
-	req.reply = make(chan frameReply, 1)
+	call := &c.calls[id]
+	call.resp = make(chan callResp, 1)
 	atomic.AddInt32(&c.nwait, 1)
-	atomic.StoreInt32(&req.active, 1)
+	atomic.StoreInt32(&call.active, 1)
 
-	if _, err := c.conn.Write(frame); err != nil {
+	req.setLength(len(req) - headerSize)
+	if _, err := c.conn.Write(req); err != nil {
+		c.conn.Close()
 		return nil, err
 	}
 
-	reply := <-req.reply
-	req.reply = nil
+	reply := <-call.resp
+	call.resp = nil
 
 	c.uniq <- id
 	return reply.buf, reply.err
 }
 
-func (c *connection) dispatch(frame buffer) {
-	id := int(frame[2])
-	if id >= len(c.requests) {
+func (c *Conn) dispatch(resp frame) {
+	id := int(resp[2])
+	if id >= len(c.calls) {
 		return
 	}
-	req := &c.requests[id]
-	if !atomic.CompareAndSwapInt32(&req.active, 1, 0) {
+	call := &c.calls[id]
+	if !atomic.CompareAndSwapInt32(&call.active, 1, 0) {
 		return
 	}
 	atomic.AddInt32(&c.nwait, -1)
-	req.reply <- frameReply{frame, nil}
+	call.resp <- callResp{resp, nil}
 }
 
-func (c *connection) prepareQuery(stmt string) *queryInfo {
+func (c *Conn) ping() error {
+	req := make(frame, headerSize)
+	req.setHeader(protoRequest, 0, 0, opOptions)
+	_, err := c.call(req)
+	return err
+}
+
+func (c *Conn) prepareQuery(stmt string) *queryInfo {
 	c.prepMu.Lock()
 	info := c.prep[stmt]
 	if info != nil {
@@ -173,12 +204,12 @@ func (c *connection) prepareQuery(stmt string) *queryInfo {
 	c.prep[stmt] = info
 	c.prepMu.Unlock()
 
-	frame := make(buffer, headerSize, headerSize+512)
+	frame := make(frame, headerSize, headerSize+512)
 	frame.setHeader(protoRequest, 0, 0, opPrepare)
 	frame.writeLongString(stmt)
 	frame.setLength(len(frame) - headerSize)
 
-	frame, err := c.request(frame)
+	frame, err := c.call(frame)
 	if err != nil {
 		return nil
 	}
@@ -191,13 +222,13 @@ func (c *connection) prepareQuery(stmt string) *queryInfo {
 	return info
 }
 
-func (c *connection) executeQuery(query *Query) (buffer, error) {
+func (c *Conn) executeQuery(query *Query) (frame, error) {
 	var info *queryInfo
 	if len(query.args) > 0 {
 		info = c.prepareQuery(query.stmt)
 	}
 
-	frame := make(buffer, headerSize, headerSize+512)
+	frame := make(frame, headerSize, headerSize+512)
 	if info == nil {
 		frame.setHeader(protoRequest, 0, 0, opQuery)
 		frame.writeLongString(query.stmt)
@@ -223,7 +254,7 @@ func (c *connection) executeQuery(query *Query) (buffer, error) {
 	}
 	frame.setLength(len(frame) - headerSize)
 
-	frame, err := c.request(frame)
+	frame, err := c.call(frame)
 	if err != nil {
 		return nil, err
 	}
@@ -244,12 +275,26 @@ type queryInfo struct {
 	wg   sync.WaitGroup
 }
 
-type frameRequest struct {
+type callReq struct {
 	active int32
-	reply  chan frameReply
+	resp   chan callResp
 }
 
-type frameReply struct {
-	buf buffer
+type callResp struct {
+	buf frame
 	err error
 }
+
+/*
+  conn := NewConn(addr, cfg)
+
+  querier := conn.Querier()
+
+
+  conn.Init(addr, cfg)
+  go func() {
+  	err := conn.Serve()
+
+  }
+*/
+var foo = 0

+ 265 - 0
frame.go

@@ -0,0 +1,265 @@
+// Copyright (c) 2012 The gocql Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gocql
+
+import (
+	"errors"
+	"net"
+)
+
+const (
+	protoRequest  byte = 0x02
+	protoResponse byte = 0x82
+
+	opError         byte = 0x00
+	opStartup       byte = 0x01
+	opReady         byte = 0x02
+	opAuthenticate  byte = 0x03
+	opOptions       byte = 0x05
+	opSupported     byte = 0x06
+	opQuery         byte = 0x07
+	opResult        byte = 0x08
+	opPrepare       byte = 0x09
+	opExecute       byte = 0x0A
+	opRegister      byte = 0x0B
+	opEvent         byte = 0x0C
+	opBatch         byte = 0x0D
+	opAuthChallenge byte = 0x0E
+	opAuthResponse  byte = 0x0F
+	opAuthSuccess   byte = 0x10
+
+	resultKindVoid          = 1
+	resultKindRows          = 2
+	resultKindKeyspace      = 3
+	resultKindPrepared      = 4
+	resultKindSchemaChanged = 5
+
+	flagQueryValues uint8 = 1
+
+	headerSize = 8
+)
+
+var ErrInvalid = errors.New("invalid response")
+
+type frame []byte
+
+func (f *frame) writeInt(v int32) {
+	p := f.grow(4)
+	(*f)[p] = byte(v >> 24)
+	(*f)[p+1] = byte(v >> 16)
+	(*f)[p+2] = byte(v >> 8)
+	(*f)[p+3] = byte(v)
+}
+
+func (f *frame) writeShort(v uint16) {
+	p := f.grow(2)
+	(*f)[p] = byte(v >> 8)
+	(*f)[p+1] = byte(v)
+}
+
+func (f *frame) writeString(v string) {
+	f.writeShort(uint16(len(v)))
+	p := f.grow(len(v))
+	copy((*f)[p:], v)
+}
+
+func (f *frame) writeLongString(v string) {
+	f.writeInt(int32(len(v)))
+	p := f.grow(len(v))
+	copy((*f)[p:], v)
+}
+
+func (f *frame) writeUUID() {
+}
+
+func (f *frame) writeStringList(v []string) {
+	f.writeShort(uint16(len(v)))
+	for i := range v {
+		f.writeString(v[i])
+	}
+}
+
+func (f *frame) writeByte(v byte) {
+	p := f.grow(1)
+	(*f)[p] = v
+}
+
+func (f *frame) writeBytes(v []byte) {
+	if v == nil {
+		f.writeInt(-1)
+		return
+	}
+	f.writeInt(int32(len(v)))
+	p := f.grow(len(v))
+	copy((*f)[p:], v)
+}
+
+func (f *frame) writeShortBytes(v []byte) {
+	f.writeShort(uint16(len(v)))
+	p := f.grow(len(v))
+	copy((*f)[p:], v)
+}
+
+func (f *frame) writeInet(ip net.IP, port int) {
+	p := f.grow(1 + len(ip))
+	(*f)[p] = byte(len(ip))
+	copy((*f)[p+1:], ip)
+	f.writeInt(int32(port))
+}
+
+func (f *frame) writeStringMap(v map[string]string) {
+	f.writeShort(uint16(len(v)))
+	for key, value := range v {
+		f.writeString(key)
+		f.writeString(value)
+	}
+}
+
+func (f *frame) writeStringMultimap(v map[string][]string) {
+	f.writeShort(uint16(len(v)))
+	for key, values := range v {
+		f.writeString(key)
+		f.writeStringList(values)
+	}
+}
+
+func (f *frame) setHeader(version, flags, stream, opcode uint8) {
+	(*f)[0] = version
+	(*f)[1] = flags
+	(*f)[2] = stream
+	(*f)[3] = opcode
+}
+
+func (f *frame) setLength(length int) {
+	(*f)[4] = byte(length >> 24)
+	(*f)[5] = byte(length >> 16)
+	(*f)[6] = byte(length >> 8)
+	(*f)[7] = byte(length)
+}
+
+func (f *frame) Length() int {
+	return int((*f)[4])<<24 | int((*f)[5])<<16 | int((*f)[6])<<8 | int((*f)[7])
+}
+
+func (f *frame) grow(n int) int {
+	if len(*f)+n >= cap(*f) {
+		buf := make(frame, len(*f), len(*f)*2+n)
+		copy(buf, *f)
+		*f = buf
+	}
+	p := len(*f)
+	*f = (*f)[:p+n]
+	return p
+}
+
+func (f *frame) skipHeader() {
+	*f = (*f)[headerSize:]
+}
+
+func (f *frame) readInt() int {
+	if len(*f) < 4 {
+		panic(ErrInvalid)
+	}
+	v := int((*f)[0])<<24 | int((*f)[1])<<16 | int((*f)[2])<<8 | int((*f)[3])
+	*f = (*f)[4:]
+	return v
+}
+
+func (f *frame) readShort() uint16 {
+	if len(*f) < 2 {
+		panic(ErrInvalid)
+	}
+	v := uint16((*f)[0])<<8 | uint16((*f)[1])
+	*f = (*f)[2:]
+	return v
+}
+
+func (f *frame) readString() string {
+	n := int(f.readShort())
+	if len(*f) < n {
+		panic(ErrInvalid)
+	}
+	v := string((*f)[:n])
+	*f = (*f)[n:]
+	return v
+}
+
+func (f *frame) readLongString() string {
+	n := f.readInt()
+	if len(*f) < n {
+		panic(ErrInvalid)
+	}
+	v := string((*f)[:n])
+	*f = (*f)[n:]
+	return v
+}
+
+func (f *frame) readBytes() []byte {
+	n := f.readInt()
+	if n < 0 {
+		return nil
+	}
+	if len(*f) < n {
+		panic(ErrInvalid)
+	}
+	v := (*f)[:n]
+	*f = (*f)[n:]
+	return v
+}
+
+func (f *frame) readShortBytes() []byte {
+	n := int(f.readShort())
+	if len(*f) < n {
+		panic(ErrInvalid)
+	}
+	v := (*f)[:n]
+	*f = (*f)[n:]
+	return v
+}
+
+func (f *frame) readTypeInfo() *TypeInfo {
+	x := f.readShort()
+	typ := &TypeInfo{Type: Type(x)}
+	switch typ.Type {
+	case TypeCustom:
+		typ.Custom = f.readString()
+	case TypeMap:
+		typ.Key = f.readTypeInfo()
+		fallthrough
+	case TypeList, TypeSet:
+		typ.Value = f.readTypeInfo()
+	}
+	return typ
+}
+
+func (f *frame) readMetaData() []columnInfo {
+	flags := f.readInt()
+	numColumns := f.readInt()
+	globalKeyspace := ""
+	globalTable := ""
+	if flags&1 != 0 {
+		globalKeyspace = f.readString()
+		globalTable = f.readString()
+	}
+	info := make([]columnInfo, numColumns)
+	for i := 0; i < numColumns; i++ {
+		info[i].Keyspace = globalKeyspace
+		info[i].Table = globalTable
+		if flags&1 == 0 {
+			info[i].Keyspace = f.readString()
+			info[i].Table = f.readString()
+		}
+		info[i].Name = f.readString()
+		info[i].TypeInfo = f.readTypeInfo()
+	}
+	return info
+}
+
+func (f *frame) readErrorFrame() Error {
+	f.skipHeader()
+	code := f.readInt()
+	desc := f.readString()
+	return Error{code, desc}
+}

+ 45 - 43
gocql.go

@@ -8,8 +8,6 @@ import (
 	"errors"
 	"fmt"
 	"strings"
-	"sync"
-	"sync/atomic"
 	"time"
 )
 
@@ -20,6 +18,8 @@ type Config struct {
 	Consistency Consistency
 	DefaultPort int
 	Timeout     time.Duration
+	NodePicker  NodePicker
+	Reconnector Reconnector
 }
 
 func (c *Config) normalize() {
@@ -32,6 +32,12 @@ func (c *Config) normalize() {
 	if c.Timeout <= 0 {
 		c.Timeout = 200 * time.Millisecond
 	}
+	if c.NodePicker == nil {
+		c.NodePicker = NewRoundRobinPicker()
+	}
+	if c.Reconnector == nil {
+		c.Reconnector = NewExponentialReconnector(1*time.Second, 10*time.Minute)
+	}
 	for i := 0; i < len(c.Nodes); i++ {
 		c.Nodes[i] = strings.TrimSpace(c.Nodes[i])
 		if strings.IndexByte(c.Nodes[i], ':') < 0 {
@@ -41,25 +47,25 @@ func (c *Config) normalize() {
 }
 
 type Session struct {
-	cfg      *Config
-	active   []*node
-	pos      uint32
-	mu       sync.RWMutex
-	keyspace string
+	cfg         *Config
+	pool        NodePicker
+	reconnector Reconnector
+	keyspace    string
+	nohosts     chan bool
 }
 
 func NewSession(cfg Config) *Session {
 	cfg.normalize()
-	active := make([]*node, 0, len(cfg.Nodes))
+	s := &Session{
+		cfg:         &cfg,
+		nohosts:     make(chan bool),
+		reconnector: cfg.Reconnector,
+		pool:        cfg.NodePicker,
+	}
 	for _, address := range cfg.Nodes {
-		con, err := connect(address, &cfg)
-		if err == nil {
-			active = append(active, &node{con})
-		} else {
-			fmt.Println("connect", err)
-		}
+		go s.reconnector.Reconnect(s, address)
 	}
-	return &Session{cfg: &cfg, active: active}
+	return s
 }
 
 func (s *Session) Query(stmt string, args ...interface{}) *Query {
@@ -75,19 +81,20 @@ func (s *Session) Close() {
 	return
 }
 
-func (s *Session) executeQuery(query *Query) (buffer, error) {
-	pos := atomic.AddUint32(&s.pos, 1)
-	var conn *connection
-	//var keyspace string
-	s.mu.RLock()
-	if len(s.active) == 0 {
-		s.mu.Unlock()
-		return nil, errors.New("no active nodes")
+func (s *Session) executeQuery(query *Query) (frame, error) {
+	node := s.pool.Pick(query)
+	if node == nil {
+		<-time.After(s.cfg.Timeout)
+		node = s.pool.Pick(query)
+	}
+	if node == nil {
+		return nil, ErrNoHostAvailable
 	}
-	conn = s.active[pos%uint32(len(s.active))].conn
-	//keyspace = s.keyspace
-	s.mu.RUnlock()
-	return conn.executeQuery(query)
+	return node.conn.executeQuery(query)
+}
+
+type Node struct {
+	conn *Conn
 }
 
 type Query struct {
@@ -152,7 +159,7 @@ func (q *Query) Consistency(cons Consistency) *Query {
 	return q
 }
 
-func (q *Query) request() (buffer, error) {
+func (q *Query) request() (frame, error) {
 	return q.ctx.executeQuery(q)
 }
 
@@ -162,10 +169,10 @@ type Iter struct {
 	numRows int
 	info    []columnInfo
 	flags   int
-	frame   buffer
+	frame   frame
 }
 
-func (iter *Iter) setFrame(frame buffer) {
+func (iter *Iter) setFrame(frame frame) {
 	info := frame.readMetaData()
 	iter.flags = 0
 	iter.info = info
@@ -199,7 +206,7 @@ func (iter *Iter) Close() error {
 }
 
 type queryContext interface {
-	executeQuery(query *Query) (buffer, error)
+	executeQuery(query *Query) (frame, error)
 }
 
 type columnInfo struct {
@@ -233,18 +240,13 @@ func (e Error) Error() string {
 	return e.Message
 }
 
-var ErrNotFound = errors.New("not found")
-
-var ErrQueryUnbound = errors.New("can not execute unbound query")
-
-// active (choose round robin)
-// connecting
-// down
-
-// getNode()
-// getNextNode() für failover
-// getNodeForShard(key) ...
+var (
+	ErrNotFound        = errors.New("not found")
+	ErrNoHostAvailable = errors.New("no host available")
+	ErrQueryUnbound    = errors.New("can not execute unbound query")
+	ErrProtocol        = errors.New("protocol error")
+)
 
 type node struct {
-	conn *connection
+	conn *Conn
 }

+ 9 - 8
gocql_test.go

@@ -33,6 +33,7 @@ func NewTestServer(t *testing.T, address string) *TestServer {
 }
 
 func (srv *TestServer) serve() {
+	defer srv.listen.Close()
 	for {
 		conn, err := srv.listen.Accept()
 		if err != nil {
@@ -53,7 +54,7 @@ func (srv *TestServer) Stop() {
 	srv.listen.Close()
 }
 
-func (srv *TestServer) process(frame buffer, conn net.Conn) {
+func (srv *TestServer) process(frame frame, conn net.Conn) {
 	switch frame[3] {
 	case opStartup:
 		frame = frame[:headerSize]
@@ -71,7 +72,7 @@ func (srv *TestServer) process(frame buffer, conn net.Conn) {
 		switch strings.ToLower(first) {
 		case "kill":
 			select {}
-		case "delay":
+		case "slow":
 			go func() {
 				<-time.After(1 * time.Second)
 				frame.writeInt(0)
@@ -101,8 +102,8 @@ func (srv *TestServer) process(frame buffer, conn net.Conn) {
 	}
 }
 
-func (srv *TestServer) readFrame(conn net.Conn) buffer {
-	frame := make(buffer, headerSize, headerSize+512)
+func (srv *TestServer) readFrame(conn net.Conn) frame {
+	frame := make(frame, headerSize, headerSize+512)
 	if _, err := io.ReadFull(conn, frame); err != nil {
 		srv.t.Fatal(err)
 	}
@@ -124,8 +125,7 @@ func TestSimple(t *testing.T) {
 		Consistency: ConQuorum,
 	})
 	if err := db.Query("void").Exec(); err != nil {
-		//t.Error("Query", err)
-		return
+		t.Error(err)
 	}
 }
 
@@ -148,7 +148,7 @@ func TestTimeout(t *testing.T) {
 	}
 }
 
-func TestLongQuery(t *testing.T) {
+func TestSlowQuery(t *testing.T) {
 	srv := NewTestServer(t, "127.0.0.1:9051")
 	defer srv.Stop()
 
@@ -157,7 +157,7 @@ func TestLongQuery(t *testing.T) {
 		Consistency: ConQuorum,
 	})
 
-	if err := db.Query("delay").Exec(); err != nil {
+	if err := db.Query("slow").Exec(); err != nil {
 		t.Fatal(err)
 	}
 }
@@ -174,6 +174,7 @@ func TestRoundRobin(t *testing.T) {
 		Nodes:       addrs,
 		Consistency: ConQuorum,
 	})
+	time.Sleep(1 * time.Second)
 
 	var wg sync.WaitGroup
 	wg.Add(5)

+ 88 - 0
session.go

@@ -0,0 +1,88 @@
+package gocql
+
+import (
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+type NodePicker interface {
+	AddNode(node *Node)
+	RemoveNode(node *Node)
+	Pick(query *Query) *Node
+}
+
+type RoundRobinPicker struct {
+	pool []*Node
+	pos  uint32
+	mu   sync.RWMutex
+}
+
+func NewRoundRobinPicker() *RoundRobinPicker {
+	return &RoundRobinPicker{}
+}
+
+func (r *RoundRobinPicker) AddNode(node *Node) {
+	r.mu.Lock()
+	r.pool = append(r.pool, node)
+	r.mu.Unlock()
+}
+
+func (r *RoundRobinPicker) RemoveNode(node *Node) {
+	r.mu.Lock()
+	n := len(r.pool)
+	for i := 0; i < n; i++ {
+		if r.pool[i] == node {
+			r.pool[i], r.pool[n-1] = r.pool[n-1], r.pool[i]
+			r.pool = r.pool[:n-1]
+			break
+		}
+	}
+	r.mu.Unlock()
+}
+
+func (r *RoundRobinPicker) Pick(query *Query) *Node {
+	pos := atomic.AddUint32(&r.pos, 1)
+	var node *Node
+	r.mu.RLock()
+	if len(r.pool) > 0 {
+		node = r.pool[pos%uint32(len(r.pool))]
+	}
+	r.mu.RUnlock()
+	return node
+}
+
+type Reconnector interface {
+	Reconnect(session *Session, address string)
+}
+
+type ExponentialReconnector struct {
+	baseDelay time.Duration
+	maxDelay  time.Duration
+}
+
+func NewExponentialReconnector(baseDelay, maxDelay time.Duration) *ExponentialReconnector {
+	return &ExponentialReconnector{baseDelay, maxDelay}
+}
+
+func (e *ExponentialReconnector) Reconnect(session *Session, address string) {
+	delay := e.baseDelay
+	for {
+		conn, err := Connect(address, session.cfg)
+		if err != nil {
+			<-time.After(delay)
+			if delay *= 2; delay > e.maxDelay {
+				delay = e.maxDelay
+			}
+			continue
+		}
+		node := &Node{conn}
+		go func() {
+			conn.Serve()
+			session.pool.RemoveNode(node)
+			e.Reconnect(session, address)
+		}()
+		session.pool.AddNode(node)
+		return
+	}
+}