Przeglądaj źródła

Add DialURL to allow connection using a Redis URL

Fixes #134
Emanuel Evans 10 lat temu
rodzic
commit
5e9d3af585
2 zmienionych plików z 166 dodań i 0 usunięć
  1. 61 0
      redis/conn.go
  2. 105 0
      redis/conn_test.go

+ 61 - 0
redis/conn.go

@@ -21,6 +21,8 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net"
 	"net"
+	"net/url"
+	"regexp"
 	"strconv"
 	"strconv"
 	"sync"
 	"sync"
 	"time"
 	"time"
@@ -113,6 +115,14 @@ func DialDatabase(db int) DialOption {
 	}}
 	}}
 }
 }
 
 
+// DialPassword specifies the password to use when connecting to
+// the Redis server.
+func DialPassword(password string) DialOption {
+	return DialOption{func(do *dialOptions) {
+		do.password = password
+	}}
+}
+
 // Dial connects to the Redis server at the given network and
 // Dial connects to the Redis server at the given network and
 // address using the specified options.
 // address using the specified options.
 func Dial(network, address string, options ...DialOption) (Conn, error) {
 func Dial(network, address string, options ...DialOption) (Conn, error) {
@@ -152,6 +162,57 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
 	return c, nil
 	return c, nil
 }
 }
 
 
+var pathDBRegexp = regexp.MustCompile(`/(\d)\z`)
+
+// DialURL connects to a Redis server at the given URL using the Redis
+// URI scheme. URLs should follow the draft IANA specification for the
+// scheme (https://www.iana.org/assignments/uri-schemes/prov/redis).
+func DialURL(rawurl string, options ...DialOption) (Conn, error) {
+	u, err := url.Parse(rawurl)
+	if err != nil {
+		return nil, err
+	}
+
+	if u.Scheme != "redis" {
+		return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme)
+	}
+
+	// As per the IANA draft spec, the host defaults to localhost and
+	// the port defaults to 6379.
+	host, port, err := net.SplitHostPort(u.Host)
+	if err != nil {
+		// assume port is missing
+		host = u.Host
+		port = "6379"
+	}
+	if host == "" {
+		host = "localhost"
+	}
+	address := net.JoinHostPort(host, port)
+
+	if u.User != nil {
+		password, isSet := u.User.Password()
+		if isSet {
+			options = append(options, DialPassword(password))
+		}
+	}
+
+	match := pathDBRegexp.FindStringSubmatch(u.Path)
+	if len(match) == 2 {
+		db, err := strconv.Atoi(match[1])
+		if err != nil {
+			return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
+		}
+		if db != 0 {
+			options = append(options, DialDatabase(db))
+		}
+	} else if u.Path != "" {
+		return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
+	}
+
+	return Dial("tcp", address, options...)
+}
+
 // NewConn returns a new Redigo connection for the given net connection.
 // NewConn returns a new Redigo connection for the given net connection.
 func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
 func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
 	return &conn{
 	return &conn{

+ 105 - 0
redis/conn_test.go

@@ -19,6 +19,7 @@ import (
 	"bytes"
 	"bytes"
 	"math"
 	"math"
 	"net"
 	"net"
+	"os"
 	"reflect"
 	"reflect"
 	"strings"
 	"strings"
 	"testing"
 	"testing"
@@ -422,6 +423,101 @@ func TestReadDeadline(t *testing.T) {
 	}
 	}
 }
 }
 
 
+var dialErrors = []struct {
+	rawurl        string
+	expectedError string
+}{
+	{
+		"localhost",
+		"invalid redis URL scheme",
+	},
+	// The error message for invalid hosts is diffferent in different
+	// versions of Go, so just check that there is an error message.
+	{
+		"redis://weird url",
+		"",
+	},
+	{
+		"redis://foo:bar:baz",
+		"",
+	},
+	{
+		"http://www.google.com",
+		"invalid redis URL scheme: http",
+	},
+	{
+		"redis://x:abc123@localhost",
+		"no password is set",
+	},
+	{
+		"redis://localhost:6379/abc123",
+		"invalid database: abc123",
+	},
+}
+
+func TestDialURL(t *testing.T) {
+	for _, d := range dialErrors {
+		_, err := redis.DialURL(d.rawurl)
+		if err == nil || !strings.Contains(err.Error(), d.expectedError) {
+			t.Errorf("DialURL did not return expected error (expected %v to contain %s)", err, d.expectedError)
+		}
+	}
+
+	checkPort := func(network, address string) (net.Conn, error) {
+		if address != "localhost:6379" {
+			t.Errorf("DialURL did not set port to 6379 by default (got %v)", address)
+		}
+		return net.Dial(network, address)
+	}
+	c, err := redis.DialURL("redis://localhost", redis.DialNetDial(checkPort))
+	if err != nil {
+		t.Error("dial error:", err)
+	}
+	c.Close()
+
+	checkHost := func(network, address string) (net.Conn, error) {
+		if address != "localhost:6379" {
+			t.Errorf("DialURL did not set host to localhost by default (got %v)", address)
+		}
+		return net.Dial(network, address)
+	}
+	c, err = redis.DialURL("redis://:6379", redis.DialNetDial(checkHost))
+	if err != nil {
+		t.Error("dial error:", err)
+	}
+	c.Close()
+
+	// Check that the database is set correctly
+	c1, err := redis.DialURL("redis://:6379/8")
+	defer c1.Close()
+	if err != nil {
+		t.Error("Dial error:", err)
+	}
+	dbSize, _ := redis.Int(c1.Do("DBSIZE"))
+	if dbSize > 0 {
+		t.Fatal("DB 8 has existing keys; aborting test to avoid overwriting data")
+	}
+	c1.Do("SET", "var", "val")
+
+	c2, err := redis.Dial("tcp", ":6379")
+	defer c2.Close()
+	if err != nil {
+		t.Error("dial error:", err)
+	}
+	_, err = c2.Do("SELECT", "8")
+	if err != nil {
+		t.Error(err)
+	}
+	got, err := redis.String(c2.Do("GET", "var"))
+	if err != nil {
+		t.Error(err)
+	}
+	if got != "val" {
+		t.Error("DialURL did not correctly set the db.")
+	}
+	_, err = c2.Do("DEL", "var")
+}
+
 // Connect to local instance of Redis running on the default port.
 // Connect to local instance of Redis running on the default port.
 func ExampleDial(x int) {
 func ExampleDial(x int) {
 	c, err := redis.Dial("tcp", ":6379")
 	c, err := redis.Dial("tcp", ":6379")
@@ -431,6 +527,15 @@ func ExampleDial(x int) {
 	defer c.Close()
 	defer c.Close()
 }
 }
 
 
+// Connect to remote instance of Redis using a URL.
+func ExampleDialURL() {
+	c, err := redis.DialURL(os.Getenv("REDIS_URL"))
+	if err != nil {
+		// handle connection error
+	}
+	defer c.Close()
+}
+
 // TextExecError tests handling of errors in a transaction. See
 // TextExecError tests handling of errors in a transaction. See
 // http://redis.io/topics/transactions for information on how Redis handles
 // http://redis.io/topics/transactions for information on how Redis handles
 // errors in a transaction.
 // errors in a transaction.