瀏覽代碼

Add Pool.Wait

This closes #56.
Gary Burd 11 年之前
父節點
當前提交
54a2929be0
共有 2 個文件被更改,包括 243 次插入55 次删除
  1. 83 51
      redis/pool.go
  2. 160 4
      redis/pool_test.go

+ 83 - 51
redis/pool.go

@@ -116,8 +116,13 @@ type Pool struct {
 	// the timeout to a value less than the server's timeout.
 	// the timeout to a value less than the server's timeout.
 	IdleTimeout time.Duration
 	IdleTimeout time.Duration
 
 
+	// If Wait is true and the pool is at the MaxIdle limit, then Get() waits
+	// for a connection to be returned to the pool before returning.
+	Wait bool
+
 	// mu protects fields defined below.
 	// mu protects fields defined below.
 	mu     sync.Mutex
 	mu     sync.Mutex
+	cond   *sync.Cond
 	closed bool
 	closed bool
 	active int
 	active int
 
 
@@ -164,6 +169,9 @@ func (p *Pool) Close() error {
 	p.idle.Init()
 	p.idle.Init()
 	p.closed = true
 	p.closed = true
 	p.active -= idle.Len()
 	p.active -= idle.Len()
+	if p.cond != nil {
+		p.cond.Broadcast()
+	}
 	p.mu.Unlock()
 	p.mu.Unlock()
 	for e := idle.Front(); e != nil; e = e.Next() {
 	for e := idle.Front(); e != nil; e = e.Next() {
 		e.Value.(idleConn).c.Close()
 		e.Value.(idleConn).c.Close()
@@ -171,16 +179,20 @@ func (p *Pool) Close() error {
 	return nil
 	return nil
 }
 }
 
 
+// release decrements the active count and signals waiters. The caller must
+// hold p.mu during the call.
+func (p *Pool) release() {
+	p.active -= 1
+	if p.cond != nil {
+		p.cond.Signal()
+	}
+}
+
 // get prunes stale connections and returns a connection from the idle list or
 // get prunes stale connections and returns a connection from the idle list or
 // creates a new connection.
 // creates a new connection.
 func (p *Pool) get() (Conn, error) {
 func (p *Pool) get() (Conn, error) {
 	p.mu.Lock()
 	p.mu.Lock()
 
 
-	if p.closed {
-		p.mu.Unlock()
-		return nil, errors.New("redigo: get on closed pool")
-	}
-
 	// Prune stale connections.
 	// Prune stale connections.
 
 
 	if timeout := p.IdleTimeout; timeout > 0 {
 	if timeout := p.IdleTimeout; timeout > 0 {
@@ -194,72 +206,92 @@ func (p *Pool) get() (Conn, error) {
 				break
 				break
 			}
 			}
 			p.idle.Remove(e)
 			p.idle.Remove(e)
-			p.active -= 1
+			p.release()
 			p.mu.Unlock()
 			p.mu.Unlock()
 			ic.c.Close()
 			ic.c.Close()
 			p.mu.Lock()
 			p.mu.Lock()
 		}
 		}
 	}
 	}
 
 
-	// Get idle connection.
+	for {
+
+		// Get idle connection.
 
 
-	for i, n := 0, p.idle.Len(); i < n; i++ {
-		e := p.idle.Front()
-		if e == nil {
-			break
+		for i, n := 0, p.idle.Len(); i < n; i++ {
+			e := p.idle.Front()
+			if e == nil {
+				break
+			}
+			ic := e.Value.(idleConn)
+			p.idle.Remove(e)
+			test := p.TestOnBorrow
+			p.mu.Unlock()
+			if test == nil || test(ic.c, ic.t) == nil {
+				return ic.c, nil
+			}
+			ic.c.Close()
+			p.mu.Lock()
+			p.release()
 		}
 		}
-		ic := e.Value.(idleConn)
-		p.idle.Remove(e)
-		test := p.TestOnBorrow
-		p.mu.Unlock()
-		if test == nil || test(ic.c, ic.t) == nil {
-			return ic.c, nil
+
+		// Check for pool closed before dialing a new connection.
+
+		if p.closed {
+			p.mu.Unlock()
+			return nil, errors.New("redigo: get on closed pool")
 		}
 		}
-		ic.c.Close()
-		p.mu.Lock()
-		p.active -= 1
-	}
 
 
-	if p.MaxActive > 0 && p.active >= p.MaxActive {
-		p.mu.Unlock()
-		return nil, ErrPoolExhausted
-	}
+		// Dial new connection if under limit.
 
 
-	// No idle connection, create new.
+		if p.MaxActive == 0 || p.active < p.MaxActive {
+			dial := p.Dial
+			p.active += 1
+			p.mu.Unlock()
+			c, err := dial()
+			if err != nil {
+				p.mu.Lock()
+				p.release()
+				p.mu.Unlock()
+				c = nil
+			}
+			return c, err
+		}
 
 
-	dial := p.Dial
-	p.active += 1
-	p.mu.Unlock()
-	c, err := dial()
-	if err != nil {
-		p.mu.Lock()
-		p.active -= 1
-		p.mu.Unlock()
-		c = nil
+		if !p.Wait {
+			p.mu.Unlock()
+			return nil, ErrPoolExhausted
+		}
+
+		if p.cond == nil {
+			p.cond = sync.NewCond(&p.mu)
+		}
+		p.cond.Wait()
 	}
 	}
-	return c, err
 }
 }
 
 
 func (p *Pool) put(c Conn, forceClose bool) error {
 func (p *Pool) put(c Conn, forceClose bool) error {
-	if c.Err() == nil && !forceClose {
-		p.mu.Lock()
-		if !p.closed {
-			p.idle.PushFront(idleConn{t: nowFunc(), c: c})
-			if p.idle.Len() > p.MaxIdle {
-				c = p.idle.Remove(p.idle.Back()).(idleConn).c
-			} else {
-				c = nil
-			}
+	err := c.Err()
+	p.mu.Lock()
+	if !p.closed && err == nil && !forceClose {
+		p.idle.PushFront(idleConn{t: nowFunc(), c: c})
+		if p.idle.Len() > p.MaxIdle {
+			c = p.idle.Remove(p.idle.Back()).(idleConn).c
+		} else {
+			c = nil
 		}
 		}
-		p.mu.Unlock()
 	}
 	}
-	if c != nil {
-		p.mu.Lock()
-		p.active -= 1
+
+	if c == nil {
+		if p.cond != nil {
+			p.cond.Signal()
+		}
 		p.mu.Unlock()
 		p.mu.Unlock()
-		return c.Close()
+		return nil
 	}
 	}
-	return nil
+
+	p.release()
+	p.mu.Unlock()
+	return c.Close()
 }
 }
 
 
 type pooledConnection struct {
 type pooledConnection struct {

+ 160 - 4
redis/pool_test.go

@@ -15,6 +15,7 @@
 package redis_test
 package redis_test
 
 
 import (
 import (
+	"errors"
 	"io"
 	"io"
 	"reflect"
 	"reflect"
 	"testing"
 	"testing"
@@ -36,6 +37,7 @@ func (c *poolTestConn) Err() error   { return c.err }
 func (c *poolTestConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
 func (c *poolTestConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
 	if commandName == "ERR" {
 	if commandName == "ERR" {
 		c.err = args[0].(error)
 		c.err = args[0].(error)
+		commandName = "PING"
 	}
 	}
 	if commandName != "" {
 	if commandName != "" {
 		c.d.commands = append(c.d.commands, commandName)
 		c.d.commands = append(c.d.commands, commandName)
@@ -49,18 +51,23 @@ func (c *poolTestConn) Send(commandName string, args ...interface{}) error {
 }
 }
 
 
 type poolDialer struct {
 type poolDialer struct {
-	t            *testing.T
-	dialed, open int
-	commands     []string
+	t        *testing.T
+	dialed   int
+	open     int
+	commands []string
+	dialErr  error
 }
 }
 
 
 func (d *poolDialer) dial() (redis.Conn, error) {
 func (d *poolDialer) dial() (redis.Conn, error) {
-	d.open += 1
 	d.dialed += 1
 	d.dialed += 1
+	if d.dialErr != nil {
+		return nil, d.dialErr
+	}
 	c, err := redistest.Dial()
 	c, err := redistest.Dial()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+	d.open += 1
 	return &poolTestConn{d: d, Conn: c}, nil
 	return &poolTestConn{d: d, Conn: c}, nil
 }
 }
 
 
@@ -402,6 +409,155 @@ func TestPoolTransactionCleanup(t *testing.T) {
 	p.Close()
 	p.Close()
 }
 }
 
 
+func startGoroutines(p *redis.Pool, cmd string, args ...interface{}) chan error {
+	errs := make(chan error, 10)
+	for i := 0; i < cap(errs); i++ {
+		go func() {
+			c := p.Get()
+			_, err := c.Do(cmd, args...)
+			errs <- err
+			c.Close()
+		}()
+	}
+
+	// Wait for goroutines to block.
+	time.Sleep(time.Second / 4)
+
+	return errs
+}
+
+func TestWaitPool(t *testing.T) {
+	d := poolDialer{t: t}
+	p := &redis.Pool{
+		MaxIdle:   1,
+		MaxActive: 1,
+		Dial:      d.dial,
+		Wait:      true,
+	}
+	defer p.Close()
+	c := p.Get()
+	errs := startGoroutines(p, "PING")
+	d.check("before close", p, 1, 1)
+	c.Close()
+	timeout := time.After(2 * time.Second)
+	for i := 0; i < cap(errs); i++ {
+		select {
+		case err := <-errs:
+			if err != nil {
+				t.Fatal(err)
+			}
+		case <-timeout:
+			t.Fatalf("timeout waiting for blocked goroutine %d", i)
+		}
+	}
+	d.check("done", p, 1, 1)
+}
+
+func TestWaitPoolClose(t *testing.T) {
+	d := poolDialer{t: t}
+	p := &redis.Pool{
+		MaxIdle:   1,
+		MaxActive: 1,
+		Dial:      d.dial,
+		Wait:      true,
+	}
+	c := p.Get()
+	if _, err := c.Do("PING"); err != nil {
+		t.Fatal(err)
+	}
+	errs := startGoroutines(p, "PING")
+	d.check("before close", p, 1, 1)
+	p.Close()
+	timeout := time.After(2 * time.Second)
+	for i := 0; i < cap(errs); i++ {
+		select {
+		case err := <-errs:
+			switch err {
+			case nil:
+				t.Fatal("blocked goroutine did not get error")
+			case redis.ErrPoolExhausted:
+				t.Fatal("blocked goroutine got pool exhausted error")
+			}
+		case <-timeout:
+			t.Fatal("timeout waiting for blocked goroutine")
+		}
+	}
+	c.Close()
+	d.check("done", p, 1, 0)
+}
+
+func TestWaitPoolCommandError(t *testing.T) {
+	testErr := errors.New("test")
+	d := poolDialer{t: t}
+	p := &redis.Pool{
+		MaxIdle:   1,
+		MaxActive: 1,
+		Dial:      d.dial,
+		Wait:      true,
+	}
+	defer p.Close()
+	c := p.Get()
+	errs := startGoroutines(p, "ERR", testErr)
+	d.check("before close", p, 1, 1)
+	c.Close()
+	timeout := time.After(2 * time.Second)
+	for i := 0; i < cap(errs); i++ {
+		select {
+		case err := <-errs:
+			if err != nil {
+				t.Fatal(err)
+			}
+		case <-timeout:
+			t.Fatalf("timeout waiting for blocked goroutine %d", i)
+		}
+	}
+	d.check("done", p, cap(errs), 0)
+}
+
+func TestWaitPoolDialError(t *testing.T) {
+	testErr := errors.New("test")
+	d := poolDialer{t: t}
+	p := &redis.Pool{
+		MaxIdle:   1,
+		MaxActive: 1,
+		Dial:      d.dial,
+		Wait:      true,
+	}
+	defer p.Close()
+	c := p.Get()
+	errs := startGoroutines(p, "ERR", testErr)
+	d.check("before close", p, 1, 1)
+
+	d.dialErr = errors.New("dial")
+	c.Close()
+
+	nilCount := 0
+	errCount := 0
+	timeout := time.After(2 * time.Second)
+	for i := 0; i < cap(errs); i++ {
+		select {
+		case err := <-errs:
+			switch err {
+			case nil:
+				nilCount++
+			case d.dialErr:
+				errCount++
+			default:
+				t.Fatalf("expected dial error or nil, got %v", err)
+			}
+		case <-timeout:
+			t.Fatalf("timeout waiting for blocked goroutine %d", i)
+		}
+	}
+	if nilCount != 1 {
+		t.Errorf("expected one nil error, got %d", nilCount)
+	}
+	if errCount != cap(errs)-1 {
+		t.Errorf("expected %d dial erors, got %d", cap(errs)-1, errCount)
+	}
+	d.check("done", p, cap(errs), 0)
+}
+
 func BenchmarkPoolGet(b *testing.B) {
 func BenchmarkPoolGet(b *testing.B) {
 	b.StopTimer()
 	b.StopTimer()
 	p := redis.Pool{Dial: redistest.Dial, MaxIdle: 2}
 	p := redis.Pool{Dial: redistest.Dial, MaxIdle: 2}