|
|
@@ -16,6 +16,9 @@ package redis_test
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
+ "crypto/tls"
|
|
|
+ "crypto/x509"
|
|
|
+ "fmt"
|
|
|
"io"
|
|
|
"math"
|
|
|
"net"
|
|
|
@@ -41,11 +44,21 @@ func (*testConn) SetReadDeadline(t time.Time) error { return nil }
|
|
|
func (*testConn) SetWriteDeadline(t time.Time) error { return nil }
|
|
|
|
|
|
func dialTestConn(r io.Reader, w io.Writer) redis.DialOption {
|
|
|
- return redis.DialNetDial(func(net, addr string) (net.Conn, error) {
|
|
|
+ return redis.DialNetDial(func(network, addr string) (net.Conn, error) {
|
|
|
return &testConn{Reader: r, Writer: w}, nil
|
|
|
})
|
|
|
}
|
|
|
|
|
|
+func dialTestConnTLS(r io.Reader, w io.Writer) redis.DialOption {
|
|
|
+ return redis.DialNetDial(func(network, addr string) (net.Conn, error) {
|
|
|
+ client, server := net.Pipe()
|
|
|
+ tlsServer := tls.Server(server, &serverTLSConfig)
|
|
|
+ go io.Copy(tlsServer, r)
|
|
|
+ go io.Copy(w, tlsServer)
|
|
|
+ return client, nil
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
type durationArg struct {
|
|
|
time.Duration
|
|
|
}
|
|
|
@@ -551,6 +564,74 @@ func TestDialURLDatabase(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func checkPingPong(t *testing.T, buf *bytes.Buffer, c redis.Conn) {
|
|
|
+ resp, err := c.Do("PING")
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal("ping error:", err)
|
|
|
+ }
|
|
|
+ expected := "*1\r\n$4\r\nPING\r\n"
|
|
|
+ actual := buf.String()
|
|
|
+ if actual != expected {
|
|
|
+ t.Errorf("commands = %q, want %q", actual, expected)
|
|
|
+ }
|
|
|
+ if resp != "PONG" {
|
|
|
+ t.Errorf("resp = %v, want %v", resp, "PONG")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func pingRespReader() io.Reader { return strings.NewReader("+PONG\r\n") }
|
|
|
+
|
|
|
+func TestDialURLTLS(t *testing.T) {
|
|
|
+ var buf bytes.Buffer
|
|
|
+ c, err := redis.DialURL("rediss://example.com/",
|
|
|
+ redis.DialTLSConfig(&clientTLSConfig),
|
|
|
+ dialTestConnTLS(pingRespReader(), &buf))
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal("dial error:", err)
|
|
|
+ }
|
|
|
+ defer c.Close()
|
|
|
+ checkPingPong(t, &buf, c)
|
|
|
+}
|
|
|
+
|
|
|
+func TestDialURLIgnoreUseTLS(t *testing.T) {
|
|
|
+ var buf bytes.Buffer
|
|
|
+ c, err := redis.DialURL("redis://example.com/",
|
|
|
+ redis.DialTLSConfig(&clientTLSConfig),
|
|
|
+ dialTestConn(pingRespReader(), &buf),
|
|
|
+ redis.DialUseTLS(true))
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal("dial error:", err)
|
|
|
+ }
|
|
|
+ defer c.Close()
|
|
|
+ checkPingPong(t, &buf, c)
|
|
|
+}
|
|
|
+
|
|
|
+func TestDialUseTLS(t *testing.T) {
|
|
|
+ var buf bytes.Buffer
|
|
|
+ c, err := redis.Dial("tcp", "example.com:6379",
|
|
|
+ redis.DialTLSConfig(&clientTLSConfig),
|
|
|
+ dialTestConnTLS(pingRespReader(), &buf),
|
|
|
+ redis.DialUseTLS(true))
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal("dial error:", err)
|
|
|
+ }
|
|
|
+ defer c.Close()
|
|
|
+ checkPingPong(t, &buf, c)
|
|
|
+}
|
|
|
+
|
|
|
+func TestDialTLSSKipVerify(t *testing.T) {
|
|
|
+ var buf bytes.Buffer
|
|
|
+ c, err := redis.Dial("tcp", "example.com:6379",
|
|
|
+ dialTestConnTLS(pingRespReader(), &buf),
|
|
|
+ redis.DialTLSSkipVerify(true),
|
|
|
+ redis.DialUseTLS(true))
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal("dial error:", err)
|
|
|
+ }
|
|
|
+ defer c.Close()
|
|
|
+ checkPingPong(t, &buf, c)
|
|
|
+}
|
|
|
+
|
|
|
// Connect to local instance of Redis running on the default port.
|
|
|
func ExampleDial() {
|
|
|
c, err := redis.Dial("tcp", ":6379")
|
|
|
@@ -680,3 +761,65 @@ func BenchmarkDoPing(b *testing.B) {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+var clientTLSConfig, serverTLSConfig tls.Config
|
|
|
+
|
|
|
+func init() {
|
|
|
+
|
|
|
+ // The certificate and key for testing TLS dial options was created
|
|
|
+ // using the command
|
|
|
+ //
|
|
|
+ // go run GOROOT/src/crypto/tls/generate_cert.go \
|
|
|
+ // --rsa-bits 1024 \
|
|
|
+ // --host 127.0.0.1,::1,example.com --ca \
|
|
|
+ // --start-date "Jan 1 00:00:00 1970" \
|
|
|
+ // --duration=1000000h
|
|
|
+ //
|
|
|
+ // where GOROOT is the value of GOROOT reported by go env.
|
|
|
+ localhostCert := []byte(`
|
|
|
+-----BEGIN CERTIFICATE-----
|
|
|
+MIICFDCCAX2gAwIBAgIRAJfBL4CUxkXcdlFurb3K+iowDQYJKoZIhvcNAQELBQAw
|
|
|
+EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2
|
|
|
+MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAw
|
|
|
+gYkCgYEArizw8WxMUQ3bGHLeuJ4fDrEpy+L2pqrbYRlKk1DasJ/VkB8bImzIpe6+
|
|
|
+LGjiYIxvnDCOJ3f3QplcQuiuMyl6f2irJlJsbFT8Lo/3obnuTKAIaqUdJUqBg6y+
|
|
|
+JaL8Auk97FvunfKFv8U1AIhgiLzAfQ/3Eaq1yi87Ra6pMjGbTtcCAwEAAaNoMGYw
|
|
|
+DgYDVR0PAQH/BAQDAgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQF
|
|
|
+MAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAA
|
|
|
+AAAAAAEwDQYJKoZIhvcNAQELBQADgYEAdZ8daIVkyhVwflt5I19m0oq1TycbGO1+
|
|
|
+ach7T6cZiBQeNR/SJtxr/wKPEpmvUgbv2BfFrKJ8QoIHYsbNSURTWSEa02pfw4k9
|
|
|
+6RQhij3ZkG79Ituj5OYRORV6Z0HUW32r670BtcuHuAhq7YA6Nxy4FtSt7bAlVdRt
|
|
|
+rrKgNsltzMk=
|
|
|
+-----END CERTIFICATE-----`)
|
|
|
+
|
|
|
+ localhostKey := []byte(`
|
|
|
+-----BEGIN RSA PRIVATE KEY-----
|
|
|
+MIICXAIBAAKBgQCuLPDxbExRDdsYct64nh8OsSnL4vamqtthGUqTUNqwn9WQHxsi
|
|
|
+bMil7r4saOJgjG+cMI4nd/dCmVxC6K4zKXp/aKsmUmxsVPwuj/ehue5MoAhqpR0l
|
|
|
+SoGDrL4lovwC6T3sW+6d8oW/xTUAiGCIvMB9D/cRqrXKLztFrqkyMZtO1wIDAQAB
|
|
|
+AoGACrc5G6FOEK6JjDeE/Fa+EmlT6PdNtXNNi+vCas3Opo8u1G8VfEi1D4BgstrB
|
|
|
+Eq+RLkrOdB8tVyuYQYWPMhabMqF+hhKJN72j0OwfuPlVvTInwb/cKjo/zbH1IA+Y
|
|
|
+HenHNK4ywv7/p/9/MvQPJ3I32cQBCgGUW5chVSH5M1sj5gECQQDabQAI1X0uDqCm
|
|
|
+KbX9gXVkAgxkFddrt6LBHt57xujFcqEKFE7nwKhDh7DweVs/VEJ+kpid4z+UnLOw
|
|
|
+KjtP9JolAkEAzCNBphQ//IsbH5rNs10wIUw3Ks/Oepicvr6kUFbIv+neRzi1iJHa
|
|
|
+m6H7EayK3PWgax6BAsR/t0Jc9XV7r2muSwJAVzN09BHnK+ADGtNEKLTqXMbEk6B0
|
|
|
+pDhn7ZmZUOkUPN+Kky+QYM11X6Bob1jDqQDGmymDbGUxGO+GfSofC8inUQJAGfci
|
|
|
+Eo3g1a6b9JksMPRZeuLG4ZstGErxJRH6tH1Va5PDwitka8qhk8o2tTjNMO3NSdLH
|
|
|
+diKoXBcE2/Pll5pJoQJBAIMiiMIzXJhnN4mX8may44J/HvMlMf2xuVH2gNMwmZuc
|
|
|
+Bjqn3yoLHaoZVvbWOi0C2TCN4FjXjaLNZGifQPbIcaA=
|
|
|
+-----END RSA PRIVATE KEY-----`)
|
|
|
+
|
|
|
+ cert, err := tls.X509KeyPair(localhostCert, localhostKey)
|
|
|
+ if err != nil {
|
|
|
+ panic(fmt.Sprintf("error creating key pair: %v", err))
|
|
|
+ }
|
|
|
+ serverTLSConfig.Certificates = []tls.Certificate{cert}
|
|
|
+
|
|
|
+ certificate, err := x509.ParseCertificate(serverTLSConfig.Certificates[0].Certificate[0])
|
|
|
+ if err != nil {
|
|
|
+ panic(fmt.Sprintf("error parsing x509 certificate: %v", err))
|
|
|
+ }
|
|
|
+
|
|
|
+ clientTLSConfig.RootCAs = x509.NewCertPool()
|
|
|
+ clientTLSConfig.RootCAs.AddCert(certificate)
|
|
|
+}
|