Browse Source

test: use context to stop test server

Chris Bannister 9 years ago
parent
commit
b3548bf565
1 changed files with 106 additions and 35 deletions
  1. 106 35
      conn_test.go

+ 106 - 35
conn_test.go

@@ -60,7 +60,7 @@ func testCluster(addr string, proto protoVersion) *ClusterConfig {
 }
 }
 
 
 func TestSimple(t *testing.T) {
 func TestSimple(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	cluster := testCluster(srv.Address, defaultProto)
 	cluster := testCluster(srv.Address, defaultProto)
@@ -75,7 +75,7 @@ func TestSimple(t *testing.T) {
 }
 }
 
 
 func TestSSLSimple(t *testing.T) {
 func TestSSLSimple(t *testing.T) {
-	srv := NewSSLTestServer(t, defaultProto)
+	srv := NewSSLTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	db, err := createTestSslCluster(srv.Address, defaultProto, true).CreateSession()
 	db, err := createTestSslCluster(srv.Address, defaultProto, true).CreateSession()
@@ -89,7 +89,7 @@ func TestSSLSimple(t *testing.T) {
 }
 }
 
 
 func TestSSLSimpleNoClientCert(t *testing.T) {
 func TestSSLSimpleNoClientCert(t *testing.T) {
-	srv := NewSSLTestServer(t, defaultProto)
+	srv := NewSSLTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	db, err := createTestSslCluster(srv.Address, defaultProto, false).CreateSession()
 	db, err := createTestSslCluster(srv.Address, defaultProto, false).CreateSession()
@@ -121,7 +121,7 @@ func createTestSslCluster(addr string, proto protoVersion, useClientCert bool) *
 func TestClosed(t *testing.T) {
 func TestClosed(t *testing.T) {
 	t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
 	t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
 
 
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	session, err := newTestSession(srv.Address, defaultProto)
 	session, err := newTestSession(srv.Address, defaultProto)
@@ -141,7 +141,9 @@ func newTestSession(addr string, proto protoVersion) (*Session, error) {
 }
 }
 
 
 func TestTimeout(t *testing.T) {
 func TestTimeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	ctx, cancel := context.WithCancel(context.Background())
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	db, err := newTestSession(srv.Address, defaultProto)
 	db, err := newTestSession(srv.Address, defaultProto)
@@ -153,8 +155,6 @@ func TestTimeout(t *testing.T) {
 	var wg sync.WaitGroup
 	var wg sync.WaitGroup
 	wg.Add(1)
 	wg.Add(1)
 
 
-	ctx, cancel := context.WithCancel(context.Background())
-
 	go func() {
 	go func() {
 		defer wg.Done()
 		defer wg.Done()
 
 
@@ -165,7 +165,7 @@ func TestTimeout(t *testing.T) {
 		}
 		}
 	}()
 	}()
 
 
-	if err := db.Query("kill").Exec(); err == nil {
+	if err := db.Query("kill").WithContext(ctx).Exec(); err == nil {
 		t.Fatal("expected error got nil")
 		t.Fatal("expected error got nil")
 	}
 	}
 	cancel()
 	cancel()
@@ -176,7 +176,10 @@ func TestTimeout(t *testing.T) {
 // TestQueryRetry will test to make sure that gocql will execute
 // TestQueryRetry will test to make sure that gocql will execute
 // the exact amount of retry queries designated by the user.
 // the exact amount of retry queries designated by the user.
 func TestQueryRetry(t *testing.T) {
 func TestQueryRetry(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	db, err := newTestSession(srv.Address, defaultProto)
 	db, err := newTestSession(srv.Address, defaultProto)
@@ -186,9 +189,14 @@ func TestQueryRetry(t *testing.T) {
 	defer db.Close()
 	defer db.Close()
 
 
 	go func() {
 	go func() {
-		<-time.After(5 * time.Second)
-		t.Fatalf("no timeout")
+		select {
+		case <-ctx.Done():
+			return
+		case <-time.After(5 * time.Second):
+			t.Errorf("no timeout")
+		}
 	}()
 	}()
+
 	rt := &SimpleRetryPolicy{NumRetries: 1}
 	rt := &SimpleRetryPolicy{NumRetries: 1}
 
 
 	qry := db.Query("kill").RetryPolicy(rt)
 	qry := db.Query("kill").RetryPolicy(rt)
@@ -209,7 +217,7 @@ func TestQueryRetry(t *testing.T) {
 }
 }
 
 
 func TestStreams_Protocol1(t *testing.T) {
 func TestStreams_Protocol1(t *testing.T) {
-	srv := NewTestServer(t, protoVersion1)
+	srv := NewTestServer(t, protoVersion1, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	// TODO: these are more like session tests and should instead operate
 	// TODO: these are more like session tests and should instead operate
@@ -241,7 +249,7 @@ func TestStreams_Protocol1(t *testing.T) {
 }
 }
 
 
 func TestStreams_Protocol3(t *testing.T) {
 func TestStreams_Protocol3(t *testing.T) {
-	srv := NewTestServer(t, protoVersion3)
+	srv := NewTestServer(t, protoVersion3, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	// TODO: these are more like session tests and should instead operate
 	// TODO: these are more like session tests and should instead operate
@@ -268,7 +276,7 @@ func TestStreams_Protocol3(t *testing.T) {
 }
 }
 
 
 func BenchmarkProtocolV3(b *testing.B) {
 func BenchmarkProtocolV3(b *testing.B) {
-	srv := NewTestServer(b, protoVersion3)
+	srv := NewTestServer(b, protoVersion3, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	// TODO: these are more like session tests and should instead operate
 	// TODO: these are more like session tests and should instead operate
@@ -294,7 +302,7 @@ func BenchmarkProtocolV3(b *testing.B) {
 
 
 // This tests that the policy connection pool handles SSL correctly
 // This tests that the policy connection pool handles SSL correctly
 func TestPolicyConnPoolSSL(t *testing.T) {
 func TestPolicyConnPoolSSL(t *testing.T) {
-	srv := NewSSLTestServer(t, defaultProto)
+	srv := NewSSLTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	cluster := createTestSslCluster(srv.Address, defaultProto, true)
 	cluster := createTestSslCluster(srv.Address, defaultProto, true)
@@ -319,7 +327,7 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 }
 }
 
 
 func TestQueryTimeout(t *testing.T) {
 func TestQueryTimeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	cluster := testCluster(srv.Address, defaultProto)
 	cluster := testCluster(srv.Address, defaultProto)
@@ -356,7 +364,7 @@ func TestQueryTimeout(t *testing.T) {
 }
 }
 
 
 func TestQueryTimeoutMany(t *testing.T) {
 func TestQueryTimeoutMany(t *testing.T) {
-	srv := NewTestServer(t, 3)
+	srv := NewTestServer(t, 3, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	cluster := testCluster(srv.Address, 3)
 	cluster := testCluster(srv.Address, 3)
@@ -381,7 +389,7 @@ func TestQueryTimeoutMany(t *testing.T) {
 }
 }
 
 
 func BenchmarkSingleConn(b *testing.B) {
 func BenchmarkSingleConn(b *testing.B) {
-	srv := NewTestServer(b, 3)
+	srv := NewTestServer(b, 3, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	cluster := testCluster(srv.Address, 3)
 	cluster := testCluster(srv.Address, 3)
@@ -412,7 +420,7 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
 	// TODO(zariel): move this to conn test, we really just want to check what
 	// TODO(zariel): move this to conn test, we really just want to check what
 	// happens when a conn is
 	// happens when a conn is
 
 
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	cluster := testCluster(srv.Address, defaultProto)
 	cluster := testCluster(srv.Address, defaultProto)
@@ -436,7 +444,7 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
 }
 }
 
 
 func TestQueryTimeoutClose(t *testing.T) {
 func TestQueryTimeoutClose(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	cluster := testCluster(srv.Address, defaultProto)
 	cluster := testCluster(srv.Address, defaultProto)
@@ -473,12 +481,20 @@ func TestQueryTimeoutClose(t *testing.T) {
 func TestStream0(t *testing.T) {
 func TestStream0(t *testing.T) {
 	const expErr = "gocql: received frame on stream 0"
 	const expErr = "gocql: received frame on stream 0"
 
 
-	srv := NewTestServer(t, defaultProto)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
 	errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
 		if !srv.isClosed() && !strings.HasPrefix(err.Error(), expErr) {
 		if !srv.isClosed() && !strings.HasPrefix(err.Error(), expErr) {
-			t.Errorf("expected to get error prefix %q got %q", expErr, err.Error())
+			select {
+			case <-ctx.Done():
+				return
+			default:
+				t.Errorf("expected to get error prefix %q got %q", expErr, err.Error())
+			}
 		}
 		}
 	})
 	})
 
 
@@ -494,7 +510,7 @@ func TestStream0(t *testing.T) {
 	})
 	})
 
 
 	// need to write out an invalid frame, which we need a connection to do
 	// need to write out an invalid frame, which we need a connection to do
-	framer, err := conn.exec(context.Background(), writer, nil)
+	framer, err := conn.exec(ctx, writer, nil)
 	if err == nil {
 	if err == nil {
 		t.Fatal("expected to get an error on stream 0")
 		t.Fatal("expected to get an error on stream 0")
 	} else if !strings.HasPrefix(err.Error(), expErr) {
 	} else if !strings.HasPrefix(err.Error(), expErr) {
@@ -512,7 +528,7 @@ func TestConnClosedBlocked(t *testing.T) {
 	// issue 664
 	// issue 664
 	const proto = 3
 	const proto = 3
 
 
-	srv := NewTestServer(t, proto)
+	srv := NewTestServer(t, proto, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 	errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
 	errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
 		t.Log(err)
 		t.Log(err)
@@ -536,7 +552,7 @@ func TestConnClosedBlocked(t *testing.T) {
 }
 }
 
 
 func TestContext_Timeout(t *testing.T) {
 func TestContext_Timeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	cluster := testCluster(srv.Address, defaultProto)
 	cluster := testCluster(srv.Address, defaultProto)
@@ -555,7 +571,7 @@ func TestContext_Timeout(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func NewTestServer(t testing.TB, protocol uint8) *TestServer {
+func NewTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
@@ -571,21 +587,24 @@ func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 		headerSize = 9
 		headerSize = 9
 	}
 	}
 
 
+	ctx, cancel := context.WithCancel(ctx)
 	srv := &TestServer{
 	srv := &TestServer{
 		Address:    listen.Addr().String(),
 		Address:    listen.Addr().String(),
 		listen:     listen,
 		listen:     listen,
 		t:          t,
 		t:          t,
 		protocol:   protocol,
 		protocol:   protocol,
 		headerSize: headerSize,
 		headerSize: headerSize,
-		quit:       make(chan struct{}),
+		ctx:        ctx,
+		cancel:     cancel,
 	}
 	}
 
 
+	go srv.closeWatch()
 	go srv.serve()
 	go srv.serve()
 
 
 	return srv
 	return srv
 }
 }
 
 
-func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
+func NewSSLTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer {
 	pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
 	pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
 	certPool := x509.NewCertPool()
 	certPool := x509.NewCertPool()
 	if !certPool.AppendCertsFromPEM(pem) {
 	if !certPool.AppendCertsFromPEM(pem) {
@@ -609,14 +628,18 @@ func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
 		headerSize = 9
 		headerSize = 9
 	}
 	}
 
 
+	ctx, cancel := context.WithCancel(ctx)
 	srv := &TestServer{
 	srv := &TestServer{
 		Address:    listen.Addr().String(),
 		Address:    listen.Addr().String(),
 		listen:     listen,
 		listen:     listen,
 		t:          t,
 		t:          t,
 		protocol:   protocol,
 		protocol:   protocol,
 		headerSize: headerSize,
 		headerSize: headerSize,
-		quit:       make(chan struct{}),
+		ctx:        ctx,
+		cancel:     cancel,
 	}
 	}
+
+	go srv.closeWatch()
 	go srv.serve()
 	go srv.serve()
 	return srv
 	return srv
 }
 }
@@ -631,28 +654,58 @@ type TestServer struct {
 
 
 	protocol   byte
 	protocol   byte
 	headerSize int
 	headerSize int
+	ctx        context.Context
+	cancel     context.CancelFunc
 
 
 	quit   chan struct{}
 	quit   chan struct{}
 	mu     sync.Mutex
 	mu     sync.Mutex
 	closed bool
 	closed bool
 }
 }
 
 
+func (srv *TestServer) closeWatch() {
+	<-srv.ctx.Done()
+
+	srv.mu.Lock()
+	defer srv.mu.Unlock()
+
+	srv.closeLocked()
+}
+
 func (srv *TestServer) serve() {
 func (srv *TestServer) serve() {
 	defer srv.listen.Close()
 	defer srv.listen.Close()
 	for {
 	for {
+		select {
+		case <-srv.ctx.Done():
+			return
+		default:
+		}
+
 		conn, err := srv.listen.Accept()
 		conn, err := srv.listen.Accept()
 		if err != nil {
 		if err != nil {
 			break
 			break
 		}
 		}
+
 		go func(conn net.Conn) {
 		go func(conn net.Conn) {
 			defer conn.Close()
 			defer conn.Close()
 			for {
 			for {
+				select {
+				case <-srv.ctx.Done():
+					return
+				default:
+				}
+
 				framer, err := srv.readFrame(conn)
 				framer, err := srv.readFrame(conn)
 				if err != nil {
 				if err != nil {
 					if err == io.EOF {
 					if err == io.EOF {
 						return
 						return
 					}
 					}
 
 
+					select {
+					case <-srv.ctx.Done():
+						return
+					default:
+					}
+
 					srv.t.Error(err)
 					srv.t.Error(err)
 					return
 					return
 				}
 				}
@@ -671,21 +724,32 @@ func (srv *TestServer) isClosed() bool {
 	return srv.closed
 	return srv.closed
 }
 }
 
 
-func (srv *TestServer) Stop() {
-	srv.mu.Lock()
-	defer srv.mu.Unlock()
+func (srv *TestServer) closeLocked() {
 	if srv.closed {
 	if srv.closed {
 		return
 		return
 	}
 	}
+
 	srv.closed = true
 	srv.closed = true
 
 
 	srv.listen.Close()
 	srv.listen.Close()
-	close(srv.quit)
+	srv.cancel()
+}
+
+func (srv *TestServer) Stop() {
+	srv.mu.Lock()
+	defer srv.mu.Unlock()
+	srv.closeLocked()
 }
 }
 
 
 func (srv *TestServer) process(f *framer) {
 func (srv *TestServer) process(f *framer) {
 	head := f.header
 	head := f.header
 	if head == nil {
 	if head == nil {
+		select {
+		case <-srv.ctx.Done():
+			return
+		default:
+		}
+
 		srv.t.Error("process frame with a nil header")
 		srv.t.Error("process frame with a nil header")
 		return
 		return
 	}
 	}
@@ -715,7 +779,7 @@ func (srv *TestServer) process(f *framer) {
 			f.writeHeader(0, opResult, head.stream)
 			f.writeHeader(0, opResult, head.stream)
 			f.writeInt(resultKindVoid)
 			f.writeInt(resultKindVoid)
 		case "timeout":
 		case "timeout":
-			<-srv.quit
+			<-srv.ctx.Done()
 			return
 			return
 		case "slow":
 		case "slow":
 			go func() {
 			go func() {
@@ -723,7 +787,8 @@ func (srv *TestServer) process(f *framer) {
 				f.writeInt(resultKindVoid)
 				f.writeInt(resultKindVoid)
 				f.wbuf[0] = srv.protocol | 0x80
 				f.wbuf[0] = srv.protocol | 0x80
 				select {
 				select {
-				case <-srv.quit:
+				case <-srv.ctx.Done():
+					return
 				case <-time.After(50 * time.Millisecond):
 				case <-time.After(50 * time.Millisecond):
 					f.finishWrite()
 					f.finishWrite()
 				}
 				}
@@ -745,6 +810,12 @@ func (srv *TestServer) process(f *framer) {
 	f.wbuf[0] = srv.protocol | 0x80
 	f.wbuf[0] = srv.protocol | 0x80
 
 
 	if err := f.finishWrite(); err != nil {
 	if err := f.finishWrite(); err != nil {
+		select {
+		case <-srv.ctx.Done():
+			return
+		default:
+		}
+
 		srv.t.Error(err)
 		srv.t.Error(err)
 	}
 	}
 }
 }