Browse Source

Merge pull request #208 from freeformz/rediss

Support rediss:// urls
Gary Burd 9 năm trước cách đây
mục cha
commit
d7de57ceb5
3 tập tin đã thay đổi với 113 bổ sung1 xóa
  1. 49 1
      redis/conn.go
  2. 33 0
      redis/go17.go
  3. 31 0
      redis/pre_go17.go

+ 49 - 1
redis/conn.go

@@ -17,6 +17,7 @@ package redis
 import (
 	"bufio"
 	"bytes"
+	"crypto/tls"
 	"errors"
 	"fmt"
 	"io"
@@ -75,6 +76,9 @@ type dialOptions struct {
 	dial         func(network, addr string) (net.Conn, error)
 	db           int
 	password     string
+	dialTLS      bool
+	skipVerify   bool
+	tlsConfig    *tls.Config
 }
 
 // DialReadTimeout specifies the timeout for reading a single command reply.
@@ -123,6 +127,22 @@ func DialPassword(password string) DialOption {
 	}}
 }
 
+// DialTLSConfig specifies the config to use when a TLS connection is dialed.
+//  Has no effect when not dialing a TLS connection.
+func DialTLSConfig(c *tls.Config) DialOption {
+	return DialOption{func(do *dialOptions) {
+		do.tlsConfig = c
+	}}
+}
+
+// DialTLSSkipVerify to disable server name verification when connecting
+// over TLS. Has no effect when not dialing a TLS connection.
+func DialTLSSkipVerify(skip bool) DialOption {
+	return DialOption{func(do *dialOptions) {
+		do.skipVerify = skip
+	}}
+}
+
 // Dial connects to the Redis server at the given network and
 // address using the specified options.
 func Dial(network, address string, options ...DialOption) (Conn, error) {
@@ -137,6 +157,26 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
 	if err != nil {
 		return nil, err
 	}
+
+	if do.dialTLS {
+		tlsConfig := cloneTLSClientConfig(do.tlsConfig, do.skipVerify)
+		if tlsConfig.ServerName == "" {
+			host, _, err := net.SplitHostPort(address)
+			if err != nil {
+				netConn.Close()
+				return nil, err
+			}
+			tlsConfig.ServerName = host
+		}
+
+		tlsConn := tls.Client(netConn, tlsConfig)
+		if err := tlsConn.Handshake(); err != nil {
+			netConn.Close()
+			return nil, err
+		}
+		netConn = tlsConn
+	}
+
 	c := &conn{
 		conn:         netConn,
 		bw:           bufio.NewWriter(netConn),
@@ -162,6 +202,10 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
 	return c, nil
 }
 
+func dialTLS(do *dialOptions) {
+	do.dialTLS = true
+}
+
 var pathDBRegexp = regexp.MustCompile(`/(\d*)\z`)
 
 // DialURL connects to a Redis server at the given URL using the Redis
@@ -173,7 +217,7 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
 		return nil, err
 	}
 
-	if u.Scheme != "redis" {
+	if u.Scheme != "redis" && u.Scheme != "rediss" {
 		return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme)
 	}
 
@@ -213,6 +257,10 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
 		return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
 	}
 
+	if u.Scheme == "rediss" {
+		options = append([]DialOption{{dialTLS}}, options...)
+	}
+
 	return Dial("tcp", address, options...)
 }
 

+ 33 - 0
redis/go17.go

@@ -0,0 +1,33 @@
+// +build go1.7
+
+package redis
+
+import "crypto/tls"
+
+// similar cloneTLSClientConfig in the stdlib, but also honor skipVerify for the nil case
+func cloneTLSClientConfig(cfg *tls.Config, skipVerify bool) *tls.Config {
+	if cfg == nil {
+		return &tls.Config{InsecureSkipVerify: skipVerify}
+	}
+	return &tls.Config{
+		Rand:                        cfg.Rand,
+		Time:                        cfg.Time,
+		Certificates:                cfg.Certificates,
+		NameToCertificate:           cfg.NameToCertificate,
+		GetCertificate:              cfg.GetCertificate,
+		RootCAs:                     cfg.RootCAs,
+		NextProtos:                  cfg.NextProtos,
+		ServerName:                  cfg.ServerName,
+		ClientAuth:                  cfg.ClientAuth,
+		ClientCAs:                   cfg.ClientCAs,
+		InsecureSkipVerify:          cfg.InsecureSkipVerify,
+		CipherSuites:                cfg.CipherSuites,
+		PreferServerCipherSuites:    cfg.PreferServerCipherSuites,
+		ClientSessionCache:          cfg.ClientSessionCache,
+		MinVersion:                  cfg.MinVersion,
+		MaxVersion:                  cfg.MaxVersion,
+		CurvePreferences:            cfg.CurvePreferences,
+		DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
+		Renegotiation:               cfg.Renegotiation,
+	}
+}

+ 31 - 0
redis/pre_go17.go

@@ -0,0 +1,31 @@
+// +build !go1.7
+
+package redis
+
+import "crypto/tls"
+
+// similar cloneTLSClientConfig in the stdlib, but also honor skipVerify for the nil case
+func cloneTLSClientConfig(cfg *tls.Config, skipVerify bool) *tls.Config {
+	if cfg == nil {
+		return &tls.Config{InsecureSkipVerify: skipVerify}
+	}
+	return &tls.Config{
+		Rand:                     cfg.Rand,
+		Time:                     cfg.Time,
+		Certificates:             cfg.Certificates,
+		NameToCertificate:        cfg.NameToCertificate,
+		GetCertificate:           cfg.GetCertificate,
+		RootCAs:                  cfg.RootCAs,
+		NextProtos:               cfg.NextProtos,
+		ServerName:               cfg.ServerName,
+		ClientAuth:               cfg.ClientAuth,
+		ClientCAs:                cfg.ClientCAs,
+		InsecureSkipVerify:       cfg.InsecureSkipVerify,
+		CipherSuites:             cfg.CipherSuites,
+		PreferServerCipherSuites: cfg.PreferServerCipherSuites,
+		ClientSessionCache:       cfg.ClientSessionCache,
+		MinVersion:               cfg.MinVersion,
+		MaxVersion:               cfg.MaxVersion,
+		CurvePreferences:         cfg.CurvePreferences,
+	}
+}