소스 검색

Only setup a TLS connection when specified

Based on feedback in the PR.

Also, default to a more secure connection by setting ServerName by
default.
Edward Muller 9 년 전
부모
커밋
03b32d3ac8
1개의 변경된 파일21개의 추가작업 그리고 10개의 파일을 삭제
  1. 21 10
      redis/conn.go

+ 21 - 10
redis/conn.go

@@ -76,6 +76,7 @@ type dialOptions struct {
 	dial         func(network, addr string) (net.Conn, error)
 	db           int
 	password     string
+	dialTLS      bool
 	tlsConfig    *tls.Config
 }
 
@@ -125,9 +126,9 @@ func DialPassword(password string) DialOption {
 	}}
 }
 
-// DialTLS specifies that the connection to the Redis server should be
-// encrypted and use the provided tls.Config.
-func DialTLS(c *tls.Config) DialOption {
+// DialTLSConfig specifies the config to use when a TLS connection is dialed.
+// This has no effect when not dialing a TLS connection.
+func DialTLSConfig(c *tls.Config) DialOption {
 	return DialOption{func(do *dialOptions) {
 		do.tlsConfig = c
 	}}
@@ -148,8 +149,14 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
 		return nil, err
 	}
 
-	if do.tlsConfig != nil {
-		tlsConn := tls.Client(netConn, do.tlsConfig)
+	if do.dialTLS {
+		tlsConfig := do.tlsConfig
+		if tlsConfig == nil {
+			// https://golang.org/pkg/crypto/tls/#Client
+			// At this point we don't know the ServerName.
+			tlsConfig = &tls.Config{InsecureSkipVerify: true}
+		}
+		tlsConn := tls.Client(netConn, tlsConfig)
 		if err := tlsConn.Handshake(); err != nil {
 			return nil, err
 		}
@@ -233,12 +240,16 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
 	}
 
 	if u.Scheme == "rediss" {
-		// insert a default DlialTLS at position 0, so we don't override any
-		// user provided *tls.Config elsewhere in options
-		t := DialTLS(&tls.Config{InsecureSkipVerify: true})
-		options = append(options, t)
-		copy(options[1:], options[0:])
+		// insert the options at the front, so all user provided options come
+		// after and override what we set
+		t := DialOption{func(do *dialOptions) {
+			do.dialTLS = true
+		}}
+		c := DialTLSConfig(&tls.Config{ServerName: host})
+		options = append(options, t, c)
+		copy(options[2:], options[0:])
 		options[0] = t
+		options[1] = c
 	}
 
 	return Dial("tcp", address, options...)