فهرست منبع

choose port of the test server automatically, don't panic while reading frames

Christoph Hack 12 سال پیش
والد
کامیت
fd97f5cba2
4فایلهای تغییر یافته به همراه84 افزوده شده و 74 حذف شده
  1. 1 15
      conn.go
  2. 16 10
      frame.go
  3. 55 40
      gocql.go
  4. 12 9
      gocql_test.go

+ 1 - 15
conn.go

@@ -132,7 +132,7 @@ func (c *Conn) recv() (frame, error) {
 		}
 		if n == headerSize && len(resp) == headerSize {
 			if resp[0] != protoResponse {
-				return nil, ErrInvalid
+				return nil, ErrProtocol
 			}
 			resp.grow(resp.Length())
 		}
@@ -284,17 +284,3 @@ type callResp struct {
 	buf frame
 	err error
 }
-
-/*
-  conn := NewConn(addr, cfg)
-
-  querier := conn.Querier()
-
-
-  conn.Init(addr, cfg)
-  go func() {
-  	err := conn.Serve()
-
-  }
-*/
-var foo = 0

+ 16 - 10
frame.go

@@ -5,7 +5,6 @@
 package gocql
 
 import (
-	"errors"
 	"net"
 )
 
@@ -41,8 +40,6 @@ const (
 	headerSize = 8
 )
 
-var ErrInvalid = errors.New("invalid response")
-
 type frame []byte
 
 func (f *frame) writeInt(v int32) {
@@ -160,7 +157,7 @@ func (f *frame) skipHeader() {
 
 func (f *frame) readInt() int {
 	if len(*f) < 4 {
-		panic(ErrInvalid)
+		panic(ErrProtocol)
 	}
 	v := int((*f)[0])<<24 | int((*f)[1])<<16 | int((*f)[2])<<8 | int((*f)[3])
 	*f = (*f)[4:]
@@ -169,7 +166,7 @@ func (f *frame) readInt() int {
 
 func (f *frame) readShort() uint16 {
 	if len(*f) < 2 {
-		panic(ErrInvalid)
+		panic(ErrProtocol)
 	}
 	v := uint16((*f)[0])<<8 | uint16((*f)[1])
 	*f = (*f)[2:]
@@ -179,7 +176,7 @@ func (f *frame) readShort() uint16 {
 func (f *frame) readString() string {
 	n := int(f.readShort())
 	if len(*f) < n {
-		panic(ErrInvalid)
+		panic(ErrProtocol)
 	}
 	v := string((*f)[:n])
 	*f = (*f)[n:]
@@ -189,7 +186,7 @@ func (f *frame) readString() string {
 func (f *frame) readLongString() string {
 	n := f.readInt()
 	if len(*f) < n {
-		panic(ErrInvalid)
+		panic(ErrProtocol)
 	}
 	v := string((*f)[:n])
 	*f = (*f)[n:]
@@ -202,7 +199,7 @@ func (f *frame) readBytes() []byte {
 		return nil
 	}
 	if len(*f) < n {
-		panic(ErrInvalid)
+		panic(ErrProtocol)
 	}
 	v := (*f)[:n]
 	*f = (*f)[n:]
@@ -212,7 +209,7 @@ func (f *frame) readBytes() []byte {
 func (f *frame) readShortBytes() []byte {
 	n := int(f.readShort())
 	if len(*f) < n {
-		panic(ErrInvalid)
+		panic(ErrProtocol)
 	}
 	v := (*f)[:n]
 	*f = (*f)[n:]
@@ -257,7 +254,16 @@ func (f *frame) readMetaData() []columnInfo {
 	return info
 }
 
-func (f *frame) readErrorFrame() Error {
+func (f *frame) readErrorFrame() (err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			if e, ok := r.(error); ok && e == ErrProtocol {
+				err = e
+				return
+			}
+			panic(r)
+		}
+	}()
 	f.skipHeader()
 	code := f.readInt()
 	desc := f.readString()

+ 55 - 40
gocql.go

@@ -77,6 +77,12 @@ func (s *Session) Query(stmt string, args ...interface{}) *Query {
 	}
 }
 
+func (s *Session) Do(query *Query) *Query {
+	q := *query
+	q.ctx = s
+	return &q
+}
+
 func (s *Session) Close() {
 	return
 }
@@ -109,34 +115,34 @@ func NewQuery(stmt string) *Query {
 }
 
 func (q *Query) Exec() error {
-	frame, err := q.request()
+	if q.ctx == nil {
+		return ErrQueryUnbound
+	}
+	frame, err := q.ctx.executeQuery(q)
 	if err != nil {
 		return err
-	}
-	if frame[3] == opResult {
-		frame.skipHeader()
-		kind := frame.readInt()
-		if kind == 3 {
-			keyspace := frame.readString()
-			fmt.Println("set keyspace:", keyspace)
-		} else {
-		}
+	} else if frame[3] == opError {
+		return frame.readErrorFrame()
+	} else if frame[3] != opResult {
+		return ErrProtocol
 	}
 	return nil
 }
 
 func (q *Query) Iter() *Iter {
-	iter := new(Iter)
-	frame, err := q.request()
-	if err != nil {
-		iter.err = err
-		return iter
+	if q.ctx == nil {
+		return &Iter{err: ErrQueryUnbound}
 	}
-	frame.skipHeader()
-	kind := frame.readInt()
-	if kind == resultKindRows {
-		iter.setFrame(frame)
+	frame, err := q.ctx.executeQuery(q)
+	if err != nil {
+		return &Iter{err: err}
+	} else if frame[3] == opError {
+		return &Iter{err: frame.readErrorFrame()}
+	} else if frame[3] != opResult {
+		return &Iter{err: ErrProtocol}
 	}
+	iter := new(Iter)
+	iter.readFrame(frame)
 	return iter
 }
 
@@ -159,45 +165,54 @@ func (q *Query) Consistency(cons Consistency) *Query {
 	return q
 }
 
-func (q *Query) request() (frame, error) {
-	return q.ctx.executeQuery(q)
+type Iter struct {
+	err    error
+	pos    int
+	values [][]byte
+	info   []columnInfo
 }
 
-type Iter struct {
-	err     error
-	pos     int
-	numRows int
-	info    []columnInfo
-	flags   int
-	frame   frame
-}
-
-func (iter *Iter) setFrame(frame frame) {
-	info := frame.readMetaData()
-	iter.flags = 0
-	iter.info = info
-	iter.numRows = frame.readInt()
+func (iter *Iter) readFrame(frame frame) {
+	defer func() {
+		if r := recover(); r != nil {
+			if e, ok := r.(error); ok && e == ErrProtocol {
+				iter.err = e
+				return
+			}
+			panic(r)
+		}
+	}()
+	frame.skipHeader()
 	iter.pos = 0
 	iter.err = nil
-	iter.frame = frame
+	iter.values = nil
+	if frame.readInt() != resultKindRows {
+		return
+	}
+	iter.info = frame.readMetaData()
+	numRows := frame.readInt()
+	iter.values = make([][]byte, numRows*len(iter.info))
+	for i := 0; i < len(iter.values); i++ {
+		iter.values[i] = frame.readBytes()
+	}
 }
 
 func (iter *Iter) Scan(values ...interface{}) bool {
-	if iter.err != nil || iter.pos >= iter.numRows {
+	if iter.err != nil || iter.pos >= len(iter.values) {
 		return false
 	}
-	iter.pos++
 	if len(values) != len(iter.info) {
 		iter.err = errors.New("count mismatch")
 		return false
 	}
 	for i := 0; i < len(values); i++ {
-		data := iter.frame.readBytes()
-		if err := Unmarshal(iter.info[i].TypeInfo, data, values[i]); err != nil {
+		err := Unmarshal(iter.info[i].TypeInfo, iter.values[i+iter.pos], values[i])
+		if err != nil {
 			iter.err = err
 			return false
 		}
 	}
+	iter.pos += len(values)
 	return true
 }
 

+ 12 - 9
gocql_test.go

@@ -5,7 +5,6 @@
 package gocql
 
 import (
-	"fmt"
 	"io"
 	"net"
 	"strings"
@@ -22,12 +21,16 @@ type TestServer struct {
 	listen  net.Listener
 }
 
-func NewTestServer(t *testing.T, address string) *TestServer {
-	listen, err := net.Listen("tcp", address)
+func NewTestServer(t *testing.T) *TestServer {
+	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
 		t.Fatal(err)
 	}
-	srv := &TestServer{Address: address, listen: listen, t: t}
+	listen, err := net.ListenTCP("tcp", laddr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	srv := &TestServer{Address: listen.Addr().String(), listen: listen, t: t}
 	go srv.serve()
 	return srv
 }
@@ -117,7 +120,7 @@ func (srv *TestServer) readFrame(conn net.Conn) frame {
 }
 
 func TestSimple(t *testing.T) {
-	srv := NewTestServer(t, "127.0.0.1:9051")
+	srv := NewTestServer(t)
 	defer srv.Stop()
 
 	db := NewSession(Config{
@@ -130,7 +133,7 @@ func TestSimple(t *testing.T) {
 }
 
 func TestTimeout(t *testing.T) {
-	srv := NewTestServer(t, "127.0.0.1:9051")
+	srv := NewTestServer(t)
 	defer srv.Stop()
 
 	db := NewSession(Config{
@@ -149,7 +152,7 @@ func TestTimeout(t *testing.T) {
 }
 
 func TestSlowQuery(t *testing.T) {
-	srv := NewTestServer(t, "127.0.0.1:9051")
+	srv := NewTestServer(t)
 	defer srv.Stop()
 
 	db := NewSession(Config{
@@ -166,8 +169,8 @@ func TestRoundRobin(t *testing.T) {
 	servers := make([]*TestServer, 5)
 	addrs := make([]string, len(servers))
 	for i := 0; i < len(servers); i++ {
-		addrs[i] = fmt.Sprintf("127.0.0.1:%d", 9051+i)
-		servers[i] = NewTestServer(t, addrs[i])
+		servers[i] = NewTestServer(t)
+		addrs[i] = servers[i].Address
 		defer servers[i].Stop()
 	}
 	db := NewSession(Config{