Browse Source

contrib/raftexample: shutdown rafthttp on closed proposal channel

Otherwise listening ports leak across unit tests and ports won't bind.
chz 10 years ago
parent
commit
63bc804253
3 changed files with 152 additions and 22 deletions
  1. 59 0
      contrib/raftexample/listener.go
  2. 49 14
      contrib/raftexample/raft.go
  3. 44 8
      contrib/raftexample/raftexample_test.go

+ 59 - 0
contrib/raftexample/listener.go

@@ -0,0 +1,59 @@
+// Copyright 2015 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"errors"
+	"net"
+	"time"
+)
+
+// stoppableListener sets TCP keep-alive timeouts on accepted
+// connections and waits on stopc message
+type stoppableListener struct {
+	*net.TCPListener
+	stopc <-chan struct{}
+}
+
+func newStoppableListener(addr string, stopc <-chan struct{}) (*stoppableListener, error) {
+	ln, err := net.Listen("tcp", addr)
+	if err != nil {
+		return nil, err
+	}
+	return &stoppableListener{ln.(*net.TCPListener), stopc}, nil
+}
+
+func (ln stoppableListener) Accept() (c net.Conn, err error) {
+	connc := make(chan *net.TCPConn, 1)
+	errc := make(chan error, 1)
+	go func() {
+		tc, err := ln.AcceptTCP()
+		if err != nil {
+			errc <- err
+		} else {
+			connc <- tc
+		}
+	}()
+	select {
+	case <-ln.stopc:
+		return nil, errors.New("server stopped")
+	case err := <-errc:
+		return nil, err
+	case tc := <-connc:
+		tc.SetKeepAlive(true)
+		tc.SetKeepAlivePeriod(3 * time.Minute)
+		return tc, nil
+	}
+}

+ 49 - 14
contrib/raftexample/raft.go

@@ -49,13 +49,16 @@ type raftNode struct {
 	raftStorage *raft.MemoryStorage
 	wal         *wal.WAL
 	transport   *rafthttp.Transport
+	stopc       chan struct{} // signals proposal channel closed
+	httpstopc   chan struct{} // signals http server to shutdown
+	httpdonec   chan struct{} // signals http server shutdown complete
 }
 
 // newRaftNode initiates a raft instance and returns a committed log entry
 // channel and error channel. Proposals for log updates are sent over the
 // provided the proposal channel. All log entries are replayed over the
 // commit channel, followed by a nil message (to indicate the channel is
-// current), then new log entries.
+// current), then new log entries. To shutdown, close proposeC and read errorC.
 func newRaftNode(id int, peers []string, proposeC <-chan string) (<-chan *string, <-chan error) {
 	rc := &raftNode{
 		proposeC:    proposeC,
@@ -65,22 +68,31 @@ func newRaftNode(id int, peers []string, proposeC <-chan string) (<-chan *string
 		peers:       peers,
 		waldir:      fmt.Sprintf("raftexample-%d", id),
 		raftStorage: raft.NewMemoryStorage(),
+		stopc:       make(chan struct{}),
+		httpstopc:   make(chan struct{}),
+		httpdonec:   make(chan struct{}),
 		// rest of structure populated after WAL replay
 	}
 	go rc.startRaft()
 	return rc.commitC, rc.errorC
 }
 
-// publishEntries writes committed log entries to commit channel.
-func (rc *raftNode) publishEntries(ents []raftpb.Entry) {
+// publishEntries writes committed log entries to commit channel and returns
+// whether all entries could be published.
+func (rc *raftNode) publishEntries(ents []raftpb.Entry) bool {
 	for i := range ents {
 		if ents[i].Type != raftpb.EntryNormal || len(ents[i].Data) == 0 {
 			// ignore conf changes and empty messages
 			continue
 		}
 		s := string(ents[i].Data)
-		rc.commitC <- &s
+		select {
+		case rc.commitC <- &s:
+		case <-rc.stopc:
+			return false
+		}
 	}
+	return true
 }
 
 // openWAL returns a WAL ready for reading.
@@ -122,6 +134,7 @@ func (rc *raftNode) replayWAL() *wal.WAL {
 }
 
 func (rc *raftNode) writeError(err error) {
+	rc.stopHTTP()
 	close(rc.commitC)
 	rc.errorC <- err
 	close(rc.errorC)
@@ -174,6 +187,20 @@ func (rc *raftNode) startRaft() {
 	go rc.serveChannels()
 }
 
+// stop closes http, closes all channels, and stops raft.
+func (rc *raftNode) stop() {
+	rc.stopHTTP()
+	close(rc.commitC)
+	close(rc.errorC)
+	rc.node.Stop()
+}
+
+func (rc *raftNode) stopHTTP() {
+	rc.transport.Stop()
+	close(rc.httpstopc)
+	<-rc.httpdonec
+}
+
 func (rc *raftNode) serveChannels() {
 	defer rc.wal.Close()
 
@@ -181,14 +208,13 @@ func (rc *raftNode) serveChannels() {
 	defer ticker.Stop()
 
 	// send proposals over raft
-	stopc := make(chan struct{}, 1)
 	go func() {
 		for prop := range rc.proposeC {
 			// blocks until accepted by raft state machine
 			rc.node.Propose(context.TODO(), []byte(prop))
 		}
 		// client closed channel; shutdown raft if not already
-		stopc <- struct{}{}
+		close(rc.stopc)
 	}()
 
 	// event loop on raft state machine updates
@@ -202,17 +228,18 @@ func (rc *raftNode) serveChannels() {
 			rc.wal.Save(rd.HardState, rd.Entries)
 			rc.raftStorage.Append(rd.Entries)
 			rc.transport.Send(rd.Messages)
-			rc.publishEntries(rd.Entries)
+			if ok := rc.publishEntries(rd.Entries); !ok {
+				rc.stop()
+				return
+			}
 			rc.node.Advance()
 
 		case err := <-rc.transport.ErrorC:
 			rc.writeError(err)
 			return
 
-		case <-stopc:
-			close(rc.commitC)
-			close(rc.errorC)
-			rc.node.Stop()
+		case <-rc.stopc:
+			rc.stop()
 			return
 		}
 	}
@@ -224,10 +251,18 @@ func (rc *raftNode) serveRaft() {
 		log.Fatalf("raftexample: Failed parsing URL (%v)", err)
 	}
 
-	srv := http.Server{Addr: url.Host, Handler: rc.transport.Handler()}
-	if err := srv.ListenAndServe(); err != nil {
-		log.Fatalf("raftexample: Failed serving rafthttp (%v)", err)
+	ln, err := newStoppableListener(url.Host, rc.httpstopc)
+	if err != nil {
+		log.Fatalf("raftexample: Failed to listen rafthttp (%v)", err)
+	}
+
+	err = (&http.Server{Handler: rc.transport.Handler()}).Serve(ln)
+	select {
+	case <-rc.httpstopc:
+	default:
+		log.Fatalf("raftexample: Failed to serve rafthttp (%v)", err)
 	}
+	close(rc.httpdonec)
 }
 
 func (rc *raftNode) Process(ctx context.Context, m raftpb.Message) error {

+ 44 - 8
contrib/raftexample/raftexample_test.go

@@ -44,15 +44,20 @@ func newCluster(n int) *cluster {
 		os.RemoveAll(fmt.Sprintf("raftexample-%d", i+1))
 		clus.proposeC[i] = make(chan string, 1)
 		clus.commitC[i], clus.errorC[i] = newRaftNode(i+1, clus.peers, clus.proposeC[i])
-		// replay local log
+	}
+
+	return clus
+}
+
+// sinkReplay reads all commits in each node's local log.
+func (clus *cluster) sinkReplay() {
+	for i := range clus.peers {
 		for s := range clus.commitC[i] {
 			if s == nil {
 				break
 			}
 		}
 	}
-
-	return clus
 }
 
 // Close closes all cluster nodes and returns an error if any failed.
@@ -72,15 +77,19 @@ func (clus *cluster) Close() (err error) {
 	return err
 }
 
+func (clus *cluster) closeNoErrors(t *testing.T) {
+	if err := clus.Close(); err != nil {
+		t.Fatal(err)
+	}
+}
+
 // TestProposeOnCommit starts three nodes and feeds commits back into the proposal
 // channel. The intent is to ensure blocking on a proposal won't block raft progress.
 func TestProposeOnCommit(t *testing.T) {
 	clus := newCluster(3)
-	defer func() {
-		if err := clus.Close(); err != nil {
-			t.Fatal(err)
-		}
-	}()
+	defer clus.closeNoErrors(t)
+
+	clus.sinkReplay()
 
 	donec := make(chan struct{})
 	for i := range clus.peers {
@@ -109,3 +118,30 @@ func TestProposeOnCommit(t *testing.T) {
 		<-donec
 	}
 }
+
+// TestCloseBeforeReplay tests closing the producer before raft starts.
+func TestCloseProposerBeforeReplay(t *testing.T) {
+	clus := newCluster(1)
+	// close before replay so raft never starts
+	defer clus.closeNoErrors(t)
+}
+
+// TestCloseProposerInflight tests closing the producer while
+// committed messages are being published to the client.
+func TestCloseProposerInflight(t *testing.T) {
+	clus := newCluster(1)
+	defer clus.closeNoErrors(t)
+
+	clus.sinkReplay()
+
+	// some inflight ops
+	go func() {
+		clus.proposeC[0] <- "foo"
+		clus.proposeC[0] <- "bar"
+	}()
+
+	// wait for one message
+	if c, ok := <-clus.commitC[0]; *c != "foo" || !ok {
+		t.Fatalf("Commit failed")
+	}
+}