Procházet zdrojové kódy

Results, scripts, concurrency.

- Add functions for parsing results of various types.
- Add helpers for scripts.
- Allow some concurrent use of connections.
Gary Burd před 13 roky
rodič
revize
719c2a9830
9 změnil soubory, kde provedl 475 přidání a 43 odebrání
  1. 0 2
      README.markdown
  2. 47 31
      redis/conn.go
  3. 15 8
      redis/conn_test.go
  4. 12 2
      redis/redis.go
  5. 145 0
      redis/result.go
  6. 85 0
      redis/result_test.go
  7. 76 0
      redis/script.go
  8. 71 0
      redis/script_test.go
  9. 24 0
      redis/test_test.go

+ 0 - 2
README.markdown

@@ -4,8 +4,6 @@ Redigo
 Redigo is a [Go](http://golang.org/) client for the [Redis](http://redis.io/)
 database.
 
-Redigo is a work in progress. 
-
 The Redigo API reference is available on
 [GoPkgDoc](http://gopkgdoc.appspot.com/pkg/github.com/garyburd/redigo/redis).
 

+ 47 - 31
redis/conn.go

@@ -22,6 +22,7 @@ import (
 	"io"
 	"net"
 	"strconv"
+	"sync"
 	"time"
 )
 
@@ -29,20 +30,19 @@ import (
 type conn struct {
 	rw      bufio.ReadWriter
 	conn    net.Conn
-	err     error
 	scratch []byte
 	pending int
+	mu      sync.Mutex
+	err     error
 }
 
 // Dial connects to the Redis server at the given network and address.
-// 
-// The returned connection is not thread-safe.
 func Dial(network, address string) (Conn, error) {
 	netConn, err := net.Dial(network, address)
 	if err != nil {
 		return nil, err
 	}
-	return newConn(netConn), nil
+	return NewConn(netConn), nil
 }
 
 // DialTimeout acts like Dial but takes a timeout. The timeout includes name
@@ -52,10 +52,11 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
 	if err != nil {
 		return nil, err
 	}
-	return newConn(netConn), nil
+	return NewConn(netConn), nil
 }
 
-func newConn(netConn net.Conn) Conn {
+// NewConn returns a new Redigo connection for the given net connection.
+func NewConn(netConn net.Conn) Conn {
 	return &conn{
 		conn: netConn,
 		rw: bufio.ReadWriter{
@@ -66,11 +67,29 @@ func newConn(netConn net.Conn) Conn {
 }
 
 func (c *conn) Close() error {
-	return c.conn.Close()
+	err := c.conn.Close()
+	if err != nil {
+		c.fatal(err)
+	} else {
+		c.fatal(errors.New("redigo: closed"))
+	}
+	return err
 }
 
 func (c *conn) Err() error {
-	return c.err
+	c.mu.Lock()
+	err := c.err
+	c.mu.Unlock()
+	return err
+}
+
+func (c *conn) fatal(err error) error {
+	c.mu.Lock()
+	if c.err != nil {
+		c.err = err
+	}
+	c.mu.Unlock()
+	return err
 }
 
 func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
@@ -190,35 +209,30 @@ func (c *conn) parseReply() (interface{}, error) {
 
 // Send sends a command for the server without waiting for a reply.
 func (c *conn) Send(cmd string, args ...interface{}) error {
-	if c.err != nil {
-		return c.err
-	}
-
-	c.err = c.writeN('*', 1+len(args))
-	if c.err != nil {
-		return c.err
+	if err := c.writeN('*', 1+len(args)); err != nil {
+		return c.fatal(err)
 	}
 
-	c.err = c.writeString(cmd)
-	if c.err != nil {
-		return c.err
+	if err := c.writeString(cmd); err != nil {
+		return c.fatal(err)
 	}
 
 	for _, arg := range args {
+		var err error
 		switch arg := arg.(type) {
 		case string:
-			c.err = c.writeString(arg)
+			err = c.writeString(arg)
 		case []byte:
-			c.err = c.writeBytes(arg)
+			err = c.writeBytes(arg)
 		case nil:
-			c.err = c.writeString("")
+			err = c.writeString("")
 		default:
 			var buf bytes.Buffer
 			fmt.Fprint(&buf, arg)
-			c.err = c.writeBytes(buf.Bytes())
+			err = c.writeBytes(buf.Bytes())
 		}
-		if c.err != nil {
-			return c.err
+		if err != nil {
+			return c.fatal(err)
 		}
 	}
 	c.pending += 1
@@ -227,15 +241,17 @@ func (c *conn) Send(cmd string, args ...interface{}) error {
 
 func (c *conn) Receive() (interface{}, error) {
 	c.pending -= 1
-	c.err = c.rw.Flush()
-	if c.err != nil {
-		return nil, c.err
+	if err := c.rw.Flush(); err != nil {
+		return nil, c.fatal(err)
 	}
 	v, err := c.parseReply()
+	if err == nil {
+		if e, ok := v.(Error); ok {
+			err = e
+		}
+	}
 	if err != nil {
-		c.err = err
-	} else if e, ok := v.(Error); ok {
-		err = e
+		return nil, c.fatal(err)
 	}
-	return v, err
+	return v, nil
 }

+ 15 - 8
redis/conn_test.go

@@ -12,12 +12,13 @@
 // License for the specific language governing permissions and limitations
 // under the License.
 
-package redis
+package redis_test
 
 import (
 	"bufio"
 	"bytes"
 	"errors"
+	"github.com/garyburd/redigo/redis"
 	"reflect"
 	"strings"
 	"testing"
@@ -52,13 +53,14 @@ var sendTests = []struct {
 func TestSend(t *testing.T) {
 	for _, tt := range sendTests {
 		var buf bytes.Buffer
-		c := conn{rw: bufio.ReadWriter{Writer: bufio.NewWriter(&buf)}}
+		rw := bufio.ReadWriter{Writer: bufio.NewWriter(&buf)}
+		c := redis.NewConnBufio(rw)
 		err := c.Send(tt.args[0].(string), tt.args[1:]...)
 		if err != nil {
 			t.Errorf("Send(%v) returned error %v", tt.args, err)
 			continue
 		}
-		c.rw.Flush()
+		rw.Flush()
 		actual := buf.String()
 		if actual != tt.expected {
 			t.Errorf("Send(%v) = %q, want %q", tt.args, actual, tt.expected)
@@ -112,10 +114,11 @@ var receiveTests = []struct {
 
 func TestReceive(t *testing.T) {
 	for _, tt := range receiveTests {
-		c := conn{rw: bufio.ReadWriter{
+		rw := bufio.ReadWriter{
 			Reader: bufio.NewReader(strings.NewReader(tt.reply)),
 			Writer: bufio.NewWriter(nil), // writer need to support Flush
-		}}
+		}
+		c := redis.NewConnBufio(rw)
 		actual, err := c.Receive()
 		if tt.expected == errorSentinel {
 			if err == nil {
@@ -133,8 +136,8 @@ func TestReceive(t *testing.T) {
 	}
 }
 
-func connect() (Conn, error) {
-	c, err := Dial("tcp", ":6379")
+func connect() (redis.Conn, error) {
+	c, err := redis.Dial("tcp", ":6379")
 	if err != nil {
 		return nil, err
 	}
@@ -156,7 +159,7 @@ func connect() (Conn, error) {
 	return c, nil
 }
 
-func disconnect(c Conn) error {
+func disconnect(c redis.Conn) error {
 	_, err := c.Do("SELECT", "9")
 	if err != nil {
 		return nil
@@ -185,6 +188,10 @@ var testCommands = []struct {
 		[]interface{}{"GET", "nokey"},
 		nil,
 	},
+	{
+		[]interface{}{"MGET", "nokey", "foo"},
+		[]interface{}{nil, []byte("bar")},
+	},
 	{
 		[]interface{}{"INCR", "mycounter"},
 		int64(1),

+ 12 - 2
redis/redis.go

@@ -89,7 +89,7 @@
 //  
 // The connection Receive method is used to implement blocking subscribers: 
 //
-//  c.Do("SUBSCRIBE", "foo")
+//  c.Send("SUBSCRIBE", "foo")
 //  for {
 //      reply, err := c.Receive()
 //      if err != nil {
@@ -97,9 +97,19 @@
 //      }
 //      // consume message
 //  }
+//
+// Thread Safety
+//
+// The Send method cannot be called concurrently with other calls to Send. The
+// Receive method cannot be called concurrently  with other calls to Receive.
+// Because the Do method invokes Send and Receive, the Do method cannot be
+// called concurrently  with Send, Receive or Do. All other concurrent access is
+// allowed.
 package redis
 
-// Error represets an error returned in a command reply.
+import ()
+
+// Error represents an error returned in a command reply.
 type Error string
 
 func (err Error) Error() string { return string(err) }

+ 145 - 0
redis/result.go

@@ -0,0 +1,145 @@
+// Copyright 2012 Gary Burd
+//
+// Licensed under the Apache License, Version 2.0 (the "License"): you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations
+// under the License.
+
+package redis
+
+import (
+	"errors"
+	"strconv"
+)
+
+var (
+	errUnexpectedResultType = errors.New("redigo: unexpected result type")
+)
+
+// Int is a helper that wraps a call to the Conn Do and Receive methods and
+// returns the result as an integer. If the result is an integer, then Int
+// returns the integer. If the result is a bulk response, then Int parses the
+// result as a signed decimal value.  Otherwise, Int returns an error.
+func Int(v interface{}, err error) (int, error) {
+	if err != nil {
+		return 0, err
+	}
+	switch v := v.(type) {
+	case int64:
+		return int(v), nil
+	case []byte:
+		n, err := strconv.ParseInt(string(v), 10, 0)
+		return int(n), err
+	case Error:
+		return 0, v
+	}
+	return 0, errUnexpectedResultType
+}
+
+// String is a helper that wraps a call to the Conn Do and Receive methods and
+// returns the result as a string. If the result is a bulk response, then
+// String returns the bytes converted to a string. If the result is an integer,
+// then String formats the result as a decimal string. Otherwise, String returns
+// an error.
+func String(v interface{}, err error) (string, error) {
+	if err != nil {
+		return "", err
+	}
+	switch v := v.(type) {
+	case int64:
+		return strconv.FormatInt(v, 10), nil
+	case []byte:
+		return string(v), nil
+	case Error:
+		return "", v
+	}
+	return "", errUnexpectedResultType
+}
+
+// Bytes is a helper that wraps a call to the Conn Do or Receive methods and
+// returns the result as a []byte. If the result is a bulk response, then Bytes
+// returns the result as is.  If the result is an integer, then Bytes formats
+// the result as a decimal string. Otherwise, Bytes returns an error.
+func Bytes(v interface{}, err error) ([]byte, error) {
+	if err != nil {
+		return nil, err
+	}
+	switch v := v.(type) {
+	case int64:
+		return strconv.AppendInt(nil, v, 10), nil
+	case []byte:
+		return v, nil
+	case Error:
+		return nil, v
+	}
+	return nil, errUnexpectedResultType
+}
+
+// Subscribe represents a subscribe or unsubscribe notification.
+type Subscription struct {
+
+	// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe"
+	Kind string
+
+	// The channel that was changed.
+	Channel string
+
+	// The current number of subscriptions for connection.
+	Count int
+}
+
+// Message represents a message notification.
+type Message struct {
+
+	// The originating channel.
+	Channel string
+
+	// The message data.
+	Data []byte
+}
+
+// Notification returns the result from the Conn Receive method as a
+// Subscription or a Message.
+func Notification(v interface{}, err error) (interface{}, error) {
+	if err != nil {
+		return nil, err
+	}
+	err = errUnexpectedResultType
+	s, ok := v.([]interface{})
+	if !ok || len(s) != 3 {
+		return nil, errUnexpectedResultType
+	}
+	b, ok := s[0].([]byte)
+	if !ok {
+		return nil, errUnexpectedResultType
+	}
+	kind := string(b)
+
+	b, ok = s[1].([]byte)
+	if !ok {
+		return nil, errUnexpectedResultType
+	}
+	channel := string(b)
+
+	if kind == "message" {
+		data, ok := s[2].([]byte)
+		if !ok {
+			return nil, errUnexpectedResultType
+		}
+		return Message{channel, data}, nil
+	}
+
+	count, ok := s[2].(int64)
+	if !ok {
+		return nil, errUnexpectedResultType
+	}
+
+	return Subscription{kind, channel, int(count)}, nil
+}

+ 85 - 0
redis/result_test.go

@@ -0,0 +1,85 @@
+// Copyright 2012 Gary Burd
+//
+// Licensed under the Apache License, Version 2.0 (the "License"): you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations
+// under the License.
+
+package redis_test
+
+import (
+	"fmt"
+	"github.com/garyburd/redigo/redis"
+	"net"
+	"reflect"
+	"testing"
+	"time"
+)
+
+func ExampleNotification(c redis.Conn) {
+	c.Send("SUBSCRIBE", "mychannel")
+	for {
+		n, err := redis.Notification(c.Receive())
+		if err != nil {
+			break
+		}
+		switch n := n.(type) {
+		case redis.Message:
+			fmt.Printf("%s: message: %s", n.Channel, n.Data)
+		case redis.Subscription:
+			fmt.Printf("%s: %s %d", n.Channel, n.Kind, n.Count)
+		default:
+			panic("unexpected")
+		}
+	}
+}
+
+func expectNotification(t *testing.T, c redis.Conn, message string, expected interface{}) {
+	actual, err := redis.Notification(c.Receive())
+	if err != nil {
+		t.Errorf("%s returned error %v", message, err)
+		return
+	}
+	if !reflect.DeepEqual(actual, expected) {
+		t.Errorf("%s = %v, want %v", message, actual, expected)
+	}
+}
+
+func TestNotification(t *testing.T) {
+	pc, err := connect()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer disconnect(pc)
+
+	nc, err := net.Dial("tcp", ":6379")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer nc.Close()
+	nc.SetReadDeadline(time.Now().Add(4 * time.Second))
+
+	c := redis.NewConn(nc)
+
+	c.Send("SUBSCRIBE", "c1")
+	expectNotification(t, c, "Subscribe(c1)", redis.Subscription{"subscribe", "c1", 1})
+	c.Send("SUBSCRIBE", "c2")
+	expectNotification(t, c, "Subscribe(c2)", redis.Subscription{"subscribe", "c2", 2})
+	c.Send("PSUBSCRIBE", "p1")
+	expectNotification(t, c, "PSubscribe(p1)", redis.Subscription{"psubscribe", "p1", 3})
+	c.Send("PSUBSCRIBE", "p2")
+	expectNotification(t, c, "PSubscribe(p2)", redis.Subscription{"psubscribe", "p2", 4})
+	c.Send("PUNSUBSCRIBE")
+	expectNotification(t, c, "Punsubscribe(p1)", redis.Subscription{"punsubscribe", "p1", 3})
+	expectNotification(t, c, "Punsubscribe()", redis.Subscription{"punsubscribe", "p2", 2})
+
+	pc.Do("PUBLISH", "c1", "hello")
+	expectNotification(t, c, "PUBLISH c1 hello", redis.Message{"c1", []byte("hello")})
+}

+ 76 - 0
redis/script.go

@@ -0,0 +1,76 @@
+// Copyright 2012 Gary Burd
+//
+// Licensed under the Apache License, Version 2.0 (the "License"): you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations
+// under the License.
+
+package redis
+
+import (
+	"crypto/sha1"
+	"encoding/hex"
+	"io"
+	"strings"
+)
+
+// Script encapsulates the source, hash and key count for a Lua script. See
+// http://redis.io/commands/eval for information on scripts in Redis.
+type Script struct {
+	keyCount int
+	src      string
+	hash     string
+}
+
+// NewScript returns a new script object initialized with the specified number
+// of keys and source code.
+func NewScript(keyCount int, src string) *Script {
+	h := sha1.New()
+	io.WriteString(h, src)
+	return &Script{keyCount, src, hex.EncodeToString(h.Sum(nil))}
+}
+
+func (s *Script) args(spec string, keysAndArgs []interface{}) []interface{} {
+	args := make([]interface{}, 2+len(keysAndArgs))
+	args[0] = spec
+	args[1] = s.keyCount
+	copy(args[2:], keysAndArgs)
+	return args
+}
+
+// Do evaluates the script and returns the result. Under the covers, Do
+// attempts to evaluate the script using the EVALSHA command. If the command
+// fails because the script is not loaded, then Do evaluates the script using
+// the EVAL command (thus causing the script to load).
+func (s *Script) Do(c Conn, keysAndArgs ...interface{}) (interface{}, error) {
+	v, err := c.Do("EVALSHA", s.args(s.hash, keysAndArgs)...)
+	if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") {
+		v, err = c.Do("EVAL", s.args(s.src, keysAndArgs)...)
+	}
+	return v, err
+}
+
+// SendHash evaluates the script without waiting for the result. The script is
+// evaluated with the EVALSHA command. The application must ensure that the
+// script is loaded by a previous call to Send, Do or Load methods.
+func (s *Script) SendHash(c Conn, keysAndArgs ...interface{}) error {
+	return c.Send("EVALSHA", s.args(s.hash, keysAndArgs)...)
+}
+
+// Send evaluates the script without waiting for the result. 
+func (s *Script) Send(c Conn, keysAndArgs ...interface{}) error {
+	return c.Send("EVAL", s.args(s.src, keysAndArgs)...)
+}
+
+// Load loads the script without evaluating it.
+func (s *Script) Load(c Conn) error {
+	_, err := c.Do("SCRIPT", "LOAD", s.src)
+	return err
+}

+ 71 - 0
redis/script_test.go

@@ -0,0 +1,71 @@
+// Copyright 2012 Gary Burd
+//
+// Licensed under the Apache License, Version 2.0 (the "License"): you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations
+// under the License.
+
+package redis_test
+
+import (
+	"fmt"
+	"github.com/garyburd/redigo/redis"
+	"reflect"
+	"testing"
+	"time"
+)
+
+func TestScript(t *testing.T) {
+	c, err := connect()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer disconnect(c)
+
+	// To test fallback in Do, we make script unique by adding comment with current time.
+	script := fmt.Sprintf("--%d\nreturn {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", time.Now().UnixNano())
+	s := redis.NewScript(2, script)
+	result := []interface{}{[]byte("key1"), []byte("key2"), []byte("arg1"), []byte("arg2")}
+
+	v, err := s.Do(c, "key1", "key2", "arg1", "arg2")
+	if err != nil {
+		t.Errorf("s.Do(c, ...) returned %v", err)
+	}
+
+	if !reflect.DeepEqual(v, result) {
+		t.Errorf("s.Do(c, ..); = %v, want %v", v, result)
+	}
+
+	err = s.Load(c)
+	if err != nil {
+		t.Errorf("s.Load(c) returned %v", err)
+	}
+
+	err = s.SendHash(c, "key1", "key2", "arg1", "arg2")
+	if err != nil {
+		t.Errorf("s.SendHash(c, ...) returned %v", err)
+	}
+
+	v, err = c.Receive()
+	if !reflect.DeepEqual(v, result) {
+		t.Errorf("s.SendHash(c, ..); s.Recevie() = %v, want %v", v, result)
+	}
+
+	err = s.Send(c, "key1", "key2", "arg1", "arg2")
+	if err != nil {
+		t.Errorf("s.Send(c, ...) returned %v", err)
+	}
+
+	v, err = c.Receive()
+	if !reflect.DeepEqual(v, result) {
+		t.Errorf("s.Send(c, ..); s.Recevie() = %v, want %v", v, result)
+	}
+
+}

+ 24 - 0
redis/test_test.go

@@ -0,0 +1,24 @@
+// Copyright 2012 Gary Burd
+//
+// Licensed under the Apache License, Version 2.0 (the "License"): you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations
+// under the License.
+
+package redis
+
+import (
+	"bufio"
+)
+
+// NewConnBufio is a hook for tests.
+func NewConnBufio(rw bufio.ReadWriter) Conn {
+	return &conn{rw: rw}
+}