Browse Source

raft: raft learners should be returned after applyConfChange

Vincent Lee 8 years ago
parent
commit
11fa4f0275
5 changed files with 80 additions and 10 deletions
  1. 6 2
      raft/node.go
  2. 52 0
      raft/node_test.go
  3. 7 1
      raft/raft.go
  4. 13 5
      raft/raft_test.go
  5. 2 2
      raft/rawnode.go

+ 6 - 2
raft/node.go

@@ -325,7 +325,9 @@ func (n *node) run(r *raft) {
 		case cc := <-n.confc:
 		case cc := <-n.confc:
 			if cc.NodeID == None {
 			if cc.NodeID == None {
 				select {
 				select {
-				case n.confstatec <- pb.ConfState{Nodes: r.nodes()}:
+				case n.confstatec <- pb.ConfState{
+					Nodes:    r.nodes(),
+					Learners: r.learnerNodes()}:
 				case <-n.done:
 				case <-n.done:
 				}
 				}
 				break
 				break
@@ -347,7 +349,9 @@ func (n *node) run(r *raft) {
 				panic("unexpected conf type")
 				panic("unexpected conf type")
 			}
 			}
 			select {
 			select {
-			case n.confstatec <- pb.ConfState{Nodes: r.nodes()}:
+			case n.confstatec <- pb.ConfState{
+				Nodes:    r.nodes(),
+				Learners: r.learnerNodes()}:
 			case <-n.done:
 			case <-n.done:
 			}
 			}
 		case <-n.tickc:
 		case <-n.tickc:

+ 52 - 0
raft/node_test.go

@@ -732,3 +732,55 @@ func TestIsHardStateEqual(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
+
+func TestNodeProposeAddLearnerNode(t *testing.T) {
+	ticker := time.NewTicker(time.Millisecond * 100)
+	defer ticker.Stop()
+	n := newNode()
+	s := NewMemoryStorage()
+	r := newTestRaft(1, []uint64{1}, 10, 1, s)
+	go n.run(r)
+	n.Campaign(context.TODO())
+	stop := make(chan struct{})
+	done := make(chan struct{})
+	applyConfChan := make(chan struct{})
+	go func() {
+		defer close(done)
+		for {
+			select {
+			case <-stop:
+				return
+			case <-ticker.C:
+				n.Tick()
+			case rd := <-n.Ready():
+				s.Append(rd.Entries)
+				t.Logf("raft: %v", rd.Entries)
+				for _, ent := range rd.Entries {
+					if ent.Type != raftpb.EntryConfChange {
+						continue
+					}
+					var cc raftpb.ConfChange
+					cc.Unmarshal(ent.Data)
+					state := n.ApplyConfChange(cc)
+					if len(state.Learners) == 0 ||
+						state.Learners[0] != cc.NodeID ||
+						cc.NodeID != 2 {
+						t.Errorf("apply conf change should return new added learner: %v", state.String())
+					}
+
+					if len(state.Nodes) != 1 {
+						t.Errorf("add learner should not change the nodes: %v", state.String())
+					}
+					t.Logf("apply raft conf %v changed to: %v", cc, state.String())
+					applyConfChan <- struct{}{}
+				}
+				n.Advance()
+			}
+		}
+	}()
+	cc := raftpb.ConfChange{Type: raftpb.ConfChangeAddLearnerNode, NodeID: 2}
+	n.ProposeConfChange(context.TODO(), cc)
+	<-applyConfChan
+	close(stop)
+	<-done
+}

+ 7 - 1
raft/raft.go

@@ -377,10 +377,16 @@ func (r *raft) hardState() pb.HardState {
 func (r *raft) quorum() int { return len(r.prs)/2 + 1 }
 func (r *raft) quorum() int { return len(r.prs)/2 + 1 }
 
 
 func (r *raft) nodes() []uint64 {
 func (r *raft) nodes() []uint64 {
-	nodes := make([]uint64, 0, len(r.prs)+len(r.learnerPrs))
+	nodes := make([]uint64, 0, len(r.prs))
 	for id := range r.prs {
 	for id := range r.prs {
 		nodes = append(nodes, id)
 		nodes = append(nodes, id)
 	}
 	}
+	sort.Sort(uint64Slice(nodes))
+	return nodes
+}
+
+func (r *raft) learnerNodes() []uint64 {
+	nodes := make([]uint64, 0, len(r.learnerPrs))
 	for id := range r.learnerPrs {
 	for id := range r.learnerPrs {
 		nodes = append(nodes, id)
 		nodes = append(nodes, id)
 	}
 	}

+ 13 - 5
raft/raft_test.go

@@ -2475,8 +2475,12 @@ func TestRestoreWithLearner(t *testing.T) {
 		t.Errorf("log.lastTerm = %d, want %d", mustTerm(sm.raftLog.term(s.Metadata.Index)), s.Metadata.Term)
 		t.Errorf("log.lastTerm = %d, want %d", mustTerm(sm.raftLog.term(s.Metadata.Index)), s.Metadata.Term)
 	}
 	}
 	sg := sm.nodes()
 	sg := sm.nodes()
-	if len(sg) != len(s.Metadata.ConfState.Nodes)+len(s.Metadata.ConfState.Learners) {
-		t.Errorf("sm.Nodes = %+v, length not equal with %+v", sg, s.Metadata.ConfState)
+	if len(sg) != len(s.Metadata.ConfState.Nodes) {
+		t.Errorf("sm.Nodes = %+v, length not equal with %+v", sg, s.Metadata.ConfState.Nodes)
+	}
+	lns := sm.learnerNodes()
+	if len(lns) != len(s.Metadata.ConfState.Learners) {
+		t.Errorf("sm.LearnerNodes = %+v, length not equal with %+v", sg, s.Metadata.ConfState.Learners)
 	}
 	}
 	for _, n := range s.Metadata.ConfState.Nodes {
 	for _, n := range s.Metadata.ConfState.Nodes {
 		if sm.prs[n].IsLearner {
 		if sm.prs[n].IsLearner {
@@ -2805,8 +2809,8 @@ func TestAddNode(t *testing.T) {
 func TestAddLearner(t *testing.T) {
 func TestAddLearner(t *testing.T) {
 	r := newTestRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 	r := newTestRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 	r.addLearner(2)
 	r.addLearner(2)
-	nodes := r.nodes()
-	wnodes := []uint64{1, 2}
+	nodes := r.learnerNodes()
+	wnodes := []uint64{2}
 	if !reflect.DeepEqual(nodes, wnodes) {
 	if !reflect.DeepEqual(nodes, wnodes) {
 		t.Errorf("nodes = %v, want %v", nodes, wnodes)
 		t.Errorf("nodes = %v, want %v", nodes, wnodes)
 	}
 	}
@@ -2877,9 +2881,13 @@ func TestRemoveLearner(t *testing.T) {
 		t.Errorf("nodes = %v, want %v", g, w)
 		t.Errorf("nodes = %v, want %v", g, w)
 	}
 	}
 
 
+	w = []uint64{}
+	if g := r.learnerNodes(); !reflect.DeepEqual(g, w) {
+		t.Errorf("nodes = %v, want %v", g, w)
+	}
+
 	// remove all nodes from cluster
 	// remove all nodes from cluster
 	r.removeNode(1)
 	r.removeNode(1)
-	w = []uint64{}
 	if g := r.nodes(); !reflect.DeepEqual(g, w) {
 	if g := r.nodes(); !reflect.DeepEqual(g, w) {
 		t.Errorf("nodes = %v, want %v", g, w)
 		t.Errorf("nodes = %v, want %v", g, w)
 	}
 	}

+ 2 - 2
raft/rawnode.go

@@ -169,7 +169,7 @@ func (rn *RawNode) ProposeConfChange(cc pb.ConfChange) error {
 // ApplyConfChange applies a config change to the local node.
 // ApplyConfChange applies a config change to the local node.
 func (rn *RawNode) ApplyConfChange(cc pb.ConfChange) *pb.ConfState {
 func (rn *RawNode) ApplyConfChange(cc pb.ConfChange) *pb.ConfState {
 	if cc.NodeID == None {
 	if cc.NodeID == None {
-		return &pb.ConfState{Nodes: rn.raft.nodes()}
+		return &pb.ConfState{Nodes: rn.raft.nodes(), Learners: rn.raft.learnerNodes()}
 	}
 	}
 	switch cc.Type {
 	switch cc.Type {
 	case pb.ConfChangeAddNode:
 	case pb.ConfChangeAddNode:
@@ -182,7 +182,7 @@ func (rn *RawNode) ApplyConfChange(cc pb.ConfChange) *pb.ConfState {
 	default:
 	default:
 		panic("unexpected conf type")
 		panic("unexpected conf type")
 	}
 	}
-	return &pb.ConfState{Nodes: rn.raft.nodes()}
+	return &pb.ConfState{Nodes: rn.raft.nodes(), Learners: rn.raft.learnerNodes()}
 }
 }
 
 
 // Step advances the state machine using the given message.
 // Step advances the state machine using the given message.