Browse Source

Add max lifetime option to pool

Gary Burd 7 years ago
parent
commit
a7327d8ced
5 changed files with 207 additions and 98 deletions
  1. 14 14
      redis/list_test.go
  2. 115 80
      redis/pool.go
  3. 3 3
      redis/pool17.go
  4. 16 0
      redis/pool17_test.go
  5. 59 1
      redis/pool_test.go

+ 14 - 14
redis/list_test.go

@@ -20,13 +20,13 @@ import "testing"
 
 func TestPoolList(t *testing.T) {
 	var idle idleList
-	var a, b, c idleConn
+	var a, b, c poolConn
 
-	check := func(ics ...*idleConn) {
-		if idle.count != len(ics) {
-			t.Fatal("idle.count != len(ics)")
+	check := func(pcs ...*poolConn) {
+		if idle.count != len(pcs) {
+			t.Fatal("idle.count != len(pcs)")
 		}
-		if len(ics) == 0 {
+		if len(pcs) == 0 {
 			if idle.front != nil {
 				t.Fatalf("front not nil")
 			}
@@ -35,11 +35,11 @@ func TestPoolList(t *testing.T) {
 			}
 			return
 		}
-		if idle.front != ics[0] {
-			t.Fatal("front != ics[0]")
+		if idle.front != pcs[0] {
+			t.Fatal("front != pcs[0]")
 		}
-		if idle.back != ics[len(ics)-1] {
-			t.Fatal("back != ics[len(ics)-1]")
+		if idle.back != pcs[len(pcs)-1] {
+			t.Fatal("back != pcs[len(pcs)-1]")
 		}
 		if idle.front.prev != nil {
 			t.Fatal("front.prev != nil")
@@ -47,12 +47,12 @@ func TestPoolList(t *testing.T) {
 		if idle.back.next != nil {
 			t.Fatal("back.next != nil")
 		}
-		for i := 1; i < len(ics)-1; i++ {
-			if ics[i-1].next != ics[i] {
-				t.Fatal("ics[i-1].next != ics[i]")
+		for i := 1; i < len(pcs)-1; i++ {
+			if pcs[i-1].next != pcs[i] {
+				t.Fatal("pcs[i-1].next != pcs[i]")
 			}
-			if ics[i+1].prev != ics[i] {
-				t.Fatal("ics[i+1].prev != ics[i]")
+			if pcs[i+1].prev != pcs[i] {
+				t.Fatal("pcs[i+1].prev != pcs[i]")
 			}
 		}
 	}

+ 115 - 80
redis/pool.go

@@ -29,8 +29,8 @@ import (
 )
 
 var (
-	_ ConnWithTimeout = (*pooledConnection)(nil)
-	_ ConnWithTimeout = (*errorConnection)(nil)
+	_ ConnWithTimeout = (*activeConn)(nil)
+	_ ConnWithTimeout = (*errorConn)(nil)
 )
 
 var nowFunc = time.Now // for testing
@@ -150,6 +150,10 @@ type Pool struct {
 	// for a connection to be returned to the pool before returning.
 	Wait bool
 
+	// Close connections older than this duration. If the value is zero, then
+	// the pool does not close connections based on age.
+	MaxConnLifetime time.Duration
+
 	chInitialized uint32 // set to 1 when field ch is initialized
 
 	mu     sync.Mutex    // mu protects the following fields
@@ -172,11 +176,11 @@ func NewPool(newFn func() (Conn, error), maxIdle int) *Pool {
 // getting an underlying connection, then the connection Err, Do, Send, Flush
 // and Receive methods return that error.
 func (p *Pool) Get() Conn {
-	c, err := p.get(nil)
+	pc, err := p.get(nil)
 	if err != nil {
-		return errorConnection{err}
+		return errorConn{err}
 	}
-	return &pooledConnection{p: p, c: c}
+	return &activeConn{p: p, pc: pc}
 }
 
 // PoolStats contains pool statistics.
@@ -226,15 +230,15 @@ func (p *Pool) Close() error {
 	}
 	p.closed = true
 	p.active -= p.idle.count
-	ic := p.idle.front
+	pc := p.idle.front
 	p.idle.count = 0
 	p.idle.front, p.idle.back = nil, nil
 	if p.ch != nil {
 		close(p.ch)
 	}
 	p.mu.Unlock()
-	for ; ic != nil; ic = ic.next {
-		ic.c.Close()
+	for ; pc != nil; pc = pc.next {
+		pc.c.Close()
 	}
 	return nil
 }
@@ -265,7 +269,7 @@ func (p *Pool) lazyInit() {
 func (p *Pool) get(ctx interface {
 	Done() <-chan struct{}
 	Err() error
-}) (Conn, error) {
+}) (*poolConn, error) {
 
 	// Handle limit for p.Wait == true.
 	if p.Wait && p.MaxActive > 0 {
@@ -287,10 +291,10 @@ func (p *Pool) get(ctx interface {
 	if p.IdleTimeout > 0 {
 		n := p.idle.count
 		for i := 0; i < n && p.idle.back != nil && p.idle.back.t.Add(p.IdleTimeout).Before(nowFunc()); i++ {
-			c := p.idle.back.c
+			pc := p.idle.back
 			p.idle.popBack()
 			p.mu.Unlock()
-			c.Close()
+			pc.c.Close()
 			p.mu.Lock()
 			p.active--
 		}
@@ -298,13 +302,14 @@ func (p *Pool) get(ctx interface {
 
 	// Get idle connection from the front of idle list.
 	for p.idle.front != nil {
-		ic := p.idle.front
+		pc := p.idle.front
 		p.idle.popFront()
 		p.mu.Unlock()
-		if p.TestOnBorrow == nil || p.TestOnBorrow(ic.c, ic.t) == nil {
-			return ic.c, nil
+		if (p.TestOnBorrow == nil || p.TestOnBorrow(pc.c, pc.t) == nil) &&
+			(p.MaxConnLifetime == 0 || nowFunc().Sub(pc.created) < p.MaxConnLifetime) {
+			return pc, nil
 		}
-		ic.c.Close()
+		pc.c.Close()
 		p.mu.Lock()
 		p.active--
 	}
@@ -333,24 +338,25 @@ func (p *Pool) get(ctx interface {
 		}
 		p.mu.Unlock()
 	}
-	return c, err
+	return &poolConn{c: c, created: nowFunc()}, err
 }
 
-func (p *Pool) put(c Conn, forceClose bool) error {
+func (p *Pool) put(pc *poolConn, forceClose bool) error {
 	p.mu.Lock()
 	if !p.closed && !forceClose {
-		p.idle.pushFront(&idleConn{t: nowFunc(), c: c})
+		pc.t = nowFunc()
+		p.idle.pushFront(pc)
 		if p.idle.count > p.MaxIdle {
-			c = p.idle.back.c
+			pc = p.idle.back
 			p.idle.popBack()
 		} else {
-			c = nil
+			pc = nil
 		}
 	}
 
-	if c != nil {
+	if pc != nil {
 		p.mu.Unlock()
-		c.Close()
+		pc.c.Close()
 		p.mu.Lock()
 		p.active--
 	}
@@ -362,9 +368,9 @@ func (p *Pool) put(c Conn, forceClose bool) error {
 	return nil
 }
 
-type pooledConnection struct {
+type activeConn struct {
 	p     *Pool
-	c     Conn
+	pc    *poolConn
 	state int
 }
 
@@ -385,79 +391,107 @@ func initSentinel() {
 	}
 }
 
-func (pc *pooledConnection) Close() error {
-	c := pc.c
-	if _, ok := c.(errorConnection); ok {
+func (ac *activeConn) Close() error {
+	pc := ac.pc
+	if pc == nil {
 		return nil
 	}
-	pc.c = errorConnection{errConnClosed}
-
-	if pc.state&internal.MultiState != 0 {
-		c.Send("DISCARD")
-		pc.state &^= (internal.MultiState | internal.WatchState)
-	} else if pc.state&internal.WatchState != 0 {
-		c.Send("UNWATCH")
-		pc.state &^= internal.WatchState
+	ac.pc = nil
+
+	if ac.state&internal.MultiState != 0 {
+		pc.c.Send("DISCARD")
+		ac.state &^= (internal.MultiState | internal.WatchState)
+	} else if ac.state&internal.WatchState != 0 {
+		pc.c.Send("UNWATCH")
+		ac.state &^= internal.WatchState
 	}
-	if pc.state&internal.SubscribeState != 0 {
-		c.Send("UNSUBSCRIBE")
-		c.Send("PUNSUBSCRIBE")
+	if ac.state&internal.SubscribeState != 0 {
+		pc.c.Send("UNSUBSCRIBE")
+		pc.c.Send("PUNSUBSCRIBE")
 		// To detect the end of the message stream, ask the server to echo
 		// a sentinel value and read until we see that value.
 		sentinelOnce.Do(initSentinel)
-		c.Send("ECHO", sentinel)
-		c.Flush()
+		pc.c.Send("ECHO", sentinel)
+		pc.c.Flush()
 		for {
-			p, err := c.Receive()
+			p, err := pc.c.Receive()
 			if err != nil {
 				break
 			}
 			if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) {
-				pc.state &^= internal.SubscribeState
+				ac.state &^= internal.SubscribeState
 				break
 			}
 		}
 	}
-	c.Do("")
-	pc.p.put(c, pc.state != 0 || c.Err() != nil)
+	pc.c.Do("")
+	ac.p.put(pc, ac.state != 0 || pc.c.Err() != nil)
 	return nil
 }
 
-func (pc *pooledConnection) Err() error {
+func (ac *activeConn) Err() error {
+	pc := ac.pc
+	if pc == nil {
+		return errConnClosed
+	}
 	return pc.c.Err()
 }
 
-func (pc *pooledConnection) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
+func (ac *activeConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
+	pc := ac.pc
+	if pc == nil {
+		return nil, errConnClosed
+	}
 	ci := internal.LookupCommandInfo(commandName)
-	pc.state = (pc.state | ci.Set) &^ ci.Clear
+	ac.state = (ac.state | ci.Set) &^ ci.Clear
 	return pc.c.Do(commandName, args...)
 }
 
-func (pc *pooledConnection) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error) {
+func (ac *activeConn) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error) {
+	pc := ac.pc
+	if pc == nil {
+		return nil, errConnClosed
+	}
 	cwt, ok := pc.c.(ConnWithTimeout)
 	if !ok {
 		return nil, errTimeoutNotSupported
 	}
 	ci := internal.LookupCommandInfo(commandName)
-	pc.state = (pc.state | ci.Set) &^ ci.Clear
+	ac.state = (ac.state | ci.Set) &^ ci.Clear
 	return cwt.DoWithTimeout(timeout, commandName, args...)
 }
 
-func (pc *pooledConnection) Send(commandName string, args ...interface{}) error {
+func (ac *activeConn) Send(commandName string, args ...interface{}) error {
+	pc := ac.pc
+	if pc == nil {
+		return errConnClosed
+	}
 	ci := internal.LookupCommandInfo(commandName)
-	pc.state = (pc.state | ci.Set) &^ ci.Clear
+	ac.state = (ac.state | ci.Set) &^ ci.Clear
 	return pc.c.Send(commandName, args...)
 }
 
-func (pc *pooledConnection) Flush() error {
+func (ac *activeConn) Flush() error {
+	pc := ac.pc
+	if pc == nil {
+		return errConnClosed
+	}
 	return pc.c.Flush()
 }
 
-func (pc *pooledConnection) Receive() (reply interface{}, err error) {
+func (ac *activeConn) Receive() (reply interface{}, err error) {
+	pc := ac.pc
+	if pc == nil {
+		return nil, errConnClosed
+	}
 	return pc.c.Receive()
 }
 
-func (pc *pooledConnection) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
+func (ac *activeConn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
+	pc := ac.pc
+	if pc == nil {
+		return nil, errConnClosed
+	}
 	cwt, ok := pc.c.(ConnWithTimeout)
 	if !ok {
 		return nil, errTimeoutNotSupported
@@ -465,63 +499,64 @@ func (pc *pooledConnection) ReceiveWithTimeout(timeout time.Duration) (reply int
 	return cwt.ReceiveWithTimeout(timeout)
 }
 
-type errorConnection struct{ err error }
+type errorConn struct{ err error }
 
-func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err }
-func (ec errorConnection) DoWithTimeout(time.Duration, string, ...interface{}) (interface{}, error) {
+func (ec errorConn) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err }
+func (ec errorConn) DoWithTimeout(time.Duration, string, ...interface{}) (interface{}, error) {
 	return nil, ec.err
 }
-func (ec errorConnection) Send(string, ...interface{}) error                     { return ec.err }
-func (ec errorConnection) Err() error                                            { return ec.err }
-func (ec errorConnection) Close() error                                          { return nil }
-func (ec errorConnection) Flush() error                                          { return ec.err }
-func (ec errorConnection) Receive() (interface{}, error)                         { return nil, ec.err }
-func (ec errorConnection) ReceiveWithTimeout(time.Duration) (interface{}, error) { return nil, ec.err }
+func (ec errorConn) Send(string, ...interface{}) error                     { return ec.err }
+func (ec errorConn) Err() error                                            { return ec.err }
+func (ec errorConn) Close() error                                          { return nil }
+func (ec errorConn) Flush() error                                          { return ec.err }
+func (ec errorConn) Receive() (interface{}, error)                         { return nil, ec.err }
+func (ec errorConn) ReceiveWithTimeout(time.Duration) (interface{}, error) { return nil, ec.err }
 
 type idleList struct {
 	count       int
-	front, back *idleConn
+	front, back *poolConn
 }
 
-type idleConn struct {
+type poolConn struct {
 	c          Conn
 	t          time.Time
-	next, prev *idleConn
+	created    time.Time
+	next, prev *poolConn
 }
 
-func (l *idleList) pushFront(ic *idleConn) {
-	ic.next = l.front
-	ic.prev = nil
+func (l *idleList) pushFront(pc *poolConn) {
+	pc.next = l.front
+	pc.prev = nil
 	if l.count == 0 {
-		l.back = ic
+		l.back = pc
 	} else {
-		l.front.prev = ic
+		l.front.prev = pc
 	}
-	l.front = ic
+	l.front = pc
 	l.count++
 	return
 }
 
 func (l *idleList) popFront() {
-	ic := l.front
+	pc := l.front
 	l.count--
 	if l.count == 0 {
 		l.front, l.back = nil, nil
 	} else {
-		ic.next.prev = nil
-		l.front = ic.next
+		pc.next.prev = nil
+		l.front = pc.next
 	}
-	ic.next, ic.prev = nil, nil
+	pc.next, pc.prev = nil, nil
 }
 
 func (l *idleList) popBack() {
-	ic := l.back
+	pc := l.back
 	l.count--
 	if l.count == 0 {
 		l.front, l.back = nil, nil
 	} else {
-		ic.prev.next = nil
-		l.back = ic.prev
+		pc.prev.next = nil
+		l.back = pc.prev
 	}
-	ic.next, ic.prev = nil, nil
+	pc.next, pc.prev = nil, nil
 }

+ 3 - 3
redis/pool17.go

@@ -27,9 +27,9 @@ import "context"
 // If the function completes without error, then the application must close the
 // returned connection.
 func (p *Pool) GetContext(ctx context.Context) (Conn, error) {
-	c, err := p.get(ctx)
+	pc, err := p.get(ctx)
 	if err != nil {
-		return errorConnection{err}, err
+		return errorConn{err}, err
 	}
-	return &pooledConnection{p: p, c: c}, nil
+	return &activeConn{p: p, pc: pc}, nil
 }

+ 16 - 0
redis/pool17_test.go

@@ -23,6 +23,22 @@ import (
 	"github.com/garyburd/redigo/redis"
 )
 
+func TestWaitPoolGetContext(t *testing.T) {
+	d := poolDialer{t: t}
+	p := &redis.Pool{
+		MaxIdle:   1,
+		MaxActive: 1,
+		Dial:      d.dial,
+		Wait:      true,
+	}
+	defer p.Close()
+	c, err := p.GetContext(context.Background())
+	if err != nil {
+		t.Fatalf("GetContext returned %v", err)
+	}
+	defer c.Close()
+}
+
 func TestWaitPoolGetAfterClose(t *testing.T) {
 	d := poolDialer{t: t}
 	p := &redis.Pool{

+ 59 - 1
redis/pool_test.go

@@ -212,7 +212,37 @@ func TestPoolClose(t *testing.T) {
 	}
 }
 
-func TestPoolTimeout(t *testing.T) {
+func TestPoolClosedConn(t *testing.T) {
+	d := poolDialer{t: t}
+	p := &redis.Pool{
+		MaxIdle:     2,
+		IdleTimeout: 300 * time.Second,
+		Dial:        d.dial,
+	}
+	defer p.Close()
+	c := p.Get()
+	if c.Err() != nil {
+		t.Fatal("get failed")
+	}
+	c.Close()
+	if err := c.Err(); err == nil {
+		t.Fatal("Err on closed connection did not return error")
+	}
+	if _, err := c.Do("PING"); err == nil {
+		t.Fatal("Do on closed connection did not return error")
+	}
+	if err := c.Send("PING"); err == nil {
+		t.Fatal("Send on closed connection did not return error")
+	}
+	if err := c.Flush(); err == nil {
+		t.Fatal("Flush on closed connection did not return error")
+	}
+	if _, err := c.Receive(); err == nil {
+		t.Fatal("Receive on closed connection did not return error")
+	}
+}
+
+func TestPoolIdleTimeout(t *testing.T) {
 	d := poolDialer{t: t}
 	p := &redis.Pool{
 		MaxIdle:     2,
@@ -240,6 +270,34 @@ func TestPoolTimeout(t *testing.T) {
 	d.check("2", p, 2, 1, 0)
 }
 
+func TestPoolMaxLifetime(t *testing.T) {
+	d := poolDialer{t: t}
+	p := &redis.Pool{
+		MaxIdle:         2,
+		MaxConnLifetime: 300 * time.Second,
+		Dial:            d.dial,
+	}
+	defer p.Close()
+
+	now := time.Now()
+	redis.SetNowFunc(func() time.Time { return now })
+	defer redis.SetNowFunc(time.Now)
+
+	c := p.Get()
+	c.Do("PING")
+	c.Close()
+
+	d.check("1", p, 1, 1, 0)
+
+	now = now.Add(p.MaxConnLifetime + 1)
+
+	c = p.Get()
+	c.Do("PING")
+	c.Close()
+
+	d.check("2", p, 2, 1, 0)
+}
+
 func TestPoolConcurrenSendReceive(t *testing.T) {
 	p := &redis.Pool{
 		Dial: redis.DialDefaultServer,