Browse Source

Merge pull request #1026 from coreos/fix_node

Fix node
Xiang Li 11 years ago
parent
commit
1a677164be
4 changed files with 48 additions and 15 deletions
  1. 28 14
      raft/node.go
  2. 1 1
      raft/node_test.go
  3. 4 0
      wal/wal.go
  4. 15 0
      wal/wal_test.go

+ 28 - 14
raft/node.go

@@ -9,7 +9,10 @@ import (
 	"github.com/coreos/etcd/third_party/code.google.com/p/go.net/context"
 	"github.com/coreos/etcd/third_party/code.google.com/p/go.net/context"
 )
 )
 
 
-var ErrStopped = errors.New("raft: stopped")
+var (
+	emptyState = pb.State{}
+	ErrStopped = errors.New("raft: stopped")
+)
 
 
 // Ready encapsulates the entries and messages that are ready to be saved to
 // Ready encapsulates the entries and messages that are ready to be saved to
 // stable storage, committed or sent to other peers.
 // stable storage, committed or sent to other peers.
@@ -35,8 +38,12 @@ func isStateEqual(a, b pb.State) bool {
 	return a.Term == b.Term && a.Vote == b.Vote && a.LastIndex == b.LastIndex
 	return a.Term == b.Term && a.Vote == b.Vote && a.LastIndex == b.LastIndex
 }
 }
 
 
-func (rd Ready) containsUpdates(prev Ready) bool {
-	return !isStateEqual(prev.State, rd.State) || len(rd.Entries) > 0 || len(rd.CommittedEntries) > 0 || len(rd.Messages) > 0
+func IsEmptyState(st pb.State) bool {
+	return isStateEqual(st, emptyState)
+}
+
+func (rd Ready) containsUpdates() bool {
+	return !IsEmptyState(rd.State) || len(rd.Entries) > 0 || len(rd.CommittedEntries) > 0 || len(rd.Messages) > 0
 }
 }
 
 
 type Node struct {
 type Node struct {
@@ -83,8 +90,7 @@ func (n *Node) run(r *raft) {
 	readyc := n.readyc
 	readyc := n.readyc
 
 
 	var lead int64
 	var lead int64
-	var prev Ready
-	prev.State = r.State
+	prevSt := r.State
 
 
 	for {
 	for {
 		if lead != r.lead {
 		if lead != r.lead {
@@ -97,16 +103,9 @@ func (n *Node) run(r *raft) {
 			}
 			}
 		}
 		}
 
 
-		rd := Ready{
-			r.State,
-			r.raftLog.unstableEnts(),
-			r.raftLog.nextEnts(),
-			r.msgs,
-		}
-
-		if rd.containsUpdates(prev) {
+		rd := newReady(r, prevSt)
+		if rd.containsUpdates() {
 			readyc = n.readyc
 			readyc = n.readyc
-			prev = rd
 		} else {
 		} else {
 			readyc = nil
 			readyc = nil
 		}
 		}
@@ -122,6 +121,9 @@ func (n *Node) run(r *raft) {
 		case readyc <- rd:
 		case readyc <- rd:
 			r.raftLog.resetNextEnts()
 			r.raftLog.resetNextEnts()
 			r.raftLog.resetUnstable()
 			r.raftLog.resetUnstable()
+			if !IsEmptyState(rd.State) {
+				prevSt = rd.State
+			}
 			r.msgs = nil
 			r.msgs = nil
 		case <-n.done:
 		case <-n.done:
 			return
 			return
@@ -169,3 +171,15 @@ func (n *Node) Step(ctx context.Context, m pb.Message) error {
 func (n *Node) Ready() <-chan Ready {
 func (n *Node) Ready() <-chan Ready {
 	return n.readyc
 	return n.readyc
 }
 }
+
+func newReady(r *raft, prev pb.State) Ready {
+	rd := Ready{
+		Entries:          r.raftLog.unstableEnts(),
+		CommittedEntries: r.raftLog.nextEnts(),
+		Messages:         r.msgs,
+	}
+	if !isStateEqual(r.State, prev) {
+		rd.State = r.State
+	}
+	return rd
+}

+ 1 - 1
raft/node_test.go

@@ -51,7 +51,7 @@ func TestNodeRestart(t *testing.T) {
 	st := raftpb.State{Term: 1, Vote: -1, Commit: 1, LastIndex: 2}
 	st := raftpb.State{Term: 1, Vote: -1, Commit: 1, LastIndex: 2}
 
 
 	want := Ready{
 	want := Ready{
-		State: st,
+		State: emptyState,
 		// commit upto index commit index in st
 		// commit upto index commit index in st
 		CommittedEntries: entries[:st.Commit],
 		CommittedEntries: entries[:st.Commit],
 	}
 	}

+ 4 - 0
wal/wal.go

@@ -26,6 +26,7 @@ import (
 	"path"
 	"path"
 	"sort"
 	"sort"
 
 
+	"github.com/coreos/etcd/raft"
 	"github.com/coreos/etcd/raft/raftpb"
 	"github.com/coreos/etcd/raft/raftpb"
 	"github.com/coreos/etcd/wal/walpb"
 	"github.com/coreos/etcd/wal/walpb"
 )
 )
@@ -253,6 +254,9 @@ func (w *WAL) SaveEntry(e *raftpb.Entry) error {
 }
 }
 
 
 func (w *WAL) SaveState(s *raftpb.State) error {
 func (w *WAL) SaveState(s *raftpb.State) error {
+	if raft.IsEmptyState(*s) {
+		return nil
+	}
 	log.Printf("path=%s wal.saveState state=\"%+v\"", w.f.Name(), s)
 	log.Printf("path=%s wal.saveState state=\"%+v\"", w.f.Name(), s)
 	b, err := s.Marshal()
 	b, err := s.Marshal()
 	if err != nil {
 	if err != nil {

+ 15 - 0
wal/wal_test.go

@@ -17,6 +17,7 @@ limitations under the License.
 package wal
 package wal
 
 
 import (
 import (
+	"bytes"
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
 	"os"
 	"os"
@@ -322,3 +323,17 @@ func TestRecoverAfterCut(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
+
+func TestSaveEmpty(t *testing.T) {
+	var buf bytes.Buffer
+	var est raftpb.State
+	w := WAL{
+		encoder: newEncoder(&buf, 0),
+	}
+	if err := w.SaveState(&est); err != nil {
+		t.Errorf("err = %v, want nil", err)
+	}
+	if len(buf.Bytes()) != 0 {
+		t.Errorf("buf.Bytes = %d, want 0", len(buf.Bytes()))
+	}
+}