Pārlūkot izejas kodu

Merge DialURL tests into single table driven test

Gary Burd 8 gadi atpakaļ
vecāks
revīzija
e3f9a55835
1 mainītis faili ar 32 papildinājumiem un 54 dzēšanām
  1. 32 54
      redis/conn_test.go

+ 32 - 54
redis/conn_test.go

@@ -43,9 +43,9 @@ func (*testConn) SetDeadline(t time.Time) error      { return nil }
 func (*testConn) SetReadDeadline(t time.Time) error  { return nil }
 func (*testConn) SetWriteDeadline(t time.Time) error { return nil }
 
-func dialTestConn(r io.Reader, w io.Writer) redis.DialOption {
+func dialTestConn(r string, w io.Writer) redis.DialOption {
 	return redis.DialNetDial(func(network, addr string) (net.Conn, error) {
-		return &testConn{Reader: r, Writer: w}, nil
+		return &testConn{Reader: strings.NewReader(r), Writer: w}, nil
 	})
 }
 
@@ -60,11 +60,11 @@ func (c *tlsTestConn) Close() error {
 	return nil
 }
 
-func dialTestConnTLS(r io.Reader, w io.Writer) redis.DialOption {
+func dialTestConnTLS(r string, w io.Writer) redis.DialOption {
 	return redis.DialNetDial(func(network, addr string) (net.Conn, error) {
 		client, server := net.Pipe()
 		tlsServer := tls.Server(server, &serverTLSConfig)
-		go io.Copy(tlsServer, r)
+		go io.Copy(tlsServer, strings.NewReader(r))
 		done := make(chan struct{})
 		go func() {
 			io.Copy(w, tlsServer)
@@ -131,7 +131,7 @@ var writeTests = []struct {
 func TestWrite(t *testing.T) {
 	for _, tt := range writeTests {
 		var buf bytes.Buffer
-		c, _ := redis.Dial("", "", dialTestConn(nil, &buf))
+		c, _ := redis.Dial("", "", dialTestConn("", &buf))
 		err := c.Send(tt.args[0].(string), tt.args[1:]...)
 		if err != nil {
 			t.Errorf("Send(%v) returned error %v", tt.args, err)
@@ -230,7 +230,7 @@ var readTests = []struct {
 
 func TestRead(t *testing.T) {
 	for _, tt := range readTests {
-		c, _ := redis.Dial("", "", dialTestConn(strings.NewReader(tt.reply), nil))
+		c, _ := redis.Dial("", "", dialTestConn(tt.reply, nil))
 		actual, err := c.Receive()
 		if tt.expected == errorSentinel {
 			if err == nil {
@@ -542,40 +542,30 @@ func TestDialURLHost(t *testing.T) {
 	}
 }
 
-func TestDialURLPassword(t *testing.T) {
-	var buf bytes.Buffer
-	_, err := redis.DialURL("redis://x:abc123@localhost", dialTestConn(strings.NewReader("+OK\r\n"), &buf))
-	if err != nil {
-		t.Error("dial error:", err)
-	}
-	expected := "*2\r\n$4\r\nAUTH\r\n$6\r\nabc123\r\n"
-	actual := buf.String()
-	if actual != expected {
-		t.Errorf("commands = %q, want %q", actual, expected)
-	}
+var dialURLTests = []struct {
+	description string
+	url         string
+	r           string
+	w           string
+}{
+	{"password", "redis://x:abc123@localhost", "+OK\r\n", "*2\r\n$4\r\nAUTH\r\n$6\r\nabc123\r\n"},
+	{"database 3", "redis://localhost/3", "+OK\r\n", "*2\r\n$6\r\nSELECT\r\n$1\r\n3\r\n"},
+	{"database 99", "redis://localhost/99", "+OK\r\n", "*2\r\n$6\r\nSELECT\r\n$2\r\n99\r\n"},
+	{"no database", "redis://localhost/", "+OK\r\n", ""},
 }
 
-func TestDialURLDatabase(t *testing.T) {
-	var buf3 bytes.Buffer
-	_, err3 := redis.DialURL("redis://localhost/3", dialTestConn(strings.NewReader("+OK\r\n"), &buf3))
-	if err3 != nil {
-		t.Error("dial error:", err3)
-	}
-	expected3 := "*2\r\n$6\r\nSELECT\r\n$1\r\n3\r\n"
-	actual3 := buf3.String()
-	if actual3 != expected3 {
-		t.Errorf("commands = %q, want %q", actual3, expected3)
-	}
-	// empty DB means 0
-	var buf0 bytes.Buffer
-	_, err0 := redis.DialURL("redis://localhost/", dialTestConn(strings.NewReader("+OK\r\n"), &buf0))
-	if err0 != nil {
-		t.Error("dial error:", err0)
-	}
-	expected0 := ""
-	actual0 := buf0.String()
-	if actual0 != expected0 {
-		t.Errorf("commands = %q, want %q", actual0, expected0)
+func TestDialURL(t *testing.T) {
+	for _, tt := range dialURLTests {
+		var buf bytes.Buffer
+		// UseTLS should be ignored in all of these tests.
+		_, err := redis.DialURL(tt.url, dialTestConn(tt.r, &buf), redis.DialUseTLS(true))
+		if err != nil {
+			t.Errorf("%s dial error: %v", tt.description, err)
+			continue
+		}
+		if w := buf.String(); w != tt.w {
+			t.Errorf("%s commands = %q, want %q", tt.description, w, tt.w)
+		}
 	}
 }
 
@@ -596,25 +586,13 @@ func checkPingPong(t *testing.T, buf *bytes.Buffer, c redis.Conn) {
 	}
 }
 
-func pingRespReader() io.Reader { return strings.NewReader("+PONG\r\n") }
+const pingResponse = "+PONG\r\n"
 
 func TestDialURLTLS(t *testing.T) {
 	var buf bytes.Buffer
 	c, err := redis.DialURL("rediss://example.com/",
 		redis.DialTLSConfig(&clientTLSConfig),
-		dialTestConnTLS(pingRespReader(), &buf))
-	if err != nil {
-		t.Fatal("dial error:", err)
-	}
-	checkPingPong(t, &buf, c)
-}
-
-func TestDialURLIgnoreUseTLS(t *testing.T) {
-	var buf bytes.Buffer
-	c, err := redis.DialURL("redis://example.com/",
-		redis.DialTLSConfig(&clientTLSConfig),
-		dialTestConn(pingRespReader(), &buf),
-		redis.DialUseTLS(true))
+		dialTestConnTLS(pingResponse, &buf))
 	if err != nil {
 		t.Fatal("dial error:", err)
 	}
@@ -625,7 +603,7 @@ func TestDialUseTLS(t *testing.T) {
 	var buf bytes.Buffer
 	c, err := redis.Dial("tcp", "example.com:6379",
 		redis.DialTLSConfig(&clientTLSConfig),
-		dialTestConnTLS(pingRespReader(), &buf),
+		dialTestConnTLS(pingResponse, &buf),
 		redis.DialUseTLS(true))
 	if err != nil {
 		t.Fatal("dial error:", err)
@@ -636,7 +614,7 @@ func TestDialUseTLS(t *testing.T) {
 func TestDialTLSSKipVerify(t *testing.T) {
 	var buf bytes.Buffer
 	c, err := redis.Dial("tcp", "example.com:6379",
-		dialTestConnTLS(pingRespReader(), &buf),
+		dialTestConnTLS(pingResponse, &buf),
 		redis.DialTLSSkipVerify(true),
 		redis.DialUseTLS(true))
 	if err != nil {