Browse Source

Merge pull request #1260 from coreos/snap_rm

raft: save removed nodes in snapshot
Xiang Li 11 years ago
parent
commit
38af14b0f4
6 changed files with 95 additions and 27 deletions
  1. 6 5
      raft/log.go
  2. 5 4
      raft/node_test.go
  3. 16 1
      raft/raft.go
  4. 27 13
      raft/raft_test.go
  5. 36 0
      raft/raftpb/raft.pb.go
  6. 5 4
      raft/raftpb/raft.proto

+ 6 - 5
raft/log.go

@@ -168,12 +168,13 @@ func (l *raftLog) compact(i int64) int64 {
 	return int64(len(l.ents))
 }
 
-func (l *raftLog) snap(d []byte, index, term int64, nodes []int64) {
+func (l *raftLog) snap(d []byte, index, term int64, nodes []int64, removed []int64) {
 	l.snapshot = pb.Snapshot{
-		Data:  d,
-		Nodes: nodes,
-		Index: index,
-		Term:  term,
+		Data:         d,
+		Nodes:        nodes,
+		Index:        index,
+		Term:         term,
+		RemovedNodes: removed,
 	}
 }
 

+ 5 - 4
raft/node_test.go

@@ -231,10 +231,11 @@ func TestNodeCompact(t *testing.T) {
 	n.Propose(ctx, []byte("foo"))
 
 	w := raftpb.Snapshot{
-		Term:  1,
-		Index: 2, // one nop + one proposal
-		Data:  []byte("a snapshot"),
-		Nodes: []int64{1},
+		Term:         1,
+		Index:        2, // one nop + one proposal
+		Data:         []byte("a snapshot"),
+		Nodes:        []int64{1},
+		RemovedNodes: []int64{},
 	}
 
 	pkg.ForceGosched()

+ 16 - 1
raft/raft.go

@@ -523,7 +523,10 @@ func (r *raft) compact(index int64, nodes []int64, d []byte) {
 	if index > r.raftLog.applied {
 		panic(fmt.Sprintf("raft: compact index (%d) exceeds applied index (%d)", index, r.raftLog.applied))
 	}
-	r.raftLog.snap(d, index, r.raftLog.term(index), nodes)
+	// We do not get the removed nodes at the given index.
+	// We get the removed nodes at the current index, so a state machine might
+	// have a newer version of removed nodes after recovery. It is OK.
+	r.raftLog.snap(d, index, r.raftLog.term(index), nodes, r.removedNodes())
 	r.raftLog.compact(index)
 }
 
@@ -543,6 +546,10 @@ func (r *raft) restore(s pb.Snapshot) bool {
 			r.setProgress(n, 0, r.raftLog.lastIndex()+1)
 		}
 	}
+	r.removed = make(map[int64]bool)
+	for _, n := range s.RemovedNodes {
+		r.removed[n] = true
+	}
 	return true
 }
 
@@ -564,6 +571,14 @@ func (r *raft) nodes() []int64 {
 	return nodes
 }
 
+func (r *raft) removedNodes() []int64 {
+	removed := make([]int64, 0, len(r.removed))
+	for k := range r.removed {
+		removed = append(removed, k)
+	}
+	return removed
+}
+
 func (r *raft) setProgress(id, match, next int64) {
 	r.prs[id] = &progress{next: next, match: match}
 }

+ 27 - 13
raft/raft_test.go

@@ -413,12 +413,13 @@ func TestCompact(t *testing.T) {
 	tests := []struct {
 		compacti int64
 		nodes    []int64
+		removed  []int64
 		snapd    []byte
 		wpanic   bool
 	}{
-		{1, []int64{1, 2, 3}, []byte("some data"), false},
-		{2, []int64{1, 2, 3}, []byte("some data"), false},
-		{4, []int64{1, 2, 3}, []byte("some data"), true}, // compact out of range
+		{1, []int64{1, 2, 3}, []int64{4, 5}, []byte("some data"), false},
+		{2, []int64{1, 2, 3}, []int64{4, 5}, []byte("some data"), false},
+		{4, []int64{1, 2, 3}, []int64{4, 5}, []byte("some data"), true}, // compact out of range
 	}
 
 	for i, tt := range tests {
@@ -426,7 +427,7 @@ func TestCompact(t *testing.T) {
 			defer func() {
 				if r := recover(); r != nil {
 					if tt.wpanic != true {
-						t.Errorf("%d: panic = %v, want %v", i, false, true)
+						t.Errorf("%d: panic = %v, want %v", i, true, tt.wpanic)
 					}
 				}
 			}()
@@ -437,8 +438,14 @@ func TestCompact(t *testing.T) {
 					applied:   2,
 					ents:      []pb.Entry{{}, {Term: 1}, {Term: 1}, {Term: 1}},
 				},
+				removed: make(map[int64]bool),
+			}
+			for _, r := range tt.removed {
+				sm.removeNode(r)
 			}
 			sm.compact(tt.compacti, tt.nodes, tt.snapd)
+			sort.Sort(int64Slice(sm.raftLog.snapshot.Nodes))
+			sort.Sort(int64Slice(sm.raftLog.snapshot.RemovedNodes))
 			if sm.raftLog.offset != tt.compacti {
 				t.Errorf("%d: log.offset = %d, want %d", i, sm.raftLog.offset, tt.compacti)
 			}
@@ -448,6 +455,9 @@ func TestCompact(t *testing.T) {
 			if !reflect.DeepEqual(sm.raftLog.snapshot.Data, tt.snapd) {
 				t.Errorf("%d: snap.data = %v, want %v", i, sm.raftLog.snapshot.Data, tt.snapd)
 			}
+			if !reflect.DeepEqual(sm.raftLog.snapshot.RemovedNodes, tt.removed) {
+				t.Errorf("%d: snap.removedNodes = %v, want %v", i, sm.raftLog.snapshot.RemovedNodes, tt.removed)
+			}
 		}()
 	}
 }
@@ -886,9 +896,10 @@ func TestRecvMsgBeat(t *testing.T) {
 
 func TestRestore(t *testing.T) {
 	s := pb.Snapshot{
-		Index: defaultCompactThreshold + 1,
-		Term:  defaultCompactThreshold + 1,
-		Nodes: []int64{1, 2, 3},
+		Index:        defaultCompactThreshold + 1,
+		Term:         defaultCompactThreshold + 1,
+		Nodes:        []int64{1, 2, 3},
+		RemovedNodes: []int64{4, 5},
 	}
 
 	sm := newRaft(1, []int64{1, 2}, 10, 1)
@@ -902,12 +913,15 @@ func TestRestore(t *testing.T) {
 	if sm.raftLog.term(s.Index) != s.Term {
 		t.Errorf("log.lastTerm = %d, want %d", sm.raftLog.term(s.Index), s.Term)
 	}
-	sg := int64Slice(sm.nodes())
-	sw := int64Slice(s.Nodes)
-	sort.Sort(sg)
-	sort.Sort(sw)
-	if !reflect.DeepEqual(sg, sw) {
-		t.Errorf("sm.Nodes = %+v, want %+v", sg, sw)
+	sg := sm.nodes()
+	srn := sm.removedNodes()
+	sort.Sort(int64Slice(sg))
+	sort.Sort(int64Slice(srn))
+	if !reflect.DeepEqual(sg, s.Nodes) {
+		t.Errorf("sm.Nodes = %+v, want %+v", sg, s.Nodes)
+	}
+	if !reflect.DeepEqual(s.RemovedNodes, srn) {
+		t.Errorf("sm.RemovedNodes = %+v, want %+v", s.RemovedNodes, srn)
 	}
 	if !reflect.DeepEqual(sm.raftLog.snapshot, s) {
 		t.Errorf("snapshot = %+v, want %+v", sm.raftLog.snapshot, s)

+ 36 - 0
raft/raftpb/raft.pb.go

@@ -124,6 +124,7 @@ type Snapshot struct {
 	Nodes            []int64 `protobuf:"varint,2,rep,name=nodes" json:"nodes"`
 	Index            int64   `protobuf:"varint,3,req,name=index" json:"index"`
 	Term             int64   `protobuf:"varint,4,req,name=term" json:"term"`
+	RemovedNodes     []int64 `protobuf:"varint,5,rep,name=removed_nodes" json:"removed_nodes"`
 	XXX_unrecognized []byte  `json:"-"`
 }
 
@@ -430,6 +431,23 @@ func (m *Snapshot) Unmarshal(data []byte) error {
 					break
 				}
 			}
+		case 5:
+			if wireType != 0 {
+				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+			}
+			var v int64
+			for shift := uint(0); ; shift += 7 {
+				if index >= l {
+					return io.ErrUnexpectedEOF
+				}
+				b := data[index]
+				index++
+				v |= (int64(b) & 0x7F) << shift
+				if b < 0x80 {
+					break
+				}
+			}
+			m.RemovedNodes = append(m.RemovedNodes, v)
 		default:
 			var sizeOfWire int
 			for {
@@ -894,6 +912,11 @@ func (m *Snapshot) Size() (n int) {
 	}
 	n += 1 + sovRaft(uint64(m.Index))
 	n += 1 + sovRaft(uint64(m.Term))
+	if len(m.RemovedNodes) > 0 {
+		for _, e := range m.RemovedNodes {
+			n += 1 + sovRaft(uint64(e))
+		}
+	}
 	if m.XXX_unrecognized != nil {
 		n += len(m.XXX_unrecognized)
 	}
@@ -1055,6 +1078,19 @@ func (m *Snapshot) MarshalTo(data []byte) (n int, err error) {
 	data[i] = 0x20
 	i++
 	i = encodeVarintRaft(data, i, uint64(m.Term))
+	if len(m.RemovedNodes) > 0 {
+		for _, num := range m.RemovedNodes {
+			data[i] = 0x28
+			i++
+			for num >= 1<<7 {
+				data[i] = uint8(uint64(num)&0x7f | 0x80)
+				num >>= 7
+				i++
+			}
+			data[i] = uint8(num)
+			i++
+		}
+	}
 	if m.XXX_unrecognized != nil {
 		i += copy(data[i:], m.XXX_unrecognized)
 	}

+ 5 - 4
raft/raftpb/raft.proto

@@ -25,10 +25,11 @@ message Entry {
 }
 
 message Snapshot {
-	required bytes data  = 1 [(gogoproto.nullable) = false];
-	repeated int64 nodes = 2 [(gogoproto.nullable) = false];
-	required int64 index = 3 [(gogoproto.nullable) = false];
-	required int64 term  = 4 [(gogoproto.nullable) = false];
+	required bytes data          = 1 [(gogoproto.nullable) = false];
+	repeated int64 nodes         = 2 [(gogoproto.nullable) = false];
+	required int64 index         = 3 [(gogoproto.nullable) = false];
+	required int64 term          = 4 [(gogoproto.nullable) = false];
+	repeated int64 removed_nodes = 5 [(gogoproto.nullable) = false];
 }
 
 message Message {