Browse Source

Improve pool

- Replace container/list with custom list
- Use channel instead of sync.Cond
- Add GetContext method
Gary Burd 7 years ago
parent
commit
f5d5b3d22e
5 changed files with 306 additions and 116 deletions
  1. 6 5
      .travis.yml
  2. 85 0
      redis/list_test.go
  3. 176 104
      redis/pool.go
  4. 36 4
      redis/pool_test.go
  5. 3 3
      redis/scan_test.go

+ 6 - 5
.travis.yml

@@ -5,11 +5,12 @@ services:
 
 
 go:
 go:
   - 1.4
   - 1.4
-  - 1.5
-  - 1.6
-  - 1.7
-  - 1.8
-  - 1.9
+  - 1.5.x
+  - 1.6.x
+  - 1.7.x
+  - 1.8.x
+  - 1.9.x
+  - 1.10.x
   - tip
   - tip
 
 
 script:
 script:

+ 85 - 0
redis/list_test.go

@@ -0,0 +1,85 @@
+// Copyright 2018 Gary Burd
+//
+// Licensed under the Apache License, Version 2.0 (the "License"): you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations
+// under the License.
+
+// +build go1.9
+
+package redis
+
+import "testing"
+
+func TestPoolList(t *testing.T) {
+	var idle idleList
+	var a, b, c idleConn
+
+	check := func(ics ...*idleConn) {
+		if idle.count != len(ics) {
+			t.Fatal("idle.count != len(ics)")
+		}
+		if len(ics) == 0 {
+			if idle.front != nil {
+				t.Fatalf("front not nil")
+			}
+			if idle.back != nil {
+				t.Fatalf("back not nil")
+			}
+			return
+		}
+		if idle.front != ics[0] {
+			t.Fatal("front != ics[0]")
+		}
+		if idle.back != ics[len(ics)-1] {
+			t.Fatal("back != ics[len(ics)-1]")
+		}
+		if idle.front.prev != nil {
+			t.Fatal("front.prev != nil")
+		}
+		if idle.back.next != nil {
+			t.Fatal("back.next != nil")
+		}
+		for i := 1; i < len(ics)-1; i++ {
+			if ics[i-1].next != ics[i] {
+				t.Fatal("ics[i-1].next != ics[i]")
+			}
+			if ics[i+1].prev != ics[i] {
+				t.Fatal("ics[i+1].prev != ics[i]")
+			}
+		}
+	}
+
+	idle.pushFront(&c)
+	check(&c)
+	idle.pushFront(&b)
+	check(&b, &c)
+	idle.pushFront(&a)
+	check(&a, &b, &c)
+	idle.popFront()
+	check(&b, &c)
+	idle.popFront()
+	check(&c)
+	idle.popFront()
+	check()
+
+	idle.pushFront(&c)
+	check(&c)
+	idle.pushFront(&b)
+	check(&b, &c)
+	idle.pushFront(&a)
+	check(&a, &b, &c)
+	idle.popBack()
+	check(&a, &b)
+	idle.popBack()
+	check(&a)
+	idle.popBack()
+	check()
+}

+ 176 - 104
redis/pool.go

@@ -16,16 +16,17 @@ package redis
 
 
 import (
 import (
 	"bytes"
 	"bytes"
-	"container/list"
 	"crypto/rand"
 	"crypto/rand"
 	"crypto/sha1"
 	"crypto/sha1"
 	"errors"
 	"errors"
 	"io"
 	"io"
 	"strconv"
 	"strconv"
 	"sync"
 	"sync"
+	"sync/atomic"
 	"time"
 	"time"
 
 
 	"github.com/garyburd/redigo/internal"
 	"github.com/garyburd/redigo/internal"
+	"golang.org/x/net/context"
 )
 )
 
 
 var (
 var (
@@ -150,19 +151,13 @@ type Pool struct {
 	// for a connection to be returned to the pool before returning.
 	// for a connection to be returned to the pool before returning.
 	Wait bool
 	Wait bool
 
 
-	// mu protects fields defined below.
-	mu     sync.Mutex
-	cond   *sync.Cond
-	closed bool
-	active int
+	chInitialized uint32 // set to 1 when field ch is initialized
 
 
-	// Stack of idleConn with most recently used at the front.
-	idle list.List
-}
-
-type idleConn struct {
-	c Conn
-	t time.Time
+	mu     sync.Mutex    // mu protects the following fields
+	closed bool          // set to true when the pool is closed.
+	active int           // the number of open connections in the pool
+	ch     chan struct{} // limits open connections when p.Wait is true
+	idle   idleList      // idle connections
 }
 }
 
 
 // NewPool creates a new pool.
 // NewPool creates a new pool.
@@ -178,16 +173,33 @@ func NewPool(newFn func() (Conn, error), maxIdle int) *Pool {
 // getting an underlying connection, then the connection Err, Do, Send, Flush
 // getting an underlying connection, then the connection Err, Do, Send, Flush
 // and Receive methods return that error.
 // and Receive methods return that error.
 func (p *Pool) Get() Conn {
 func (p *Pool) Get() Conn {
-	c, err := p.get()
+	c, err := p.get(nil)
 	if err != nil {
 	if err != nil {
 		return errorConnection{err}
 		return errorConnection{err}
 	}
 	}
 	return &pooledConnection{p: p, c: c}
 	return &pooledConnection{p: p, c: c}
 }
 }
 
 
+// GetContext gets a connection using the provided context.
+//
+// The provided Context must be non-nil. If the context expires before the
+// connection is complete, an error is returned. Any expiration on the context
+// will not affect the returned connection.
+//
+// If the function completes without error, then the application must close the
+// returned connection.
+func (p *Pool) GetContext(ctx context.Context) (Conn, error) {
+	c, err := p.get(ctx)
+	if err != nil {
+		return errorConnection{err}, err
+	}
+	return &pooledConnection{p: p, c: c}, nil
+}
+
 // PoolStats contains pool statistics.
 // PoolStats contains pool statistics.
 type PoolStats struct {
 type PoolStats struct {
-	// ActiveCount is the number of connections in the pool. The count includes idle connections and connections in use.
+	// ActiveCount is the number of connections in the pool. The count includes
+	// idle connections and connections in use.
 	ActiveCount int
 	ActiveCount int
 	// IdleCount is the number of idle connections in the pool.
 	// IdleCount is the number of idle connections in the pool.
 	IdleCount int
 	IdleCount int
@@ -198,14 +210,15 @@ func (p *Pool) Stats() PoolStats {
 	p.mu.Lock()
 	p.mu.Lock()
 	stats := PoolStats{
 	stats := PoolStats{
 		ActiveCount: p.active,
 		ActiveCount: p.active,
-		IdleCount:   p.idle.Len(),
+		IdleCount:   p.idle.count,
 	}
 	}
 	p.mu.Unlock()
 	p.mu.Unlock()
 
 
 	return stats
 	return stats
 }
 }
 
 
-// ActiveCount returns the number of connections in the pool. The count includes idle connections and connections in use.
+// ActiveCount returns the number of connections in the pool. The count
+// includes idle connections and connections in use.
 func (p *Pool) ActiveCount() int {
 func (p *Pool) ActiveCount() int {
 	p.mu.Lock()
 	p.mu.Lock()
 	active := p.active
 	active := p.active
@@ -216,7 +229,7 @@ func (p *Pool) ActiveCount() int {
 // IdleCount returns the number of idle connections in the pool.
 // IdleCount returns the number of idle connections in the pool.
 func (p *Pool) IdleCount() int {
 func (p *Pool) IdleCount() int {
 	p.mu.Lock()
 	p.mu.Lock()
-	idle := p.idle.Len()
+	idle := p.idle.count
 	p.mu.Unlock()
 	p.mu.Unlock()
 	return idle
 	return idle
 }
 }
@@ -224,132 +237,143 @@ func (p *Pool) IdleCount() int {
 // Close releases the resources used by the pool.
 // Close releases the resources used by the pool.
 func (p *Pool) Close() error {
 func (p *Pool) Close() error {
 	p.mu.Lock()
 	p.mu.Lock()
-	idle := p.idle
-	p.idle.Init()
+	if p.closed {
+		p.mu.Unlock()
+		return nil
+	}
 	p.closed = true
 	p.closed = true
-	p.active -= idle.Len()
-	if p.cond != nil {
-		p.cond.Broadcast()
+	p.active -= p.idle.count
+	ic := p.idle.front
+	p.idle.count = 0
+	p.idle.front, p.idle.back = nil, nil
+	if p.ch != nil {
+		close(p.ch)
 	}
 	}
 	p.mu.Unlock()
 	p.mu.Unlock()
-	for e := idle.Front(); e != nil; e = e.Next() {
-		e.Value.(idleConn).c.Close()
+	for ; ic != nil; ic = ic.next {
+		ic.c.Close()
 	}
 	}
 	return nil
 	return nil
 }
 }
 
 
-// release decrements the active count and signals waiters. The caller must
-// hold p.mu during the call.
-func (p *Pool) release() {
-	p.active -= 1
-	if p.cond != nil {
-		p.cond.Signal()
+func (p *Pool) lazyInit() {
+	// Fast path.
+	if atomic.LoadUint32(&p.chInitialized) == 1 {
+		return
+	}
+	// Slow path.
+	p.mu.Lock()
+	if p.chInitialized == 0 {
+		p.ch = make(chan struct{}, p.MaxActive)
+		if p.closed {
+			close(p.ch)
+		} else {
+			for i := 0; i < p.MaxActive; i++ {
+				p.ch <- struct{}{}
+			}
+		}
+		atomic.StoreUint32(&p.chInitialized, 1)
 	}
 	}
+	p.mu.Unlock()
 }
 }
 
 
 // get prunes stale connections and returns a connection from the idle list or
 // get prunes stale connections and returns a connection from the idle list or
 // creates a new connection.
 // creates a new connection.
-func (p *Pool) get() (Conn, error) {
-	p.mu.Lock()
-
-	// Prune stale connections.
+func (p *Pool) get(ctx context.Context) (Conn, error) {
 
 
-	if timeout := p.IdleTimeout; timeout > 0 {
-		for i, n := 0, p.idle.Len(); i < n; i++ {
-			e := p.idle.Back()
-			if e == nil {
-				break
-			}
-			ic := e.Value.(idleConn)
-			if ic.t.Add(timeout).After(nowFunc()) {
-				break
+	// Handle limit for p.Wait == true.
+	if p.Wait && p.MaxActive > 0 {
+		p.lazyInit()
+		if ctx == nil {
+			<-p.ch
+		} else {
+			select {
+			case <-p.ch:
+			case <-ctx.Done():
+				return nil, ctx.Err()
 			}
 			}
-			p.idle.Remove(e)
-			p.release()
-			p.mu.Unlock()
-			ic.c.Close()
-			p.mu.Lock()
 		}
 		}
 	}
 	}
 
 
-	for {
-		// Get idle connection.
+	p.mu.Lock()
 
 
-		for i, n := 0, p.idle.Len(); i < n; i++ {
-			e := p.idle.Front()
-			if e == nil {
-				break
-			}
-			ic := e.Value.(idleConn)
-			p.idle.Remove(e)
-			test := p.TestOnBorrow
+	// Prune stale connections at the back of the idle list.
+	if p.IdleTimeout > 0 {
+		n := p.idle.count
+		for i := 0; i < n && p.idle.back != nil && p.idle.back.t.Add(p.IdleTimeout).Before(nowFunc()); i++ {
+			c := p.idle.back.c
+			p.idle.popBack()
 			p.mu.Unlock()
 			p.mu.Unlock()
-			if test == nil || test(ic.c, ic.t) == nil {
-				return ic.c, nil
-			}
-			ic.c.Close()
+			c.Close()
 			p.mu.Lock()
 			p.mu.Lock()
-			p.release()
+			p.active--
 		}
 		}
+	}
 
 
-		// Check for pool closed before dialing a new connection.
-
-		if p.closed {
-			p.mu.Unlock()
-			return nil, errors.New("redigo: get on closed pool")
+	// Get idle connection from the front of idle list.
+	for p.idle.front != nil {
+		ic := p.idle.front
+		p.idle.popFront()
+		p.mu.Unlock()
+		if p.TestOnBorrow == nil || p.TestOnBorrow(ic.c, ic.t) == nil {
+			return ic.c, nil
 		}
 		}
+		ic.c.Close()
+		p.mu.Lock()
+		p.active--
+	}
 
 
-		// Dial new connection if under limit.
-
-		if p.MaxActive == 0 || p.active < p.MaxActive {
-			dial := p.Dial
-			p.active += 1
-			p.mu.Unlock()
-			c, err := dial()
-			if err != nil {
-				p.mu.Lock()
-				p.release()
-				p.mu.Unlock()
-				c = nil
-			}
-			return c, err
-		}
+	// Check for pool closed before dialing a new connection.
+	if p.closed {
+		p.mu.Unlock()
+		return nil, errors.New("redigo: get on closed pool")
+	}
 
 
-		if !p.Wait {
-			p.mu.Unlock()
-			return nil, ErrPoolExhausted
-		}
+	// Handle limit for p.Wait == false.
+	if !p.Wait && p.MaxActive > 0 && p.active >= p.MaxActive {
+		p.mu.Unlock()
+		return nil, ErrPoolExhausted
+	}
 
 
-		if p.cond == nil {
-			p.cond = sync.NewCond(&p.mu)
+	p.active++
+	p.mu.Unlock()
+	c, err := p.Dial()
+	if err != nil {
+		c = nil
+		p.mu.Lock()
+		p.active--
+		if p.ch != nil && !p.closed {
+			p.ch <- struct{}{}
 		}
 		}
-		p.cond.Wait()
+		p.mu.Unlock()
 	}
 	}
+	return c, err
 }
 }
 
 
 func (p *Pool) put(c Conn, forceClose bool) error {
 func (p *Pool) put(c Conn, forceClose bool) error {
-	err := c.Err()
 	p.mu.Lock()
 	p.mu.Lock()
-	if !p.closed && err == nil && !forceClose {
-		p.idle.PushFront(idleConn{t: nowFunc(), c: c})
-		if p.idle.Len() > p.MaxIdle {
-			c = p.idle.Remove(p.idle.Back()).(idleConn).c
+	if !p.closed && !forceClose {
+		p.idle.pushFront(&idleConn{t: nowFunc(), c: c})
+		if p.idle.count > p.MaxIdle {
+			c = p.idle.back.c
+			p.idle.popBack()
 		} else {
 		} else {
 			c = nil
 			c = nil
 		}
 		}
 	}
 	}
 
 
-	if c == nil {
-		if p.cond != nil {
-			p.cond.Signal()
-		}
+	if c != nil {
 		p.mu.Unlock()
 		p.mu.Unlock()
-		return nil
+		c.Close()
+		p.mu.Lock()
+		p.active--
 	}
 	}
 
 
-	p.release()
+	if p.ch != nil && !p.closed {
+		p.ch <- struct{}{}
+	}
 	p.mu.Unlock()
 	p.mu.Unlock()
-	return c.Close()
+	return nil
 }
 }
 
 
 type pooledConnection struct {
 type pooledConnection struct {
@@ -409,7 +433,7 @@ func (pc *pooledConnection) Close() error {
 		}
 		}
 	}
 	}
 	c.Do("")
 	c.Do("")
-	pc.p.put(c, pc.state != 0)
+	pc.p.put(c, pc.state != 0 || c.Err() != nil)
 	return nil
 	return nil
 }
 }
 
 
@@ -467,3 +491,51 @@ func (ec errorConnection) Close() error
 func (ec errorConnection) Flush() error                                          { return ec.err }
 func (ec errorConnection) Flush() error                                          { return ec.err }
 func (ec errorConnection) Receive() (interface{}, error)                         { return nil, ec.err }
 func (ec errorConnection) Receive() (interface{}, error)                         { return nil, ec.err }
 func (ec errorConnection) ReceiveWithTimeout(time.Duration) (interface{}, error) { return nil, ec.err }
 func (ec errorConnection) ReceiveWithTimeout(time.Duration) (interface{}, error) { return nil, ec.err }
+
+type idleList struct {
+	count       int
+	front, back *idleConn
+}
+
+type idleConn struct {
+	c          Conn
+	t          time.Time
+	next, prev *idleConn
+}
+
+func (l *idleList) pushFront(ic *idleConn) {
+	ic.next = l.front
+	ic.prev = nil
+	if l.count == 0 {
+		l.back = ic
+	} else {
+		l.front.prev = ic
+	}
+	l.front = ic
+	l.count++
+	return
+}
+
+func (l *idleList) popFront() {
+	ic := l.front
+	l.count--
+	if l.count == 0 {
+		l.front, l.back = nil, nil
+	} else {
+		ic.next.prev = nil
+		l.front = ic.next
+	}
+	ic.next, ic.prev = nil, nil
+}
+
+func (l *idleList) popBack() {
+	ic := l.back
+	l.count--
+	if l.count == 0 {
+		l.front, l.back = nil, nil
+	} else {
+		ic.prev.next = nil
+		l.back = ic.prev
+	}
+	ic.next, ic.prev = nil, nil
+}

+ 36 - 4
redis/pool_test.go

@@ -23,6 +23,7 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/garyburd/redigo/redis"
 	"github.com/garyburd/redigo/redis"
+	"golang.org/x/net/context"
 )
 )
 
 
 type poolTestConn struct {
 type poolTestConn struct {
@@ -231,7 +232,7 @@ func TestPoolTimeout(t *testing.T) {
 
 
 	d.check("1", p, 1, 1, 0)
 	d.check("1", p, 1, 1, 0)
 
 
-	now = now.Add(p.IdleTimeout)
+	now = now.Add(p.IdleTimeout + 1)
 
 
 	c = p.Get()
 	c = p.Get()
 	c.Do("PING")
 	c.Do("PING")
@@ -445,9 +446,6 @@ func startGoroutines(p *redis.Pool, cmd string, args ...interface{}) chan error
 		}()
 		}()
 	}
 	}
 
 
-	// Wait for goroutines to block.
-	time.Sleep(time.Second / 4)
-
 	return errs
 	return errs
 }
 }
 
 
@@ -588,6 +586,40 @@ func TestWaitPoolDialError(t *testing.T) {
 	d.check("done", p, cap(errs), 0, 0)
 	d.check("done", p, cap(errs), 0, 0)
 }
 }
 
 
+func TestWaitPoolGetAfterClose(t *testing.T) {
+	d := poolDialer{t: t}
+	p := &redis.Pool{
+		MaxIdle:   1,
+		MaxActive: 1,
+		Dial:      d.dial,
+		Wait:      true,
+	}
+	p.Close()
+	_, err := p.GetContext(context.Background())
+	if err == nil {
+		t.Fatal("expected error")
+	}
+}
+
+func TestWaitPoolGetCanceledContext(t *testing.T) {
+	d := poolDialer{t: t}
+	p := &redis.Pool{
+		MaxIdle:   1,
+		MaxActive: 1,
+		Dial:      d.dial,
+		Wait:      true,
+	}
+	defer p.Close()
+	ctx, f := context.WithCancel(context.Background())
+	f()
+	c := p.Get()
+	defer c.Close()
+	_, err := p.GetContext(ctx)
+	if err != context.Canceled {
+		t.Fatalf("got error %v, want %v", err, context.Canceled)
+	}
+}
+
 // Borrowing requires us to iterate over the idle connections, unlock the pool,
 // Borrowing requires us to iterate over the idle connections, unlock the pool,
 // and perform a blocking operation to check the connection still works. If
 // and perform a blocking operation to check the connection still works. If
 // TestOnBorrow fails, we must reacquire the lock and continue iteration. This
 // TestOnBorrow fails, we must reacquire the lock and continue iteration. This

+ 3 - 3
redis/scan_test.go

@@ -84,8 +84,8 @@ var scanConversionTests = []struct {
 	{"1m", durationScan{Duration: time.Minute}},
 	{"1m", durationScan{Duration: time.Minute}},
 	{[]byte("1m"), durationScan{Duration: time.Minute}},
 	{[]byte("1m"), durationScan{Duration: time.Minute}},
 	{time.Minute.Nanoseconds(), durationScan{Duration: time.Minute}},
 	{time.Minute.Nanoseconds(), durationScan{Duration: time.Minute}},
-	{[]interface{}{[]byte("1m")}, []durationScan{durationScan{Duration: time.Minute}}},
-	{[]interface{}{[]byte("1m")}, []*durationScan{&durationScan{Duration: time.Minute}}},
+	{[]interface{}{[]byte("1m")}, []durationScan{{Duration: time.Minute}}},
+	{[]interface{}{[]byte("1m")}, []*durationScan{{Duration: time.Minute}}},
 }
 }
 
 
 func TestScanConversion(t *testing.T) {
 func TestScanConversion(t *testing.T) {
@@ -318,7 +318,7 @@ var scanSliceTests = []struct {
 		[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
 		[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
 		nil,
 		nil,
 		true,
 		true,
-		[]*struct{ A, B string }{{"a1", "b1"}, {"a2", "b2"}},
+		[]*struct{ A, B string }{{A: "a1", B: "b1"}, {A: "a2", B: "b2"}},
 	},
 	},
 	{
 	{
 		[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
 		[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},