Browse Source

Update tests to support multiple protocol versions

Allow specifiying the protocol version for the server to run in
and expect queries.
Chris Bannister 10 years ago
parent
commit
eca02a950c
1 changed files with 261 additions and 76 deletions
  1. 261 76
      conn_test.go

+ 261 - 76
conn_test.go

@@ -5,6 +5,7 @@ package gocql
 import (
 	"crypto/tls"
 	"crypto/x509"
+	"fmt"
 	"io"
 	"io/ioutil"
 	"net"
@@ -15,6 +16,10 @@ import (
 	"time"
 )
 
+const (
+	defaultProto = protoVersion2
+)
+
 func TestJoinHostPort(t *testing.T) {
 	tests := map[string]string{
 		"127.0.0.1:0":                                 JoinHostPort("127.0.0.1", 0),
@@ -29,43 +34,38 @@ func TestJoinHostPort(t *testing.T) {
 	}
 }
 
-type TestServer struct {
-	Address  string
-	t        *testing.T
-	nreq     uint64
-	listen   net.Listener
-	nKillReq uint64
-}
-
 func TestSimple(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	cluster := NewCluster(srv.Address)
+	cluster.ProtoVersion = int(defaultProto)
+	db, err := cluster.CreateSession()
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Errorf("0x%x: NewCluster: %v", defaultProto, err)
+		return
 	}
 
 	if err := db.Query("void").Exec(); err != nil {
-		t.Error(err)
+		t.Errorf("0x%x: %v", defaultProto, err)
 	}
 }
 
 func TestSSLSimple(t *testing.T) {
-	srv := NewSSLTestServer(t)
+	srv := NewSSLTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := createTestSslCluster(srv.Address).CreateSession()
+	db, err := createTestSslCluster(srv.Address, defaultProto).CreateSession()
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
 	}
 
 	if err := db.Query("void").Exec(); err != nil {
-		t.Error(err)
+		t.Fatalf("0x%x: %v", defaultProto, err)
 	}
 }
 
-func createTestSslCluster(hosts string) *ClusterConfig {
+func createTestSslCluster(hosts string, proto uint8) *ClusterConfig {
 	cluster := NewCluster(hosts)
 	cluster.SslOpts = &SslOptions{
 		CertPath:               "testdata/pki/gocql.crt",
@@ -73,82 +73,103 @@ func createTestSslCluster(hosts string) *ClusterConfig {
 		CaPath:                 "testdata/pki/ca.crt",
 		EnableHostVerification: false,
 	}
+	cluster.ProtoVersion = int(proto)
 	return cluster
 }
 
 func TestClosed(t *testing.T) {
 	t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	session, err := NewCluster(srv.Address).CreateSession()
+	cluster := NewCluster(srv.Address)
+	cluster.ProtoVersion = int(defaultProto)
+
+	session, err := cluster.CreateSession()
+	defer session.Close()
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Errorf("0x%x: NewCluster: %v", defaultProto, err)
+		return
 	}
-	session.Close()
 
 	if err := session.Query("void").Exec(); err != ErrSessionClosed {
-		t.Errorf("expected %#v, got %#v", ErrSessionClosed, err)
+		t.Errorf("0x%x: expected %#v, got %#v", defaultProto, ErrSessionClosed, err)
+		return
 	}
 }
 
+func newTestSession(addr string, proto uint8) (*Session, error) {
+	cluster := NewCluster(addr)
+	cluster.ProtoVersion = int(proto)
+	return cluster.CreateSession()
+}
+
 func TestTimeout(t *testing.T) {
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 		t.Errorf("NewCluster: %v", err)
+		return
 	}
+	defer db.Close()
 
 	go func() {
 		<-time.After(2 * time.Second)
-		t.Fatal("no timeout")
+		t.Errorf("no timeout")
 	}()
 
 	if err := db.Query("kill").Exec(); err == nil {
-		t.Fatal("expected error")
+		t.Errorf("expected error")
 	}
 }
 
 // TestQueryRetry will test to make sure that gocql will execute
 // the exact amount of retry queries designated by the user.
 func TestQueryRetry(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("NewCluster: %v", err)
 	}
+	defer db.Close()
 
 	go func() {
 		<-time.After(5 * time.Second)
-		t.Fatal("no timeout")
+		t.Fatalf("no timeout")
 	}()
 	rt := &SimpleRetryPolicy{NumRetries: 1}
 
 	qry := db.Query("kill").RetryPolicy(rt)
 	if err := qry.Exec(); err == nil {
-		t.Fatal("expected error")
+		t.Fatalf("expected error")
 	}
-	requests := srv.nKillReq
-	if requests != uint64(qry.Attempts()) {
-		t.Fatalf("expected requests %v to match query attemps %v", requests, qry.Attempts())
+
+	requests := atomic.LoadInt64(&srv.nKillReq)
+	attempts := qry.Attempts()
+	if requests != int64(attempts) {
+		t.Fatalf("expected requests %v to match query attemps %v", requests, attempts)
 	}
+
 	//Minus 1 from the requests variable since there is the initial query attempt
-	if requests-1 != uint64(rt.NumRetries) {
+	if requests-1 != int64(rt.NumRetries) {
 		t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, requests-1)
 	}
 }
 
 func TestSlowQuery(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 		t.Errorf("NewCluster: %v", err)
+		return
 	}
 
 	if err := db.Query("slow").Exec(); err != nil {
@@ -159,22 +180,24 @@ func TestSlowQuery(t *testing.T) {
 func TestRoundRobin(t *testing.T) {
 	servers := make([]*TestServer, 5)
 	addrs := make([]string, len(servers))
-	for i := 0; i < len(servers); i++ {
-		servers[i] = NewTestServer(t)
-		addrs[i] = servers[i].Address
-		defer servers[i].Stop()
+	for n := 0; n < len(servers); n++ {
+		servers[n] = NewTestServer(t, defaultProto)
+		addrs[n] = servers[n].Address
+		defer servers[n].Stop()
 	}
 	cluster := NewCluster(addrs...)
+	cluster.ProtoVersion = defaultProto
+
 	db, err := cluster.CreateSession()
-	time.Sleep(1 * time.Second) //Sleep to allow the Cluster.fillPool to complete
+	time.Sleep(1 * time.Second) // Sleep to allow the Cluster.fillPool to complete
 
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("NewCluster: %v", err)
 	}
 
 	var wg sync.WaitGroup
 	wg.Add(5)
-	for i := 0; i < 5; i++ {
+	for n := 0; n < 5; n++ {
 		go func() {
 			for j := 0; j < 5; j++ {
 				if err := db.Query("void").Exec(); err != nil {
@@ -187,12 +210,12 @@ func TestRoundRobin(t *testing.T) {
 	wg.Wait()
 
 	diff := 0
-	for i := 1; i < len(servers); i++ {
+	for n := 1; n < len(servers); n++ {
 		d := 0
-		if servers[i].nreq > servers[i-1].nreq {
-			d = int(servers[i].nreq - servers[i-1].nreq)
+		if servers[n].nreq > servers[n-1].nreq {
+			d = int(servers[n].nreq - servers[n-1].nreq)
 		} else {
-			d = int(servers[i-1].nreq - servers[i].nreq)
+			d = int(servers[n-1].nreq - servers[n].nreq)
 		}
 		if d > diff {
 			diff = d
@@ -206,7 +229,8 @@ func TestRoundRobin(t *testing.T) {
 
 func TestConnClosing(t *testing.T) {
 	t.Skip("Skipping until test can be ran reliably")
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, protoVersion2)
 	defer srv.Stop()
 
 	db, err := NewCluster(srv.Address).CreateSession()
@@ -238,21 +262,147 @@ func TestConnClosing(t *testing.T) {
 	}
 }
 
-func NewTestServer(t *testing.T) *TestServer {
+func TestStreams_Protocol1(t *testing.T) {
+	srv := NewTestServer(t, protoVersion1)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 1
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	var wg sync.WaitGroup
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// here were just validating that if we send NumStream request we get
+		// a response for every stream and the lengths for the queries are set
+		// correctly.
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if err := db.Query("void").Exec(); err != nil {
+				t.Error(err)
+			}
+		}()
+	}
+	wg.Wait()
+}
+
+func TestStreams_Protocol2(t *testing.T) {
+	srv := NewTestServer(t, protoVersion2)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 2
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// the test server processes each conn synchronously
+		// here were just validating that if we send NumStream request we get
+		// a response for every stream and the lengths for the queries are set
+		// correctly.
+		if err = db.Query("void").Exec(); err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
+func TestStreams_Protocol3(t *testing.T) {
+	srv := NewTestServer(t, protoVersion3)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 3
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// the test server processes each conn synchronously
+		// here were just validating that if we send NumStream request we get
+		// a response for every stream and the lengths for the queries are set
+		// correctly.
+		if err = db.Query("void").Exec(); err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
+func BenchmarkProtocolV3(b *testing.B) {
+	srv := NewTestServer(b, protoVersion3)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 3
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer db.Close()
+
+	b.ResetTimer()
+	b.ReportAllocs()
+	for i := 0; i < b.N; i++ {
+		if err = db.Query("void").Exec(); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
 		t.Fatal(err)
 	}
+
 	listen, err := net.ListenTCP("tcp", laddr)
 	if err != nil {
 		t.Fatal(err)
 	}
-	srv := &TestServer{Address: listen.Addr().String(), listen: listen, t: t}
+
+	headerSize := 8
+	if protocol > protoVersion2 {
+		headerSize = 9
+	}
+
+	srv := &TestServer{
+		Address:    listen.Addr().String(),
+		listen:     listen,
+		t:          t,
+		protocol:   protocol,
+		headerSize: headerSize,
+	}
+
 	go srv.serve()
+
 	return srv
 }
 
-func NewSSLTestServer(t *testing.T) *TestServer {
+func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
 	pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
 	certPool := x509.NewCertPool()
 	if !certPool.AppendCertsFromPEM(pem) {
@@ -270,11 +420,34 @@ func NewSSLTestServer(t *testing.T) *TestServer {
 	if err != nil {
 		t.Fatal(err)
 	}
-	srv := &TestServer{Address: listen.Addr().String(), listen: listen, t: t}
+
+	headerSize := 8
+	if protocol > protoVersion2 {
+		headerSize = 9
+	}
+
+	srv := &TestServer{
+		Address:    listen.Addr().String(),
+		listen:     listen,
+		t:          t,
+		protocol:   protocol,
+		headerSize: headerSize,
+	}
 	go srv.serve()
 	return srv
 }
 
+type TestServer struct {
+	Address  string
+	t        testing.TB
+	nreq     uint64
+	listen   net.Listener
+	nKillReq int64
+
+	protocol   uint8
+	headerSize int
+}
+
 func (srv *TestServer) serve() {
 	defer srv.listen.Close()
 	for {
@@ -297,51 +470,63 @@ func (srv *TestServer) Stop() {
 	srv.listen.Close()
 }
 
-func (srv *TestServer) process(frame frame, conn net.Conn) {
-	switch frame[3] {
+func (srv *TestServer) process(f frame, conn net.Conn) {
+	headerSize := headerProtoSize[srv.protocol]
+	stream := f.Stream(srv.protocol)
+
+	switch f.Op(srv.protocol) {
 	case opStartup:
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opReady)
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opReady)
+	case opOptions:
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opSupported)
+		f.writeShort(0)
 	case opQuery:
-		input := frame
-		input.skipHeader()
+		input := f
+		input.skipHeader(srv.protocol)
 		query := strings.TrimSpace(input.readLongString())
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opResult)
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opResult)
 		first := query
 		if n := strings.Index(query, " "); n > 0 {
 			first = first[:n]
 		}
 		switch strings.ToLower(first) {
 		case "kill":
-			atomic.AddUint64(&srv.nKillReq, 1)
-			select {}
+			atomic.AddInt64(&srv.nKillReq, 1)
+			f = f[:headerSize]
+			f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
+			f.writeInt(0x1001)
+			f.writeString("query killed")
 		case "slow":
 			go func() {
 				<-time.After(1 * time.Second)
-				frame.writeInt(resultKindVoid)
-				frame.setLength(len(frame) - headerSize)
-				if _, err := conn.Write(frame); err != nil {
+				f.writeInt(resultKindVoid)
+				f.setLength(len(f)-headerSize, srv.protocol)
+				if _, err := conn.Write(f); err != nil {
 					return
 				}
 			}()
 			return
 		case "use":
-			frame.writeInt(3)
-			frame.writeString(strings.TrimSpace(query[3:]))
+			f.writeInt(3)
+			f.writeString(strings.TrimSpace(query[3:]))
 		case "void":
-			frame.writeInt(resultKindVoid)
+			f.writeInt(resultKindVoid)
 		default:
-			frame.writeInt(resultKindVoid)
+			f.writeInt(resultKindVoid)
 		}
 	default:
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opError)
-		frame.writeInt(0)
-		frame.writeString("not supported")
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
+		f.writeInt(0)
+		f.writeString("not supported")
 	}
-	frame.setLength(len(frame) - headerSize)
-	if _, err := conn.Write(frame); err != nil {
+
+	f.setLength(len(f)-headerSize, srv.protocol)
+	if _, err := conn.Write(f); err != nil {
+		srv.t.Log(err)
 		return
 	}
 }