Ver código fonte

Improve pubsub example

Gary Burd 8 anos atrás
pai
commit
1e086faa00
4 arquivos alterados com 178 adições e 87 exclusões
  1. 164 0
      redis/pubsub_example_test.go
  2. 0 81
      redis/pubsub_test.go
  3. 5 0
      redis/reply_test.go
  4. 9 6
      redis/test_test.go

+ 164 - 0
redis/pubsub_example_test.go

@@ -0,0 +1,164 @@
+// Copyright 2012 Gary Burd
+//
+// Licensed under the Apache License, Version 2.0 (the "License"): you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations
+// under the License.
+
+package redis_test
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/garyburd/redigo/redis"
+)
+
+// listenPubSubChannels listens for messages on Redis pubsub channels. The
+// onStart function is called after the channels are subscribed. The onMessage
+// function is called for each message.
+func listenPubSubChannels(ctx context.Context, redisServerAddr string,
+	onStart func() error,
+	onMessage func(channel string, data []byte) error,
+	channels ...string) error {
+
+	// A ping is set to the server with this period to test for the health of
+	// the connection and server.
+	const healthCheckPeriod = time.Minute
+
+	c, err := redis.Dial("tcp", redisServerAddr,
+		// Read timeout on server should be greater than ping period.
+		redis.DialReadTimeout(healthCheckPeriod+10*time.Second),
+		redis.DialWriteTimeout(10*time.Second))
+	if err != nil {
+		return err
+	}
+	defer c.Close()
+
+	psc := redis.PubSubConn{Conn: c}
+
+	if err := psc.Subscribe(redis.Args{}.AddFlat(channels)...); err != nil {
+		return err
+	}
+
+	done := make(chan error, 1)
+
+	// Start a goroutine to receive notifications from the server.
+	go func() {
+		for {
+			switch n := psc.Receive().(type) {
+			case error:
+				done <- n
+				return
+			case redis.Message:
+				if err := onMessage(n.Channel, n.Data); err != nil {
+					done <- err
+					return
+				}
+			case redis.Subscription:
+				switch n.Count {
+				case len(channels):
+					// Notify application when all channels are subscribed.
+					if err := onStart(); err != nil {
+						done <- err
+						return
+					}
+				case 0:
+					// Return from the goroutine when all channels are unsubscribed.
+					done <- nil
+					return
+				}
+			}
+		}
+	}()
+
+	ticker := time.NewTicker(healthCheckPeriod)
+	defer ticker.Stop()
+loop:
+	for err == nil {
+		select {
+		case <-ticker.C:
+			// Send ping to test health of connection and server. If
+			// corresponding pong is not received, then receive on the
+			// connection will timeout and the receive goroutine will exit.
+			if err = psc.Ping(""); err != nil {
+				break loop
+			}
+		case <-ctx.Done():
+			break loop
+		case err := <-done:
+			// Return error from the receive goroutine.
+			return err
+		}
+	}
+
+	// Signal the receiving goroutine to exit by unsubscribing from all channels.
+	psc.Unsubscribe()
+
+	// Wait for goroutine to complete.
+	return <-done
+}
+
+func publish() {
+	c, err := dial()
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+	defer c.Close()
+
+	c.Do("PUBLISH", "c1", "hello")
+	c.Do("PUBLISH", "c2", "world")
+	c.Do("PUBLISH", "c1", "goodbye")
+}
+
+// This example shows how receive pubsub notifications with cancelation and
+// health checks.
+func ExamplePubSubConn() {
+	redisServerAddr, err := serverAddr()
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	err = listenPubSubChannels(ctx,
+		redisServerAddr,
+		func() error {
+			// The start callback is a good place to backfill missed
+			// notifications. For the purpose of this example, a goroutine is
+			// started to send notifications.
+			go publish()
+			return nil
+		},
+		func(channel string, message []byte) error {
+			fmt.Printf("channel: %s, message: %s\n", channel, message)
+
+			// For the purpose of this example, cancel the listener's context
+			// after receiving last message sent by publish().
+			if string(message) == "goodbye" {
+				cancel()
+			}
+			return nil
+		},
+		"c1", "c2")
+
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+
+	// Output:
+	// channel: c1, message: hello
+	// channel: c2, message: world
+	// channel: c1, message: goodbye
+}

+ 0 - 81
redis/pubsub_test.go

@@ -15,93 +15,12 @@
 package redis_test
 
 import (
-	"fmt"
 	"reflect"
-	"sync"
 	"testing"
 
 	"github.com/garyburd/redigo/redis"
 )
 
-func publish(channel, value interface{}) {
-	c, err := dial()
-	if err != nil {
-		fmt.Println(err)
-		return
-	}
-	defer c.Close()
-	c.Do("PUBLISH", channel, value)
-}
-
-// Applications can receive pushed messages from one goroutine and manage subscriptions from another goroutine.
-func ExamplePubSubConn() {
-	c, err := dial()
-	if err != nil {
-		fmt.Println(err)
-		return
-	}
-	defer c.Close()
-	var wg sync.WaitGroup
-	wg.Add(2)
-
-	psc := redis.PubSubConn{Conn: c}
-
-	// This goroutine receives and prints pushed notifications from the server.
-	// The goroutine exits when the connection is unsubscribed from all
-	// channels or there is an error.
-	go func() {
-		defer wg.Done()
-		for {
-			switch n := psc.Receive().(type) {
-			case redis.Message:
-				fmt.Printf("Message: %s %s\n", n.Channel, n.Data)
-			case redis.PMessage:
-				fmt.Printf("PMessage: %s %s %s\n", n.Pattern, n.Channel, n.Data)
-			case redis.Subscription:
-				fmt.Printf("Subscription: %s %s %d\n", n.Kind, n.Channel, n.Count)
-				if n.Count == 0 {
-					return
-				}
-			case error:
-				fmt.Printf("error: %v\n", n)
-				return
-			}
-		}
-	}()
-
-	// This goroutine manages subscriptions for the connection.
-	go func() {
-		defer wg.Done()
-
-		psc.Subscribe("example")
-		psc.PSubscribe("p*")
-
-		// The following function calls publish a message using another
-		// connection to the Redis server.
-		publish("example", "hello")
-		publish("example", "world")
-		publish("pexample", "foo")
-		publish("pexample", "bar")
-
-		// Unsubscribe from all connections. This will cause the receiving
-		// goroutine to exit.
-		psc.Unsubscribe()
-		psc.PUnsubscribe()
-	}()
-
-	wg.Wait()
-
-	// Output:
-	// Subscription: subscribe example 1
-	// Subscription: psubscribe p* 2
-	// Message: example hello
-	// Message: example world
-	// PMessage: p* pexample foo
-	// PMessage: p* pexample bar
-	// Subscription: unsubscribe example 1
-	// Subscription: punsubscribe p* 0
-}
-
 func expectPushed(t *testing.T, c redis.PubSubConn, message string, expected interface{}) {
 	actual := c.Receive()
 	if !reflect.DeepEqual(actual, expected) {

+ 5 - 0
redis/reply_test.go

@@ -140,6 +140,11 @@ func dial() (redis.Conn, error) {
 	return redis.DialDefaultServer()
 }
 
+// serverAddr wraps DefaultServerAddr() with a more suitable function name for examples.
+func serverAddr() (string, error) {
+	return redis.DefaultServerAddr()
+}
+
 func ExampleBool() {
 	c, err := dial()
 	if err != nil {

+ 9 - 6
redis/test_test.go

@@ -127,12 +127,14 @@ func stopDefaultServer() {
 	}
 }
 
-// startDefaultServer starts the default server if not already running.
-func startDefaultServer() error {
+// DefaultServerAddr starts the test server if not already started and returns
+// the address of that server.
+func DefaultServerAddr() (string, error) {
 	defaultServerMu.Lock()
 	defer defaultServerMu.Unlock()
+	addr := fmt.Sprintf("%v:%d", *serverAddress, *serverBasePort)
 	if defaultServer != nil || defaultServerErr != nil {
-		return defaultServerErr
+		return addr, defaultServerErr
 	}
 	defaultServer, defaultServerErr = NewServer(
 		"default",
@@ -140,16 +142,17 @@ func startDefaultServer() error {
 		"--bind", *serverAddress,
 		"--save", "",
 		"--appendonly", "no")
-	return defaultServerErr
+	return addr, defaultServerErr
 }
 
 // DialDefaultServer starts the test server if not already started and dials a
 // connection to the server.
 func DialDefaultServer() (Conn, error) {
-	if err := startDefaultServer(); err != nil {
+	addr, err := DefaultServerAddr()
+	if err != nil {
 		return nil, err
 	}
-	c, err := Dial("tcp", fmt.Sprintf("%v:%d", *serverAddress, *serverBasePort), DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second))
+	c, err := Dial("tcp", addr, DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second))
 	if err != nil {
 		return nil, err
 	}