Bläddra i källkod

Cleanup pub/sub state on pool close.

Gary Burd 11 år sedan
förälder
incheckning
19f7ce87fe
7 ändrade filer med 147 tillägg och 75 borttagningar
  1. 17 53
      redis/conn_test.go
  2. 42 3
      redis/pool.go
  3. 35 17
      redis/pool_test.go
  4. 4 1
      redis/pubsub_test.go
  5. 4 0
      redis/reply_test.go
  6. 4 1
      redis/script_test.go
  7. 41 0
      redis/test_test.go

+ 17 - 53
redis/conn_test.go

@@ -17,7 +17,6 @@ package redis_test
 import (
 	"bufio"
 	"bytes"
-	"errors"
 	"math"
 	"net"
 	"reflect"
@@ -160,53 +159,6 @@ func TestRead(t *testing.T) {
 	}
 }
 
-type testConn struct {
-	redis.Conn
-}
-
-func (t testConn) Close() error {
-	_, err := t.Conn.Do("SELECT", "9")
-	if err != nil {
-		return nil
-	}
-	_, err = t.Conn.Do("FLUSHDB")
-	if err != nil {
-		return err
-	}
-	return t.Conn.Close()
-}
-
-func dial() (redis.Conn, error) {
-	c, err := redis.DialTimeout("tcp", ":6379", 0, 1*time.Second, 1*time.Second)
-	if err != nil {
-		return nil, err
-	}
-
-	_, err = c.Do("SELECT", "9")
-	if err != nil {
-		return nil, err
-	}
-
-	n, err := redis.Int(c.Do("DBSIZE"))
-	if err != nil {
-		return nil, err
-	}
-
-	if n != 0 {
-		return nil, errors.New("database #9 is not empty, test can not continue")
-	}
-
-	return testConn{c}, nil
-}
-
-func dialt(t *testing.T) redis.Conn {
-	c, err := dial()
-	if err != nil {
-		t.Fatalf("error connection to database, %v", err)
-	}
-	return c
-}
-
 var testCommands = []struct {
 	args     []interface{}
 	expected interface{}
@@ -269,7 +221,10 @@ var testCommands = []struct {
 }
 
 func TestDoCommands(t *testing.T) {
-	c := dialt(t)
+	c, err := redis.DialTest()
+	if err != nil {
+		t.Fatalf("error connection to database, %v", err)
+	}
 	defer c.Close()
 
 	for _, cmd := range testCommands {
@@ -285,7 +240,10 @@ func TestDoCommands(t *testing.T) {
 }
 
 func TestPipelineCommands(t *testing.T) {
-	c := dialt(t)
+	c, err := redis.DialTest()
+	if err != nil {
+		t.Fatalf("error connection to database, %v", err)
+	}
 	defer c.Close()
 
 	for _, cmd := range testCommands {
@@ -308,7 +266,10 @@ func TestPipelineCommands(t *testing.T) {
 }
 
 func TestBlankCommmand(t *testing.T) {
-	c := dialt(t)
+	c, err := redis.DialTest()
+	if err != nil {
+		t.Fatalf("error connection to database, %v", err)
+	}
 	defer c.Close()
 
 	for _, cmd := range testCommands {
@@ -332,11 +293,14 @@ func TestBlankCommmand(t *testing.T) {
 }
 
 func TestError(t *testing.T) {
-	c := dialt(t)
+	c, err := redis.DialTest()
+	if err != nil {
+		t.Fatalf("error connection to database, %v", err)
+	}
 	defer c.Close()
 
 	c.Do("SET", "key", "val")
-	_, err := c.Do("HSET", "key", "fld", "val")
+	_, err = c.Do("HSET", "key", "fld", "val")
 	if err == nil {
 		t.Errorf("Expected err for HSET on string key.")
 	}

+ 42 - 3
redis/pool.go

@@ -15,8 +15,13 @@
 package redis
 
 import (
+	"bytes"
 	"container/list"
+	"crypto/rand"
+	"crypto/sha1"
 	"errors"
+	"io"
+	"strconv"
 	"sync"
 	"time"
 )
@@ -260,6 +265,23 @@ func (c *pooledConnection) get() error {
 	return c.err
 }
 
+var (
+	sentinel     []byte
+	sentinelOnce sync.Once
+)
+
+func initSentinel() {
+	p := make([]byte, 64)
+	if _, err := rand.Read(p); err == nil {
+		sentinel = p
+	} else {
+		h := sha1.New()
+		io.WriteString(h, "Oops, rand failed. Use time instead.")
+		io.WriteString(h, strconv.FormatInt(time.Now().UnixNano(), 10))
+		sentinel = h.Sum(nil)
+	}
+}
+
 func (c *pooledConnection) Close() (err error) {
 	if c.c != nil {
 		if c.state&multiState != 0 {
@@ -269,9 +291,26 @@ func (c *pooledConnection) Close() (err error) {
 			c.c.Send("UNWATCH")
 			c.state &^= watchState
 		}
-		// TODO: Clear subscription state by executing PUNSUBSCRIBE,
-		// UNSUBSCRIBE and ECHO sentinel and receiving until the sentinel is
-		// found. The sentinel is a random string generated once at runtime.
+		if c.state&subscribeState != 0 {
+			c.c.Send("UNSUBSCRIBE")
+			c.c.Send("PUNSUBSCRIBE")
+			// To detect the end of the message stream, ask the server to echo
+			// a sentinel value and read until we see that value.
+			sentinelOnce.Do(initSentinel)
+			c.c.Send("ECHO", sentinel)
+			c.c.Flush()
+			//for i := 0; i < 10; i++ {
+			for {
+				p, err := c.c.Receive()
+				if err != nil {
+					break
+				}
+				if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) {
+					c.state &^= subscribeState
+					break
+				}
+			}
+		}
 		c.c.Do("")
 		c.p.put(c.c, c.state != 0)
 		c.c = nil

+ 35 - 17
redis/pool_test.go

@@ -24,6 +24,7 @@ import (
 type poolTestConn struct {
 	d   *poolDialer
 	err error
+	Conn
 }
 
 func (c *poolTestConn) Close() error { c.d.open -= 1; return nil }
@@ -36,20 +37,12 @@ func (c *poolTestConn) Do(commandName string, args ...interface{}) (reply interf
 	if commandName != "" {
 		c.d.commands = append(c.d.commands, commandName)
 	}
-	return nil, nil
+	return c.Conn.Do(commandName, args...)
 }
 
 func (c *poolTestConn) Send(commandName string, args ...interface{}) error {
 	c.d.commands = append(c.d.commands, commandName)
-	return nil
-}
-
-func (c *poolTestConn) Flush() error {
-	return nil
-}
-
-func (c *poolTestConn) Receive() (reply interface{}, err error) {
-	return nil, nil
+	return c.Conn.Send(commandName, args...)
 }
 
 type poolDialer struct {
@@ -61,7 +54,11 @@ type poolDialer struct {
 func (d *poolDialer) dial() (Conn, error) {
 	d.open += 1
 	d.dialed += 1
-	return &poolTestConn{d: d}, nil
+	c, err := DialTest()
+	if err != nil {
+		return nil, err
+	}
+	return &poolTestConn{d: d, Conn: c}, nil
 }
 
 func (d *poolDialer) check(message string, p *Pool, dialed, open int) {
@@ -255,7 +252,7 @@ func TestMaxActive(t *testing.T) {
 	d.check("4", p, 2, 2)
 }
 
-func TestPoolPubSubMonitorCleanup(t *testing.T) {
+func TestMonitorCleanup(t *testing.T) {
 	d := poolDialer{t: t}
 	p := &Pool{
 		MaxIdle:   2,
@@ -263,18 +260,39 @@ func TestPoolPubSubMonitorCleanup(t *testing.T) {
 		Dial:      d.dial,
 	}
 	c := p.Get()
-	c.Send("SUBSCRIBE", "x")
+	c.Send("MONITOR")
 	c.Close()
 
-	c = p.Get()
-	c.Send("PSUBSCRIBE", "x")
+	d.check("", p, 1, 0)
+}
+
+func TestPubSubCleanup(t *testing.T) {
+	d := poolDialer{t: t}
+	p := &Pool{
+		MaxIdle:   2,
+		MaxActive: 2,
+		Dial:      d.dial,
+	}
+
+	c := p.Get()
+	c.Send("SUBSCRIBE", "x")
 	c.Close()
 
+	want := []string{"SUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
+	if !reflect.DeepEqual(d.commands, want) {
+		t.Errorf("got commands %v, want %v", d.commands, want)
+	}
+	d.commands = nil
+
 	c = p.Get()
-	c.Send("MONITOR")
+	c.Send("PSUBSCRIBE", "x*")
 	c.Close()
 
-	d.check("", p, 3, 0)
+	want = []string{"PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
+	if !reflect.DeepEqual(d.commands, want) {
+		t.Errorf("got commands %v, want %v", d.commands, want)
+	}
+	d.commands = nil
 }
 
 func TestTransactionCleanup(t *testing.T) {

+ 4 - 1
redis/pubsub_test.go

@@ -109,7 +109,10 @@ func expectPushed(t *testing.T, c redis.PubSubConn, message string, expected int
 }
 
 func TestPushed(t *testing.T) {
-	pc := dialt(t)
+	pc, err := redis.DialTest()
+	if err != nil {
+		t.Fatalf("error connection to database, %v", err)
+	}
 	defer pc.Close()
 
 	nc, err := net.Dial("tcp", ":6379")

+ 4 - 0
redis/reply_test.go

@@ -80,6 +80,10 @@ func TestReply(t *testing.T) {
 	}
 }
 
+func dial() (redis.Conn, error) {
+	return redis.DialTest()
+}
+
 func ExampleBool() {
 	c, err := dial()
 	if err != nil {

+ 4 - 1
redis/script_test.go

@@ -33,7 +33,10 @@ func ExampleScript(c redis.Conn, reply interface{}, err error) {
 }
 
 func TestScript(t *testing.T) {
-	c := dialt(t)
+	c, err := redis.DialTest()
+	if err != nil {
+		t.Fatalf("error connection to database, %v", err)
+	}
 	defer c.Close()
 
 	// To test fall back in Do, we make script unique by adding comment with current time.

+ 41 - 0
redis/test_test.go

@@ -16,9 +16,50 @@ package redis
 
 import (
 	"bufio"
+	"errors"
 	"net"
+	"time"
 )
 
+type testConn struct {
+	Conn
+}
+
+func (t testConn) Close() error {
+	_, err := t.Conn.Do("SELECT", "9")
+	if err != nil {
+		return nil
+	}
+	_, err = t.Conn.Do("FLUSHDB")
+	if err != nil {
+		return err
+	}
+	return t.Conn.Close()
+}
+
+func DialTest() (Conn, error) {
+	c, err := DialTimeout("tcp", ":6379", 0, 1*time.Second, 1*time.Second)
+	if err != nil {
+		return nil, err
+	}
+
+	_, err = c.Do("SELECT", "9")
+	if err != nil {
+		return nil, err
+	}
+
+	n, err := Int(c.Do("DBSIZE"))
+	if err != nil {
+		return nil, err
+	}
+
+	if n != 0 {
+		return nil, errors.New("database #9 is not empty, test can not continue")
+	}
+
+	return testConn{c}, nil
+}
+
 type dummyClose struct{ net.Conn }
 
 func (dummyClose) Close() error { return nil }