Browse Source

start new raft implementation

Blake Mizerany 11 years ago
parent
commit
9f8ede7b03
2 changed files with 559 additions and 0 deletions
  1. 328 0
      raft.go
  2. 231 0
      raft_test.go

+ 328 - 0
raft.go

@@ -0,0 +1,328 @@
+package raft
+
+import (
+	"errors"
+	"sort"
+)
+
+// none is the sentinel address meaning "no peer". reset stores it in both
+// lead (no known leader) and vote (no vote cast).
+const none = -1
+
+// messageType identifies the kind of a Message handled by stateMachine.step.
+type messageType int
+
+const (
+	// msgHup tells the receiving node to start an election: step bumps the
+	// term, turns candidate and sends msgVote to every other peer.
+	msgHup messageType = iota
+	// msgProp carries a proposal: the leader appends m.Data to its log, a
+	// node that knows another leader forwards the message to it.
+	msgProp
+	// msgApp is sent by sendAppend and carries log entries to a peer.
+	msgApp
+	// msgAppResp answers msgApp; a negative Index reports a log mismatch.
+	msgAppResp
+	// msgVote is a candidate's request for a vote.
+	msgVote
+	// msgVoteResp answers msgVote; a negative Index means the vote was denied.
+	msgVoteResp
+)
+
+// mtmap holds the printable name of every messageType, indexed by the
+// constant's value. It backs messageType.String.
+var mtmap = [...]string{
+	msgHup:      "msgHup",
+	msgProp:     "msgProp",
+	msgApp:      "msgApp",
+	msgAppResp:  "msgAppResp",
+	msgVote:     "msgVote",
+	msgVoteResp: "msgVoteResp",
+}
+
+// String returns the symbolic name of mt. A value outside the declared
+// message types yields "unknown" instead of panicking on the table lookup,
+// so that formatting a malformed Message stays safe.
+func (mt messageType) String() string {
+	if mt < 0 || int(mt) >= len(mtmap) {
+		return "unknown"
+	}
+	return mtmap[mt]
+}
+
+// errNoLeader is the error value for "no leader is known".
+// NOTE(review): nothing in the visible code returns or checks it yet.
+var errNoLeader = errors.New("no leader")
+
+// stateType is the role a stateMachine currently plays in the cluster.
+type stateType int
+
+// The roles are declared as stateType (mirroring the messageType constants)
+// so that a role cannot be silently mixed up with a plain int.
+const (
+	// stateFollower is the zero value, so a fresh stateMachine is a follower.
+	stateFollower stateType = iota
+	// stateCandidate is entered on msgHup while votes are being collected.
+	stateCandidate
+	// stateLeader is entered once a quorum of votes has been granted.
+	stateLeader
+)
+
+// stmap holds the printable name of every stateType, indexed by the
+// constant's value. It backs stateType.String.
+var stmap = [...]string{
+	stateFollower:  "stateFollower",
+	stateCandidate: "stateCandidate",
+	stateLeader:    "stateLeader",
+}
+
+// String returns the symbolic name of st. A value outside the declared roles
+// yields "unknown" instead of panicking on the table lookup.
+func (st stateType) String() string {
+	if st < 0 || int(st) >= len(stmap) {
+		return "unknown"
+	}
+	return stmap[st]
+}
+
+// Entry is a single record in a node's log.
+type Entry struct {
+	// Term is the term the leader was in when it appended the entry.
+	Term int
+	// Data is the proposed payload, taken from Message.Data by step.
+	Data []byte
+}
+
+// Message is the unit of communication between nodes and the sole input of
+// stateMachine.step.
+type Message struct {
+	// Type selects how step handles the message.
+	Type     messageType
+	// To is the address of the receiving node.
+	To       int
+	// From is the address of the sender; send overwrites it.
+	From     int
+	// Term is the sender's term; send overwrites it.
+	Term     int
+	// LogTerm is the term of the sender's log entry at Index.
+	LogTerm  int
+	// Index is a log index. In msgApp it is the entry preceding Entries, in
+	// msgVote the candidate's last entry; in responses a negative value
+	// signals rejection.
+	Index    int
+	// PrevTerm is not read or written by the visible code.
+	PrevTerm int
+	// Entries are the log entries to store after Index (msgApp).
+	Entries  []Entry
+	// Commit is not read or written by the visible code.
+	Commit   int
+	// Data is the payload of a proposal (msgProp).
+	Data     []byte
+}
+
+// stepper is anything that can consume a Message. stateMachine implements it
+// and also uses one (its next field) as the sink for outgoing messages.
+type stepper interface {
+	// step delivers m to the receiver.
+	step(m Message)
+}
+
+// index is the bookkeeping a stateMachine keeps for one peer address:
+// match is read by theN when looking for a commit index, next is the log
+// position sendAppend starts sending from.
+type index struct {
+	match, next int
+}
+
+// update records n as the matching index and makes the entry right after it
+// the next one to send.
+func (in *index) update(n int) {
+	in.match, in.next = n, n+1
+}
+
+// decr moves next one entry back, but never below 1.
+func (in *index) decr() {
+	if in.next > 1 {
+		in.next--
+		return
+	}
+	in.next = 1
+}
+
+// stateMachine holds the raft state of a single node. It is driven entirely
+// by messages passed to step and emits messages through next.
+type stateMachine struct {
+	// k is the number of peers, this node included
+	k int
+
+	// addr is an integer representation of our address amongst our peers. It is 0 <= addr < k.
+	addr int
+
+	// the term we are participating in at any time
+	term int
+
+	// who we voted for in term
+	vote int
+
+	// the log; newStateMachine seeds it with one empty entry at index 0
+	log []Entry
+
+	// ins holds one index (match/next) per peer address; rebuilt by reset
+	ins []*index
+
+	// state is our current role
+	state stateType
+
+	// commit is the highest log index considered committed; only
+	// maybeAdvanceCommit moves it
+	commit int
+
+	// votes records the answers collected by poll, keyed by peer address
+	votes map[int]bool
+
+	// next receives every message handed to send
+	next stepper
+
+	// the leader addr
+	lead int
+}
+
+// newStateMachine builds the node with address addr in a cluster of k peers.
+// Outgoing messages are delivered to next. The log starts with a single empty
+// entry at index 0 and has room for 1024 entries before it must grow.
+func newStateMachine(k, addr int, next stepper) *stateMachine {
+	sm := &stateMachine{
+		k:    k,
+		addr: addr,
+		next: next,
+		log:  make([]Entry, 1, 1024),
+	}
+	sm.reset()
+	return sm
+}
+
+// canStep reports whether m may be passed to step right now. Only proposals
+// are ever refused, and only while no leader is known.
+func (sm *stateMachine) canStep(m Message) bool {
+	return m.Type != msgProp || sm.lead != none
+}
+
+// poll records v as the answer of peer addr, unless that peer has already
+// answered, and returns how many peers have granted their vote so far.
+func (sm *stateMachine) poll(addr int, v bool) (granted int) {
+	if _, seen := sm.votes[addr]; !seen {
+		sm.votes[addr] = v
+	}
+	for _, yes := range sm.votes {
+		if !yes {
+			continue
+		}
+		granted++
+	}
+	return granted
+}
+
+// empty is the zero Entry.
+// NOTE(review): not referenced anywhere in the visible code.
+var empty = Entry{}
+
+// append drops every log entry behind index after, stores ents in their
+// place and returns the index of the new last entry.
+func (sm *stateMachine) append(after int, ents ...Entry) int {
+	kept := sm.log[:after+1]
+	sm.log = append(kept, ents...)
+	return len(sm.log) - 1
+}
+
+// isLogOk reports whether our log has an entry at index i that was written in
+// the given term. An index outside the log, including a negative one taken
+// from a malformed message, is reported as not ok rather than causing an
+// out-of-range panic.
+func (sm *stateMachine) isLogOk(i, term int) bool {
+	if i < 0 || i > sm.li() {
+		return false
+	}
+	return sm.log[i].Term == term
+}
+
+// send stamps m with our address and current term and hands it to sm.next,
+// which is responsible for getting it to m.To.
+// NOTE(review): nothing is persisted here, although an earlier comment said so.
+func (sm *stateMachine) send(m Message) {
+	m.From, m.Term = sm.addr, sm.term
+	sm.next.step(m)
+}
+
+// sendAppend sends a msgApp to every peer except ourselves. Each message
+// carries all log entries from that peer's next index onwards (possibly
+// none), plus the index and term of the entry preceding them.
+func (sm *stateMachine) sendAppend() {
+	for peer := 0; peer < sm.k; peer++ {
+		if peer == sm.addr {
+			continue
+		}
+		prev := sm.ins[peer].next - 1
+		sm.send(Message{
+			Type:    msgApp,
+			To:      peer,
+			Index:   prev,
+			LogTerm: sm.log[prev].Term,
+			Entries: sm.log[prev+1:],
+		})
+	}
+}
+
+// theN returns the highest log index N such that a quorum of the match
+// indexes in sm.ins is >= N and log[N] was written in the current term.
+// It returns -1 when no such index exists.
+//
+// NOTE(review): the visible code only updates match for peers that answer a
+// msgApp, so this node's own slot stays at 0 and is counted as such.
+func (sm *stateMachine) theN() int {
+	// TODO(bmizerany): optimize.. Currently naive
+	mis := make([]int, len(sm.ins))
+	for i := range mis {
+		mis[i] = sm.ins[i].match
+	}
+	sort.Ints(mis)
+
+	// With mis ascending, the quorum entries at positions len(mis)-quorum and
+	// above are all >= mis[len(mis)-quorum], so that is the largest index
+	// held by a quorum. Starting at k/2+1, as the code used to, only looked
+	// at indexes held by fewer than a quorum of nodes.
+	quorum := sm.k/2 + 1
+	for i := len(mis) - quorum; i >= 0; i-- {
+		if mi := mis[i]; sm.log[mi].Term == sm.term {
+			return mi
+		}
+	}
+	return -1
+}
+
+// maybeAdvanceCommit raises sm.commit to the index reported by theN when
+// that index is larger, and returns the resulting commit index.
+func (sm *stateMachine) maybeAdvanceCommit() int {
+	if n := sm.theN(); n > sm.commit {
+		sm.commit = n
+	}
+	return sm.commit
+}
+
+// reset forgets the leader, our vote and all collected votes, and rebuilds
+// the per-peer indexes so that next points just past the end of our log.
+func (sm *stateMachine) reset() {
+	sm.lead, sm.vote = none, none
+	sm.votes = map[int]bool{}
+
+	next := len(sm.log)
+	ins := make([]*index, sm.k)
+	for i := range ins {
+		ins[i] = &index{next: next}
+	}
+	sm.ins = ins
+}
+
+// q returns the quorum size: the smallest number of peers that is more than
+// half of k.
+func (sm *stateMachine) q() int {
+	half := sm.k / 2
+	return half + 1
+}
+
+// voteWorthy reports whether a candidate whose last log entry has index i and
+// the given term is at least as up-to-date as we are, i.e. whether it may
+// receive our vote.
+func (sm *stateMachine) voteWorthy(i, term int) bool {
+	// LET logOk == \/ m.mlastLogTerm > LastTerm(log[i])
+	//              \/ /\ m.mlastLogTerm = LastTerm(log[i])
+	//                 /\ m.mlastLogIndex >= Len(log[i])
+	last := sm.li()
+	lastTerm := sm.log[last].Term
+	// The first comparison must be strict: with ">=" the index check below was
+	// unreachable and a candidate with the same last term but a shorter log
+	// was accepted.
+	return term > lastTerm || (term == lastTerm && i >= last)
+}
+
+// li returns the index of the last entry in our log.
+func (sm *stateMachine) li() int {
+	n := len(sm.log)
+	return n - 1
+}
+
+// becomeFollower resets the election state and turns us into a follower of
+// lead in the given term.
+func (sm *stateMachine) becomeFollower(term, lead int) {
+	sm.reset()
+	sm.state = stateFollower
+	sm.term, sm.lead = term, lead
+}
+
+// step advances the state machine with a single message.
+//
+// msgHup and msgProp are local commands and are handled regardless of m.Term.
+// Every other message is first checked against our term: a newer term turns
+// us into a follower of that term, an older one is dropped. What remains is
+// dispatched on our current role.
+//
+// step panics when given a msgProp while no leader is known; callers can use
+// canStep to check beforehand.
+func (sm *stateMachine) step(m Message) {
+	switch m.Type {
+	case msgHup:
+		// Start an election for the next term.
+		sm.term++
+		sm.reset()
+		sm.state = stateCandidate
+		// Remember that our vote for this term is spent on ourselves, so we
+		// cannot grant it to someone else should we drop back to follower.
+		sm.vote = sm.addr
+		sm.poll(sm.addr, true)
+		for i := 0; i < sm.k; i++ {
+			if i == sm.addr {
+				continue
+			}
+			lasti := sm.li()
+			sm.send(Message{To: i, Type: msgVote, Index: lasti, LogTerm: sm.log[lasti].Term})
+		}
+		return
+	case msgProp:
+		switch sm.lead {
+		case sm.addr:
+			// We are the leader: store the proposal and replicate it.
+			sm.append(sm.li(), Entry{Term: sm.term, Data: m.Data})
+			sm.sendAppend()
+		case none:
+			panic("msgProp given without leader")
+		default:
+			// Forward to the leader we know of.
+			m.To = sm.lead
+			sm.send(m)
+		}
+		return
+	}
+
+	switch {
+	case m.Term > sm.term:
+		// Only msgApp proves that the sender leads the newer term. A vote
+		// request merely shows that somebody is campaigning, so it must not
+		// make us forward proposals to the sender.
+		lead := none
+		if m.Type == msgApp {
+			lead = m.From
+		}
+		sm.becomeFollower(m.Term, lead)
+	case m.Term < sm.term:
+		// ignore
+		return
+	}
+
+	// handleAppendEntries stores m.Entries when our log matches the sender's
+	// at m.Index and answers with our last index, or with -1 on a mismatch.
+	handleAppendEntries := func() {
+		if sm.isLogOk(m.Index, m.LogTerm) {
+			sm.append(m.Index, m.Entries...)
+			sm.send(Message{To: m.From, Type: msgAppResp, Index: sm.li()})
+		} else {
+			sm.send(Message{To: m.From, Type: msgAppResp, Index: -1})
+		}
+	}
+
+	switch sm.state {
+	case stateLeader:
+		switch m.Type {
+		case msgAppResp:
+			in := sm.ins[m.From]
+			if m.Index < 0 {
+				// Mismatch: step back one entry and try again.
+				in.decr()
+				sm.sendAppend()
+			} else {
+				in.update(m.Index)
+			}
+		}
+	case stateCandidate:
+		switch m.Type {
+		case msgApp:
+			println("lost to appendEnts")
+			sm.becomeFollower(sm.term, m.From)
+			// becomeFollower clears the vote, but we are still in the term in
+			// which we voted for ourselves.
+			sm.vote = sm.addr
+			handleAppendEntries()
+		case msgVoteResp:
+			gr := sm.poll(m.From, m.Index >= 0)
+			switch sm.q() {
+			case gr:
+				// A quorum granted its vote: we won.
+				sm.state = stateLeader
+				sm.lead = sm.addr
+				sm.sendAppend()
+			case len(sm.votes) - gr:
+				// A quorum refused: we lost.
+				sm.state = stateFollower
+			}
+		}
+	case stateFollower:
+		switch m.Type {
+		case msgApp:
+			// m.Term equals our term here, so the sender leads this term.
+			sm.lead = m.From
+			handleAppendEntries()
+		case msgVote:
+			// Grant at most one vote per term; a repeated request from the
+			// candidate we already voted for is granted again.
+			canVote := sm.vote == none || sm.vote == m.From
+			if canVote && sm.voteWorthy(m.Index, m.LogTerm) {
+				sm.vote = m.From
+				sm.send(Message{To: m.From, Type: msgVoteResp, Index: sm.li()})
+			} else {
+				sm.send(Message{To: m.From, Type: msgVoteResp, Index: -1})
+			}
+		}
+	}
+}

+ 231 - 0
raft_test.go

@@ -0,0 +1,231 @@
+package raft
+
+import (
+	"fmt"
+	"reflect"
+	"testing"
+)
+
+// TestLeaderElection lets node 0 campaign in networks with a varying number
+// of unresponsive peers (nopStepper) and checks the role it ends up in and
+// that exactly one term has been started.
+func TestLeaderElection(t *testing.T) {
+	cases := []struct {
+		network
+		state stateType
+	}{
+		{newNetwork(nil, nil, nil), stateLeader},
+		{newNetwork(nil, nil, nopStepper), stateLeader},
+		{newNetwork(nil, nopStepper, nopStepper), stateCandidate},
+		{newNetwork(nil, nopStepper, nopStepper, nil), stateCandidate},
+		{newNetwork(nil, nopStepper, nopStepper, nil, nil), stateLeader},
+	}
+
+	for i, tc := range cases {
+		// node 0 campaigns
+		tc.step(Message{To: 0, Type: msgHup})
+
+		sm := tc.network[0].(*stateMachine)
+		if got, want := sm.state, tc.state; got != want {
+			t.Errorf("#%d: state = %s, want %s", i, got, want)
+		}
+		if sm.term != 1 {
+			t.Errorf("#%d: term = %d, want %d", i, sm.term, 1)
+		}
+	}
+}
+
+func TestProposal(t *testing.T) {
+	tests := []struct {
+		network
+		success bool
+	}{
+		{newNetwork(nil, nil, nil), true},
+		{newNetwork(nil, nil, nopStepper), true},
+		{newNetwork(nil, nopStepper, nopStepper), false},
+		{newNetwork(nil, nopStepper, nopStepper, nil), false},
+		{newNetwork(nil, nopStepper, nopStepper, nil, nil), true},
+	}
+
+	for i, tt := range tests {
+		step := stepperFunc(func(m Message) {
+			defer func() {
+				if !tt.success {
+					// not expected success implies there
+					// will be no known leader which will
+					// cause step to panic - swallow it.
+					e := recover()
+					if e != nil {
+						t.Logf("#%d: err: %s", i, e)
+					}
+				}
+			}()
+			t.Logf("#%d: m = %+v", i, m)
+			tt.step(m)
+		})
+
+		var data = []byte("somedata")
+
+		// promote 0 the leader
+		step(Message{To: 0, Type: msgHup})
+		step(Message{To: 0, Type: msgProp, Data: data})
+
+		w := []Entry{{}}
+		if tt.success {
+			w = append(w, Entry{Term: 1, Data: data})
+		}
+		ls := append([][]Entry{w}, tt.logs()...)
+
+		if g := diffLogs(ls); g != nil {
+			for _, diff := range g {
+				t.Errorf("#%d: bag log:\n%s", i, diff)
+			}
+		}
+		sm := tt.network[0].(*stateMachine)
+		if g := sm.term; g != 1 {
+			t.Errorf("#%d: term = %d, want %d", i, g, 1)
+		}
+	}
+}
+
+// TestProposalByProxy elects node 0 and then hands a proposal to follower 1,
+// which has to forward it to the leader. Afterwards the logs of all state
+// machines must agree.
+func TestProposalByProxy(t *testing.T) {
+	tests := []struct {
+		network
+		success bool
+	}{
+		{newNetwork(nil, nil, nil), true},
+		{newNetwork(nil, nil, nopStepper), true},
+	}
+
+	for i, tt := range tests {
+		step := stepperFunc(func(m Message) {
+			t.Logf("#%d: m = %+v", i, m)
+			tt.step(m)
+		})
+
+		// promote 0 the leader
+		step(Message{To: 0, Type: msgHup})
+
+		// propose via follower
+		step(Message{To: 1, Type: msgProp, Data: []byte("somedata")})
+
+		// diffLogs returns nil when all logs agree, so the loop is then a no-op.
+		for _, d := range diffLogs(tt.logs()) {
+			t.Errorf("#%d: bad log:\n%s", i, d)
+		}
+		sm := tt.network[0].(*stateMachine)
+		if g := sm.term; g != 1 {
+			t.Errorf("#%d: term = %d, want %d", i, g, 1)
+		}
+	}
+}
+
+// network is an in-memory cluster for tests: the stepper at position i is the
+// node with address i, and step delivers a message to the node named in m.To.
+type network []stepper
+
+// newNetwork wires nodes into a network. Every nil node is replaced by a
+// fresh *stateMachine; every *stateMachine passed in has its k, addr and
+// next fields overwritten to fit the network. Other steppers are kept as is.
+func newNetwork(nodes ...stepper) network {
+	nt := network(nodes)
+	size := len(nodes)
+	for i := range nodes {
+		switch v := nodes[i].(type) {
+		case nil:
+			nt[i] = newStateMachine(size, i, &nt)
+		case *stateMachine:
+			v.k, v.addr, v.next = size, i, &nt
+		}
+	}
+	return nt
+}
+
+// step delivers m to the node whose address is m.To.
+func (nt network) step(m Message) {
+	dst := nt[m.To]
+	dst.step(m)
+}
+
+// logs returns the log of every node, indexed by address. Nodes that are not
+// a *stateMachine contribute a nil log.
+func (nt network) logs() [][]Entry {
+	out := make([][]Entry, len(nt))
+	for i := range nt {
+		sm, ok := nt[i].(*stateMachine)
+		if !ok {
+			continue
+		}
+		out[i] = sm.log
+	}
+	return out
+}
+
+// diff describes one log index at which the compared logs disagree.
+type diff struct {
+	// i is the log index the row belongs to.
+	i    int
+	// ents holds, per compared log, the entry at index i or one of the
+	// markers naEntry / nologEntry.
+	ents []*Entry // pointers so they can be nil for N/A
+}
+
+// naEntry and nologEntry are marker values recognised by pointer identity.
+// diffLogs uses naEntry where a log is too short to have the entry and
+// nologEntry where there is no log at all (a nil slice).
+var naEntry = &Entry{}
+var nologEntry = &Entry{}
+
+func (d diff) String() string {
+	s := fmt.Sprintf("[%d] ", d.i)
+	for i, e := range d.ents {
+		switch e {
+		case nologEntry:
+			s += fmt.Sprintf("<NL>")
+		case naEntry:
+			s += fmt.Sprintf("<N/A>")
+		case nil:
+			s += fmt.Sprintf("<nil>")
+		default:
+			s += fmt.Sprintf("<%d:%q>", e.Term, string(e.Data))
+		}
+		if i != len(d.ents)-1 {
+			s += "\t\t"
+		}
+	}
+	return s
+}
+
+// diffLogs compares logs index by index and returns one diff per index at
+// which they disagree, or nil when they all agree. A nil log is skipped; a
+// log that is too short for an index always counts as a disagreement; every
+// other entry is compared with the cell of the log just before it, unless
+// that cell is one of the two markers.
+func diffLogs(logs [][]Entry) []diff {
+	longest := 0
+	for _, l := range logs {
+		if len(l) > longest {
+			longest = len(l)
+		}
+	}
+
+	// row collects entry i of every log; it returns nil when they agree.
+	row := func(i int) []*Entry {
+		cells := make([]*Entry, len(logs))
+		differs := false
+		for j := range logs {
+			switch l := logs[j]; {
+			case l == nil:
+				cells[j] = nologEntry
+			case i >= len(l):
+				cells[j] = naEntry
+				differs = true
+			default:
+				cells[j] = &l[i]
+				if j > 0 {
+					prev := cells[j-1]
+					if prev != nologEntry && prev != naEntry && !reflect.DeepEqual(prev, cells[j]) {
+						differs = true
+					}
+				}
+			}
+		}
+		if !differs {
+			return nil
+		}
+		return cells
+	}
+
+	var out []diff
+	for i := 0; i < longest; i++ {
+		if cells := row(i); cells != nil {
+			out = append(out, diff{i, cells})
+		}
+	}
+	return out
+}
+
+// stepperFunc adapts a plain function to the stepper interface.
+type stepperFunc func(Message)
+
+// step calls f with m.
+func (f stepperFunc) step(m Message) {
+	f(m)
+}
+
+// nopStepper discards every message; the tests use it as a node that never
+// answers.
+var nopStepper = stepperFunc(func(Message) {})