Browse Source

raft: add clusterId

Xiang Li 11 years ago
parent
commit
060de128a7
5 changed files with 102 additions and 23 deletions
  1. 47 1
      raft/cluster_test.go
  2. 1 0
      raft/log.go
  3. 33 5
      raft/node.go
  4. 6 5
      raft/node_test.go
  5. 15 12
      raft/raft.go

+ 47 - 1
raft/cluster_test.go

@@ -63,6 +63,52 @@ func TestBuildCluster(t *testing.T) {
 	}
 }
 
+func TestInitCluster(t *testing.T) {
+	node := New(1, defaultHeartbeat, defaultElection)
+	dictate(node)
+	node.Next()
+
+	if node.ClusterId() != 0xBEEF {
+		t.Errorf("clusterId = %x, want %x", node.ClusterId(), 0xBEEF)
+	}
+
+	func() {
+		defer func() {
+			e := recover()
+			if e != "cannot init a started cluster" {
+				t.Errorf("err = %v, want cannot init a started cluster", e)
+			}
+		}()
+		node.InitCluster(0xFBEE)
+		node.Next()
+	}()
+}
+
+func TestMessageFromDifferentCluster(t *testing.T) {
+	tests := []struct {
+		clusterId int64
+		wType     messageType
+	}{
+		{0xBEEF, msgVoteResp},
+		{0xFBEE, msgDenied},
+	}
+
+	for i, tt := range tests {
+		node := New(1, defaultHeartbeat, defaultElection)
+		dictate(node)
+		node.Next()
+
+		node.Step(Message{From: 1, ClusterId: tt.clusterId, Type: msgVote, Term: 2, LogTerm: 2, Index: 2})
+		msgs := node.Msgs()
+		if len(msgs) != 1 {
+			t.Errorf("#%d: len(msgs) = %d, want 1", i, len(msgs))
+		}
+		if msgs[0].Type != tt.wType {
+			t.Errorf("#%d: msg.Type = %v, want %d", i, msgs[0].Type, tt.wType)
+		}
+	}
+}
+
 // TestBasicCluster ensures all nodes can send proposal to the cluster.
 // And all the proposals will get committed.
 func TestBasicCluster(t *testing.T) {
@@ -83,7 +129,7 @@ func TestBasicCluster(t *testing.T) {
 		for j := 0; j < tt.round; j++ {
 			for _, n := range nodes {
 				data := []byte{byte(n.Id())}
-				nt.send(Message{From: n.Id(), To: n.Id(), Type: msgProp, Entries: []Entry{{Data: data}}})
+				nt.send(Message{From: n.Id(), To: n.Id(), ClusterId: n.ClusterId(), Type: msgProp, Entries: []Entry{{Data: data}}})
 
 				base := nodes[0].Next()
 				if len(base) != 1 {

+ 1 - 0
raft/log.go

@@ -5,6 +5,7 @@ import "fmt"
 const (
 	Normal int64 = iota
 
+	ClusterInit
 	AddNode
 	RemoveNode
 )

+ 33 - 5
raft/node.go

@@ -1,6 +1,7 @@
 package raft
 
 import (
+	"encoding/binary"
 	"encoding/json"
 	golog "log"
 	"math/rand"
@@ -52,6 +53,8 @@ func New(id int64, heartbeat, election tick) *Node {
 
 func (n *Node) Id() int64 { return n.sm.id }
 
+func (n *Node) ClusterId() int64 { return n.sm.clusterId }
+
 func (n *Node) Index() int64 { return n.sm.index.Get() }
 
 func (n *Node) Term() int64 { return n.sm.term.Get() }
@@ -70,16 +73,24 @@ func (n *Node) IsRemoved() bool { return n.removed }
 func (n *Node) Propose(data []byte) { n.propose(Normal, data) }
 
 func (n *Node) propose(t int64, data []byte) {
-	n.Step(Message{From: n.sm.id, Type: msgProp, Entries: []Entry{{Type: t, Data: data}}})
+	n.Step(Message{From: n.sm.id, ClusterId: n.ClusterId(), Type: msgProp, Entries: []Entry{{Type: t, Data: data}}})
 }
 
-func (n *Node) Campaign() { n.Step(Message{From: n.sm.id, Type: msgHup}) }
+func (n *Node) Campaign() { n.Step(Message{From: n.sm.id, ClusterId: n.ClusterId(), Type: msgHup}) }
+
+func (n *Node) InitCluster(clusterId int64) {
+	d := make([]byte, 8)
+	wn := binary.PutVarint(d, clusterId)
+	n.propose(ClusterInit, d[:wn])
+}
 
 func (n *Node) Add(id int64, addr string, context []byte) {
 	n.UpdateConf(AddNode, &Config{NodeId: id, Addr: addr, Context: context})
 }
 
-func (n *Node) Remove(id int64) { n.UpdateConf(RemoveNode, &Config{NodeId: id}) }
+func (n *Node) Remove(id int64) {
+	n.UpdateConf(RemoveNode, &Config{NodeId: id})
+}
 
 func (n *Node) Msgs() []Message { return n.sm.Msgs() }
 
@@ -88,17 +99,25 @@ func (n *Node) Step(m Message) bool {
 		n.removed = true
 		return false
 	}
+	if n.ClusterId() != none && m.ClusterId != none && m.ClusterId != n.ClusterId() {
+		golog.Printf("denied a message from node %d, cluster %d. accept cluster: %d\n", m.From, m.ClusterId, n.ClusterId())
+		n.sm.send(Message{To: m.From, ClusterId: n.ClusterId(), Type: msgDenied})
+		return true
+	}
+
 	if _, ok := n.rmNodes[m.From]; ok {
 		if m.From != n.sm.id {
-			n.sm.send(Message{From: n.sm.id, To: m.From, Type: msgDenied})
+			n.sm.send(Message{To: m.From, ClusterId: n.ClusterId(), Type: msgDenied})
 		}
 		return true
 	}
 
 	l := len(n.sm.msgs)
+
 	if !n.sm.Step(m) {
 		return false
 	}
+
 	for _, m := range n.sm.msgs[l:] {
 		switch m.Type {
 		case msgAppResp:
@@ -120,6 +139,15 @@ func (n *Node) Next() []Entry {
 	for i := range ents {
 		switch ents[i].Type {
 		case Normal:
+		case ClusterInit:
+			cid, nr := binary.Varint(ents[i].Data)
+			if nr <= 0 {
+				panic("init cluster failed: cannot read clusterId")
+			}
+			if n.ClusterId() != -1 {
+				panic("cannot init a started cluster")
+			}
+			n.sm.clusterId = cid
 		case AddNode:
 			c := new(Config)
 			if err := json.Unmarshal(ents[i].Data, c); err != nil {
@@ -159,7 +187,7 @@ func (n *Node) Tick() {
 		timeout, msgType = n.heartbeat, msgBeat
 	}
 	if n.elapsed >= timeout {
-		n.Step(Message{From: n.sm.id, Type: msgType})
+		n.Step(Message{From: n.sm.id, ClusterId: n.ClusterId(), Type: msgType})
 		n.elapsed = 0
 		if n.sm.state != stateLeader {
 			n.electionRand = n.election + tick(rand.Int31())%n.election

+ 6 - 5
raft/node_test.go

@@ -40,7 +40,7 @@ func TestTickMsgBeat(t *testing.T) {
 		n.Add(int64(i), "", nil)
 		for _, m := range n.Msgs() {
 			if m.Type == msgApp {
-				n.Step(Message{From: m.To, Type: msgAppResp, Index: m.Index + int64(len(m.Entries))})
+				n.Step(Message{From: m.To, ClusterId: m.ClusterId, Type: msgAppResp, Index: m.Index + int64(len(m.Entries))})
 			}
 		}
 		// ignore commit index update messages
@@ -131,7 +131,7 @@ func TestRemove(t *testing.T) {
 	n.Add(1, "", nil)
 	n.Next()
 	n.Remove(0)
-	n.Step(Message{Type: msgAppResp, From: 1, Term: 1, Index: 4})
+	n.Step(Message{Type: msgAppResp, From: 1, ClusterId: n.ClusterId(), Term: 1, Index: 5})
 	n.Next()
 
 	if len(n.sm.ins) != 1 {
@@ -176,10 +176,10 @@ func TestDenial(t *testing.T) {
 		n.Next()
 
 		for id, denied := range tt.wdenied {
-			n.Step(Message{From: id, To: 0, Type: msgApp, Term: 1})
+			n.Step(Message{From: id, To: 0, ClusterId: n.ClusterId(), Type: msgApp, Term: 1})
 			w := []Message{}
 			if denied {
-				w = []Message{{From: 0, To: id, Term: 1, Type: msgDenied}}
+				w = []Message{{From: 0, To: id, ClusterId: n.ClusterId(), Term: 1, Type: msgDenied}}
 			}
 			if g := n.Msgs(); !reflect.DeepEqual(g, w) {
 				t.Errorf("#%d: msgs for %d = %+v, want %+v", i, id, g, w)
@@ -189,7 +189,8 @@ func TestDenial(t *testing.T) {
 }
 
 func dictate(n *Node) *Node {
-	n.Step(Message{Type: msgHup})
+	n.Step(Message{From: n.Id(), Type: msgHup})
+	n.InitCluster(0xBEEF)
 	n.Add(n.Id(), "", nil)
 	return n
 }

+ 15 - 12
raft/raft.go

@@ -65,16 +65,17 @@ func (st stateType) String() string {
 }
 
 type Message struct {
-	Type     messageType
-	To       int64
-	From     int64
-	Term     int64
-	LogTerm  int64
-	Index    int64
-	PrevTerm int64
-	Entries  []Entry
-	Commit   int64
-	Snapshot Snapshot
+	Type      messageType
+	ClusterId int64
+	To        int64
+	From      int64
+	Term      int64
+	LogTerm   int64
+	Index     int64
+	PrevTerm  int64
+	Entries   []Entry
+	Commit    int64
+	Snapshot  Snapshot
 }
 
 type index struct {
@@ -111,7 +112,8 @@ func (p int64Slice) Less(i, j int) bool { return p[i] < p[j] }
 func (p int64Slice) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }
 
 type stateMachine struct {
-	id int64
+	clusterId int64
+	id        int64
 
 	// the term we are participating in at any time
 	term  atomicInt
@@ -144,7 +146,7 @@ func newStateMachine(id int64, peers []int64) *stateMachine {
 	if id == none {
 		panic("cannot use none id")
 	}
-	sm := &stateMachine{id: id, lead: none, log: newLog(), ins: make(map[int64]*index)}
+	sm := &stateMachine{id: id, clusterId: none, lead: none, log: newLog(), ins: make(map[int64]*index)}
 	for _, p := range peers {
 		sm.ins[p] = &index{}
 	}
@@ -170,6 +172,7 @@ func (sm *stateMachine) poll(id int64, v bool) (granted int) {
 
 // send persists state to stable storage and then sends to its mailbox.
 func (sm *stateMachine) send(m Message) {
+	m.ClusterId = sm.clusterId
 	m.From = sm.id
 	m.Term = sm.term.Get()
 	sm.msgs = append(sm.msgs, m)