Browse Source

raft: add atomicInt

Xiang Li 11 years ago
parent
commit
e11c7f35b4
4 changed files with 23 additions and 9 deletions
  1. 1 1
      raft/cluster_test.go
  2. 1 1
      raft/node.go
  3. 20 6
      raft/raft.go
  4. 1 1
      raft/raft_test.go

+ 1 - 1
raft/cluster_test.go

@@ -39,7 +39,7 @@ func TestBuildCluster(t *testing.T) {
 			if tt.ids != nil {
 			if tt.ids != nil {
 				w = tt.ids[0]
 				w = tt.ids[0]
 			}
 			}
-			if g := n.sm.lead; g != w {
+			if g := n.sm.lead.Get(); g != w {
 				t.Errorf("#%d.%d: lead = %d, want %d", i, j, g, w)
 				t.Errorf("#%d.%d: lead = %d, want %d", i, j, g, w)
 			}
 			}
 
 

+ 1 - 1
raft/node.go

@@ -55,7 +55,7 @@ func (n *Node) HasLeader() bool { return n.Leader() != none }
 
 
 func (n *Node) IsLeader() bool { return n.Leader() == n.Id() }
 func (n *Node) IsLeader() bool { return n.Leader() == n.Id() }
 
 
-func (n *Node) Leader() int64 { return atomic.LoadInt64(&n.sm.lead) }
+func (n *Node) Leader() int64 { return n.sm.lead.Get() }
 
 
 // Propose asynchronously proposes data be applied to the underlying state machine.
 // Propose asynchronously proposes data be applied to the underlying state machine.
 func (n *Node) Propose(data []byte) { n.propose(Normal, data) }
 func (n *Node) Propose(data []byte) { n.propose(Normal, data) }

+ 20 - 6
raft/raft.go

@@ -3,6 +3,7 @@ package raft
 import (
 import (
 	"errors"
 	"errors"
 	"sort"
 	"sort"
+	"sync/atomic"
 )
 )
 
 
 const none = -1
 const none = -1
@@ -89,6 +90,19 @@ func (in *index) decr() {
 	}
 	}
 }
 }
 
 
+// An AtomicInt is an int64 to be accessed atomically.
+type atomicInt int64
+
+// Add atomically adds n to i.
+func (i *atomicInt) Set(n int64) {
+	atomic.StoreInt64((*int64)(i), n)
+}
+
+// Get atomically gets the value of i.
+func (i *atomicInt) Get() int64 {
+	return atomic.LoadInt64((*int64)(i))
+}
+
 type stateMachine struct {
 type stateMachine struct {
 	id int64
 	id int64
 
 
@@ -110,7 +124,7 @@ type stateMachine struct {
 	msgs []Message
 	msgs []Message
 
 
 	// the leader id
 	// the leader id
-	lead int64
+	lead atomicInt
 
 
 	// pending reconfiguration
 	// pending reconfiguration
 	pendingConf bool
 	pendingConf bool
@@ -197,7 +211,7 @@ func (sm *stateMachine) nextEnts() (ents []Entry) {
 
 
 func (sm *stateMachine) reset(term int) {
 func (sm *stateMachine) reset(term int) {
 	sm.term = term
 	sm.term = term
-	sm.lead = none
+	sm.lead.Set(none)
 	sm.vote = none
 	sm.vote = none
 	sm.votes = make(map[int64]bool)
 	sm.votes = make(map[int64]bool)
 	for i := range sm.ins {
 	for i := range sm.ins {
@@ -228,7 +242,7 @@ func (sm *stateMachine) promotable() bool {
 
 
 func (sm *stateMachine) becomeFollower(term int, lead int64) {
 func (sm *stateMachine) becomeFollower(term int, lead int64) {
 	sm.reset(term)
 	sm.reset(term)
-	sm.lead = lead
+	sm.lead.Set(lead)
 	sm.state = stateFollower
 	sm.state = stateFollower
 	sm.pendingConf = false
 	sm.pendingConf = false
 }
 }
@@ -249,7 +263,7 @@ func (sm *stateMachine) becomeLeader() {
 		panic("invalid transition [follower -> leader]")
 		panic("invalid transition [follower -> leader]")
 	}
 	}
 	sm.reset(sm.term)
 	sm.reset(sm.term)
-	sm.lead = sm.id
+	sm.lead.Set(sm.id)
 	sm.state = stateLeader
 	sm.state = stateLeader
 
 
 	for _, e := range sm.log.entries(sm.log.committed + 1) {
 	for _, e := range sm.log.entries(sm.log.committed + 1) {
@@ -384,10 +398,10 @@ func stepCandidate(sm *stateMachine, m Message) bool {
 func stepFollower(sm *stateMachine, m Message) bool {
 func stepFollower(sm *stateMachine, m Message) bool {
 	switch m.Type {
 	switch m.Type {
 	case msgProp:
 	case msgProp:
-		if sm.lead == none {
+		if sm.lead.Get() == none {
 			return false
 			return false
 		}
 		}
-		m.To = sm.lead
+		m.To = sm.lead.Get()
 		sm.send(m)
 		sm.send(m)
 	case msgApp:
 	case msgApp:
 		sm.handleAppendEntries(m)
 		sm.handleAppendEntries(m)

+ 1 - 1
raft/raft_test.go

@@ -545,7 +545,7 @@ func TestStateTransition(t *testing.T) {
 			if sm.term != tt.wterm {
 			if sm.term != tt.wterm {
 				t.Errorf("%d: term = %d, want %d", i, sm.term, tt.wterm)
 				t.Errorf("%d: term = %d, want %d", i, sm.term, tt.wterm)
 			}
 			}
-			if sm.lead != tt.wlead {
+			if sm.lead.Get() != tt.wlead {
 				t.Errorf("%d: lead = %d, want %d", i, sm.lead, tt.wlead)
 				t.Errorf("%d: lead = %d, want %d", i, sm.lead, tt.wlead)
 			}
 			}
 		}()
 		}()