Переглянути джерело

Add CLIENT SETNAME to dial options

Marco Bersani 6 роки тому
батько
коміт
aefe8f574a
4 змінених файлів з 59 додано та 6 видалено
  1. 16 0
      redis/conn.go
  2. 37 0
      redis/conn_test.go
  3. 4 4
      redis/pool_test.go
  4. 2 2
      redis/test_test.go

+ 16 - 0
redis/conn.go

@@ -80,6 +80,7 @@ type dialOptions struct {
 	dial         func(network, addr string) (net.Conn, error)
 	db           int
 	password     string
+	clientName   string
 	useTLS       bool
 	skipVerify   bool
 	tlsConfig    *tls.Config
@@ -141,6 +142,14 @@ func DialPassword(password string) DialOption {
 	}}
 }
 
+// DialClientName specifies a client name to be used
+// by the Redis server connection.
+func DialClientName(name string) DialOption {
+	return DialOption{func(do *dialOptions) {
+		do.clientName = name
+	}}
+}
+
 // DialTLSConfig specifies the config to use when a TLS connection is dialed.
 // Has no effect when not dialing a TLS connection.
 func DialTLSConfig(c *tls.Config) DialOption {
@@ -224,6 +233,13 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
 		}
 	}
 
+	if do.clientName != "" {
+		if _, err := c.Do("CLIENT", "SETNAME", do.clientName); err != nil {
+			netConn.Close()
+			return nil, err
+		}
+	}
+
 	if do.db != 0 {
 		if _, err := c.Do("SELECT", do.db); err != nil {
 			netConn.Close()

+ 37 - 0
redis/conn_test.go

@@ -671,6 +671,43 @@ func TestDialTLSSKipVerify(t *testing.T) {
 	checkPingPong(t, &buf, c)
 }
 
+func TestDialClientName(t *testing.T) {
+	var buf bytes.Buffer
+	_, err := redis.Dial("tcp", ":6379",
+		dialTestConn(pingResponse, &buf),
+		redis.DialClientName("redis-connection"),
+	)
+	if err != nil {
+		t.Fatal("dial error:", err)
+	}
+	expected := "*3\r\n$6\r\nCLIENT\r\n$7\r\nSETNAME\r\n$16\r\nredis-connection\r\n"
+	if w := buf.String(); w != expected {
+		t.Errorf("got %q, want %q", w, expected)
+	}
+
+	// testing against a real server
+	connectionName := "test-connection"
+	c, err := redis.DialDefaultServer(redis.DialClientName(connectionName))
+	if err != nil {
+		t.Fatalf("error connection to database, %v", err)
+	}
+	defer c.Close()
+
+	v, err := c.Do("CLIENT", "GETNAME")
+	if err != nil {
+		t.Fatalf("CLIENT GETNAME returned error %v", err)
+	}
+
+	vs, err := redis.String(v, nil)
+	if err != nil {
+		t.Fatalf("String(v) returned error %v", err)
+	}
+
+	if vs != connectionName {
+		t.Fatalf("wrong connection name. Got '%s', expected '%s'", vs, connectionName)
+	}
+}
+
 // Connect to local instance of Redis running on the default port.
 func ExampleDial() {
 	c, err := redis.Dial("tcp", ":6379")

+ 4 - 4
redis/pool_test.go

@@ -300,7 +300,7 @@ func TestPoolMaxLifetime(t *testing.T) {
 
 func TestPoolConcurrenSendReceive(t *testing.T) {
 	p := &redis.Pool{
-		Dial: redis.DialDefaultServer,
+		Dial: func() (redis.Conn, error) { return redis.DialDefaultServer() },
 	}
 	defer p.Close()
 
@@ -693,7 +693,7 @@ func TestLocking_TestOnBorrowFails_PoolDoesntCrash(t *testing.T) {
 
 func BenchmarkPoolGet(b *testing.B) {
 	b.StopTimer()
-	p := redis.Pool{Dial: redis.DialDefaultServer, MaxIdle: 2}
+	p := redis.Pool{Dial: func() (redis.Conn, error) { return redis.DialDefaultServer() }, MaxIdle: 2}
 	c := p.Get()
 	if err := c.Err(); err != nil {
 		b.Fatal(err)
@@ -709,7 +709,7 @@ func BenchmarkPoolGet(b *testing.B) {
 
 func BenchmarkPoolGetErr(b *testing.B) {
 	b.StopTimer()
-	p := redis.Pool{Dial: redis.DialDefaultServer, MaxIdle: 2}
+	p := redis.Pool{Dial: func() (redis.Conn, error) { return redis.DialDefaultServer() }, MaxIdle: 2}
 	c := p.Get()
 	if err := c.Err(); err != nil {
 		b.Fatal(err)
@@ -728,7 +728,7 @@ func BenchmarkPoolGetErr(b *testing.B) {
 
 func BenchmarkPoolGetPing(b *testing.B) {
 	b.StopTimer()
-	p := redis.Pool{Dial: redis.DialDefaultServer, MaxIdle: 2}
+	p := redis.Pool{Dial: func() (redis.Conn, error) { return redis.DialDefaultServer() }, MaxIdle: 2}
 	c := p.Get()
 	if err := c.Err(); err != nil {
 		b.Fatal(err)

+ 2 - 2
redis/test_test.go

@@ -147,12 +147,12 @@ func DefaultServerAddr() (string, error) {
 
 // DialDefaultServer starts the test server if not already started and dials a
 // connection to the server.
-func DialDefaultServer() (Conn, error) {
+func DialDefaultServer(options ...DialOption) (Conn, error) {
 	addr, err := DefaultServerAddr()
 	if err != nil {
 		return nil, err
 	}
-	c, err := Dial("tcp", addr, DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second))
+	c, err := Dial("tcp", addr, append([]DialOption{DialReadTimeout(1 * time.Second), DialWriteTimeout(1 * time.Second)}, options...)...)
 	if err != nil {
 		return nil, err
 	}