Selaa lähdekoodia

Always pass ServerName, provide DialTLSSkipVerify

Based on PR feedback.
Edward Muller 9 vuotta sitten
vanhempi
commit
a08f884ee1
1 muutettua tiedostoa jossa 53 lisäystä ja 16 poistoa
  1. 53 16
      redis/conn.go

+ 53 - 16
redis/conn.go

@@ -77,6 +77,7 @@ type dialOptions struct {
 	db           int
 	password     string
 	dialTLS      bool
+	skipVerify   bool
 	tlsConfig    *tls.Config
 }
 
@@ -127,13 +128,49 @@ func DialPassword(password string) DialOption {
 }
 
 // DialTLSConfig specifies the config to use when a TLS connection is dialed.
-// This has no effect when not dialing a TLS connection.
+//  Has no effect when not dialing a TLS connection.
 func DialTLSConfig(c *tls.Config) DialOption {
 	return DialOption{func(do *dialOptions) {
 		do.tlsConfig = c
 	}}
 }
 
+// DialTLSSkipVerify to disable server name verification when connecting
+// over TLS. Has no effect when not dialing a TLS connection.
+func DialTLSSkipVerify(skip bool) DialOption {
+	return DialOption{func(do *dialOptions) {
+		do.skipVerify = skip
+	}}
+}
+
+// similar cloneTLSClientConfig in the stdlib, but also honor skipVerify for the nil case
+func cloneTLSClientConfig(cfg *tls.Config, skipVerify bool) *tls.Config {
+	if cfg == nil {
+		return &tls.Config{InsecureSkipVerify: skipVerify}
+	}
+	return &tls.Config{
+		Rand:                        cfg.Rand,
+		Time:                        cfg.Time,
+		Certificates:                cfg.Certificates,
+		NameToCertificate:           cfg.NameToCertificate,
+		GetCertificate:              cfg.GetCertificate,
+		RootCAs:                     cfg.RootCAs,
+		NextProtos:                  cfg.NextProtos,
+		ServerName:                  cfg.ServerName,
+		ClientAuth:                  cfg.ClientAuth,
+		ClientCAs:                   cfg.ClientCAs,
+		InsecureSkipVerify:          cfg.InsecureSkipVerify,
+		CipherSuites:                cfg.CipherSuites,
+		PreferServerCipherSuites:    cfg.PreferServerCipherSuites,
+		ClientSessionCache:          cfg.ClientSessionCache,
+		MinVersion:                  cfg.MinVersion,
+		MaxVersion:                  cfg.MaxVersion,
+		CurvePreferences:            cfg.CurvePreferences,
+		DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
+		Renegotiation:               cfg.Renegotiation,
+	}
+}
+
 // Dial connects to the Redis server at the given network and
 // address using the specified options.
 func Dial(network, address string, options ...DialOption) (Conn, error) {
@@ -150,14 +187,19 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
 	}
 
 	if do.dialTLS {
-		tlsConfig := do.tlsConfig
-		if tlsConfig == nil {
-			// https://golang.org/pkg/crypto/tls/#Client
-			// At this point we don't know the ServerName.
-			tlsConfig = &tls.Config{InsecureSkipVerify: true}
+		tlsConfig := cloneTLSClientConfig(do.tlsConfig, do.skipVerify)
+		if tlsConfig.ServerName == "" {
+			host, _, err := net.SplitHostPort(address)
+			if err != nil {
+				netConn.Close()
+				return nil, err
+			}
+			tlsConfig.ServerName = host
 		}
+
 		tlsConn := tls.Client(netConn, tlsConfig)
 		if err := tlsConn.Handshake(); err != nil {
+			netConn.Close()
 			return nil, err
 		}
 		netConn = tlsConn
@@ -188,6 +230,10 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
 	return c, nil
 }
 
+func dialTLS(do *dialOptions) {
+	do.dialTLS = true
+}
+
 var pathDBRegexp = regexp.MustCompile(`/(\d*)\z`)
 
 // DialURL connects to a Redis server at the given URL using the Redis
@@ -240,16 +286,7 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
 	}
 
 	if u.Scheme == "rediss" {
-		// insert the options at the front, so all user provided options come
-		// after and override what we set
-		t := DialOption{func(do *dialOptions) {
-			do.dialTLS = true
-		}}
-		c := DialTLSConfig(&tls.Config{ServerName: host})
-		options = append(options, t, c)
-		copy(options[2:], options[0:])
-		options[0] = t
-		options[1] = c
+		options = append([]DialOption{{dialTLS}}, options...)
 	}
 
 	return Dial("tcp", address, options...)