فهرست منبع

Merge pull request #770 from Zariel/go17

fix test timeout panic in go 1.7
Chris Bannister 9 سال پیش
والد
کامیت
a5da50294a
3فایلهای تغییر یافته به همراه130 افزوده شده و 64 حذف شده
  1. 8 2
      cassandra_test.go
  2. 121 61
      conn_test.go
  3. 1 1
      policies_test.go

+ 8 - 2
cassandra_test.go

@@ -545,9 +545,15 @@ func TestNotEnoughQueryArgs(t *testing.T) {
 // TestCreateSessionTimeout tests to make sure the CreateSession function timeouts out correctly
 // and prevents an infinite loop of connection retries.
 func TestCreateSessionTimeout(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	go func() {
-		<-time.After(2 * time.Second)
-		t.Error("no startup timeout")
+		select {
+		case <-time.After(2 * time.Second):
+			t.Error("no startup timeout")
+		case <-ctx.Done():
+		}
 	}()
 
 	cluster := createCluster()

+ 121 - 61
conn_test.go

@@ -9,7 +9,6 @@ import (
 	"crypto/tls"
 	"crypto/x509"
 	"fmt"
-	"golang.org/x/net/context"
 	"io"
 	"io/ioutil"
 	"net"
@@ -18,6 +17,8 @@ import (
 	"sync/atomic"
 	"testing"
 	"time"
+
+	"golang.org/x/net/context"
 )
 
 const (
@@ -59,7 +60,7 @@ func testCluster(addr string, proto protoVersion) *ClusterConfig {
 }
 
 func TestSimple(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -74,7 +75,7 @@ func TestSimple(t *testing.T) {
 }
 
 func TestSSLSimple(t *testing.T) {
-	srv := NewSSLTestServer(t, defaultProto)
+	srv := NewSSLTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	db, err := createTestSslCluster(srv.Address, defaultProto, true).CreateSession()
@@ -88,7 +89,7 @@ func TestSSLSimple(t *testing.T) {
 }
 
 func TestSSLSimpleNoClientCert(t *testing.T) {
-	srv := NewSSLTestServer(t, defaultProto)
+	srv := NewSSLTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	db, err := createTestSslCluster(srv.Address, defaultProto, false).CreateSession()
@@ -120,7 +121,7 @@ func createTestSslCluster(addr string, proto protoVersion, useClientCert bool) *
 func TestClosed(t *testing.T) {
 	t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
 
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	session, err := newTestSession(srv.Address, defaultProto)
@@ -140,7 +141,9 @@ func newTestSession(addr string, proto protoVersion) (*Session, error) {
 }
 
 func TestTimeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	ctx, cancel := context.WithCancel(context.Background())
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 
 	db, err := newTestSession(srv.Address, defaultProto)
@@ -149,20 +152,34 @@ func TestTimeout(t *testing.T) {
 	}
 	defer db.Close()
 
+	var wg sync.WaitGroup
+	wg.Add(1)
+
 	go func() {
-		<-time.After(2 * time.Second)
-		t.Errorf("no timeout")
+		defer wg.Done()
+
+		select {
+		case <-time.After(5 * time.Second):
+			t.Errorf("no timeout")
+		case <-ctx.Done():
+		}
 	}()
 
-	if err := db.Query("kill").Exec(); err == nil {
-		t.Errorf("expected error")
+	if err := db.Query("kill").WithContext(ctx).Exec(); err == nil {
+		t.Fatal("expected error got nil")
 	}
+	cancel()
+
+	wg.Wait()
 }
 
 // TestQueryRetry will test to make sure that gocql will execute
 // the exact amount of retry queries designated by the user.
 func TestQueryRetry(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 
 	db, err := newTestSession(srv.Address, defaultProto)
@@ -172,9 +189,14 @@ func TestQueryRetry(t *testing.T) {
 	defer db.Close()
 
 	go func() {
-		<-time.After(5 * time.Second)
-		t.Fatalf("no timeout")
+		select {
+		case <-ctx.Done():
+			return
+		case <-time.After(5 * time.Second):
+			t.Errorf("no timeout")
+		}
 	}()
+
 	rt := &SimpleRetryPolicy{NumRetries: 1}
 
 	qry := db.Query("kill").RetryPolicy(rt)
@@ -195,7 +217,7 @@ func TestQueryRetry(t *testing.T) {
 }
 
 func TestStreams_Protocol1(t *testing.T) {
-	srv := NewTestServer(t, protoVersion1)
+	srv := NewTestServer(t, protoVersion1, context.Background())
 	defer srv.Stop()
 
 	// TODO: these are more like session tests and should instead operate
@@ -227,7 +249,7 @@ func TestStreams_Protocol1(t *testing.T) {
 }
 
 func TestStreams_Protocol3(t *testing.T) {
-	srv := NewTestServer(t, protoVersion3)
+	srv := NewTestServer(t, protoVersion3, context.Background())
 	defer srv.Stop()
 
 	// TODO: these are more like session tests and should instead operate
@@ -254,7 +276,7 @@ func TestStreams_Protocol3(t *testing.T) {
 }
 
 func BenchmarkProtocolV3(b *testing.B) {
-	srv := NewTestServer(b, protoVersion3)
+	srv := NewTestServer(b, protoVersion3, context.Background())
 	defer srv.Stop()
 
 	// TODO: these are more like session tests and should instead operate
@@ -280,7 +302,7 @@ func BenchmarkProtocolV3(b *testing.B) {
 
 // This tests that the policy connection pool handles SSL correctly
 func TestPolicyConnPoolSSL(t *testing.T) {
-	srv := NewSSLTestServer(t, defaultProto)
+	srv := NewSSLTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := createTestSslCluster(srv.Address, defaultProto, true)
@@ -305,7 +327,7 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 }
 
 func TestQueryTimeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -341,33 +363,8 @@ func TestQueryTimeout(t *testing.T) {
 	}
 }
 
-func TestQueryTimeoutMany(t *testing.T) {
-	srv := NewTestServer(t, 3)
-	defer srv.Stop()
-
-	cluster := testCluster(srv.Address, 3)
-	// Set the timeout arbitrarily low so that the query hits the timeout in a
-	// timely manner.
-	cluster.Timeout = 5 * time.Millisecond
-	cluster.NumConns = 1
-
-	db, err := cluster.CreateSession()
-	if err != nil {
-		t.Fatalf("NewCluster: %v", err)
-	}
-	defer db.Close()
-
-	for i := 0; i < 128; i++ {
-		err := db.Query("void").Exec()
-		if err != nil {
-			t.Error(err)
-			return
-		}
-	}
-}
-
 func BenchmarkSingleConn(b *testing.B) {
-	srv := NewTestServer(b, 3)
+	srv := NewTestServer(b, 3, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, 3)
@@ -398,7 +395,7 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
 	// TODO(zariel): move this to conn test, we really just want to check what
 	// happens when a conn is
 
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -422,7 +419,7 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
 }
 
 func TestQueryTimeoutClose(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -459,12 +456,20 @@ func TestQueryTimeoutClose(t *testing.T) {
 func TestStream0(t *testing.T) {
 	const expErr = "gocql: received frame on stream 0"
 
-	srv := NewTestServer(t, defaultProto)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 
 	errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
 		if !srv.isClosed() && !strings.HasPrefix(err.Error(), expErr) {
-			t.Errorf("expected to get error prefix %q got %q", expErr, err.Error())
+			select {
+			case <-ctx.Done():
+				return
+			default:
+				t.Errorf("expected to get error prefix %q got %q", expErr, err.Error())
+			}
 		}
 	})
 
@@ -480,7 +485,7 @@ func TestStream0(t *testing.T) {
 	})
 
 	// need to write out an invalid frame, which we need a connection to do
-	framer, err := conn.exec(context.Background(), writer, nil)
+	framer, err := conn.exec(ctx, writer, nil)
 	if err == nil {
 		t.Fatal("expected to get an error on stream 0")
 	} else if !strings.HasPrefix(err.Error(), expErr) {
@@ -498,7 +503,7 @@ func TestConnClosedBlocked(t *testing.T) {
 	// issue 664
 	const proto = 3
 
-	srv := NewTestServer(t, proto)
+	srv := NewTestServer(t, proto, context.Background())
 	defer srv.Stop()
 	errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
 		t.Log(err)
@@ -522,7 +527,7 @@ func TestConnClosedBlocked(t *testing.T) {
 }
 
 func TestContext_Timeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -541,7 +546,7 @@ func TestContext_Timeout(t *testing.T) {
 	}
 }
 
-func NewTestServer(t testing.TB, protocol uint8) *TestServer {
+func NewTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
 		t.Fatal(err)
@@ -557,21 +562,24 @@ func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 		headerSize = 9
 	}
 
+	ctx, cancel := context.WithCancel(ctx)
 	srv := &TestServer{
 		Address:    listen.Addr().String(),
 		listen:     listen,
 		t:          t,
 		protocol:   protocol,
 		headerSize: headerSize,
-		quit:       make(chan struct{}),
+		ctx:        ctx,
+		cancel:     cancel,
 	}
 
+	go srv.closeWatch()
 	go srv.serve()
 
 	return srv
 }
 
-func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
+func NewSSLTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer {
 	pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
 	certPool := x509.NewCertPool()
 	if !certPool.AppendCertsFromPEM(pem) {
@@ -595,14 +603,18 @@ func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
 		headerSize = 9
 	}
 
+	ctx, cancel := context.WithCancel(ctx)
 	srv := &TestServer{
 		Address:    listen.Addr().String(),
 		listen:     listen,
 		t:          t,
 		protocol:   protocol,
 		headerSize: headerSize,
-		quit:       make(chan struct{}),
+		ctx:        ctx,
+		cancel:     cancel,
 	}
+
+	go srv.closeWatch()
 	go srv.serve()
 	return srv
 }
@@ -617,28 +629,58 @@ type TestServer struct {
 
 	protocol   byte
 	headerSize int
+	ctx        context.Context
+	cancel     context.CancelFunc
 
 	quit   chan struct{}
 	mu     sync.Mutex
 	closed bool
 }
 
+func (srv *TestServer) closeWatch() {
+	<-srv.ctx.Done()
+
+	srv.mu.Lock()
+	defer srv.mu.Unlock()
+
+	srv.closeLocked()
+}
+
 func (srv *TestServer) serve() {
 	defer srv.listen.Close()
 	for {
+		select {
+		case <-srv.ctx.Done():
+			return
+		default:
+		}
+
 		conn, err := srv.listen.Accept()
 		if err != nil {
 			break
 		}
+
 		go func(conn net.Conn) {
 			defer conn.Close()
 			for {
+				select {
+				case <-srv.ctx.Done():
+					return
+				default:
+				}
+
 				framer, err := srv.readFrame(conn)
 				if err != nil {
 					if err == io.EOF {
 						return
 					}
 
+					select {
+					case <-srv.ctx.Done():
+						return
+					default:
+					}
+
 					srv.t.Error(err)
 					return
 				}
@@ -657,21 +699,32 @@ func (srv *TestServer) isClosed() bool {
 	return srv.closed
 }
 
-func (srv *TestServer) Stop() {
-	srv.mu.Lock()
-	defer srv.mu.Unlock()
+func (srv *TestServer) closeLocked() {
 	if srv.closed {
 		return
 	}
+
 	srv.closed = true
 
 	srv.listen.Close()
-	close(srv.quit)
+	srv.cancel()
+}
+
+func (srv *TestServer) Stop() {
+	srv.mu.Lock()
+	defer srv.mu.Unlock()
+	srv.closeLocked()
 }
 
 func (srv *TestServer) process(f *framer) {
 	head := f.header
 	if head == nil {
+		select {
+		case <-srv.ctx.Done():
+			return
+		default:
+		}
+
 		srv.t.Error("process frame with a nil header")
 		return
 	}
@@ -701,7 +754,7 @@ func (srv *TestServer) process(f *framer) {
 			f.writeHeader(0, opResult, head.stream)
 			f.writeInt(resultKindVoid)
 		case "timeout":
-			<-srv.quit
+			<-srv.ctx.Done()
 			return
 		case "slow":
 			go func() {
@@ -709,7 +762,8 @@ func (srv *TestServer) process(f *framer) {
 				f.writeInt(resultKindVoid)
 				f.wbuf[0] = srv.protocol | 0x80
 				select {
-				case <-srv.quit:
+				case <-srv.ctx.Done():
+					return
 				case <-time.After(50 * time.Millisecond):
 					f.finishWrite()
 				}
@@ -731,6 +785,12 @@ func (srv *TestServer) process(f *framer) {
 	f.wbuf[0] = srv.protocol | 0x80
 
 	if err := f.finishWrite(); err != nil {
+		select {
+		case <-srv.ctx.Done():
+			return
+		default:
+		}
+
 		srv.t.Error(err)
 	}
 }

+ 1 - 1
policies_test.go

@@ -161,7 +161,7 @@ func TestRoundRobinNilHostInfo(t *testing.T) {
 	} else if v := next.Info(); v == nil {
 		t.Fatal("got nil HostInfo")
 	} else if v.HostID() != host.HostID() {
-		t.Fatalf("expected host %v got %v", host, *v)
+		t.Fatalf("expected host %v got %v", host, v)
 	}
 
 	next = iter()