Kaynağa Gözat

Improve pooled connection concurrency support.

Add support for one concurrent reader and one concurrent writer on
pooled connections.

Fixes #73
Gary Burd 11 yıl önce
ebeveyn
işleme
7eff00e9b3
2 değiştirilmiş dosya ile 82 ekleme ve 62 silme
  1. 56 61
      redis/pool.go
  2. 26 1
      redis/pool_test.go

+ 56 - 61
redis/pool.go

@@ -33,7 +33,10 @@ var nowFunc = time.Now // for testing
 // pool has been reached.
 var ErrPoolExhausted = errors.New("redigo: connection pool exhausted")
 
-var errPoolClosed = errors.New("redigo: connection pool closed")
+var (
+	errPoolClosed = errors.New("redigo: connection pool closed")
+	errConnClosed = errors.New("redigo: connection closed")
+)
 
 // Pool maintains a pool of connections. The application calls the Get method
 // to get a connection from the pool and the connection's Close method to
@@ -131,12 +134,16 @@ func NewPool(newFn func() (Conn, error), maxIdle int) *Pool {
 }
 
 // Get gets a connection. The application must close the returned connection.
-// The connection acquires an underlying connection on the first call to the
-// connection Do, Send, Receive, Flush or Err methods. An application can force
-// the connection to acquire an underlying connection without executing a Redis
-// command by calling the Err method.
+// This method always returns a valid connection so that applications can defer
+// error handling to the first use of the connection. If there is an error
+// getting an underlying connection, then the connection Err, Do, Send, Flush
+// and Receive methods return that error.
 func (p *Pool) Get() Conn {
-	return &pooledConnection{p: p}
+	c, err := p.get()
+	if err != nil {
+		return errorConnection{err}
+	}
+	return &pooledConnection{p: p, c: c}
 }
 
 // ActiveCount returns the number of active connections in the pool.
@@ -253,19 +260,11 @@ func (p *Pool) put(c Conn, forceClose bool) error {
 }
 
 type pooledConnection struct {
-	c     Conn
-	err   error
 	p     *Pool
+	c     Conn
 	state int
 }
 
-func (pc *pooledConnection) get() error {
-	if pc.err == nil && pc.c == nil {
-		pc.c, pc.err = pc.p.get()
-	}
-	return pc.err
-}
-
 var (
 	sentinel     []byte
 	sentinelOnce sync.Once
@@ -283,77 +282,73 @@ func initSentinel() {
 	}
 }
 
-func (pc *pooledConnection) Close() (err error) {
-	if pc.c != nil {
-		if pc.state&multiState != 0 {
-			pc.c.Send("DISCARD")
-			pc.state &^= (multiState | watchState)
-		} else if pc.state&watchState != 0 {
-			pc.c.Send("UNWATCH")
-			pc.state &^= watchState
-		}
-		if pc.state&subscribeState != 0 {
-			pc.c.Send("UNSUBSCRIBE")
-			pc.c.Send("PUNSUBSCRIBE")
-			// To detect the end of the message stream, ask the server to echo
-			// a sentinel value and read until we see that value.
-			sentinelOnce.Do(initSentinel)
-			pc.c.Send("ECHO", sentinel)
-			pc.c.Flush()
-			for {
-				p, err := pc.c.Receive()
-				if err != nil {
-					break
-				}
-				if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) {
-					pc.state &^= subscribeState
-					break
-				}
+func (pc *pooledConnection) Close() error {
+	c := pc.c
+	if _, ok := c.(errorConnection); ok {
+		return nil
+	}
+	pc.c = errorConnection{errConnClosed}
+
+	if pc.state&multiState != 0 {
+		c.Send("DISCARD")
+		pc.state &^= (multiState | watchState)
+	} else if pc.state&watchState != 0 {
+		c.Send("UNWATCH")
+		pc.state &^= watchState
+	}
+	if pc.state&subscribeState != 0 {
+		c.Send("UNSUBSCRIBE")
+		c.Send("PUNSUBSCRIBE")
+		// To detect the end of the message stream, ask the server to echo
+		// a sentinel value and read until we see that value.
+		sentinelOnce.Do(initSentinel)
+		c.Send("ECHO", sentinel)
+		c.Flush()
+		for {
+			p, err := c.Receive()
+			if err != nil {
+				break
+			}
+			if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) {
+				pc.state &^= subscribeState
+				break
 			}
 		}
-		pc.c.Do("")
-		pc.p.put(pc.c, pc.state != 0)
-		pc.c = nil
-		pc.err = errPoolClosed
 	}
-	return err
+	c.Do("")
+	pc.p.put(c, pc.state != 0)
+	return nil
 }
 
 func (pc *pooledConnection) Err() error {
-	if err := pc.get(); err != nil {
-		return err
-	}
 	return pc.c.Err()
 }
 
 func (pc *pooledConnection) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
-	if err := pc.get(); err != nil {
-		return nil, err
-	}
 	ci := lookupCommandInfo(commandName)
 	pc.state = (pc.state | ci.set) &^ ci.clear
 	return pc.c.Do(commandName, args...)
 }
 
 func (pc *pooledConnection) Send(commandName string, args ...interface{}) error {
-	if err := pc.get(); err != nil {
-		return err
-	}
 	ci := lookupCommandInfo(commandName)
 	pc.state = (pc.state | ci.set) &^ ci.clear
 	return pc.c.Send(commandName, args...)
 }
 
 func (pc *pooledConnection) Flush() error {
-	if err := pc.get(); err != nil {
-		return err
-	}
 	return pc.c.Flush()
 }
 
 func (pc *pooledConnection) Receive() (reply interface{}, err error) {
-	if err := pc.get(); err != nil {
-		return nil, err
-	}
 	return pc.c.Receive()
 }
+
+type errorConnection struct{ err error }
+
+func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err }
+func (ec errorConnection) Send(string, ...interface{}) error              { return ec.err }
+func (ec errorConnection) Err() error                                     { return ec.err }
+func (ec errorConnection) Close() error                                   { return ec.err }
+func (ec errorConnection) Flush() error                                   { return ec.err }
+func (ec errorConnection) Receive() (interface{}, error)                  { return nil, ec.err }

+ 26 - 1
redis/pool_test.go

@@ -156,6 +156,7 @@ func TestPoolClose(t *testing.T) {
 		t.Errorf("expected error after connection closed")
 	}
 
+	c2.Close()
 	c2.Close()
 
 	p.Close()
@@ -168,7 +169,7 @@ func TestPoolClose(t *testing.T) {
 
 	c3.Close()
 
-	d.check("after channel close", p, 3, 0)
+	d.check("after conn close", p, 3, 0)
 
 	c1 = p.Get()
 	if _, err := c1.Do("PING"); err == nil {
@@ -205,6 +206,30 @@ func TestPoolTimeout(t *testing.T) {
 	p.Close()
 }
 
+func TestConcurrenSendReceive(t *testing.T) {
+	p := &Pool{
+		Dial: DialTestDB,
+	}
+	c := p.Get()
+	done := make(chan error, 1)
+	go func() {
+		_, err := c.Receive()
+		done <- err
+	}()
+	c.Send("PING")
+	c.Flush()
+	err := <-done
+	if err != nil {
+		t.Fatalf("Receive() returned error %v", err)
+	}
+	_, err = c.Do("")
+	if err != nil {
+		t.Fatalf("Do() returned error %v", err)
+	}
+	c.Close()
+	p.Close()
+}
+
 func TestBorrowCheck(t *testing.T) {
 	d := poolDialer{t: t}
 	p := &Pool{