Browse Source

Require a non-nil Storage parameter in newLog.

Callers must in general have a reference to their Storage objects to
transfer entries from Ready to Storage, so it doesn't make sense to
create a hidden Storage for them.

By explicitly creating Storage objects in tests we can remove a
few casts of raftLog's storage field.
Ben Darnell 11 years ago
parent
commit
76a3de9a33
6 changed files with 77 additions and 70 deletions
  1. 1 1
      raft/log.go
  2. 7 7
      raft/log_test.go
  3. 7 4
      raft/node_test.go
  4. 30 27
      raft/raft_paper_test.go
  5. 31 28
      raft/raft_test.go
  6. 1 3
      raft/storage.go

+ 1 - 1
raft/log.go

@@ -47,7 +47,7 @@ type raftLog struct {
 
 func newLog(storage Storage) *raftLog {
 	if storage == nil {
-		storage = NewMemoryStorage()
+		panic("storage must not be nil")
 	}
 	lastIndex, err := storage.GetLastIndex()
 	if err != nil {

+ 7 - 7
raft/log_test.go

@@ -49,7 +49,7 @@ func TestFindConflict(t *testing.T) {
 	}
 
 	for i, tt := range tests {
-		raftLog := newLog(nil)
+		raftLog := newLog(NewMemoryStorage())
 		raftLog.append(raftLog.lastIndex(), previousEnts...)
 
 		gconflict := raftLog.findConflict(tt.from, tt.ents)
@@ -61,7 +61,7 @@ func TestFindConflict(t *testing.T) {
 
 func TestIsUpToDate(t *testing.T) {
 	previousEnts := []pb.Entry{{Term: 1}, {Term: 2}, {Term: 3}}
-	raftLog := newLog(nil)
+	raftLog := newLog(NewMemoryStorage())
 	raftLog.append(raftLog.lastIndex(), previousEnts...)
 	tests := []struct {
 		lastIndex uint64
@@ -241,7 +241,7 @@ func TestLogMaybeAppend(t *testing.T) {
 	}
 
 	for i, tt := range tests {
-		raftLog := newLog(nil)
+		raftLog := newLog(NewMemoryStorage())
 		raftLog.append(raftLog.lastIndex(), previousEnts...)
 		raftLog.committed = commit
 		func() {
@@ -410,7 +410,7 @@ func TestCompaction(t *testing.T) {
 
 func TestLogRestore(t *testing.T) {
 	var i uint64
-	raftLog := newLog(nil)
+	raftLog := newLog(NewMemoryStorage())
 	for i = 0; i < 100; i++ {
 		raftLog.append(i, pb.Entry{Term: i + 1})
 	}
@@ -443,7 +443,7 @@ func TestLogRestore(t *testing.T) {
 func TestIsOutOfBounds(t *testing.T) {
 	offset := uint64(100)
 	num := uint64(100)
-	l := newLog(nil)
+	l := newLog(NewMemoryStorage())
 	l.restore(pb.Snapshot{Index: offset})
 	l.append(offset, make([]pb.Entry, num)...)
 
@@ -471,7 +471,7 @@ func TestAt(t *testing.T) {
 	offset := uint64(100)
 	num := uint64(100)
 
-	l := newLog(nil)
+	l := newLog(NewMemoryStorage())
 	l.restore(pb.Snapshot{Index: offset})
 	for i = 0; i < num; i++ {
 		l.append(offset+i-1, pb.Entry{Term: i})
@@ -501,7 +501,7 @@ func TestSlice(t *testing.T) {
 	offset := uint64(100)
 	num := uint64(100)
 
-	l := newLog(nil)
+	l := newLog(NewMemoryStorage())
 	l.restore(pb.Snapshot{Index: offset})
 	for i = 0; i < num; i++ {
 		l.append(offset+i-1, pb.Entry{Term: i})

+ 7 - 4
raft/node_test.go

@@ -112,7 +112,7 @@ func TestNodeStepUnblock(t *testing.T) {
 // who is the current leader.
 func TestBlockProposal(t *testing.T) {
 	n := newNode()
-	r := newRaft(1, []uint64{1}, 10, 1, nil)
+	r := newRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 	go n.run(r)
 	defer n.Stop()
 
@@ -230,7 +230,7 @@ func TestNodeRestart(t *testing.T) {
 		CommittedEntries: entries[1 : st.Commit+1],
 	}
 
-	n := RestartNode(1, 10, 1, nil, st, entries, nil)
+	n := RestartNode(1, 10, 1, nil, st, entries, NewMemoryStorage())
 	if g := <-n.Ready(); !reflect.DeepEqual(g, want) {
 		t.Errorf("g = %+v,\n             w   %+v", g, want)
 	} else {
@@ -303,16 +303,19 @@ func TestNodeAdvance(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	n := StartNode(1, []Peer{{ID: 1}}, 10, 1, nil)
+	storage := NewMemoryStorage()
+	n := StartNode(1, []Peer{{ID: 1}}, 10, 1, storage)
 	n.ApplyConfChange(raftpb.ConfChange{Type: raftpb.ConfChangeAddNode, NodeID: 1})
 	n.Campaign(ctx)
 	<-n.Ready()
 	n.Propose(ctx, []byte("foo"))
+	var rd Ready
 	select {
-	case rd := <-n.Ready():
+	case rd = <-n.Ready():
 		t.Fatalf("unexpected Ready before Advance: %+v", rd)
 	default:
 	}
+	storage.Append(rd.Entries)
 	n.Advance()
 	select {
 	case <-n.Ready():

+ 30 - 27
raft/raft_paper_test.go

@@ -52,7 +52,7 @@ func TestLeaderUpdateTermFromMessage(t *testing.T) {
 // it immediately reverts to follower state.
 // Reference: section 5.1
 func testUpdateTermFromMessage(t *testing.T, state StateType) {
-	r := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+	r := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 	switch state {
 	case StateFollower:
 		r.becomeFollower(1, 2)
@@ -82,7 +82,7 @@ func TestRejectStaleTermMessage(t *testing.T) {
 	fakeStep := func(r *raft, m pb.Message) {
 		called = true
 	}
-	r := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+	r := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 	r.step = fakeStep
 	r.loadState(pb.HardState{Term: 2})
 
@@ -96,7 +96,7 @@ func TestRejectStaleTermMessage(t *testing.T) {
 // TestStartAsFollower tests that when servers start up, they begin as followers.
 // Reference: section 5.2
 func TestStartAsFollower(t *testing.T) {
-	r := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+	r := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 	if r.state != StateFollower {
 		t.Errorf("state = %s, want %s", r.state, StateFollower)
 	}
@@ -109,7 +109,7 @@ func TestStartAsFollower(t *testing.T) {
 func TestLeaderBcastBeat(t *testing.T) {
 	// heartbeat interval
 	hi := 1
-	r := newRaft(1, []uint64{1, 2, 3}, 10, hi, nil)
+	r := newRaft(1, []uint64{1, 2, 3}, 10, hi, NewMemoryStorage())
 	r.becomeCandidate()
 	r.becomeLeader()
 	for i := 0; i < 10; i++ {
@@ -151,7 +151,7 @@ func TestCandidateStartNewElection(t *testing.T) {
 func testNonleaderStartElection(t *testing.T, state StateType) {
 	// election timeout
 	et := 10
-	r := newRaft(1, []uint64{1, 2, 3}, et, 1, nil)
+	r := newRaft(1, []uint64{1, 2, 3}, et, 1, NewMemoryStorage())
 	switch state {
 	case StateFollower:
 		r.becomeFollower(1, 2)
@@ -215,7 +215,7 @@ func TestLeaderElectionInOneRoundRPC(t *testing.T) {
 		{5, map[uint64]bool{}, StateCandidate},
 	}
 	for i, tt := range tests {
-		r := newRaft(1, idsBySize(tt.size), 10, 1, nil)
+		r := newRaft(1, idsBySize(tt.size), 10, 1, NewMemoryStorage())
 
 		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgHup})
 		for id, vote := range tt.votes {
@@ -248,7 +248,7 @@ func TestFollowerVote(t *testing.T) {
 		{2, 1, true},
 	}
 	for i, tt := range tests {
-		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 		r.loadState(pb.HardState{Term: 1, Vote: tt.vote})
 
 		r.Step(pb.Message{From: tt.nvote, To: 1, Term: 1, Type: pb.MsgVote})
@@ -274,7 +274,7 @@ func TestCandidateFallback(t *testing.T) {
 		{From: 2, To: 1, Term: 2, Type: pb.MsgApp},
 	}
 	for i, tt := range tests {
-		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgHup})
 		if r.state != StateCandidate {
 			t.Fatalf("unexpected state = %s, want %s", r.state, StateCandidate)
@@ -303,7 +303,7 @@ func TestCandidateElectionTimeoutRandomized(t *testing.T) {
 // Reference: section 5.2
 func testNonleaderElectionTimeoutRandomized(t *testing.T, state StateType) {
 	et := 10
-	r := newRaft(1, []uint64{1, 2, 3}, et, 1, nil)
+	r := newRaft(1, []uint64{1, 2, 3}, et, 1, NewMemoryStorage())
 	timeouts := make(map[int]bool)
 	for round := 0; round < 50*et; round++ {
 		switch state {
@@ -345,7 +345,7 @@ func testNonleadersElectionTimeoutNonconflict(t *testing.T, state StateType) {
 	rs := make([]*raft, size)
 	ids := idsBySize(size)
 	for k := range rs {
-		rs[k] = newRaft(ids[k], ids, et, 1, nil)
+		rs[k] = newRaft(ids[k], ids, et, 1, NewMemoryStorage())
 	}
 	conflicts := 0
 	for round := 0; round < 1000; round++ {
@@ -387,10 +387,11 @@ func testNonleadersElectionTimeoutNonconflict(t *testing.T, state StateType) {
 // Also, it writes the new entry into stable storage.
 // Reference: section 5.3
 func TestLeaderStartReplication(t *testing.T) {
-	r := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+	s := NewMemoryStorage()
+	r := newRaft(1, []uint64{1, 2, 3}, 10, 1, s)
 	r.becomeCandidate()
 	r.becomeLeader()
-	commitNoopEntry(r)
+	commitNoopEntry(r, s)
 	li := r.raftLog.lastIndex()
 
 	ents := []pb.Entry{{Data: []byte("some data")}}
@@ -425,10 +426,11 @@ func TestLeaderStartReplication(t *testing.T) {
 // servers eventually find out.
 // Reference: section 5.3
 func TestLeaderCommitEntry(t *testing.T) {
-	r := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+	s := NewMemoryStorage()
+	r := newRaft(1, []uint64{1, 2, 3}, 10, 1, s)
 	r.becomeCandidate()
 	r.becomeLeader()
-	commitNoopEntry(r)
+	commitNoopEntry(r, s)
 	li := r.raftLog.lastIndex()
 	r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("some data")}}})
 
@@ -478,10 +480,11 @@ func TestLeaderAcknowledgeCommit(t *testing.T) {
 		{5, map[uint64]bool{2: true, 3: true, 4: true, 5: true}, true},
 	}
 	for i, tt := range tests {
-		r := newRaft(1, idsBySize(tt.size), 10, 1, nil)
+		s := NewMemoryStorage()
+		r := newRaft(1, idsBySize(tt.size), 10, 1, s)
 		r.becomeCandidate()
 		r.becomeLeader()
-		commitNoopEntry(r)
+		commitNoopEntry(r, s)
 		li := r.raftLog.lastIndex()
 		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("some data")}}})
 
@@ -510,7 +513,7 @@ func TestLeaderCommitPrecedingEntries(t *testing.T) {
 		{{Term: 1, Index: 1}},
 	}
 	for i, tt := range tests {
-		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 		r.loadEnts(append([]pb.Entry{{}}, tt...))
 		r.loadState(pb.HardState{Term: 2})
 		r.becomeCandidate()
@@ -566,7 +569,7 @@ func TestFollowerCommitEntry(t *testing.T) {
 		},
 	}
 	for i, tt := range tests {
-		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 		r.becomeFollower(1, 2)
 
 		r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgApp, Term: 1, Entries: tt.ents, Commit: tt.commit})
@@ -601,7 +604,7 @@ func TestFollowerCheckMsgApp(t *testing.T) {
 		{3, 3, true},
 	}
 	for i, tt := range tests {
-		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 		r.loadEnts(ents)
 		r.loadState(pb.HardState{Commit: 2})
 		r.becomeFollower(2, 2)
@@ -656,7 +659,7 @@ func TestFollowerAppendEntries(t *testing.T) {
 		},
 	}
 	for i, tt := range tests {
-		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 		r.loadEnts([]pb.Entry{{}, {Term: 1, Index: 1}, {Term: 2, Index: 2}})
 		r.becomeFollower(2, 2)
 
@@ -724,10 +727,10 @@ func TestLeaderSyncFollowerLog(t *testing.T) {
 		},
 	}
 	for i, tt := range tests {
-		lead := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+		lead := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 		lead.loadEnts(ents)
 		lead.loadState(pb.HardState{Commit: lead.raftLog.lastIndex(), Term: term})
-		follower := newRaft(2, []uint64{1, 2, 3}, 10, 1, nil)
+		follower := newRaft(2, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 		follower.loadEnts(tt)
 		follower.loadState(pb.HardState{Term: term - 1})
 		// It is necessary to have a three-node cluster.
@@ -757,7 +760,7 @@ func TestVoteRequest(t *testing.T) {
 		{[]pb.Entry{{Term: 1, Index: 1}, {Term: 2, Index: 2}}, 3},
 	}
 	for i, tt := range tests {
-		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+		r := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 		r.Step(pb.Message{
 			From: 2, To: 1, Type: pb.MsgApp, Term: tt.wterm - 1, LogTerm: 0, Index: 0, Entries: tt.ents,
 		})
@@ -818,7 +821,7 @@ func TestVoter(t *testing.T) {
 		{[]pb.Entry{{}, {Term: 2, Index: 1}, {Term: 1, Index: 2}}, 1, 1, true},
 	}
 	for i, tt := range tests {
-		r := newRaft(1, []uint64{1, 2}, 10, 1, nil)
+		r := newRaft(1, []uint64{1, 2}, 10, 1, NewMemoryStorage())
 		r.loadEnts(tt.ents)
 
 		r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgVote, Term: 3, LogTerm: tt.logterm, Index: tt.index})
@@ -853,7 +856,7 @@ func TestLeaderOnlyCommitsLogFromCurrentTerm(t *testing.T) {
 		{3, 3},
 	}
 	for i, tt := range tests {
-		r := newRaft(1, []uint64{1, 2}, 10, 1, nil)
+		r := newRaft(1, []uint64{1, 2}, 10, 1, NewMemoryStorage())
 		r.loadEnts(ents)
 		r.loadState(pb.HardState{Term: 2})
 		// become leader at term 3
@@ -876,7 +879,7 @@ func (s messageSlice) Len() int           { return len(s) }
 func (s messageSlice) Less(i, j int) bool { return fmt.Sprint(s[i]) < fmt.Sprint(s[j]) }
 func (s messageSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
 
-func commitNoopEntry(r *raft) {
+func commitNoopEntry(r *raft, s *MemoryStorage) {
 	if r.state != StateLeader {
 		panic("it should only be used when it is the leader")
 	}
@@ -891,7 +894,7 @@ func commitNoopEntry(r *raft) {
 	}
 	// ignore further messages to refresh followers' commmit index
 	r.readMessages()
-	r.raftLog.storage.(*MemoryStorage).Append(r.raftLog.unstableEntries())
+	s.Append(r.raftLog.unstableEntries())
 	r.raftLog.appliedTo(r.raftLog.committed)
 	r.raftLog.stableTo(r.raftLog.lastIndex())
 }

+ 31 - 28
raft/raft_test.go

@@ -29,10 +29,9 @@ import (
 )
 
 // nextEnts returns the appliable entries and updates the applied index
-func nextEnts(r *raft) (ents []pb.Entry) {
+func nextEnts(r *raft, s *MemoryStorage) (ents []pb.Entry) {
 	// Transfer all unstable entries to "stable" storage.
-	memStorage := r.raftLog.storage.(*MemoryStorage)
-	memStorage.Append(r.raftLog.unstableEntries())
+	s.Append(r.raftLog.unstableEntries())
 	r.raftLog.stableTo(r.raftLog.lastIndex())
 
 	ents = r.raftLog.nextEnts()
@@ -176,7 +175,7 @@ func TestLogReplication(t *testing.T) {
 			}
 
 			ents := []pb.Entry{}
-			for _, e := range nextEnts(sm) {
+			for _, e := range nextEnts(sm, tt.network.storage[j]) {
 				if e.Data != nil {
 					ents = append(ents, e)
 				}
@@ -285,9 +284,9 @@ func TestCommitWithoutNewTermEntry(t *testing.T) {
 }
 
 func TestDuelingCandidates(t *testing.T) {
-	a := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
-	b := newRaft(2, []uint64{1, 2, 3}, 10, 1, nil)
-	c := newRaft(3, []uint64{1, 2, 3}, 10, 1, nil)
+	a := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
+	b := newRaft(2, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
+	c := newRaft(3, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 
 	nt := newNetwork(a, b, c)
 	nt.cut(1, 3)
@@ -311,7 +310,7 @@ func TestDuelingCandidates(t *testing.T) {
 	}{
 		{a, StateFollower, 2, wlog},
 		{b, StateFollower, 2, wlog},
-		{c, StateFollower, 2, newLog(nil)},
+		{c, StateFollower, 2, newLog(NewMemoryStorage())},
 	}
 
 	for i, tt := range tests {
@@ -450,7 +449,7 @@ func TestProposal(t *testing.T) {
 		send(pb.Message{From: 1, To: 1, Type: pb.MsgHup})
 		send(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: data}}})
 
-		wantLog := newLog(nil)
+		wantLog := newLog(NewMemoryStorage())
 		if tt.success {
 			wantLog = &raftLog{
 				storage: &MemoryStorage{
@@ -620,7 +619,7 @@ func TestIsElectionTimeout(t *testing.T) {
 	}
 
 	for i, tt := range tests {
-		sm := newRaft(1, []uint64{1}, 10, 1, nil)
+		sm := newRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 		sm.elapsed = tt.elapse
 		c := 0
 		for j := 0; j < 10000; j++ {
@@ -645,7 +644,7 @@ func TestStepIgnoreOldTermMsg(t *testing.T) {
 	fakeStep := func(r *raft, m pb.Message) {
 		called = true
 	}
-	sm := newRaft(1, []uint64{1}, 10, 1, nil)
+	sm := newRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 	sm.step = fakeStep
 	sm.Term = 2
 	sm.Step(pb.Message{Type: pb.MsgApp, Term: sm.Term - 1})
@@ -747,7 +746,7 @@ func TestRecvMsgVote(t *testing.T) {
 	}
 
 	for i, tt := range tests {
-		sm := newRaft(1, []uint64{1}, 10, 1, nil)
+		sm := newRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 		sm.state = tt.state
 		switch tt.state {
 		case StateFollower:
@@ -807,7 +806,7 @@ func TestStateTransition(t *testing.T) {
 				}
 			}()
 
-			sm := newRaft(1, []uint64{1}, 10, 1, nil)
+			sm := newRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 			sm.state = tt.from
 
 			switch tt.to {
@@ -846,7 +845,7 @@ func TestAllServerStepdown(t *testing.T) {
 	tterm := uint64(3)
 
 	for i, tt := range tests {
-		sm := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+		sm := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 		switch tt.state {
 		case StateFollower:
 			sm.becomeFollower(1, None)
@@ -902,7 +901,7 @@ func TestLeaderAppResp(t *testing.T) {
 	for i, tt := range tests {
 		// sm term is 1 after it becomes the leader.
 		// thus the last log term must be 1 to be committed.
-		sm := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+		sm := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 		sm.raftLog = &raftLog{
 			storage:  &MemoryStorage{ents: []pb.Entry{{}, {Term: 0}, {Term: 1}}},
 			unstable: 3,
@@ -946,7 +945,7 @@ func TestBcastBeat(t *testing.T) {
 		Term:  1,
 		Nodes: []uint64{1, 2, 3},
 	}
-	sm := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+	sm := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 	sm.Term = 1
 	sm.restore(s)
 
@@ -996,7 +995,7 @@ func TestRecvMsgBeat(t *testing.T) {
 	}
 
 	for i, tt := range tests {
-		sm := newRaft(1, []uint64{1, 2, 3}, 10, 1, nil)
+		sm := newRaft(1, []uint64{1, 2, 3}, 10, 1, NewMemoryStorage())
 		sm.raftLog = &raftLog{storage: &MemoryStorage{ents: []pb.Entry{{}, {Term: 0}, {Term: 1}}}}
 		sm.Term = 1
 		sm.state = tt.state
@@ -1029,7 +1028,7 @@ func TestRestore(t *testing.T) {
 		Nodes: []uint64{1, 2, 3},
 	}
 
-	sm := newRaft(1, []uint64{1, 2}, 10, 1, nil)
+	sm := newRaft(1, []uint64{1, 2}, 10, 1, NewMemoryStorage())
 	if ok := sm.restore(s); !ok {
 		t.Fatal("restore fail, want succeed")
 	}
@@ -1060,7 +1059,7 @@ func TestProvideSnap(t *testing.T) {
 		Term:  11, // magic number
 		Nodes: []uint64{1, 2},
 	}
-	sm := newRaft(1, []uint64{1}, 10, 1, nil)
+	sm := newRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 	// restore the statemachin from a snapshot
 	// so it has a compacted log and a snapshot
 	sm.restore(s)
@@ -1091,7 +1090,7 @@ func TestRestoreFromSnapMsg(t *testing.T) {
 	}
 	m := pb.Message{Type: pb.MsgSnap, From: 1, Term: 2, Snapshot: s}
 
-	sm := newRaft(2, []uint64{1, 2}, 10, 1, nil)
+	sm := newRaft(2, []uint64{1, 2}, 10, 1, NewMemoryStorage())
 	sm.Step(m)
 
 	if !reflect.DeepEqual(sm.raftLog.snapshot, s) {
@@ -1108,7 +1107,7 @@ func TestSlowNodeRestore(t *testing.T) {
 		nt.send(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{}}})
 	}
 	lead := nt.peers[1].(*raft)
-	nextEnts(lead)
+	nextEnts(lead, nt.storage[1])
 	lead.compact(lead.raftLog.applied, lead.nodes(), nil)
 
 	nt.recover()
@@ -1130,7 +1129,7 @@ func TestSlowNodeRestore(t *testing.T) {
 // it appends the entry to log and sets pendingConf to be true.
 func TestStepConfig(t *testing.T) {
 	// a raft that cannot make progress
-	r := newRaft(1, []uint64{1, 2}, 10, 1, nil)
+	r := newRaft(1, []uint64{1, 2}, 10, 1, NewMemoryStorage())
 	r.becomeCandidate()
 	r.becomeLeader()
 	index := r.raftLog.lastIndex()
@@ -1148,7 +1147,7 @@ func TestStepConfig(t *testing.T) {
 // the proposal and keep its original state.
 func TestStepIgnoreConfig(t *testing.T) {
 	// a raft that cannot make progress
-	r := newRaft(1, []uint64{1, 2}, 10, 1, nil)
+	r := newRaft(1, []uint64{1, 2}, 10, 1, NewMemoryStorage())
 	r.becomeCandidate()
 	r.becomeLeader()
 	r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Type: pb.EntryConfChange}}})
@@ -1174,7 +1173,7 @@ func TestRecoverPendingConfig(t *testing.T) {
 		{pb.EntryConfChange, true},
 	}
 	for i, tt := range tests {
-		r := newRaft(1, []uint64{1, 2}, 10, 1, nil)
+		r := newRaft(1, []uint64{1, 2}, 10, 1, NewMemoryStorage())
 		r.appendEntry(pb.Entry{Type: tt.entType})
 		r.becomeCandidate()
 		r.becomeLeader()
@@ -1193,7 +1192,7 @@ func TestRecoverDoublePendingConfig(t *testing.T) {
 				t.Errorf("expect panic, but nothing happens")
 			}
 		}()
-		r := newRaft(1, []uint64{1, 2}, 10, 1, nil)
+		r := newRaft(1, []uint64{1, 2}, 10, 1, NewMemoryStorage())
 		r.appendEntry(pb.Entry{Type: pb.EntryConfChange})
 		r.appendEntry(pb.Entry{Type: pb.EntryConfChange})
 		r.becomeCandidate()
@@ -1203,7 +1202,7 @@ func TestRecoverDoublePendingConfig(t *testing.T) {
 
 // TestAddNode tests that addNode could update pendingConf and nodes correctly.
 func TestAddNode(t *testing.T) {
-	r := newRaft(1, []uint64{1}, 10, 1, nil)
+	r := newRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 	r.pendingConf = true
 	r.addNode(2)
 	if r.pendingConf != false {
@@ -1220,7 +1219,7 @@ func TestAddNode(t *testing.T) {
 // TestRemoveNode tests that removeNode could update pendingConf, nodes and
 // and removed list correctly.
 func TestRemoveNode(t *testing.T) {
-	r := newRaft(1, []uint64{1, 2}, 10, 1, nil)
+	r := newRaft(1, []uint64{1, 2}, 10, 1, NewMemoryStorage())
 	r.pendingConf = true
 	r.removeNode(2)
 	if r.pendingConf != false {
@@ -1272,6 +1271,7 @@ func ents(terms ...uint64) *raft {
 
 type network struct {
 	peers   map[uint64]Interface
+	storage map[uint64]*MemoryStorage
 	dropm   map[connem]float64
 	ignorem map[pb.MessageType]bool
 }
@@ -1285,12 +1285,14 @@ func newNetwork(peers ...Interface) *network {
 	peerAddrs := idsBySize(size)
 
 	npeers := make(map[uint64]Interface, size)
+	nstorage := make(map[uint64]*MemoryStorage, size)
 
 	for i, p := range peers {
 		id := peerAddrs[i]
 		switch v := p.(type) {
 		case nil:
-			sm := newRaft(id, peerAddrs, 10, 1, nil)
+			nstorage[id] = NewMemoryStorage()
+			sm := newRaft(id, peerAddrs, 10, 1, nstorage[id])
 			npeers[id] = sm
 		case *raft:
 			v.id = id
@@ -1308,6 +1310,7 @@ func newNetwork(peers ...Interface) *network {
 	}
 	return &network{
 		peers:   npeers,
+		storage: nstorage,
 		dropm:   make(map[connem]float64),
 		ignorem: make(map[pb.MessageType]bool),
 	}

+ 1 - 3
raft/storage.go

@@ -23,9 +23,7 @@ import (
 )
 
 // Storage is an interface that may be implemented by the application
-// to retrieve log entries from storage. If no storage implementation
-// is supplied by the application, a MemoryStorage will be used, which
-// retains all log entries in memory.
+// to retrieve log entries from storage.
 //
 // If any Storage method returns an error, the raft instance will
 // become inoperable and refuse to participate in elections; the