Browse Source

raft: heartbeat should not contain entries

Xiang Li 11 years ago
parent
commit
2665cc1cc8
4 changed files with 34 additions and 3 deletions
  1. 7 0
      raft/log.go
  2. 1 1
      raft/node_test.go
  3. 25 1
      raft/raft.go
  4. 1 1
      raft/raft_test.go

+ 7 - 0
raft/log.go

@@ -183,3 +183,10 @@ func min(a, b int64) int64 {
 	}
 	}
 	return a
 	return a
 }
 }
+
+func max(a, b int64) int64 {
+	if a > b {
+		return a
+	}
+	return b
+}

+ 1 - 1
raft/node_test.go

@@ -54,7 +54,7 @@ func TestTickMsgBeat(t *testing.T) {
 
 
 	called := 0
 	called := 0
 	for _, m := range n.Msgs() {
 	for _, m := range n.Msgs() {
-		if m.Type == msgApp {
+		if m.Type == msgApp && len(m.Entries) == 0 {
 			called++
 			called++
 		}
 		}
 	}
 	}

+ 25 - 1
raft/raft.go

@@ -196,6 +196,20 @@ func (sm *stateMachine) sendAppend(to int64) {
 	sm.send(m)
 	sm.send(m)
 }
 }
 
 
+// sendHeartbeat sends RRPC, without entries to the given peer.
+func (sm *stateMachine) sendHeartbeat(to int64) {
+	in := sm.ins[to]
+	index := max(in.next-1, sm.log.lastIndex())
+	m := Message{
+		To:      to,
+		Type:    msgApp,
+		Index:   index,
+		LogTerm: sm.log.term(index),
+		Commit:  sm.log.committed,
+	}
+	sm.send(m)
+}
+
 // bcastAppend sends RRPC, with entries to all peers that are not up-to-date according to sm.mis.
 // bcastAppend sends RRPC, with entries to all peers that are not up-to-date according to sm.mis.
 func (sm *stateMachine) bcastAppend() {
 func (sm *stateMachine) bcastAppend() {
 	for i := range sm.ins {
 	for i := range sm.ins {
@@ -206,6 +220,16 @@ func (sm *stateMachine) bcastAppend() {
 	}
 	}
 }
 }
 
 
+// bcastHeartbeat sends RRPC, without entries to all the peers.
+func (sm *stateMachine) bcastHeartbeat() {
+	for i := range sm.ins {
+		if i == sm.id {
+			continue
+		}
+		sm.sendHeartbeat(i)
+	}
+}
+
 func (sm *stateMachine) maybeCommit() bool {
 func (sm *stateMachine) maybeCommit() bool {
 	// TODO(bmizerany): optimize.. Currently naive
 	// TODO(bmizerany): optimize.. Currently naive
 	mis := make(int64Slice, 0, len(sm.ins))
 	mis := make(int64Slice, 0, len(sm.ins))
@@ -359,7 +383,7 @@ type stepFunc func(sm *stateMachine, m Message) bool
 func stepLeader(sm *stateMachine, m Message) bool {
 func stepLeader(sm *stateMachine, m Message) bool {
 	switch m.Type {
 	switch m.Type {
 	case msgBeat:
 	case msgBeat:
-		sm.bcastAppend()
+		sm.bcastHeartbeat()
 	case msgProp:
 	case msgProp:
 		if len(m.Entries) != 1 {
 		if len(m.Entries) != 1 {
 			panic("unexpected length(entries) of a msgProp")
 			panic("unexpected length(entries) of a msgProp")

+ 1 - 1
raft/raft_test.go

@@ -897,7 +897,7 @@ func TestProvideSnap(t *testing.T) {
 	// node 1 needs a snapshot
 	// node 1 needs a snapshot
 	sm.ins[1].next = sm.log.offset
 	sm.ins[1].next = sm.log.offset
 
 
-	sm.Step(Message{From: 0, To: 0, Type: msgBeat})
+	sm.Step(Message{From: 1, To: 0, Type: msgAppResp, Index: -1})
 	msgs = sm.Msgs()
 	msgs = sm.Msgs()
 	if len(msgs) != 1 {
 	if len(msgs) != 1 {
 		t.Errorf("len(msgs) = %d, want 1", len(msgs))
 		t.Errorf("len(msgs) = %d, want 1", len(msgs))