Browse Source

Add dial options

Gary Burd 10 years ago
parent
commit
5f01fdde26
2 changed files with 77 additions and 34 deletions
  1. 3 0
      internal/redistest/testdb.go
  2. 74 34
      redis/conn.go

+ 3 - 0
internal/redistest/testdb.go

@@ -49,15 +49,18 @@ func Dial() (redis.Conn, error) {
 
 	_, err = c.Do("SELECT", "9")
 	if err != nil {
+		c.Close()
 		return nil, err
 	}
 
 	n, err := redis.Int(c.Do("DBSIZE"))
 	if err != nil {
+		c.Close()
 		return nil, err
 	}
 
 	if n != 0 {
+		c.Close()
 		return nil, errors.New("database #9 is not empty, test can not continue")
 	}
 

+ 74 - 34
redis/conn.go

@@ -51,56 +51,96 @@ type conn struct {
 	numScratch [40]byte
 }
 
-// Dial connects to the Redis server at the given network and address.
-func Dial(network, address string) (Conn, error) {
-	dialer := xDialer{}
-	return dialer.Dial(network, address)
-}
-
 // DialTimeout acts like Dial but takes timeouts for establishing the
 // connection to the server, writing a command and reading a reply.
+//
+// DialTimeout is deprecated.
 func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
-	netDialer := net.Dialer{Timeout: connectTimeout}
-	dialer := xDialer{
-		NetDial:      netDialer.Dial,
-		ReadTimeout:  readTimeout,
-		WriteTimeout: writeTimeout,
-	}
-	return dialer.Dial(network, address)
+	return Dial(network, address,
+		DialConnectTimeout(connectTimeout),
+		DialReadTimeout(readTimeout),
+		DialWriteTimeout(writeTimeout))
 }
 
-// A Dialer specifies options for connecting to a Redis server.
-type xDialer struct {
-	// NetDial specifies the dial function for creating TCP connections. If
-	// NetDial is nil, then net.Dial is used.
-	NetDial func(network, addr string) (net.Conn, error)
+// DialOption specifies an option for dialing a Redis server.
+type DialOption struct {
+	f func(*dialOptions)
+}
 
-	// ReadTimeout specifies the timeout for reading a single command
-	// reply. If ReadTimeout is zero, then no timeout is used.
-	ReadTimeout time.Duration
+type dialOptions struct {
+	readTimeout  time.Duration
+	writeTimeout time.Duration
+	dial         func(network, addr string) (net.Conn, error)
+	db           int
+	password     string
+}
+
+// DialReadTimeout specifies the timeout for reading a single command reply.
+func DialReadTimeout(d time.Duration) DialOption {
+	return DialOption{func(do *dialOptions) {
+		do.readTimeout = d
+	}}
+}
+
+// DialWriteTimeout specifies the timeout for writing a single command.
+func DialWriteTimeout(d time.Duration) DialOption {
+	return DialOption{func(do *dialOptions) {
+		do.writeTimeout = d
+	}}
+}
+
+// DialConnectTimeout specifies the timeout for connecting to the Redis server.
+func DialConnectTimeout(d time.Duration) DialOption {
+	return DialOption{func(do *dialOptions) {
+		dialer := net.Dialer{Timeout: d}
+		do.dial = dialer.Dial
+	}}
+}
 
-	// WriteTimeout specifies the timeout for writing a single command.  If
-	// WriteTimeout is zero, then no timeout is used.
-	WriteTimeout time.Duration
+// DialDatabase specifies the database to select when dialing a connection.
+func DialDatabase(db int) DialOption {
+	return DialOption{func(do *dialOptions) {
+		do.db = db
+	}}
 }
 
-// Dial connects to the Redis server at address on the named network.
-func (d *xDialer) Dial(network, address string) (Conn, error) {
-	dial := d.NetDial
-	if dial == nil {
-		dial = net.Dial
+// Dial connects to the Redis server at the given network and
+// address using the specified options.
+func Dial(network, address string, options ...DialOption) (Conn, error) {
+	do := dialOptions{
+		dial: net.Dial,
 	}
-	netConn, err := dial(network, address)
+	for _, option := range options {
+		option.f(&do)
+	}
+
+	netConn, err := do.dial(network, address)
 	if err != nil {
 		return nil, err
 	}
-	return &conn{
+	c := &conn{
 		conn:         netConn,
 		bw:           bufio.NewWriter(netConn),
 		br:           bufio.NewReader(netConn),
-		readTimeout:  d.ReadTimeout,
-		writeTimeout: d.WriteTimeout,
-	}, nil
+		readTimeout:  do.readTimeout,
+		writeTimeout: do.writeTimeout,
+	}
+
+	if do.password != "" {
+		if _, err := c.Do("AUTH", do.password); err != nil {
+			netConn.Close()
+			return nil, err
+		}
+	}
+
+	if do.db != 0 {
+		if _, err := c.Do("SELECT", do.db); err != nil {
+			netConn.Close()
+			return nil, err
+		}
+	}
+
+	return c, nil
 }
 
 // NewConn returns a new Redigo connection for the given net connection.