Browse Source

Support rediss:// urls

rediss:// is a valid scheme based on IANA :
https://www.iana.org/assignments/uri-schemes/prov/rediss

We (Heroku) would like to move secure redis connections forward, but
have this chicken and egg problem with some drivers. Drivers for other
languages already support rediss://
Edward Muller 9 years ago
parent
commit
8838467569
1 changed files with 29 additions and 1 deletions
  1. 29 1
      redis/conn.go

+ 29 - 1
redis/conn.go

@@ -17,6 +17,7 @@ package redis
 import (
 	"bufio"
 	"bytes"
+	"crypto/tls"
 	"errors"
 	"fmt"
 	"io"
@@ -75,6 +76,7 @@ type dialOptions struct {
 	dial         func(network, addr string) (net.Conn, error)
 	db           int
 	password     string
+	tlsConfig    *tls.Config
 }
 
 // DialReadTimeout specifies the timeout for reading a single command reply.
@@ -123,6 +125,14 @@ func DialPassword(password string) DialOption {
 	}}
 }
 
+// DialTLS specifies that the connection to the Redis server should be
+// encrypted and use the provided tls.Config.
+func DialTLS(c *tls.Config) DialOption {
+	return DialOption{func(do *dialOptions) {
+		do.tlsConfig = c
+	}}
+}
+
 // Dial connects to the Redis server at the given network and
 // address using the specified options.
 func Dial(network, address string, options ...DialOption) (Conn, error) {
@@ -137,6 +147,15 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
 	if err != nil {
 		return nil, err
 	}
+
+	if do.tlsConfig != nil {
+		tlsConn := tls.Client(netConn, do.tlsConfig)
+		if err := tlsConn.Handshake(); err != nil {
+			return nil, err
+		}
+		netConn = tlsConn
+	}
+
 	c := &conn{
 		conn:         netConn,
 		bw:           bufio.NewWriter(netConn),
@@ -173,7 +192,7 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
 		return nil, err
 	}
 
-	if u.Scheme != "redis" {
+	if u.Scheme != "redis" && u.Scheme != "rediss" {
 		return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme)
 	}
 
@@ -213,6 +232,15 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
 		return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
 	}
 
+	if u.Scheme == "rediss" {
+		// insert a default DlialTLS at position 0, so we don't override any
+		// user provided *tls.Config elsewhere in options
+		t := DialTLS(&tls.Config{InsecureSkipVerify: true})
+		options = append(options, t)
+		copy(options[1:], options[0:])
+		options[0] = t
+	}
+
 	return Dial("tcp", address, options...)
 }