Explorar o código

Add DialUseTLS and TLS tests

Gary Burd %!s(int64=8) %!d(string=hai) anos
pai
achega
8bed3b5cdc
Modificáronse 2 ficheiros con 157 adicións e 12 borrados
  1. 13 11
      redis/conn.go
  2. 144 1
      redis/conn_test.go

+ 13 - 11
redis/conn.go

@@ -76,7 +76,7 @@ type dialOptions struct {
 	dial         func(network, addr string) (net.Conn, error)
 	db           int
 	password     string
-	dialTLS      bool
+	useTLS       bool
 	skipVerify   bool
 	tlsConfig    *tls.Config
 }
@@ -135,14 +135,22 @@ func DialTLSConfig(c *tls.Config) DialOption {
 	}}
 }
 
-// DialTLSSkipVerify to disable server name verification when connecting
-// over TLS. Has no effect when not dialing a TLS connection.
+// DialTLSSkipVerify disables server name verification when connecting over
+// TLS. Has no effect when not dialing a TLS connection.
 func DialTLSSkipVerify(skip bool) DialOption {
 	return DialOption{func(do *dialOptions) {
 		do.skipVerify = skip
 	}}
 }
 
+// DialUseTLS specifies whether TLS should be used when connecting to the
+// server. This option is ignore by DialURL.
+func DialUseTLS(useTLS bool) DialOption {
+	return DialOption{func(do *dialOptions) {
+		do.useTLS = useTLS
+	}}
+}
+
 // Dial connects to the Redis server at the given network and
 // address using the specified options.
 func Dial(network, address string, options ...DialOption) (Conn, error) {
@@ -158,7 +166,7 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
 		return nil, err
 	}
 
-	if do.dialTLS {
+	if do.useTLS {
 		tlsConfig := cloneTLSClientConfig(do.tlsConfig, do.skipVerify)
 		if tlsConfig.ServerName == "" {
 			host, _, err := net.SplitHostPort(address)
@@ -202,10 +210,6 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
 	return c, nil
 }
 
-func dialTLS(do *dialOptions) {
-	do.dialTLS = true
-}
-
 var pathDBRegexp = regexp.MustCompile(`/(\d*)\z`)
 
 // DialURL connects to a Redis server at the given URL using the Redis
@@ -257,9 +261,7 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
 		return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
 	}
 
-	if u.Scheme == "rediss" {
-		options = append([]DialOption{{dialTLS}}, options...)
-	}
+	options = append(options, DialUseTLS(u.Scheme == "rediss"))
 
 	return Dial("tcp", address, options...)
 }

+ 144 - 1
redis/conn_test.go

@@ -16,6 +16,9 @@ package redis_test
 
 import (
 	"bytes"
+	"crypto/tls"
+	"crypto/x509"
+	"fmt"
 	"io"
 	"math"
 	"net"
@@ -41,11 +44,21 @@ func (*testConn) SetReadDeadline(t time.Time) error  { return nil }
 func (*testConn) SetWriteDeadline(t time.Time) error { return nil }
 
 func dialTestConn(r io.Reader, w io.Writer) redis.DialOption {
-	return redis.DialNetDial(func(net, addr string) (net.Conn, error) {
+	return redis.DialNetDial(func(network, addr string) (net.Conn, error) {
 		return &testConn{Reader: r, Writer: w}, nil
 	})
 }
 
+func dialTestConnTLS(r io.Reader, w io.Writer) redis.DialOption {
+	return redis.DialNetDial(func(network, addr string) (net.Conn, error) {
+		client, server := net.Pipe()
+		tlsServer := tls.Server(server, &serverTLSConfig)
+		go io.Copy(tlsServer, r)
+		go io.Copy(w, tlsServer)
+		return client, nil
+	})
+}
+
 type durationArg struct {
 	time.Duration
 }
@@ -551,6 +564,74 @@ func TestDialURLDatabase(t *testing.T) {
 	}
 }
 
+func checkPingPong(t *testing.T, buf *bytes.Buffer, c redis.Conn) {
+	resp, err := c.Do("PING")
+	if err != nil {
+		t.Fatal("ping error:", err)
+	}
+	expected := "*1\r\n$4\r\nPING\r\n"
+	actual := buf.String()
+	if actual != expected {
+		t.Errorf("commands = %q, want %q", actual, expected)
+	}
+	if resp != "PONG" {
+		t.Errorf("resp = %v, want %v", resp, "PONG")
+	}
+}
+
+func pingRespReader() io.Reader { return strings.NewReader("+PONG\r\n") }
+
+func TestDialURLTLS(t *testing.T) {
+	var buf bytes.Buffer
+	c, err := redis.DialURL("rediss://example.com/",
+		redis.DialTLSConfig(&clientTLSConfig),
+		dialTestConnTLS(pingRespReader(), &buf))
+	if err != nil {
+		t.Fatal("dial error:", err)
+	}
+	defer c.Close()
+	checkPingPong(t, &buf, c)
+}
+
+func TestDialURLIgnoreUseTLS(t *testing.T) {
+	var buf bytes.Buffer
+	c, err := redis.DialURL("redis://example.com/",
+		redis.DialTLSConfig(&clientTLSConfig),
+		dialTestConn(pingRespReader(), &buf),
+		redis.DialUseTLS(true))
+	if err != nil {
+		t.Fatal("dial error:", err)
+	}
+	defer c.Close()
+	checkPingPong(t, &buf, c)
+}
+
+func TestDialUseTLS(t *testing.T) {
+	var buf bytes.Buffer
+	c, err := redis.Dial("tcp", "example.com:6379",
+		redis.DialTLSConfig(&clientTLSConfig),
+		dialTestConnTLS(pingRespReader(), &buf),
+		redis.DialUseTLS(true))
+	if err != nil {
+		t.Fatal("dial error:", err)
+	}
+	defer c.Close()
+	checkPingPong(t, &buf, c)
+}
+
+func TestDialTLSSKipVerify(t *testing.T) {
+	var buf bytes.Buffer
+	c, err := redis.Dial("tcp", "example.com:6379",
+		dialTestConnTLS(pingRespReader(), &buf),
+		redis.DialTLSSkipVerify(true),
+		redis.DialUseTLS(true))
+	if err != nil {
+		t.Fatal("dial error:", err)
+	}
+	defer c.Close()
+	checkPingPong(t, &buf, c)
+}
+
 // Connect to local instance of Redis running on the default port.
 func ExampleDial() {
 	c, err := redis.Dial("tcp", ":6379")
@@ -680,3 +761,65 @@ func BenchmarkDoPing(b *testing.B) {
 		}
 	}
 }
+
+var clientTLSConfig, serverTLSConfig tls.Config
+
+func init() {
+
+	// The certificate and key for testing TLS dial options was created
+	// using the command
+	//
+	//   go run GOROOT/src/crypto/tls/generate_cert.go  \
+	//      --rsa-bits 1024 \
+	//      --host 127.0.0.1,::1,example.com --ca \
+	//      --start-date "Jan 1 00:00:00 1970" \
+	//      --duration=1000000h
+	//
+	// where GOROOT is the value of GOROOT reported by go env.
+	localhostCert := []byte(`
+-----BEGIN CERTIFICATE-----
+MIICFDCCAX2gAwIBAgIRAJfBL4CUxkXcdlFurb3K+iowDQYJKoZIhvcNAQELBQAw
+EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2
+MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAw
+gYkCgYEArizw8WxMUQ3bGHLeuJ4fDrEpy+L2pqrbYRlKk1DasJ/VkB8bImzIpe6+
+LGjiYIxvnDCOJ3f3QplcQuiuMyl6f2irJlJsbFT8Lo/3obnuTKAIaqUdJUqBg6y+
+JaL8Auk97FvunfKFv8U1AIhgiLzAfQ/3Eaq1yi87Ra6pMjGbTtcCAwEAAaNoMGYw
+DgYDVR0PAQH/BAQDAgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQF
+MAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAA
+AAAAAAEwDQYJKoZIhvcNAQELBQADgYEAdZ8daIVkyhVwflt5I19m0oq1TycbGO1+
+ach7T6cZiBQeNR/SJtxr/wKPEpmvUgbv2BfFrKJ8QoIHYsbNSURTWSEa02pfw4k9
+6RQhij3ZkG79Ituj5OYRORV6Z0HUW32r670BtcuHuAhq7YA6Nxy4FtSt7bAlVdRt
+rrKgNsltzMk=
+-----END CERTIFICATE-----`)
+
+	localhostKey := []byte(`
+-----BEGIN RSA PRIVATE KEY-----
+MIICXAIBAAKBgQCuLPDxbExRDdsYct64nh8OsSnL4vamqtthGUqTUNqwn9WQHxsi
+bMil7r4saOJgjG+cMI4nd/dCmVxC6K4zKXp/aKsmUmxsVPwuj/ehue5MoAhqpR0l
+SoGDrL4lovwC6T3sW+6d8oW/xTUAiGCIvMB9D/cRqrXKLztFrqkyMZtO1wIDAQAB
+AoGACrc5G6FOEK6JjDeE/Fa+EmlT6PdNtXNNi+vCas3Opo8u1G8VfEi1D4BgstrB
+Eq+RLkrOdB8tVyuYQYWPMhabMqF+hhKJN72j0OwfuPlVvTInwb/cKjo/zbH1IA+Y
+HenHNK4ywv7/p/9/MvQPJ3I32cQBCgGUW5chVSH5M1sj5gECQQDabQAI1X0uDqCm
+KbX9gXVkAgxkFddrt6LBHt57xujFcqEKFE7nwKhDh7DweVs/VEJ+kpid4z+UnLOw
+KjtP9JolAkEAzCNBphQ//IsbH5rNs10wIUw3Ks/Oepicvr6kUFbIv+neRzi1iJHa
+m6H7EayK3PWgax6BAsR/t0Jc9XV7r2muSwJAVzN09BHnK+ADGtNEKLTqXMbEk6B0
+pDhn7ZmZUOkUPN+Kky+QYM11X6Bob1jDqQDGmymDbGUxGO+GfSofC8inUQJAGfci
+Eo3g1a6b9JksMPRZeuLG4ZstGErxJRH6tH1Va5PDwitka8qhk8o2tTjNMO3NSdLH
+diKoXBcE2/Pll5pJoQJBAIMiiMIzXJhnN4mX8may44J/HvMlMf2xuVH2gNMwmZuc
+Bjqn3yoLHaoZVvbWOi0C2TCN4FjXjaLNZGifQPbIcaA=
+-----END RSA PRIVATE KEY-----`)
+
+	cert, err := tls.X509KeyPair(localhostCert, localhostKey)
+	if err != nil {
+		panic(fmt.Sprintf("error creating key pair: %v", err))
+	}
+	serverTLSConfig.Certificates = []tls.Certificate{cert}
+
+	certificate, err := x509.ParseCertificate(serverTLSConfig.Certificates[0].Certificate[0])
+	if err != nil {
+		panic(fmt.Sprintf("error parsing x509 certificate: %v", err))
+	}
+
+	clientTLSConfig.RootCAs = x509.NewCertPool()
+	clientTLSConfig.RootCAs.AddCert(certificate)
+}