Преглед изворни кода

Add *WithTimeout methods

Add methods to set read timeout on call to Do and Receive.
Gary Burd пре 8 година
родитељ
комит
2ca15d09f9
8 измењених фајлова са 270 додато и 18 уклоњено
  1. 22 5
      redis/conn.go
  2. 50 6
      redis/conn_test.go
  3. 17 0
      redis/log.go
  4. 32 5
      redis/pool.go
  5. 15 2
      redis/pubsub.go
  6. 7 0
      redis/pubsub_test.go
  7. 56 0
      redis/redis.go
  8. 71 0
      redis/redis_test.go

+ 22 - 5
redis/conn.go

@@ -29,6 +29,10 @@ import (
 	"time"
 )
 
+var (
+	_ ConnWithTimeout = (*conn)(nil)
+)
+
 // conn is the low-level implementation of Conn
 type conn struct {
 	// Shared
@@ -571,10 +575,17 @@ func (c *conn) Flush() error {
 	return nil
 }
 
-func (c *conn) Receive() (reply interface{}, err error) {
-	if c.readTimeout != 0 {
-		c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
+func (c *conn) Receive() (interface{}, error) {
+	return c.ReceiveWithTimeout(c.readTimeout)
+}
+
+func (c *conn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
+	var deadline time.Time
+	if timeout != 0 {
+		deadline = time.Now().Add(timeout)
 	}
+	c.conn.SetReadDeadline(deadline)
+
 	if reply, err = c.readReply(); err != nil {
 		return nil, c.fatal(err)
 	}
@@ -597,6 +608,10 @@ func (c *conn) Receive() (reply interface{}, err error) {
 }
 
 func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
+	return c.DoWithTimeout(c.readTimeout, cmd, args...)
+}
+
+func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
 	c.mu.Lock()
 	pending := c.pending
 	c.pending = 0
@@ -620,9 +635,11 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
 		return nil, c.fatal(err)
 	}
 
-	if c.readTimeout != 0 {
-		c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
+	var deadline time.Time
+	if readTimeout != 0 {
+		deadline = time.Now().Add(readTimeout)
 	}
+	c.conn.SetReadDeadline(deadline)
 
 	if cmd == "" {
 		reply := make([]interface{}, pending)

+ 50 - 6
redis/conn_test.go

@@ -34,14 +34,16 @@ import (
 type testConn struct {
 	io.Reader
 	io.Writer
+	readDeadline  time.Time
+	writeDeadline time.Time
 }
 
-func (*testConn) Close() error                       { return nil }
-func (*testConn) LocalAddr() net.Addr                { return nil }
-func (*testConn) RemoteAddr() net.Addr               { return nil }
-func (*testConn) SetDeadline(t time.Time) error      { return nil }
-func (*testConn) SetReadDeadline(t time.Time) error  { return nil }
-func (*testConn) SetWriteDeadline(t time.Time) error { return nil }
+func (*testConn) Close() error                         { return nil }
+func (*testConn) LocalAddr() net.Addr                  { return nil }
+func (*testConn) RemoteAddr() net.Addr                 { return nil }
+func (c *testConn) SetDeadline(t time.Time) error      { c.readDeadline = t; c.writeDeadline = t; return nil }
+func (c *testConn) SetReadDeadline(t time.Time) error  { c.readDeadline = t; return nil }
+func (c *testConn) SetWriteDeadline(t time.Time) error { c.writeDeadline = t; return nil }
 
 func dialTestConn(r string, w io.Writer) redis.DialOption {
 	return redis.DialNetDial(func(network, addr string) (net.Conn, error) {
@@ -821,3 +823,45 @@ Bjqn3yoLHaoZVvbWOi0C2TCN4FjXjaLNZGifQPbIcaA=
 	clientTLSConfig.RootCAs = x509.NewCertPool()
 	clientTLSConfig.RootCAs.AddCert(certificate)
 }
+
+func TestWithTimeout(t *testing.T) {
+	for _, recv := range []bool{true, false} {
+		for _, defaultTimout := range []time.Duration{0, time.Minute} {
+			var buf bytes.Buffer
+			nc := &testConn{Reader: strings.NewReader("+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n"), Writer: &buf}
+			c, _ := redis.Dial("", "", redis.DialReadTimeout(defaultTimout), redis.DialNetDial(func(network, addr string) (net.Conn, error) { return nc, nil }))
+			for i := 0; i < 4; i++ {
+				var minDeadline, maxDeadline time.Time
+
+				// Alternate between default and specified timeout.
+				if i%2 == 0 {
+					if defaultTimout != 0 {
+						minDeadline = time.Now().Add(defaultTimout)
+					}
+					if recv {
+						c.Receive()
+					} else {
+						c.Do("PING")
+					}
+					if defaultTimout != 0 {
+						maxDeadline = time.Now().Add(defaultTimout)
+					}
+				} else {
+					timeout := 10 * time.Minute
+					minDeadline = time.Now().Add(timeout)
+					if recv {
+						redis.ReceiveWithTimeout(c, timeout)
+					} else {
+						redis.DoWithTimeout(c, timeout, "PING")
+					}
+					maxDeadline = time.Now().Add(timeout)
+				}
+
+				// Expect set deadline in expected range.
+				if nc.readDeadline.Before(minDeadline) || nc.readDeadline.After(maxDeadline) {
+					t.Errorf("recv %v, %d: do deadline error: %v, %v, %v", recv, i, minDeadline, nc.readDeadline, maxDeadline)
+				}
+			}
+		}
+	}
+}

+ 17 - 0
redis/log.go

@@ -18,6 +18,11 @@ import (
 	"bytes"
 	"fmt"
 	"log"
+	"time"
+)
+
+var (
+	_ ConnWithTimeout = (*loggingConn)(nil)
 )
 
 // NewLoggingConn returns a logging wrapper around a connection.
@@ -104,6 +109,12 @@ func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{},
 	return reply, err
 }
 
+func (c *loggingConn) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (interface{}, error) {
+	reply, err := DoWithTimeout(c.Conn, timeout, commandName, args...)
+	c.print("DoWithTimeout", commandName, args, reply, err)
+	return reply, err
+}
+
 func (c *loggingConn) Send(commandName string, args ...interface{}) error {
 	err := c.Conn.Send(commandName, args...)
 	c.print("Send", commandName, args, nil, err)
@@ -115,3 +126,9 @@ func (c *loggingConn) Receive() (interface{}, error) {
 	c.print("Receive", "", nil, reply, err)
 	return reply, err
 }
+
+func (c *loggingConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) {
+	reply, err := ReceiveWithTimeout(c.Conn, timeout)
+	c.print("ReceiveWithTimeout", "", nil, reply, err)
+	return reply, err
+}

+ 32 - 5
redis/pool.go

@@ -28,6 +28,11 @@ import (
 	"github.com/garyburd/redigo/internal"
 )
 
+var (
+	_ ConnWithTimeout = (*pooledConnection)(nil)
+	_ ConnWithTimeout = (*errorConnection)(nil)
+)
+
 var nowFunc = time.Now // for testing
 
 // ErrPoolExhausted is returned from a pool connection method (Do, Send,
@@ -418,6 +423,16 @@ func (pc *pooledConnection) Do(commandName string, args ...interface{}) (reply i
 	return pc.c.Do(commandName, args...)
 }
 
+func (pc *pooledConnection) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error) {
+	cwt, ok := pc.c.(ConnWithTimeout)
+	if !ok {
+		return nil, errTimeoutNotSupported
+	}
+	ci := internal.LookupCommandInfo(commandName)
+	pc.state = (pc.state | ci.Set) &^ ci.Clear
+	return cwt.DoWithTimeout(timeout, commandName, args...)
+}
+
 func (pc *pooledConnection) Send(commandName string, args ...interface{}) error {
 	ci := internal.LookupCommandInfo(commandName)
 	pc.state = (pc.state | ci.Set) &^ ci.Clear
@@ -432,11 +447,23 @@ func (pc *pooledConnection) Receive() (reply interface{}, err error) {
 	return pc.c.Receive()
 }
 
+func (pc *pooledConnection) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
+	cwt, ok := pc.c.(ConnWithTimeout)
+	if !ok {
+		return nil, errTimeoutNotSupported
+	}
+	return cwt.ReceiveWithTimeout(timeout)
+}
+
 type errorConnection struct{ err error }
 
 func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err }
-func (ec errorConnection) Send(string, ...interface{}) error              { return ec.err }
-func (ec errorConnection) Err() error                                     { return ec.err }
-func (ec errorConnection) Close() error                                   { return ec.err }
-func (ec errorConnection) Flush() error                                   { return ec.err }
-func (ec errorConnection) Receive() (interface{}, error)                  { return nil, ec.err }
+func (ec errorConnection) DoWithTimeout(time.Duration, string, ...interface{}) (interface{}, error) {
+	return nil, ec.err
+}
+func (ec errorConnection) Send(string, ...interface{}) error                     { return ec.err }
+func (ec errorConnection) Err() error                                            { return ec.err }
+func (ec errorConnection) Close() error                                          { return nil }
+func (ec errorConnection) Flush() error                                          { return ec.err }
+func (ec errorConnection) Receive() (interface{}, error)                         { return nil, ec.err }
+func (ec errorConnection) ReceiveWithTimeout(time.Duration) (interface{}, error) { return nil, ec.err }

+ 15 - 2
redis/pubsub.go

@@ -14,7 +14,10 @@
 
 package redis
 
-import "errors"
+import (
+	"errors"
+	"time"
+)
 
 // Subscription represents a subscribe or unsubscribe notification.
 type Subscription struct {
@@ -103,7 +106,17 @@ func (c PubSubConn) Ping(data string) error {
 // or error. The return value is intended to be used directly in a type switch
 // as illustrated in the PubSubConn example.
 func (c PubSubConn) Receive() interface{} {
-	reply, err := Values(c.Conn.Receive())
+	return c.receiveInternal(c.Conn.Receive())
+}
+
+// ReceiveWithTimeout is like Receive, but it allows the application to
+// override the connection's default timeout.
+func (c PubSubConn) ReceiveWithTimeout(timeout time.Duration) interface{} {
+	return c.receiveInternal(ReceiveWithTimeout(c.Conn, timeout))
+}
+
+func (c PubSubConn) receiveInternal(replyArg interface{}, errArg error) interface{} {
+	reply, err := Values(replyArg, errArg)
 	if err != nil {
 		return err
 	}

+ 7 - 0
redis/pubsub_test.go

@@ -17,6 +17,7 @@ package redis_test
 import (
 	"reflect"
 	"testing"
+	"time"
 
 	"github.com/garyburd/redigo/redis"
 )
@@ -64,4 +65,10 @@ func TestPushed(t *testing.T) {
 	c.Conn.Send("PING")
 	c.Conn.Flush()
 	expectPushed(t, c, `Send("PING")`, redis.Pong{})
+
+	c.Ping("timeout")
+	got := c.ReceiveWithTimeout(time.Minute)
+	if want := (redis.Pong{Data: "timeout"}); want != got {
+		t.Errorf("recv /w timeout got %v, want %v", got, want)
+	}
 }

+ 56 - 0
redis/redis.go

@@ -14,6 +14,11 @@
 
 package redis
 
+import (
+	"errors"
+	"time"
+)
+
 // Error represents an error returned in a command reply.
 type Error string
 
@@ -59,3 +64,54 @@ type Scanner interface {
 	// loss of information.
 	RedisScan(src interface{}) error
 }
+
+// ConnWithTimeout is an optional interface that allows the caller to override
+// a connection's default read timeout. This interface is useful for executing
+// the BLPOP, BRPOP, BRPOPLPUSH, XREAD and other commands that block at the
+// server.
+//
+// A connection's default read timeout is set with the DialReadTimeout dial
+// option. Applications should rely on the default timeout for commands that do
+// not block at the server.
+//
+// All of the Conn implementations in this package satisfy the ConnWithTimeout
+// interface.
+//
+// Use the DoWithTimeout and ReceiveWithTimeout helper functions to simplify
+// use of this interface.
+type ConnWithTimeout interface {
+	Conn
+
+	// Do sends a command to the server and returns the received reply.
+	// The timeout overrides the read timeout set when dialing the
+	// connection.
+	DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error)
+
+	// Receive receives a single reply from the Redis server. The timeout
+	// overrides the read timeout set when dialing the connection.
+	ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error)
+}
+
+var errTimeoutNotSupported = errors.New("redis: connection does not support ConnWithTimeout")
+
+// DoWithTimeout executes a Redis command with the specified read timeout. If
+// the connection does not satisfy the ConnWithTimeout interface, then an error
+// is returned.
+func DoWithTimeout(c Conn, timeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
+	cwt, ok := c.(ConnWithTimeout)
+	if !ok {
+		return nil, errTimeoutNotSupported
+	}
+	return cwt.DoWithTimeout(timeout, cmd, args...)
+}
+
+// ReceiveWithTimeout receives a reply with the specified read timeout. If the
+// connection does not satisfy the ConnWithTimeout interface, then an error is
+// returned.
+func ReceiveWithTimeout(c Conn, timeout time.Duration) (interface{}, error) {
+	cwt, ok := c.(ConnWithTimeout)
+	if !ok {
+		return nil, errTimeoutNotSupported
+	}
+	return cwt.ReceiveWithTimeout(timeout)
+}

+ 71 - 0
redis/redis_test.go

@@ -0,0 +1,71 @@
+// Copyright 2017 Gary Burd
+//
+// Licensed under the Apache License, Version 2.0 (the "License"): you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations
+// under the License.
+
+package redis_test
+
+import (
+	"testing"
+	"time"
+
+	"github.com/garyburd/redigo/redis"
+)
+
+type timeoutTestConn int
+
+func (tc timeoutTestConn) Do(string, ...interface{}) (interface{}, error) {
+	return time.Duration(-1), nil
+}
+func (tc timeoutTestConn) DoWithTimeout(timeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
+	return timeout, nil
+}
+
+func (tc timeoutTestConn) Receive() (interface{}, error) {
+	return time.Duration(-1), nil
+}
+func (tc timeoutTestConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) {
+	return timeout, nil
+}
+
+func (tc timeoutTestConn) Send(string, ...interface{}) error { return nil }
+func (tc timeoutTestConn) Err() error                        { return nil }
+func (tc timeoutTestConn) Close() error                      { return nil }
+func (tc timeoutTestConn) Flush() error                      { return nil }
+
+func testTimeout(t *testing.T, c redis.Conn) {
+	r, err := c.Do("PING")
+	if r != time.Duration(-1) || err != nil {
+		t.Errorf("Do() = %v, %v, want %v, %v", r, err, time.Duration(-1), nil)
+	}
+	r, err = redis.DoWithTimeout(c, time.Minute, "PING")
+	if r != time.Minute || err != nil {
+		t.Errorf("DoWithTimeout() = %v, %v, want %v, %v", r, err, time.Minute, nil)
+	}
+	r, err = c.Receive()
+	if r != time.Duration(-1) || err != nil {
+		t.Errorf("Receive() = %v, %v, want %v, %v", r, err, time.Duration(-1), nil)
+	}
+	r, err = redis.ReceiveWithTimeout(c, time.Minute)
+	if r != time.Minute || err != nil {
+		t.Errorf("ReceiveWithTimeout() = %v, %v, want %v, %v", r, err, time.Minute, nil)
+	}
+}
+
+func TestConnTimeout(t *testing.T) {
+	testTimeout(t, timeoutTestConn(0))
+}
+
+func TestPoolConnTimeout(t *testing.T) {
+	p := &redis.Pool{Dial: func() (redis.Conn, error) { return timeoutTestConn(0), nil }}
+	testTimeout(t, p.Get())
+}