Explorar o código

Cleanup pool tests

- Fix race.
- Fix connection leak.
Gary Burd %!s(int64=10) %!d(string=hai) anos
pai
achega
dd7accd139
Modificáronse 1 ficheiros con 46 adicións e 35 borrados
  1. 46 35
      redis/pool_test.go

+ 46 - 35
redis/pool_test.go

@@ -32,8 +32,14 @@ type poolTestConn struct {
 	redis.Conn
 }
 
-func (c *poolTestConn) Close() error { c.d.open -= 1; return nil }
-func (c *poolTestConn) Err() error   { return c.err }
+func (c *poolTestConn) Close() error {
+	c.d.mu.Lock()
+	c.d.open -= 1
+	c.d.mu.Unlock()
+	return c.Conn.Close()
+}
+
+func (c *poolTestConn) Err() error { return c.err }
 
 func (c *poolTestConn) Do(commandName string, args ...interface{}) (interface{}, error) {
 	if commandName == "ERR" {
@@ -52,6 +58,7 @@ func (c *poolTestConn) Send(commandName string, args ...interface{}) error {
 }
 
 type poolDialer struct {
+	mu       sync.Mutex
 	t        *testing.T
 	dialed   int
 	open     int
@@ -60,19 +67,25 @@ type poolDialer struct {
 }
 
 func (d *poolDialer) dial() (redis.Conn, error) {
+	d.mu.Lock()
 	d.dialed += 1
-	if d.dialErr != nil {
+	dialErr := d.dialErr
+	d.mu.Unlock()
+	if dialErr != nil {
 		return nil, d.dialErr
 	}
 	c, err := redistest.Dial()
 	if err != nil {
 		return nil, err
 	}
+	d.mu.Lock()
 	d.open += 1
+	d.mu.Unlock()
 	return &poolTestConn{d: d, Conn: c}, nil
 }
 
 func (d *poolDialer) check(message string, p *redis.Pool, dialed, open int) {
+	d.mu.Lock()
 	if d.dialed != dialed {
 		d.t.Errorf("%s: dialed=%d, want %d", message, d.dialed, dialed)
 	}
@@ -82,6 +95,7 @@ func (d *poolDialer) check(message string, p *redis.Pool, dialed, open int) {
 	if active := p.ActiveCount(); active != open {
 		d.t.Errorf("%s: active=%d, want %d", message, active, open)
 	}
+	d.mu.Unlock()
 }
 
 func TestPoolReuse(t *testing.T) {
@@ -111,6 +125,8 @@ func TestPoolMaxIdle(t *testing.T) {
 		MaxIdle: 2,
 		Dial:    d.dial,
 	}
+	defer p.Close()
+
 	for i := 0; i < 10; i++ {
 		c1 := p.Get()
 		c1.Do("PING")
@@ -133,6 +149,7 @@ func TestPoolError(t *testing.T) {
 		MaxIdle: 2,
 		Dial:    d.dial,
 	}
+	defer p.Close()
 
 	c := p.Get()
 	c.Do("ERR", io.EOF)
@@ -154,6 +171,7 @@ func TestPoolClose(t *testing.T) {
 		MaxIdle: 2,
 		Dial:    d.dial,
 	}
+	defer p.Close()
 
 	c1 := p.Get()
 	c1.Do("PING")
@@ -195,6 +213,7 @@ func TestPoolTimeout(t *testing.T) {
 		IdleTimeout: 300 * time.Second,
 		Dial:        d.dial,
 	}
+	defer p.Close()
 
 	now := time.Now()
 	redis.SetNowFunc(func() time.Time { return now })
@@ -213,14 +232,14 @@ func TestPoolTimeout(t *testing.T) {
 	c.Close()
 
 	d.check("2", p, 2, 1)
-
-	p.Close()
 }
 
 func TestPoolConcurrenSendReceive(t *testing.T) {
 	p := &redis.Pool{
 		Dial: redistest.Dial,
 	}
+	defer p.Close()
+
 	c := p.Get()
 	done := make(chan error, 1)
 	go func() {
@@ -238,7 +257,6 @@ func TestPoolConcurrenSendReceive(t *testing.T) {
 		t.Fatalf("Do() returned error %v", err)
 	}
 	c.Close()
-	p.Close()
 }
 
 func TestPoolBorrowCheck(t *testing.T) {
@@ -248,6 +266,7 @@ func TestPoolBorrowCheck(t *testing.T) {
 		Dial:         d.dial,
 		TestOnBorrow: func(redis.Conn, time.Time) error { return redis.Error("BLAH") },
 	}
+	defer p.Close()
 
 	for i := 0; i < 10; i++ {
 		c := p.Get()
@@ -255,7 +274,6 @@ func TestPoolBorrowCheck(t *testing.T) {
 		c.Close()
 	}
 	d.check("1", p, 10, 1)
-	p.Close()
 }
 
 func TestPoolMaxActive(t *testing.T) {
@@ -265,6 +283,8 @@ func TestPoolMaxActive(t *testing.T) {
 		MaxActive: 2,
 		Dial:      d.dial,
 	}
+	defer p.Close()
+
 	c1 := p.Get()
 	c1.Do("PING")
 	c2 := p.Get()
@@ -289,7 +309,6 @@ func TestPoolMaxActive(t *testing.T) {
 	c3.Close()
 
 	d.check("4", p, 2, 2)
-	p.Close()
 }
 
 func TestPoolMonitorCleanup(t *testing.T) {
@@ -299,12 +318,13 @@ func TestPoolMonitorCleanup(t *testing.T) {
 		MaxActive: 2,
 		Dial:      d.dial,
 	}
+	defer p.Close()
+
 	c := p.Get()
 	c.Send("MONITOR")
 	c.Close()
 
 	d.check("", p, 1, 0)
-	p.Close()
 }
 
 func TestPoolPubSubCleanup(t *testing.T) {
@@ -314,6 +334,7 @@ func TestPoolPubSubCleanup(t *testing.T) {
 		MaxActive: 2,
 		Dial:      d.dial,
 	}
+	defer p.Close()
 
 	c := p.Get()
 	c.Send("SUBSCRIBE", "x")
@@ -334,8 +355,6 @@ func TestPoolPubSubCleanup(t *testing.T) {
 		t.Errorf("got commands %v, want %v", d.commands, want)
 	}
 	d.commands = nil
-
-	p.Close()
 }
 
 func TestPoolTransactionCleanup(t *testing.T) {
@@ -345,6 +364,7 @@ func TestPoolTransactionCleanup(t *testing.T) {
 		MaxActive: 2,
 		Dial:      d.dial,
 	}
+	defer p.Close()
 
 	c := p.Get()
 	c.Do("WATCH", "key")
@@ -406,8 +426,6 @@ func TestPoolTransactionCleanup(t *testing.T) {
 		t.Errorf("got commands %v, want %v", d.commands, want)
 	}
 	d.commands = nil
-
-	p.Close()
 }
 
 func startGoroutines(p *redis.Pool, cmd string, args ...interface{}) chan error {
@@ -436,6 +454,7 @@ func TestWaitPool(t *testing.T) {
 		Wait:      true,
 	}
 	defer p.Close()
+
 	c := p.Get()
 	errs := startGoroutines(p, "PING")
 	d.check("before close", p, 1, 1)
@@ -462,6 +481,8 @@ func TestWaitPoolClose(t *testing.T) {
 		Dial:      d.dial,
 		Wait:      true,
 	}
+	defer p.Close()
+
 	c := p.Get()
 	if _, err := c.Do("PING"); err != nil {
 		t.Fatal(err)
@@ -497,6 +518,7 @@ func TestWaitPoolCommandError(t *testing.T) {
 		Wait:      true,
 	}
 	defer p.Close()
+
 	c := p.Get()
 	errs := startGoroutines(p, "ERR", testErr)
 	d.check("before close", p, 1, 1)
@@ -525,6 +547,7 @@ func TestWaitPoolDialError(t *testing.T) {
 		Wait:      true,
 	}
 	defer p.Close()
+
 	c := p.Get()
 	errs := startGoroutines(p, "ERR", testErr)
 	d.check("before close", p, 1, 1)
@@ -565,7 +588,7 @@ func TestWaitPoolDialError(t *testing.T) {
 // test ensures that iteration will work correctly if multiple threads are
 // iterating simultaneously.
 func TestLocking_TestOnBorrowFails_PoolDoesntCrash(t *testing.T) {
-	count := 100
+	const count = 100
 
 	// First we'll Create a pool where the pilfering of idle connections fails.
 	d := poolDialer{t: t}
@@ -580,29 +603,17 @@ func TestLocking_TestOnBorrowFails_PoolDoesntCrash(t *testing.T) {
 	defer p.Close()
 
 	// Fill the pool with idle connections.
-	b1 := sync.WaitGroup{}
-	b1.Add(count)
-	b2 := sync.WaitGroup{}
-	b2.Add(count)
-	for i := 0; i < count; i++ {
-		go func() {
-			c := p.Get()
-			if c.Err() != nil {
-				t.Errorf("pool get failed: %v", c.Err())
-			}
-			b1.Done()
-			b1.Wait()
-			c.Close()
-			b2.Done()
-		}()
+	conns := make([]redis.Conn, count)
+	for i := range conns {
+		conns[i] = p.Get()
 	}
-	b2.Wait()
-	if d.dialed != count {
-		t.Errorf("Expected %d dials, got %d", count, d.dialed)
+	for i := range conns {
+		conns[i].Close()
 	}
 
 	// Spawn a bunch of goroutines to thrash the pool.
-	b2.Add(count)
+	var wg sync.WaitGroup
+	wg.Add(count)
 	for i := 0; i < count; i++ {
 		go func() {
 			c := p.Get()
@@ -610,10 +621,10 @@ func TestLocking_TestOnBorrowFails_PoolDoesntCrash(t *testing.T) {
 				t.Errorf("pool get failed: %v", c.Err())
 			}
 			c.Close()
-			b2.Done()
+			wg.Done()
 		}()
 	}
-	b2.Wait()
+	wg.Wait()
 	if d.dialed != count*2 {
 		t.Errorf("Expected %d dials, got %d", count*2, d.dialed)
 	}