|
|
@@ -157,6 +157,47 @@ func TestRoundRobin(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func TestConnClosing(t *testing.T) {
|
|
|
+ srv := NewTestServer(t)
|
|
|
+ defer srv.Stop()
|
|
|
+
|
|
|
+ db, err := NewCluster(srv.Address).CreateSession()
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("NewCluster: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ numConns := db.cfg.NumConns
|
|
|
+ count := db.cfg.NumStreams * numConns
|
|
|
+
|
|
|
+ wg := &sync.WaitGroup{}
|
|
|
+ wg.Add(count)
|
|
|
+ for i := 0; i < count; i++ {
|
|
|
+ go func(wg *sync.WaitGroup) {
|
|
|
+ wg.Done()
|
|
|
+ db.Query("kill").Exec()
|
|
|
+ }(wg)
|
|
|
+ }
|
|
|
+
|
|
|
+ wg.Wait()
|
|
|
+
|
|
|
+ cluster := db.Node.(*clusterImpl)
|
|
|
+ cluster.mu.Lock()
|
|
|
+ for conn := range cluster.conns {
|
|
|
+ conn.conn.Close()
|
|
|
+ }
|
|
|
+
|
|
|
+ cluster.mu.Unlock()
|
|
|
+
|
|
|
+ time.Sleep(20 * time.Millisecond)
|
|
|
+ cluster.mu.Lock()
|
|
|
+ conns := len(cluster.conns)
|
|
|
+ cluster.mu.Unlock()
|
|
|
+
|
|
|
+ if conns != numConns {
|
|
|
+ t.Fatalf("Expected to have %d connections but have %d", numConns, conns)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func NewTestServer(t *testing.T) *TestServer {
|
|
|
laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
|
|
|
if err != nil {
|