Browse Source

Merge pull request #3035 from yichengq/update-term

rafthttp: update term when AddPeer
Yicheng Qin 10 years ago
parent
commit
2d426b518a
2 changed files with 15 additions and 3 deletions
  1. 6 2
      rafthttp/transport.go
  2. 9 1
      rafthttp/transport_test.go

+ 6 - 2
rafthttp/transport.go

@@ -79,8 +79,8 @@ type transport struct {
 	serverStats  *stats.ServerStats
 	leaderStats  *stats.LeaderStats
 
+	mu      sync.RWMutex         // protect the term, remote and peer map
 	term    uint64               // the latest term that has been observed
-	mu      sync.RWMutex         // protect the remote and peer map
 	remotes map[types.ID]*remote // remotes map that helps newly joined member to catch up
 	peers   map[types.ID]Peer    // peers map
 	errorc  chan error
@@ -116,6 +116,8 @@ func (t *transport) Get(id types.ID) Peer {
 }
 
 func (t *transport) maybeUpdatePeersTerm(term uint64) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
 	if t.term >= term {
 		return
 	}
@@ -192,7 +194,9 @@ func (t *transport) AddPeer(id types.ID, us []string) {
 		plog.Panicf("newURLs %+v should never fail: %+v", us, err)
 	}
 	fs := t.leaderStats.Follower(id.String())
-	t.peers[id] = startPeer(t.roundTripper, urls, t.id, id, t.clusterID, t.raft, fs, t.errorc)
+	p := startPeer(t.roundTripper, urls, t.id, id, t.clusterID, t.raft, fs, t.errorc)
+	p.setTerm(t.term)
+	t.peers[id] = p
 }
 
 func (t *transport) RemovePeer(id types.ID) {

+ 9 - 1
rafthttp/transport_test.go

@@ -67,19 +67,21 @@ func TestTransportSend(t *testing.T) {
 
 func TestTransportAdd(t *testing.T) {
 	ls := stats.NewLeaderStats("")
+	term := uint64(10)
 	tr := &transport{
 		roundTripper: &roundTripperRecorder{},
 		leaderStats:  ls,
+		term:         term,
 		peers:        make(map[types.ID]Peer),
 	}
 	tr.AddPeer(1, []string{"http://localhost:2380"})
-	defer tr.Stop()
 
 	if _, ok := ls.Followers["1"]; !ok {
 		t.Errorf("FollowerStats[1] is nil, want exists")
 	}
 	s, ok := tr.peers[types.ID(1)]
 	if !ok {
+		tr.Stop()
 		t.Fatalf("senders[1] is nil, want exists")
 	}
 
@@ -89,6 +91,12 @@ func TestTransportAdd(t *testing.T) {
 	if s != ns {
 		t.Errorf("sender = %v, want %v", ns, s)
 	}
+
+	tr.Stop()
+
+	if g := s.(*peer).msgAppReader.msgAppTerm; g != term {
+		t.Errorf("peer.term = %d, want %d", g, term)
+	}
 }
 
 func TestTransportRemove(t *testing.T) {