Browse Source

Fix connection pending count.

To handle the case where Receive is called before Send, decrement the
pending count after the reply is read.

Fixes #78
Gary Burd 11 years ago
parent
commit
917dc959b1
2 changed files with 33 additions and 8 deletions
  1. 12 8
      redis/conn.go
  2. 21 0
      redis/conn_test.go

+ 12 - 8
redis/conn.go

@@ -345,20 +345,24 @@ func (c *conn) Flush() error {
 }
 }
 
 
 func (c *conn) Receive() (reply interface{}, err error) {
 func (c *conn) Receive() (reply interface{}, err error) {
-	c.mu.Lock()
-	// There can be more receives than sends when using pub/sub. To allow
-	// normal use of the connection after unsubscribe from all channels, do not
-	// decrement pending to a negative value.
-	if c.pending > 0 {
-		c.pending -= 1
-	}
-	c.mu.Unlock()
 	if c.readTimeout != 0 {
 	if c.readTimeout != 0 {
 		c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
 		c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
 	}
 	}
 	if reply, err = c.readReply(); err != nil {
 	if reply, err = c.readReply(); err != nil {
 		return nil, c.fatal(err)
 		return nil, c.fatal(err)
 	}
 	}
+	// When using pub/sub, the number of receives can be greater than the
+	// number of sends. To enable normal use of the connection after
+	// unsubscribing from all channels, we do not decrement pending to a
+	// negative value.
+	//
+	// The pending field is decremented after the reply is read to handle the
+	// case where Receive is called before Send.
+	c.mu.Lock()
+	if c.pending > 0 {
+		c.pending -= 1
+	}
+	c.mu.Unlock()
 	if err, ok := reply.(Error); ok {
 	if err, ok := reply.(Error); ok {
 		return nil, err
 		return nil, err
 	}
 	}

+ 21 - 0
redis/conn_test.go

@@ -292,6 +292,27 @@ func TestBlankCommmand(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestRecvBeforeSend(t *testing.T) {
+	c, err := redis.DialTestDB()
+	if err != nil {
+		t.Fatalf("error connection to database, %v", err)
+	}
+	defer c.Close()
+	done := make(chan struct{})
+	go func() {
+		c.Receive()
+		close(done)
+	}()
+	time.Sleep(time.Millisecond)
+	c.Send("PING")
+	c.Flush()
+	<-done
+	_, err = c.Do("")
+	if err != nil {
+		t.Fatalf("error=%v", err)
+	}
+}
+
 func TestError(t *testing.T) {
 func TestError(t *testing.T) {
 	c, err := redis.DialTestDB()
 	c, err := redis.DialTestDB()
 	if err != nil {
 	if err != nil {