瀏覽代碼

Support DialContext on Pool

Support Pool.DialContext option

Fixes #406
Masayuki Izumi 6 年之前
父節點
當前提交
39e2c31b7c
共有 2 個文件被更改,包括 40 次插入5 次删除
  1. 20 5
      redis/pool.go
  2. 20 0
      redis/pool_test.go

+ 20 - 5
redis/pool.go

@@ -57,6 +57,7 @@ var (
 //    return &redis.Pool{
 //      MaxIdle: 3,
 //      IdleTimeout: 240 * time.Second,
+//      // Dial or DialContext must be set. When both are set, DialContext takes precedence over Dial.
 //      Dial: func () (redis.Conn, error) { return redis.Dial("tcp", addr) },
 //    }
 //  }
@@ -126,6 +127,13 @@ type Pool struct {
 	// (subscribed to pubsub channel, transaction started, ...).
 	Dial func() (Conn, error)
 
+	// DialContext is an application supplied function for creating and configuring a
+	// connection with the given context.
+	//
+	// The connection returned from Dial must not be in a special state
+	// (subscribed to pubsub channel, transaction started, ...).
+	DialContext func(ctx context.Context) (Conn, error)
+
 	// TestOnBorrow is an optional application supplied function for checking
 	// the health of an idle connection before the connection is used again by
 	// the application. Argument t is the time that the connection was returned
@@ -293,10 +301,7 @@ func (p *Pool) lazyInit() {
 
 // get prunes stale connections and returns a connection from the idle list or
 // creates a new connection.
-func (p *Pool) get(ctx interface {
-	Done() <-chan struct{}
-	Err() error
-}) (*poolConn, error) {
+func (p *Pool) get(ctx context.Context) (*poolConn, error) {
 
 	// Handle limit for p.Wait == true.
 	var waited time.Duration
@@ -372,7 +377,7 @@ func (p *Pool) get(ctx interface {
 
 	p.active++
 	p.mu.Unlock()
-	c, err := p.Dial()
+	c, err := p.dial(ctx)
 	if err != nil {
 		c = nil
 		p.mu.Lock()
@@ -385,6 +390,16 @@ func (p *Pool) get(ctx interface {
 	return &poolConn{c: c, created: nowFunc()}, err
 }
 
+func (p *Pool) dial(ctx context.Context) (Conn, error) {
+	if p.DialContext != nil {
+		return p.DialContext(ctx)
+	}
+	if p.Dial != nil {
+		return p.Dial()
+	}
+	return nil, errors.New("redigo: must pass Dial or DialContext to pool")
+}
+
 func (p *Pool) put(pc *poolConn, forceClose bool) error {
 	p.mu.Lock()
 	if !p.closed && !forceClose {

+ 20 - 0
redis/pool_test.go

@@ -88,6 +88,10 @@ func (d *poolDialer) dial() (redis.Conn, error) {
 	return &poolTestConn{d: d, Conn: c}, nil
 }
 
+func (d *poolDialer) dialContext(ctx context.Context) (redis.Conn, error) {
+	return d.dial()
+}
+
 func (d *poolDialer) check(message string, p *redis.Pool, dialed, open, inuse int) {
 	d.checkAll(message, p, dialed, open, inuse, 0, 0)
 }
@@ -820,6 +824,22 @@ func TestWaitPoolGetContext(t *testing.T) {
 	defer c.Close()
 }
 
+func TestWaitPoolGetContextWithDialContext(t *testing.T) {
+	d := poolDialer{t: t}
+	p := &redis.Pool{
+		MaxIdle:     1,
+		MaxActive:   1,
+		DialContext: d.dialContext,
+		Wait:        true,
+	}
+	defer p.Close()
+	c, err := p.GetContext(context.Background())
+	if err != nil {
+		t.Fatalf("GetContext returned %v", err)
+	}
+	defer c.Close()
+}
+
 func TestWaitPoolGetAfterClose(t *testing.T) {
 	d := poolDialer{t: t}
 	p := &redis.Pool{