Browse Source

raft: new log struct

Xiang Li 11 years ago
parent
commit
4c609ec59c
3 changed files with 114 additions and 63 deletions
  1. 74 0
      raft/log.go
  2. 16 41
      raft/raft.go
  3. 24 22
      raft/raft_test.go

+ 74 - 0
raft/log.go

@@ -0,0 +1,74 @@
+package raft
+
+type Entry struct {
+	Term int
+	Data []byte
+}
+
+type log struct {
+	ents    []Entry
+	commit  int
+	applied int
+}
+
+func newLog() *log {
+	return &log{
+		ents:    make([]Entry, 1, 1024),
+		commit:  0,
+		applied: 0,
+	}
+}
+
+func (l *log) maybeAppend(index, logTerm int, ents ...Entry) bool {
+	if l.isOk(index, logTerm) {
+		l.append(index, ents...)
+		return true
+	}
+	return false
+}
+
+func (l *log) append(after int, ents ...Entry) int {
+	l.ents = append(l.ents[:after+1], ents...)
+	return len(l.ents) - 1
+}
+
+func (l *log) len() int {
+	return len(l.ents) - 1
+}
+
+func (l *log) term(i int) int {
+	if i > l.len() {
+		return -1
+	}
+	return l.ents[i].Term
+}
+
+func (l *log) entries(i int) []Entry {
+	if i > l.len() {
+		return nil
+	}
+	return l.ents[i:]
+}
+
+func (l *log) isUpToDate(i, term int) bool {
+	// LET upToDate == \/ m.mlastLogTerm > LastTerm(log[i])
+	//              \/ /\ m.mlastLogTerm = LastTerm(log[i])
+	//                 /\ m.mlastLogIndex >= Len(log[i])
+	e := l.ents[l.len()]
+	return term > e.Term || (term == e.Term && i >= l.len())
+}
+
+func (l *log) isOk(i, term int) bool {
+	if i > l.len() {
+		return false
+	}
+	return l.ents[i].Term == term
+}
+
+func (l *log) nextEnts() (ents []Entry) {
+	if l.commit > l.applied {
+		ents = l.ents[l.applied+1 : l.commit+1]
+		l.applied = l.commit
+	}
+	return ents
+}

+ 16 - 41
raft/raft.go

@@ -51,11 +51,6 @@ func (st stateType) String() string {
 	return stmap[int(st)]
 }
 
-type Entry struct {
-	Term int
-	Data []byte
-}
-
 type Message struct {
 	Type     messageType
 	To       int
@@ -98,15 +93,12 @@ type stateMachine struct {
 	vote int
 
 	// the log
-	log []Entry
+	log *log
 
 	ins []*index
 
 	state stateType
 
-	commit  int
-	applied int
-
 	votes map[int]bool
 
 	msgs []Message
@@ -116,8 +108,7 @@ type stateMachine struct {
 }
 
 func newStateMachine(k, addr int) *stateMachine {
-	log := make([]Entry, 1, 1024)
-	sm := &stateMachine{k: k, addr: addr, log: log}
+	sm := &stateMachine{k: k, addr: addr, log: newLog()}
 	sm.reset()
 	return sm
 }
@@ -141,16 +132,8 @@ func (sm *stateMachine) poll(addr int, v bool) (granted int) {
 	return granted
 }
 
-func (sm *stateMachine) append(after int, ents ...Entry) int {
-	sm.log = append(sm.log[:after+1], ents...)
-	return len(sm.log) - 1
-}
-
 func (sm *stateMachine) isLogOk(i, term int) bool {
-	if i > sm.li() {
-		return false
-	}
-	return sm.log[i].Term == term
+	return sm.log.isOk(i, term)
 }
 
 // send persists state to stable storage and then sends to its mailbox
@@ -171,9 +154,9 @@ func (sm *stateMachine) sendAppend() {
 		m.Type = msgApp
 		m.To = i
 		m.Index = in.next - 1
-		m.LogTerm = sm.log[in.next-1].Term
-		m.Entries = sm.log[in.next:]
-		m.Commit = sm.commit
+		m.LogTerm = sm.log.term(in.next - 1)
+		m.Entries = sm.log.entries(in.next)
+		m.Commit = sm.log.commit
 		sm.send(m)
 	}
 }
@@ -187,8 +170,8 @@ func (sm *stateMachine) maybeCommit() bool {
 	sort.Sort(sort.Reverse(sort.IntSlice(mis)))
 	mci := mis[sm.q()-1]
 
-	if mci > sm.commit && sm.log[mci].Term == sm.term {
-		sm.commit = mci
+	if mci > sm.log.commit && sm.log.term(mci) == sm.term {
+		sm.log.commit = mci
 		return true
 	}
 
@@ -197,11 +180,7 @@ func (sm *stateMachine) maybeCommit() bool {
 
 // nextEnts returns the appliable entries and updates the applied index
 func (sm *stateMachine) nextEnts() (ents []Entry) {
-	if sm.commit > sm.applied {
-		ents = sm.log[sm.applied+1 : sm.commit+1]
-		sm.applied = sm.commit
-	}
-	return ents
+	return sm.log.nextEnts()
 }
 
 func (sm *stateMachine) reset() {
@@ -210,7 +189,7 @@ func (sm *stateMachine) reset() {
 	sm.votes = make(map[int]bool)
 	sm.ins = make([]*index, sm.k)
 	for i := range sm.ins {
-		sm.ins[i] = &index{next: len(sm.log)}
+		sm.ins[i] = &index{next: sm.log.len() + 1}
 	}
 }
 
@@ -219,15 +198,11 @@ func (sm *stateMachine) q() int {
 }
 
 func (sm *stateMachine) voteWorthy(i, term int) bool {
-	// LET logOk == \/ m.mlastLogTerm > LastTerm(log[i])
-	//              \/ /\ m.mlastLogTerm = LastTerm(log[i])
-	//                 /\ m.mlastLogIndex >= Len(log[i])
-	e := sm.log[sm.li()]
-	return term > e.Term || (term == e.Term && i >= sm.li())
+	return sm.log.isUpToDate(i, term)
 }
 
 func (sm *stateMachine) li() int {
-	return len(sm.log) - 1
+	return sm.log.len()
 }
 
 func (sm *stateMachine) becomeFollower(term, lead int) {
@@ -275,13 +250,13 @@ func (sm *stateMachine) Step(m Message) {
 				continue
 			}
 			lasti := sm.li()
-			sm.send(Message{To: i, Type: msgVote, Index: lasti, LogTerm: sm.log[lasti].Term})
+			sm.send(Message{To: i, Type: msgVote, Index: lasti, LogTerm: sm.log.term(lasti)})
 		}
 		return
 	case msgProp:
 		switch sm.lead {
 		case sm.addr:
-			sm.append(sm.li(), Entry{Term: sm.term, Data: m.Data})
+			sm.log.append(sm.log.len(), Entry{Term: sm.term, Data: m.Data})
 			sm.sendAppend()
 		case none:
 			panic("msgProp given without leader")
@@ -302,8 +277,8 @@ func (sm *stateMachine) Step(m Message) {
 
 	handleAppendEntries := func() {
 		if sm.isLogOk(m.Index, m.LogTerm) {
-			sm.commit = m.Commit
-			sm.append(m.Index, m.Entries...)
+			sm.log.commit = m.Commit
+			sm.log.append(m.Index, m.Entries...)
 			sm.send(Message{To: m.From, Type: msgAppResp, Index: sm.li()})
 		} else {
 			sm.send(Message{To: m.From, Type: msgAppResp, Index: -1})

+ 24 - 22
raft/raft_test.go

@@ -7,8 +7,6 @@ import (
 	"testing"
 )
 
-var defaultLog = []Entry{{}}
-
 func TestLeaderElection(t *testing.T) {
 	tests := []struct {
 		*network
@@ -24,9 +22,9 @@ func TestLeaderElection(t *testing.T) {
 		{
 			newNetwork(
 				nil,
-				&nsm{stateMachine{log: []Entry{{}, {Term: 1}}}, nil},
-				&nsm{stateMachine{log: []Entry{{}, {Term: 2}}}, nil},
-				&nsm{stateMachine{log: []Entry{{}, {Term: 1}, {Term: 3}}}, nil},
+				&nsm{stateMachine{log: &log{ents: []Entry{{}, {Term: 1}}}}, nil},
+				&nsm{stateMachine{log: &log{ents: []Entry{{}, {Term: 2}}}}, nil},
+				&nsm{stateMachine{log: &log{ents: []Entry{{}, {Term: 1}, {Term: 3}}}}, nil},
 				nil,
 			),
 			stateFollower,
@@ -35,10 +33,10 @@ func TestLeaderElection(t *testing.T) {
 		// logs converge
 		{
 			newNetwork(
-				&nsm{stateMachine{log: []Entry{{}, {Term: 1}}}, nil},
+				&nsm{stateMachine{log: &log{ents: []Entry{{}, {Term: 1}}}}, nil},
 				nil,
-				&nsm{stateMachine{log: []Entry{{}, {Term: 2}}}, nil},
-				&nsm{stateMachine{log: []Entry{{}, {Term: 1}}}, nil},
+				&nsm{stateMachine{log: &log{ents: []Entry{{}, {Term: 2}}}}, nil},
+				&nsm{stateMachine{log: &log{ents: []Entry{{}, {Term: 1}}}}, nil},
 				nil,
 			),
 			stateLeader,
@@ -94,8 +92,8 @@ func TestLogReplication(t *testing.T) {
 		for j, ism := range tt.ss {
 			sm := ism.(*nsm)
 
-			if sm.commit != tt.wcommit {
-				t.Errorf("#%d.%d: commit = %d, want %d", i, j, sm.commit, tt.wcommit)
+			if sm.log.commit != tt.wcommit {
+				t.Errorf("#%d.%d: commit = %d, want %d", i, j, sm.log.commit, tt.wcommit)
 			}
 
 			ents := sm.nextEnts()
@@ -115,8 +113,8 @@ func TestLogReplication(t *testing.T) {
 }
 
 func TestDualingCandidates(t *testing.T) {
-	a := &nsm{stateMachine{log: defaultLog}, nil}
-	c := &nsm{stateMachine{log: defaultLog}, nil}
+	a := &nsm{stateMachine{log: defaultLog()}, nil}
+	c := &nsm{stateMachine{log: defaultLog()}, nil}
 
 	tt := newNetwork(a, nil, c)
 
@@ -156,7 +154,7 @@ func TestDualingCandidates(t *testing.T) {
 			t.Errorf("#%d: term = %d, want %d", i, g, tt.term)
 		}
 	}
-	if g := diffLogs(defaultLog, tt.logs()); g != nil {
+	if g := diffLogs(defaultLog().ents, tt.logs()); g != nil {
 		for _, diff := range g {
 			t.Errorf("bag log:\n%s", diff)
 		}
@@ -164,7 +162,7 @@ func TestDualingCandidates(t *testing.T) {
 }
 
 func TestCandidateConcede(t *testing.T) {
-	a := &nsm{stateMachine{log: defaultLog}, nil}
+	a := &nsm{stateMachine{log: defaultLog()}, nil}
 
 	tt := newNetwork(a, nil, nil)
 	tt.tee = stepperFunc(func(m Message) {
@@ -205,7 +203,7 @@ func TestOldMessages(t *testing.T) {
 	tt.Step(Message{To: 0, Type: msgHup})
 	// pretend we're an old leader trying to make progress
 	tt.Step(Message{To: 0, Type: msgApp, Term: 1, Entries: []Entry{{Term: 1}}})
-	if g := diffLogs(defaultLog, tt.logs()); g != nil {
+	if g := diffLogs(defaultLog().ents, tt.logs()); g != nil {
 		for _, diff := range g {
 			t.Errorf("bag log:\n%s", diff)
 		}
@@ -255,7 +253,7 @@ func TestProposal(t *testing.T) {
 		if tt.success {
 			wantLog = []Entry{{}, {Term: 1, Data: data}}
 		} else {
-			wantLog = defaultLog
+			wantLog = defaultLog().ents
 		}
 		if g := diffLogs(wantLog, tt.logs()); g != nil {
 			for _, diff := range g {
@@ -327,9 +325,9 @@ func TestCommit(t *testing.T) {
 		for j := 0; j < len(ins); j++ {
 			ins[j] = &index{tt.matches[j], tt.matches[j] + 1}
 		}
-		sm := &stateMachine{log: tt.logs, ins: ins, k: len(ins), term: tt.smTerm}
+		sm := &stateMachine{log: &log{ents: tt.logs}, ins: ins, k: len(ins), term: tt.smTerm}
 		sm.maybeCommit()
-		if g := sm.commit; g != tt.w {
+		if g := sm.log.commit; g != tt.w {
 			t.Errorf("#%d: commit = %d, want %d", i, g, tt.w)
 		}
 	}
@@ -363,7 +361,7 @@ func TestVote(t *testing.T) {
 
 	for i, tt := range tests {
 		called := false
-		sm := &nsm{stateMachine{log: []Entry{{}, {Term: 2}, {Term: 2}}}, nil}
+		sm := &nsm{stateMachine{log: &log{ents: []Entry{{}, {Term: 2}, {Term: 2}}}}, nil}
 		sm.next = stepperFunc(func(m Message) {
 			called = true
 			if m.Index != tt.w {
@@ -410,8 +408,8 @@ func TestAllServerStepdown(t *testing.T) {
 			if sm.term != want.term {
 				t.Errorf("#%d.%d term = %v , want %v", i, j, sm.term, want.term)
 			}
-			if len(sm.log) != want.index {
-				t.Errorf("#%d.%d index = %v , want %v", i, j, len(sm.log), want.index)
+			if len(sm.log.ents) != want.index {
+				t.Errorf("#%d.%d index = %v , want %v", i, j, len(sm.log.ents), want.index)
 			}
 		}
 	}
@@ -474,7 +472,7 @@ func (nt network) logs() [][]Entry {
 	ls := make([][]Entry, len(nt.ss))
 	for i, node := range nt.ss {
 		if sm, ok := node.(*nsm); ok {
-			ls[i] = sm.log
+			ls[i] = sm.log.ents
 		}
 	}
 	return ls
@@ -573,3 +571,7 @@ func (n *nsm) Step(m Message) {
 		n.next.Step(m)
 	}
 }
+
+func defaultLog() *log {
+	return &log{ents: []Entry{{}}}
+}