|
|
@@ -17,6 +17,7 @@ package redis
|
|
|
import (
|
|
|
"bufio"
|
|
|
"bytes"
|
|
|
+ "crypto/tls"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
@@ -75,6 +76,9 @@ type dialOptions struct {
|
|
|
dial func(network, addr string) (net.Conn, error)
|
|
|
db int
|
|
|
password string
|
|
|
+ dialTLS bool
|
|
|
+ skipVerify bool
|
|
|
+ tlsConfig *tls.Config
|
|
|
}
|
|
|
|
|
|
// DialReadTimeout specifies the timeout for reading a single command reply.
|
|
|
@@ -123,6 +127,22 @@ func DialPassword(password string) DialOption {
|
|
|
}}
|
|
|
}
|
|
|
|
|
|
+// DialTLSConfig specifies the config to use when a TLS connection is dialed.
|
|
|
+// Has no effect when not dialing a TLS connection.
|
|
|
+func DialTLSConfig(c *tls.Config) DialOption {
|
|
|
+ return DialOption{func(do *dialOptions) {
|
|
|
+ do.tlsConfig = c
|
|
|
+ }}
|
|
|
+}
|
|
|
+
|
|
|
+// DialTLSSkipVerify to disable server name verification when connecting
|
|
|
+// over TLS. Has no effect when not dialing a TLS connection.
|
|
|
+func DialTLSSkipVerify(skip bool) DialOption {
|
|
|
+ return DialOption{func(do *dialOptions) {
|
|
|
+ do.skipVerify = skip
|
|
|
+ }}
|
|
|
+}
|
|
|
+
|
|
|
// Dial connects to the Redis server at the given network and
|
|
|
// address using the specified options.
|
|
|
func Dial(network, address string, options ...DialOption) (Conn, error) {
|
|
|
@@ -137,6 +157,26 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
+
|
|
|
+ if do.dialTLS {
|
|
|
+ tlsConfig := cloneTLSClientConfig(do.tlsConfig, do.skipVerify)
|
|
|
+ if tlsConfig.ServerName == "" {
|
|
|
+ host, _, err := net.SplitHostPort(address)
|
|
|
+ if err != nil {
|
|
|
+ netConn.Close()
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ tlsConfig.ServerName = host
|
|
|
+ }
|
|
|
+
|
|
|
+ tlsConn := tls.Client(netConn, tlsConfig)
|
|
|
+ if err := tlsConn.Handshake(); err != nil {
|
|
|
+ netConn.Close()
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ netConn = tlsConn
|
|
|
+ }
|
|
|
+
|
|
|
c := &conn{
|
|
|
conn: netConn,
|
|
|
bw: bufio.NewWriter(netConn),
|
|
|
@@ -162,6 +202,10 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
|
|
|
return c, nil
|
|
|
}
|
|
|
|
|
|
+func dialTLS(do *dialOptions) {
|
|
|
+ do.dialTLS = true
|
|
|
+}
|
|
|
+
|
|
|
var pathDBRegexp = regexp.MustCompile(`/(\d*)\z`)
|
|
|
|
|
|
// DialURL connects to a Redis server at the given URL using the Redis
|
|
|
@@ -173,7 +217,7 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
- if u.Scheme != "redis" {
|
|
|
+ if u.Scheme != "redis" && u.Scheme != "rediss" {
|
|
|
return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme)
|
|
|
}
|
|
|
|
|
|
@@ -213,6 +257,10 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
|
|
|
return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
|
|
|
}
|
|
|
|
|
|
+ if u.Scheme == "rediss" {
|
|
|
+ options = append([]DialOption{{dialTLS}}, options...)
|
|
|
+ }
|
|
|
+
|
|
|
return Dial("tcp", address, options...)
|
|
|
}
|
|
|
|