浏览代码

test: use context to stop test server

Chris Bannister 9 年之前
父节点
当前提交
b3548bf565
共有 1 个文件被更改,包括 106 次插入35 次删除
  1. 106 35
      conn_test.go

+ 106 - 35
conn_test.go

@@ -60,7 +60,7 @@ func testCluster(addr string, proto protoVersion) *ClusterConfig {
 }
 
 func TestSimple(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -75,7 +75,7 @@ func TestSimple(t *testing.T) {
 }
 
 func TestSSLSimple(t *testing.T) {
-	srv := NewSSLTestServer(t, defaultProto)
+	srv := NewSSLTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	db, err := createTestSslCluster(srv.Address, defaultProto, true).CreateSession()
@@ -89,7 +89,7 @@ func TestSSLSimple(t *testing.T) {
 }
 
 func TestSSLSimpleNoClientCert(t *testing.T) {
-	srv := NewSSLTestServer(t, defaultProto)
+	srv := NewSSLTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	db, err := createTestSslCluster(srv.Address, defaultProto, false).CreateSession()
@@ -121,7 +121,7 @@ func createTestSslCluster(addr string, proto protoVersion, useClientCert bool) *
 func TestClosed(t *testing.T) {
 	t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
 
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	session, err := newTestSession(srv.Address, defaultProto)
@@ -141,7 +141,9 @@ func newTestSession(addr string, proto protoVersion) (*Session, error) {
 }
 
 func TestTimeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	ctx, cancel := context.WithCancel(context.Background())
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 
 	db, err := newTestSession(srv.Address, defaultProto)
@@ -153,8 +155,6 @@ func TestTimeout(t *testing.T) {
 	var wg sync.WaitGroup
 	wg.Add(1)
 
-	ctx, cancel := context.WithCancel(context.Background())
-
 	go func() {
 		defer wg.Done()
 
@@ -165,7 +165,7 @@ func TestTimeout(t *testing.T) {
 		}
 	}()
 
-	if err := db.Query("kill").Exec(); err == nil {
+	if err := db.Query("kill").WithContext(ctx).Exec(); err == nil {
 		t.Fatal("expected error got nil")
 	}
 	cancel()
@@ -176,7 +176,10 @@ func TestTimeout(t *testing.T) {
 // TestQueryRetry will test to make sure that gocql will execute
 // the exact amount of retry queries designated by the user.
 func TestQueryRetry(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 
 	db, err := newTestSession(srv.Address, defaultProto)
@@ -186,9 +189,14 @@ func TestQueryRetry(t *testing.T) {
 	defer db.Close()
 
 	go func() {
-		<-time.After(5 * time.Second)
-		t.Fatalf("no timeout")
+		select {
+		case <-ctx.Done():
+			return
+		case <-time.After(5 * time.Second):
+			t.Errorf("no timeout")
+		}
 	}()
+
 	rt := &SimpleRetryPolicy{NumRetries: 1}
 
 	qry := db.Query("kill").RetryPolicy(rt)
@@ -209,7 +217,7 @@ func TestQueryRetry(t *testing.T) {
 }
 
 func TestStreams_Protocol1(t *testing.T) {
-	srv := NewTestServer(t, protoVersion1)
+	srv := NewTestServer(t, protoVersion1, context.Background())
 	defer srv.Stop()
 
 	// TODO: these are more like session tests and should instead operate
@@ -241,7 +249,7 @@ func TestStreams_Protocol1(t *testing.T) {
 }
 
 func TestStreams_Protocol3(t *testing.T) {
-	srv := NewTestServer(t, protoVersion3)
+	srv := NewTestServer(t, protoVersion3, context.Background())
 	defer srv.Stop()
 
 	// TODO: these are more like session tests and should instead operate
@@ -268,7 +276,7 @@ func TestStreams_Protocol3(t *testing.T) {
 }
 
 func BenchmarkProtocolV3(b *testing.B) {
-	srv := NewTestServer(b, protoVersion3)
+	srv := NewTestServer(b, protoVersion3, context.Background())
 	defer srv.Stop()
 
 	// TODO: these are more like session tests and should instead operate
@@ -294,7 +302,7 @@ func BenchmarkProtocolV3(b *testing.B) {
 
 // This tests that the policy connection pool handles SSL correctly
 func TestPolicyConnPoolSSL(t *testing.T) {
-	srv := NewSSLTestServer(t, defaultProto)
+	srv := NewSSLTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := createTestSslCluster(srv.Address, defaultProto, true)
@@ -319,7 +327,7 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 }
 
 func TestQueryTimeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -356,7 +364,7 @@ func TestQueryTimeout(t *testing.T) {
 }
 
 func TestQueryTimeoutMany(t *testing.T) {
-	srv := NewTestServer(t, 3)
+	srv := NewTestServer(t, 3, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, 3)
@@ -381,7 +389,7 @@ func TestQueryTimeoutMany(t *testing.T) {
 }
 
 func BenchmarkSingleConn(b *testing.B) {
-	srv := NewTestServer(b, 3)
+	srv := NewTestServer(b, 3, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, 3)
@@ -412,7 +420,7 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
 	// TODO(zariel): move this to conn test, we really just want to check what
 	// happens when a conn is
 
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -436,7 +444,7 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
 }
 
 func TestQueryTimeoutClose(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -473,12 +481,20 @@ func TestQueryTimeoutClose(t *testing.T) {
 func TestStream0(t *testing.T) {
 	const expErr = "gocql: received frame on stream 0"
 
-	srv := NewTestServer(t, defaultProto)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 
 	errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
 		if !srv.isClosed() && !strings.HasPrefix(err.Error(), expErr) {
-			t.Errorf("expected to get error prefix %q got %q", expErr, err.Error())
+			select {
+			case <-ctx.Done():
+				return
+			default:
+				t.Errorf("expected to get error prefix %q got %q", expErr, err.Error())
+			}
 		}
 	})
 
@@ -494,7 +510,7 @@ func TestStream0(t *testing.T) {
 	})
 
 	// need to write out an invalid frame, which we need a connection to do
-	framer, err := conn.exec(context.Background(), writer, nil)
+	framer, err := conn.exec(ctx, writer, nil)
 	if err == nil {
 		t.Fatal("expected to get an error on stream 0")
 	} else if !strings.HasPrefix(err.Error(), expErr) {
@@ -512,7 +528,7 @@ func TestConnClosedBlocked(t *testing.T) {
 	// issue 664
 	const proto = 3
 
-	srv := NewTestServer(t, proto)
+	srv := NewTestServer(t, proto, context.Background())
 	defer srv.Stop()
 	errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
 		t.Log(err)
@@ -536,7 +552,7 @@ func TestConnClosedBlocked(t *testing.T) {
 }
 
 func TestContext_Timeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -555,7 +571,7 @@ func TestContext_Timeout(t *testing.T) {
 	}
 }
 
-func NewTestServer(t testing.TB, protocol uint8) *TestServer {
+func NewTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
 		t.Fatal(err)
@@ -571,21 +587,24 @@ func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 		headerSize = 9
 	}
 
+	ctx, cancel := context.WithCancel(ctx)
 	srv := &TestServer{
 		Address:    listen.Addr().String(),
 		listen:     listen,
 		t:          t,
 		protocol:   protocol,
 		headerSize: headerSize,
-		quit:       make(chan struct{}),
+		ctx:        ctx,
+		cancel:     cancel,
 	}
 
+	go srv.closeWatch()
 	go srv.serve()
 
 	return srv
 }
 
-func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
+func NewSSLTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer {
 	pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
 	certPool := x509.NewCertPool()
 	if !certPool.AppendCertsFromPEM(pem) {
@@ -609,14 +628,18 @@ func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
 		headerSize = 9
 	}
 
+	ctx, cancel := context.WithCancel(ctx)
 	srv := &TestServer{
 		Address:    listen.Addr().String(),
 		listen:     listen,
 		t:          t,
 		protocol:   protocol,
 		headerSize: headerSize,
-		quit:       make(chan struct{}),
+		ctx:        ctx,
+		cancel:     cancel,
 	}
+
+	go srv.closeWatch()
 	go srv.serve()
 	return srv
 }
@@ -631,28 +654,58 @@ type TestServer struct {
 
 	protocol   byte
 	headerSize int
+	ctx        context.Context
+	cancel     context.CancelFunc
 
 	quit   chan struct{}
 	mu     sync.Mutex
 	closed bool
 }
 
+func (srv *TestServer) closeWatch() {
+	<-srv.ctx.Done()
+
+	srv.mu.Lock()
+	defer srv.mu.Unlock()
+
+	srv.closeLocked()
+}
+
 func (srv *TestServer) serve() {
 	defer srv.listen.Close()
 	for {
+		select {
+		case <-srv.ctx.Done():
+			return
+		default:
+		}
+
 		conn, err := srv.listen.Accept()
 		if err != nil {
 			break
 		}
+
 		go func(conn net.Conn) {
 			defer conn.Close()
 			for {
+				select {
+				case <-srv.ctx.Done():
+					return
+				default:
+				}
+
 				framer, err := srv.readFrame(conn)
 				if err != nil {
 					if err == io.EOF {
 						return
 					}
 
+					select {
+					case <-srv.ctx.Done():
+						return
+					default:
+					}
+
 					srv.t.Error(err)
 					return
 				}
@@ -671,21 +724,32 @@ func (srv *TestServer) isClosed() bool {
 	return srv.closed
 }
 
-func (srv *TestServer) Stop() {
-	srv.mu.Lock()
-	defer srv.mu.Unlock()
+func (srv *TestServer) closeLocked() {
 	if srv.closed {
 		return
 	}
+
 	srv.closed = true
 
 	srv.listen.Close()
-	close(srv.quit)
+	srv.cancel()
+}
+
+func (srv *TestServer) Stop() {
+	srv.mu.Lock()
+	defer srv.mu.Unlock()
+	srv.closeLocked()
 }
 
 func (srv *TestServer) process(f *framer) {
 	head := f.header
 	if head == nil {
+		select {
+		case <-srv.ctx.Done():
+			return
+		default:
+		}
+
 		srv.t.Error("process frame with a nil header")
 		return
 	}
@@ -715,7 +779,7 @@ func (srv *TestServer) process(f *framer) {
 			f.writeHeader(0, opResult, head.stream)
 			f.writeInt(resultKindVoid)
 		case "timeout":
-			<-srv.quit
+			<-srv.ctx.Done()
 			return
 		case "slow":
 			go func() {
@@ -723,7 +787,8 @@ func (srv *TestServer) process(f *framer) {
 				f.writeInt(resultKindVoid)
 				f.wbuf[0] = srv.protocol | 0x80
 				select {
-				case <-srv.quit:
+				case <-srv.ctx.Done():
+					return
 				case <-time.After(50 * time.Millisecond):
 					f.finishWrite()
 				}
@@ -745,6 +810,12 @@ func (srv *TestServer) process(f *framer) {
 	f.wbuf[0] = srv.protocol | 0x80
 
 	if err := f.finishWrite(); err != nil {
+		select {
+		case <-srv.ctx.Done():
+			return
+		default:
+		}
+
 		srv.t.Error(err)
 	}
 }