Quellcode durchsuchen

pool: handle server connection state

Add server state tracking to the pool connection. The connection Close
method clears watch and transaction state. If a subscription or monitor
command was used, the underlying connection is closed.

Fixes issue #53.
Gary Burd vor 11 Jahren
Ursprung
Commit
32bf3361d3
3 geänderte Dateien mit 186 neuen und 28 gelöschten Zeilen
  1. 45 0
      redis/commandinfo.go
  2. 21 6
      redis/pool.go
  3. 120 22
      redis/pool_test.go

+ 45 - 0
redis/commandinfo.go

@@ -0,0 +1,45 @@
+// Copyright 2014 Gary Burd
+//
+// Licensed under the Apache License, Version 2.0 (the "License"): you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations
+// under the License.
+
+package redis
+
+import (
+	"strings"
+)
+
+const (
+	watchState = 1 << iota
+	multiState
+	subscribeState
+	monitorState
+)
+
+type commandInfo struct {
+	set, clear int
+}
+
+var commandInfos = map[string]commandInfo{
+	"WATCH":      commandInfo{set: watchState},
+	"UNWATCH":    commandInfo{clear: watchState},
+	"MULTI":      commandInfo{set: multiState},
+	"EXEC":       commandInfo{clear: watchState | multiState},
+	"DISCARD":    commandInfo{clear: watchState | multiState},
+	"PSUBSCRIBE": commandInfo{set: subscribeState},
+	"SUBSCRIBE":  commandInfo{set: subscribeState},
+	"MONITOR":    commandInfo{set: monitorState},
+}
+
+func lookupCommandInfo(commandName string) commandInfo {
+	return commandInfos[strings.ToUpper(commandName)]
+}

+ 21 - 6
redis/pool.go

@@ -210,8 +210,8 @@ func (p *Pool) get() (Conn, error) {
 	return c, err
 }
 
-func (p *Pool) put(c Conn) error {
-	if c.Err() == nil {
+func (p *Pool) put(c Conn, forceClose bool) error {
+	if c.Err() == nil && !forceClose {
 		p.mu.Lock()
 		if !p.closed {
 			p.idle.PushFront(idleConn{t: nowFunc(), c: c})
@@ -233,9 +233,10 @@ func (p *Pool) put(c Conn) error {
 }
 
 type pooledConnection struct {
-	c   Conn
-	err error
-	p   *Pool
+	c     Conn
+	err   error
+	p     *Pool
+	state int
 }
 
 func (c *pooledConnection) get() error {
@@ -247,8 +248,18 @@ func (c *pooledConnection) get() error {
 
 func (c *pooledConnection) Close() (err error) {
 	if c.c != nil {
+		if c.state&multiState != 0 {
+			c.c.Send("DISCARD")
+			c.state &^= (multiState | watchState)
+		} else if c.state&watchState != 0 {
+			c.c.Send("UNWATCH")
+			c.state &^= watchState
+		}
+		// TODO: Clear subscription state by executing PUNSUBSCRIBE,
+		// UNSUBSCRIBE and ECHO sentinel and receiving until the sentinel is
+		// found. The sentinel is a random string generated once at runtime.
 		c.c.Do("")
-		c.p.put(c.c)
+		c.p.put(c.c, c.state != 0)
 		c.c = nil
 		c.err = errPoolClosed
 	}
@@ -266,6 +277,8 @@ func (c *pooledConnection) Do(commandName string, args ...interface{}) (reply in
 	if err := c.get(); err != nil {
 		return nil, err
 	}
+	ci := lookupCommandInfo(commandName)
+	c.state = (c.state | ci.set) &^ ci.clear
 	return c.c.Do(commandName, args...)
 }
 
@@ -273,6 +286,8 @@ func (c *pooledConnection) Send(commandName string, args ...interface{}) error {
 	if err := c.get(); err != nil {
 		return err
 	}
+	ci := lookupCommandInfo(commandName)
+	c.state = (c.state | ci.set) &^ ci.clear
 	return c.c.Send(commandName, args...)
 }
 

+ 120 - 22
redis/pool_test.go

@@ -16,49 +16,55 @@ package redis
 
 import (
 	"io"
+	"reflect"
 	"testing"
 	"time"
 )
 
-type fakeConn struct {
-	open *int
-	err  error
+type poolTestConn struct {
+	d   *poolDialer
+	err error
 }
 
-func (c *fakeConn) Close() error { *c.open -= 1; return nil }
-func (c *fakeConn) Err() error   { return c.err }
+func (c *poolTestConn) Close() error { c.d.open -= 1; return nil }
+func (c *poolTestConn) Err() error   { return c.err }
 
-func (c *fakeConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
+func (c *poolTestConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
 	if commandName == "ERR" {
 		c.err = args[0].(error)
 	}
+	if commandName != "" {
+		c.d.commands = append(c.d.commands, commandName)
+	}
 	return nil, nil
 }
 
-func (c *fakeConn) Send(commandName string, args ...interface{}) error {
+func (c *poolTestConn) Send(commandName string, args ...interface{}) error {
+	c.d.commands = append(c.d.commands, commandName)
 	return nil
 }
 
-func (c *fakeConn) Flush() error {
+func (c *poolTestConn) Flush() error {
 	return nil
 }
 
-func (c *fakeConn) Receive() (reply interface{}, err error) {
+func (c *poolTestConn) Receive() (reply interface{}, err error) {
 	return nil, nil
 }
 
-type dialer struct {
+type poolDialer struct {
 	t            *testing.T
 	dialed, open int
+	commands     []string
 }
 
-func (d *dialer) dial() (Conn, error) {
+func (d *poolDialer) dial() (Conn, error) {
 	d.open += 1
 	d.dialed += 1
-	return &fakeConn{open: &d.open}, nil
+	return &poolTestConn{d: d}, nil
 }
 
-func (d *dialer) check(message string, p *Pool, dialed, open int) {
+func (d *poolDialer) check(message string, p *Pool, dialed, open int) {
 	if d.dialed != dialed {
 		d.t.Errorf("%s: dialed=%d, want %d", message, d.dialed, dialed)
 	}
@@ -71,7 +77,7 @@ func (d *dialer) check(message string, p *Pool, dialed, open int) {
 }
 
 func TestPoolReuse(t *testing.T) {
-	d := dialer{t: t}
+	d := poolDialer{t: t}
 	p := &Pool{
 		MaxIdle: 2,
 		Dial:    d.dial,
@@ -92,7 +98,7 @@ func TestPoolReuse(t *testing.T) {
 }
 
 func TestPoolMaxIdle(t *testing.T) {
-	d := dialer{t: t}
+	d := poolDialer{t: t}
 	p := &Pool{
 		MaxIdle: 2,
 		Dial:    d.dial,
@@ -114,7 +120,7 @@ func TestPoolMaxIdle(t *testing.T) {
 }
 
 func TestPoolError(t *testing.T) {
-	d := dialer{t: t}
+	d := poolDialer{t: t}
 	p := &Pool{
 		MaxIdle: 2,
 		Dial:    d.dial,
@@ -135,7 +141,7 @@ func TestPoolError(t *testing.T) {
 }
 
 func TestPoolClose(t *testing.T) {
-	d := dialer{t: t}
+	d := poolDialer{t: t}
 	p := &Pool{
 		MaxIdle: 2,
 		Dial:    d.dial,
@@ -174,7 +180,7 @@ func TestPoolClose(t *testing.T) {
 }
 
 func TestPoolTimeout(t *testing.T) {
-	d := dialer{t: t}
+	d := poolDialer{t: t}
 	p := &Pool{
 		MaxIdle:     2,
 		IdleTimeout: 300 * time.Second,
@@ -201,7 +207,7 @@ func TestPoolTimeout(t *testing.T) {
 }
 
 func TestBorrowCheck(t *testing.T) {
-	d := dialer{t: t}
+	d := poolDialer{t: t}
 	p := &Pool{
 		MaxIdle:      2,
 		Dial:         d.dial,
@@ -217,7 +223,7 @@ func TestBorrowCheck(t *testing.T) {
 }
 
 func TestMaxActive(t *testing.T) {
-	d := dialer{t: t}
+	d := poolDialer{t: t}
 	p := &Pool{
 		MaxIdle:   2,
 		MaxActive: 2,
@@ -238,7 +244,7 @@ func TestMaxActive(t *testing.T) {
 	c3.Close()
 	d.check("2", p, 2, 2)
 	c2.Close()
-	d.check("2", p, 2, 2)
+	d.check("3", p, 2, 2)
 
 	c3 = p.Get()
 	if _, err := c3.Do("PING"); err != nil {
@@ -246,5 +252,97 @@ func TestMaxActive(t *testing.T) {
 	}
 	c3.Close()
 
-	d.check("2", p, 2, 2)
+	d.check("4", p, 2, 2)
+}
+
+func TestPoolPubSubMonitorCleanup(t *testing.T) {
+	d := poolDialer{t: t}
+	p := &Pool{
+		MaxIdle:   2,
+		MaxActive: 2,
+		Dial:      d.dial,
+	}
+	c := p.Get()
+	c.Send("SUBSCRIBE", "x")
+	c.Close()
+
+	c = p.Get()
+	c.Send("PSUBSCRIBE", "x")
+	c.Close()
+
+	c = p.Get()
+	c.Send("MONITOR")
+	c.Close()
+
+	d.check("", p, 3, 0)
+}
+
+func TestTransactionCleanup(t *testing.T) {
+	d := poolDialer{t: t}
+	p := &Pool{
+		MaxIdle:   2,
+		MaxActive: 2,
+		Dial:      d.dial,
+	}
+
+	c := p.Get()
+	c.Do("WATCH", "key")
+	c.Do("PING")
+	c.Close()
+
+	want := []string{"WATCH", "PING", "UNWATCH"}
+	if !reflect.DeepEqual(d.commands, want) {
+		t.Errorf("got commands %v, want %v", d.commands, want)
+	}
+	d.commands = nil
+
+	c = p.Get()
+	c.Do("WATCH", "key")
+	c.Do("UNWATCH")
+	c.Do("PING")
+	c.Close()
+
+	want = []string{"WATCH", "UNWATCH", "PING"}
+	if !reflect.DeepEqual(d.commands, want) {
+		t.Errorf("got commands %v, want %v", d.commands, want)
+	}
+	d.commands = nil
+
+	c = p.Get()
+	c.Do("WATCH", "key")
+	c.Do("MULTI")
+	c.Do("PING")
+	c.Close()
+
+	want = []string{"WATCH", "MULTI", "PING", "DISCARD"}
+	if !reflect.DeepEqual(d.commands, want) {
+		t.Errorf("got commands %v, want %v", d.commands, want)
+	}
+	d.commands = nil
+
+	c = p.Get()
+	c.Do("WATCH", "key")
+	c.Do("MULTI")
+	c.Do("DISCARD")
+	c.Do("PING")
+	c.Close()
+
+	want = []string{"WATCH", "MULTI", "DISCARD", "PING"}
+	if !reflect.DeepEqual(d.commands, want) {
+		t.Errorf("got commands %v, want %v", d.commands, want)
+	}
+	d.commands = nil
+
+	c = p.Get()
+	c.Do("WATCH", "key")
+	c.Do("MULTI")
+	c.Do("EXEC")
+	c.Do("PING")
+	c.Close()
+
+	want = []string{"WATCH", "MULTI", "EXEC", "PING"}
+	if !reflect.DeepEqual(d.commands, want) {
+		t.Errorf("got commands %v, want %v", d.commands, want)
+	}
+	d.commands = nil
 }