Browse Source

rafthttp: a stopped stream does not accept any methods

Xiang Li 11 years ago
parent
commit
3319f716d9
1 changed files with 24 additions and 10 deletions
  1. 24 10
      rafthttp/streamer.go

+ 24 - 10
rafthttp/streamer.go

@@ -17,6 +17,7 @@
 package rafthttp
 package rafthttp
 
 
 import (
 import (
+	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"log"
 	"log"
@@ -41,24 +42,27 @@ const (
 
 
 // TODO: a stream might hava one stream server or one stream client, but not both.
 // TODO: a stream might hava one stream server or one stream client, but not both.
 type stream struct {
 type stream struct {
-	// the server might be attached asynchronously with the owner of the stream
-	// use a mutex to protect it
 	sync.Mutex
 	sync.Mutex
-	w *streamWriter
-
-	r *streamReader
+	w       *streamWriter
+	r       *streamReader
+	stopped bool
 }
 }
 
 
 func (s *stream) open(from, to, cid types.ID, term uint64, tr http.RoundTripper, u string, r Raft) error {
 func (s *stream) open(from, to, cid types.ID, term uint64, tr http.RoundTripper, u string, r Raft) error {
-	if s.r != nil {
-		panic("open: stream is open")
-	}
-
 	c, err := newStreamReader(from, to, cid, term, tr, u, r)
 	c, err := newStreamReader(from, to, cid, term, tr, u, r)
 	if err != nil {
 	if err != nil {
 		log.Printf("stream: error opening stream: %v", err)
 		log.Printf("stream: error opening stream: %v", err)
 		return err
 		return err
 	}
 	}
+
+	s.Lock()
+	defer s.Unlock()
+	if s.stopped {
+		return errors.New("stream: stopped")
+	}
+	if s.r != nil {
+		panic("open: stream is open")
+	}
 	s.r = c
 	s.r = c
 	return nil
 	return nil
 }
 }
@@ -66,6 +70,9 @@ func (s *stream) open(from, to, cid types.ID, term uint64, tr http.RoundTripper,
 func (s *stream) attach(sw *streamWriter) error {
 func (s *stream) attach(sw *streamWriter) error {
 	s.Lock()
 	s.Lock()
 	defer s.Unlock()
 	defer s.Unlock()
+	if s.stopped {
+		return errors.New("stream: stopped")
+	}
 	if s.w != nil {
 	if s.w != nil {
 		// ignore lower-term streaming request
 		// ignore lower-term streaming request
 		if sw.term < s.w.term {
 		if sw.term < s.w.term {
@@ -80,6 +87,9 @@ func (s *stream) attach(sw *streamWriter) error {
 func (s *stream) write(m raftpb.Message) bool {
 func (s *stream) write(m raftpb.Message) bool {
 	s.Lock()
 	s.Lock()
 	defer s.Unlock()
 	defer s.Unlock()
+	if s.stopped {
+		return false
+	}
 	if s.w == nil {
 	if s.w == nil {
 		return false
 		return false
 	}
 	}
@@ -105,7 +115,6 @@ func (s *stream) write(m raftpb.Message) bool {
 func (s *stream) invalidate(term uint64) {
 func (s *stream) invalidate(term uint64) {
 	s.Lock()
 	s.Lock()
 	defer s.Unlock()
 	defer s.Unlock()
-
 	if s.w != nil {
 	if s.w != nil {
 		if s.w.term < term {
 		if s.w.term < term {
 			s.w.stop()
 			s.w.stop()
@@ -118,6 +127,9 @@ func (s *stream) invalidate(term uint64) {
 			s.r = nil
 			s.r = nil
 		}
 		}
 	}
 	}
+	if term == math.MaxUint64 {
+		s.stopped = true
+	}
 }
 }
 
 
 func (s *stream) stop() {
 func (s *stream) stop() {
@@ -125,6 +137,8 @@ func (s *stream) stop() {
 }
 }
 
 
 func (s *stream) isOpen() bool {
 func (s *stream) isOpen() bool {
+	s.Lock()
+	defer s.Unlock()
 	if s.r != nil && s.r.isStopped() {
 	if s.r != nil && s.r.isStopped() {
 		s.r = nil
 		s.r = nil
 	}
 	}