@@ -5,6 +5,7 @@ package gocql
 import (
 	"crypto/tls"
 	"crypto/x509"
+	"fmt"
 	"io"
 	"io/ioutil"
 	"net"
@@ -15,6 +16,10 @@ import (
 	"time"
 )
 
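+// defaultProto is the protocol version the tests below run against unless a
+// test exercises a specific version explicitly.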
+const (
+	defaultProto = protoVersion2
+)
+
 func TestJoinHostPort(t *testing.T) {
 	tests := map[string]string{
 		"127.0.0.1:0": JoinHostPort("127.0.0.1", 0),
@@ -29,43 +34,38 @@ func TestJoinHostPort(t *testing.T) {
 	}
 }
 
-type TestServer struct {
-	Address  string
-	t        *testing.T
-	nreq     uint64
-	listen   net.Listener
-	nKillReq uint64
-}
-
 func TestSimple(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	cluster := NewCluster(srv.Address)
+	cluster.ProtoVersion = int(defaultProto)
+	db, err := cluster.CreateSession()
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Errorf("0x%x: NewCluster: %v", defaultProto, err)
+		return
 	}
 
 	if err := db.Query("void").Exec(); err != nil {
-		t.Error(err)
+		t.Errorf("0x%x: %v", defaultProto, err)
 	}
 }
 
 func TestSSLSimple(t *testing.T) {
-	srv := NewSSLTestServer(t)
+	srv := NewSSLTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := createTestSslCluster(srv.Address).CreateSession()
+	db, err := createTestSslCluster(srv.Address, defaultProto).CreateSession()
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
 	}
 
 	if err := db.Query("void").Exec(); err != nil {
-		t.Error(err)
+		t.Fatalf("0x%x: %v", defaultProto, err)
 	}
 }
 
-func createTestSslCluster(hosts string) *ClusterConfig {
+func createTestSslCluster(hosts string, proto uint8) *ClusterConfig {
 	cluster := NewCluster(hosts)
 	cluster.SslOpts = &SslOptions{
 		CertPath: "testdata/pki/gocql.crt",
@@ -73,82 +73,103 @@ func createTestSslCluster(hosts string) *ClusterConfig {
 		CaPath: "testdata/pki/ca.crt",
 		EnableHostVerification: false,
 	}
+	cluster.ProtoVersion = int(proto)
 	return cluster
 }
 
 func TestClosed(t *testing.T) {
 	t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	session, err := NewCluster(srv.Address).CreateSession()
+	cluster := NewCluster(srv.Address)
+	cluster.ProtoVersion = int(defaultProto)
+
+	session, err := cluster.CreateSession()
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Errorf("0x%x: NewCluster: %v", defaultProto, err)
+		return
 	}
 	session.Close()
 
 	if err := session.Query("void").Exec(); err != ErrSessionClosed {
-		t.Errorf("expected %#v, got %#v", ErrSessionClosed, err)
+		t.Errorf("0x%x: expected %#v, got %#v", defaultProto, ErrSessionClosed, err)
+		return
 	}
 }
 
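+// newTestSession is a small helper so each test can build a session against
+// the test server with an explicit protocol version.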
+func newTestSession(addr string, proto uint8) (*Session, error) {
+	cluster := NewCluster(addr)
+	cluster.ProtoVersion = int(proto)
+	return cluster.CreateSession()
+}
+
 func TestTimeout(t *testing.T) {
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 		t.Errorf("NewCluster: %v", err)
+		return
 	}
+	defer db.Close()
 
 	go func() {
 		<-time.After(2 * time.Second)
-		t.Fatal("no timeout")
+		t.Errorf("no timeout")
 	}()
 
 	if err := db.Query("kill").Exec(); err == nil {
-		t.Fatal("expected error")
+		t.Errorf("expected error")
 	}
 }
 
 // TestQueryRetry will test to make sure that gocql will execute
 // the exact amount of retry queries designated by the user.
 func TestQueryRetry(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("NewCluster: %v", err)
 	}
+	defer db.Close()
 
 	go func() {
 		<-time.After(5 * time.Second)
-		t.Fatal("no timeout")
+		t.Errorf("no timeout")
 	}()
 	rt := &SimpleRetryPolicy{NumRetries: 1}
 
 	qry := db.Query("kill").RetryPolicy(rt)
 	if err := qry.Exec(); err == nil {
-		t.Fatal("expected error")
+		t.Fatalf("expected error")
 	}
-	requests := srv.nKillReq
-	if requests != uint64(qry.Attempts()) {
-		t.Fatalf("expected requests %v to match query attemps %v", requests, qry.Attempts())
+
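+	// the test server bumps nKillReq once per "kill" query it receives, so
+	// the count should match the attempts recorded on the query.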
+	requests := atomic.LoadInt64(&srv.nKillReq)
+	attempts := qry.Attempts()
+	if requests != int64(attempts) {
+		t.Fatalf("expected requests %v to match query attempts %v", requests, attempts)
 	}
+
 	//Minus 1 from the requests variable since there is the initial query attempt
-	if requests-1 != uint64(rt.NumRetries) {
+	if requests-1 != int64(rt.NumRetries) {
 		t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, requests-1)
 	}
 }
 
 func TestSlowQuery(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 		t.Errorf("NewCluster: %v", err)
+		return
 	}
 
 	if err := db.Query("slow").Exec(); err != nil {
@@ -159,22 +180,24 @@ func TestSlowQuery(t *testing.T) {
 func TestRoundRobin(t *testing.T) {
 	servers := make([]*TestServer, 5)
 	addrs := make([]string, len(servers))
-	for i := 0; i < len(servers); i++ {
-		servers[i] = NewTestServer(t)
-		addrs[i] = servers[i].Address
-		defer servers[i].Stop()
+	for n := 0; n < len(servers); n++ {
+		servers[n] = NewTestServer(t, defaultProto)
+		addrs[n] = servers[n].Address
+		defer servers[n].Stop()
 	}
 	cluster := NewCluster(addrs...)
+	cluster.ProtoVersion = int(defaultProto)
+
 	db, err := cluster.CreateSession()
-	time.Sleep(1 * time.Second) //Sleep to allow the Cluster.fillPool to complete
+	time.Sleep(1 * time.Second) // Sleep to allow the Cluster.fillPool to complete
 
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("NewCluster: %v", err)
 	}
 
 	var wg sync.WaitGroup
 	wg.Add(5)
-	for i := 0; i < 5; i++ {
+	for n := 0; n < 5; n++ {
 		go func() {
 			for j := 0; j < 5; j++ {
 				if err := db.Query("void").Exec(); err != nil {
@@ -187,12 +210,12 @@ func TestRoundRobin(t *testing.T) {
 	wg.Wait()
 
 	diff := 0
-	for i := 1; i < len(servers); i++ {
+	for n := 1; n < len(servers); n++ {
 		d := 0
-		if servers[i].nreq > servers[i-1].nreq {
-			d = int(servers[i].nreq - servers[i-1].nreq)
+		if servers[n].nreq > servers[n-1].nreq {
+			d = int(servers[n].nreq - servers[n-1].nreq)
 		} else {
-			d = int(servers[i-1].nreq - servers[i].nreq)
+			d = int(servers[n-1].nreq - servers[n].nreq)
 		}
 		if d > diff {
 			diff = d
@@ -206,7 +229,8 @@ func TestRoundRobin(t *testing.T) {
 
 func TestConnClosing(t *testing.T) {
 	t.Skip("Skipping until test can be ran reliably")
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, protoVersion2)
 	defer srv.Stop()
 
 	db, err := NewCluster(srv.Address).CreateSession()
@@ -238,21 +262,147 @@ func TestConnClosing(t *testing.T) {
 	}
 }
 
-func NewTestServer(t *testing.T) *TestServer {
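+// The TestStreams_* tests drive a query down every stream id the configured
+// protocol version makes available on a single connection.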
+func TestStreams_Protocol1(t *testing.T) {
+	srv := NewTestServer(t, protoVersion1)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 1
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	var wg sync.WaitGroup
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// here we're just validating that if we send NumStreams requests
+		// we get a response for every stream and the lengths for the
+		// queries are set correctly.
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if err := db.Query("void").Exec(); err != nil {
+				t.Error(err)
+			}
+		}()
+	}
+	wg.Wait()
+}
+
+func TestStreams_Protocol2(t *testing.T) {
+	srv := NewTestServer(t, protoVersion2)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 2
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// the test server processes each conn synchronously; here we're
+		// just validating that if we send NumStreams requests we get a
+		// response for every stream and the lengths for the queries are
+		// set correctly.
+		if err = db.Query("void").Exec(); err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
+func TestStreams_Protocol3(t *testing.T) {
+	srv := NewTestServer(t, protoVersion3)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 3
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// the test server processes each conn synchronously; here we're
+		// just validating that if we send NumStreams requests we get a
+		// response for every stream and the lengths for the queries are
+		// set correctly.
+		if err = db.Query("void").Exec(); err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
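+// BenchmarkProtocolV3 measures the round-trip cost of a void query against
+// the in-process test server using v3 framing.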
+func BenchmarkProtocolV3(b *testing.B) {
+	srv := NewTestServer(b, protoVersion3)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 3
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer db.Close()
+
+	b.ResetTimer()
+	b.ReportAllocs()
+	for i := 0; i < b.N; i++ {
+		if err = db.Query("void").Exec(); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
 		t.Fatal(err)
 	}
+
 	listen, err := net.ListenTCP("tcp", laddr)
 	if err != nil {
 		t.Fatal(err)
 	}
-	srv := &TestServer{Address: listen.Addr().String(), listen: listen, t: t}
+
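+	// protocol v3 widened the stream id from one byte to two, growing the
+	// frame header from 8 to 9 bytes.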
+	headerSize := 8
+	if protocol > protoVersion2 {
+		headerSize = 9
+	}
+
+	srv := &TestServer{
+		Address:    listen.Addr().String(),
+		listen:     listen,
+		t:          t,
+		protocol:   protocol,
+		headerSize: headerSize,
+	}
+
 	go srv.serve()
+
 	return srv
 }
 
-func NewSSLTestServer(t *testing.T) *TestServer {
+func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
 	pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
 	certPool := x509.NewCertPool()
 	if !certPool.AppendCertsFromPEM(pem) {
@@ -270,11 +420,34 @@ func NewSSLTestServer(t *testing.T) *TestServer {
 	if err != nil {
 		t.Fatal(err)
 	}
-	srv := &TestServer{Address: listen.Addr().String(), listen: listen, t: t}
+
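+	// same header sizing rule as NewTestServer: 9 bytes from protocol v3 on.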
+	headerSize := 8
+	if protocol > protoVersion2 {
+		headerSize = 9
+	}
+
+	srv := &TestServer{
+		Address:    listen.Addr().String(),
+		listen:     listen,
+		t:          t,
+		protocol:   protocol,
+		headerSize: headerSize,
+	}
 	go srv.serve()
 	return srv
 }
 
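+// TestServer is a minimal in-process server that speaks just enough of the
+// wire protocol for the tests above.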
+type TestServer struct {
+	Address  string
+	t        testing.TB
+	nreq     uint64
+	listen   net.Listener
+	nKillReq int64
+
+	protocol   uint8
+	headerSize int
+}
+
 func (srv *TestServer) serve() {
 	defer srv.listen.Close()
 	for {
@@ -297,51 +470,63 @@ func (srv *TestServer) Stop() {
 	srv.listen.Close()
 }
 
-func (srv *TestServer) process(frame frame, conn net.Conn) {
-	switch frame[3] {
+func (srv *TestServer) process(f frame, conn net.Conn) {
+	headerSize := srv.headerSize
+	stream := f.Stream(srv.protocol)
+
+	switch f.Op(srv.protocol) {
 	case opStartup:
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opReady)
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opReady)
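+	// answer OPTIONS with an empty SUPPORTED multimap so clients probing
+	// the server during startup get a well-formed response.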
+	case opOptions:
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opSupported)
+		f.writeShort(0)
 	case opQuery:
-		input := frame
-		input.skipHeader()
+		input := f
+		input.skipHeader(srv.protocol)
 		query := strings.TrimSpace(input.readLongString())
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opResult)
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opResult)
 		first := query
 		if n := strings.Index(query, " "); n > 0 {
 			first = first[:n]
 		}
 		switch strings.ToLower(first) {
 		case "kill":
-			atomic.AddUint64(&srv.nKillReq, 1)
-			select {}
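+			// rather than hanging forever, answer "kill" with a server
+			// error so the retry policy sees one response per attempt.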
+			atomic.AddInt64(&srv.nKillReq, 1)
+			f = f[:headerSize]
+			f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
+			f.writeInt(0x1001)
+			f.writeString("query killed")
 		case "slow":
 			go func() {
 				<-time.After(1 * time.Second)
-				frame.writeInt(resultKindVoid)
-				frame.setLength(len(frame) - headerSize)
-				if _, err := conn.Write(frame); err != nil {
+				f.writeInt(resultKindVoid)
+				f.setLength(len(f)-headerSize, srv.protocol)
+				if _, err := conn.Write(f); err != nil {
 					return
 				}
 			}()
 			return
 		case "use":
-			frame.writeInt(3)
-			frame.writeString(strings.TrimSpace(query[3:]))
+			f.writeInt(3)
+			f.writeString(strings.TrimSpace(query[3:]))
 		case "void":
-			frame.writeInt(resultKindVoid)
+			f.writeInt(resultKindVoid)
 		default:
-			frame.writeInt(resultKindVoid)
+			f.writeInt(resultKindVoid)
 		}
 	default:
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opError)
-		frame.writeInt(0)
-		frame.writeString("not supported")
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
+		f.writeInt(0)
+		f.writeString("not supported")
 	}
-	frame.setLength(len(frame) - headerSize)
-	if _, err := conn.Write(frame); err != nil {
+
+	f.setLength(len(f)-headerSize, srv.protocol)
+	if _, err := conn.Write(f); err != nil {
+		srv.t.Log(err)
 		return
 	}
 }