Browse Source

raft: add remove node

Yicheng Qin 11 years ago
parent
commit
193756fa38
2 changed files with 39 additions and 13 deletions
  1. 20 13
      raft/node.go
  2. 19 0
      raft/node_test.go

+ 20 - 13
raft/node.go

@@ -13,7 +13,7 @@ type tick int
 
 type ConfigCmd struct {
 	Type string
-	Id   int
+	Addr   int
 }
 
 type Node struct {
@@ -24,6 +24,8 @@ type Node struct {
 	// elapsed ticks after the last reset
 	elapsed tick
 	sm      *stateMachine
+
+	addr int
 }
 
 func New(addr int, peers []int, heartbeat, election tick) *Node {
@@ -35,6 +37,7 @@ func New(addr int, peers []int, heartbeat, election tick) *Node {
 		sm:        newStateMachine(addr, peers),
 		heartbeat: heartbeat,
 		election:  election,
+		addr: addr,
 	}
 
 	return n
@@ -46,18 +49,12 @@ func (n *Node) Propose(data []byte) {
 	n.Step(m)
 }
 
-func (n *Node) Add(id int) {
-	c := &ConfigCmd{
-		Type: "add",
-		Id:   id,
-	}
+func (n *Node) Add(addr int) {
+	n.Step(n.confMessage(&ConfigCmd{Type: "add", Addr: addr}))
+}
 
-	data, err := json.Marshal(c)
-	if err != nil {
-		panic(err)
-	}
-	m := Message{Type: msgProp, Entries: []Entry{Entry{Type: config, Data: data}}}
-	n.Step(m)
+func (n *Node) Remove(addr int) {
+	n.Step(n.confMessage(&ConfigCmd{Type: "remove", Addr: addr}))
 }
 
 func (n *Node) Msgs() []Message {
@@ -119,10 +116,20 @@ func (n *Node) Tick() {
 	}
 }
 
+func (n *Node) confMessage(c *ConfigCmd) Message {
+	data, err := json.Marshal(c)
+	if err != nil {
+		panic(err)
+	}
+	return Message{Type: msgProp, Entries: []Entry{Entry{Type: config, Data: data}}}
+}
+
 func (n *Node) updateConf(c *ConfigCmd) {
 	switch c.Type {
 	case "add":
-		n.sm.Add(c.Id)
+		n.sm.Add(c.Addr)
+	case "remove":
+		n.sm.Remove(c.Addr)
 	default:
 		// warn
 	}

+ 19 - 0
raft/node_test.go

@@ -97,3 +97,22 @@ func TestAdd(t *testing.T) {
 		t.Errorf("k = %d, want 2", len(n.sm.ins))
 	}
 }
+
+func TestRemove(t *testing.T) {
+	n := New(0, []int{0}, defaultHeartbeat, defaultElection)
+
+	n.sm.becomeCandidate()
+	n.sm.becomeLeader()
+	n.Add(1)
+	n.Next()
+	n.Remove(0)
+	n.Step(Message{Type: msgAppResp, From: 1, Term: 1, Index: 3})
+	n.Next()
+
+	if len(n.sm.ins) != 1 {
+		t.Errorf("k = %d, want 1", len(n.sm.ins))
+	}
+	if n.sm.addr != 0 {
+		t.Errorf("addr = %d, want 0", n.sm.addr)
+	}
+}