raft_test.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. package raft
  2. import (
  3. "fmt"
  4. "reflect"
  5. "testing"
  6. )
  7. var defaultLog = []Entry{{}}
  8. func TestLeaderElection(t *testing.T) {
  9. tests := []struct {
  10. *network
  11. state stateType
  12. }{
  13. {newNetwork(nil, nil, nil), stateLeader},
  14. {newNetwork(nil, nil, nopStepper), stateLeader},
  15. {newNetwork(nil, nopStepper, nopStepper), stateCandidate},
  16. {newNetwork(nil, nopStepper, nopStepper, nil), stateCandidate},
  17. {newNetwork(nil, nopStepper, nopStepper, nil, nil), stateLeader},
  18. // three nodes are have logs further along than 0
  19. {
  20. newNetwork(
  21. nil,
  22. &stateMachine{log: []Entry{{}, {Term: 1}}},
  23. &stateMachine{log: []Entry{{}, {Term: 2}}},
  24. &stateMachine{log: []Entry{{}, {Term: 1}}},
  25. nil,
  26. ),
  27. stateFollower,
  28. },
  29. }
  30. for i, tt := range tests {
  31. tt.step(Message{To: 0, Type: msgHup})
  32. sm := tt.network.ss[0].(*stateMachine)
  33. if sm.state != tt.state {
  34. t.Errorf("#%d: state = %s, want %s", i, sm.state, tt.state)
  35. }
  36. if g := sm.term; g != 1 {
  37. t.Errorf("#%d: term = %d, want %d", i, g, 1)
  38. }
  39. }
  40. }
  41. func TestDualingCandidates(t *testing.T) {
  42. a := &stateMachine{log: defaultLog}
  43. c := &stateMachine{log: defaultLog}
  44. tt := newNetwork(a, nil, c)
  45. heal := false
  46. next := stepperFunc(func(m Message) {
  47. if heal {
  48. tt.step(m)
  49. }
  50. })
  51. a.next = next
  52. c.next = next
  53. tt.tee = stepperFunc(func(m Message) {
  54. t.Logf("m = %+v", m)
  55. })
  56. tt.step(Message{To: 0, Type: msgHup})
  57. tt.step(Message{To: 2, Type: msgHup})
  58. t.Log("healing")
  59. heal = true
  60. tt.step(Message{To: 2, Type: msgHup})
  61. tests := []struct {
  62. sm *stateMachine
  63. state stateType
  64. term int
  65. }{
  66. {a, stateFollower, 2},
  67. {c, stateLeader, 2},
  68. }
  69. for i, tt := range tests {
  70. if g := tt.sm.state; g != tt.state {
  71. t.Errorf("#%d: state = %s, want %s", i, g, tt.state)
  72. }
  73. if g := tt.sm.term; g != tt.term {
  74. t.Errorf("#%d: term = %d, want %d", i, g, tt.term)
  75. }
  76. }
  77. if g := diffLogs(defaultLog, tt.logs()); g != nil {
  78. for _, diff := range g {
  79. t.Errorf("bag log:\n%s", diff)
  80. }
  81. }
  82. }
  83. func TestOldMessages(t *testing.T) {
  84. tt := newNetwork(nil, nil, nil)
  85. // make 0 leader @ term 3
  86. tt.step(Message{To: 0, Type: msgHup})
  87. tt.step(Message{To: 0, Type: msgHup})
  88. tt.step(Message{To: 0, Type: msgHup})
  89. // pretend we're an old leader trying to make progress
  90. tt.step(Message{To: 0, Type: msgApp, Term: 1, Entries: []Entry{{Term: 1}}})
  91. if g := diffLogs(defaultLog, tt.logs()); g != nil {
  92. for _, diff := range g {
  93. t.Errorf("bag log:\n%s", diff)
  94. }
  95. }
  96. }
// TODO(TestOldMessagesReply): optimization — answer a message from an older term with a reply carrying the newer term.
  98. func TestProposal(t *testing.T) {
  99. tests := []struct {
  100. *network
  101. success bool
  102. }{
  103. {newNetwork(nil, nil, nil), true},
  104. {newNetwork(nil, nil, nopStepper), true},
  105. {newNetwork(nil, nopStepper, nopStepper), false},
  106. {newNetwork(nil, nopStepper, nopStepper, nil), false},
  107. {newNetwork(nil, nopStepper, nopStepper, nil, nil), true},
  108. }
  109. for i, tt := range tests {
  110. tt.tee = stepperFunc(func(m Message) {
  111. t.Logf("#%d: m = %+v", i, m)
  112. })
  113. step := stepperFunc(func(m Message) {
  114. defer func() {
  115. // only recover is we expect it to panic so
  116. // panics we don't expect go up.
  117. if !tt.success {
  118. e := recover()
  119. if e != nil {
  120. t.Logf("#%d: err: %s", i, e)
  121. }
  122. }
  123. }()
  124. tt.step(m)
  125. })
  126. data := []byte("somedata")
  127. // promote 0 the leader
  128. step(Message{To: 0, Type: msgHup})
  129. step(Message{To: 0, Type: msgProp, Data: data})
  130. var wantLog []Entry
  131. if tt.success {
  132. wantLog = []Entry{{}, {Term: 1, Data: data}}
  133. } else {
  134. wantLog = defaultLog
  135. }
  136. if g := diffLogs(wantLog, tt.logs()); g != nil {
  137. for _, diff := range g {
  138. t.Errorf("#%d: diff:%s", i, diff)
  139. }
  140. }
  141. sm := tt.network.ss[0].(*stateMachine)
  142. if g := sm.term; g != 1 {
  143. t.Errorf("#%d: term = %d, want %d", i, g, 1)
  144. }
  145. }
  146. }
  147. func TestProposalByProxy(t *testing.T) {
  148. data := []byte("somedata")
  149. tests := []*network{
  150. newNetwork(nil, nil, nil),
  151. newNetwork(nil, nil, nopStepper),
  152. }
  153. for i, tt := range tests {
  154. tt.tee = stepperFunc(func(m Message) {
  155. t.Logf("#%d: m = %+v", i, m)
  156. })
  157. // promote 0 the leader
  158. tt.step(Message{To: 0, Type: msgHup})
  159. // propose via follower
  160. tt.step(Message{To: 1, Type: msgProp, Data: []byte("somedata")})
  161. wantLog := []Entry{{}, {Term: 1, Data: data}}
  162. if g := diffLogs(wantLog, tt.logs()); g != nil {
  163. for _, diff := range g {
  164. t.Errorf("#%d: bad entry: %s", i, diff)
  165. }
  166. }
  167. sm := tt.ss[0].(*stateMachine)
  168. if g := sm.term; g != 1 {
  169. t.Errorf("#%d: term = %d, want %d", i, g, 1)
  170. }
  171. }
  172. }
  173. func TestVote(t *testing.T) {
  174. tests := []struct {
  175. i, term int
  176. w int
  177. }{
  178. {0, 0, -1},
  179. {0, 1, -1},
  180. {0, 2, -1},
  181. {0, 3, 2},
  182. {1, 0, -1},
  183. {1, 1, -1},
  184. {1, 2, -1},
  185. {1, 3, 2},
  186. {2, 0, -1},
  187. {2, 1, -1},
  188. {2, 2, 2},
  189. {2, 3, 2},
  190. {3, 0, -1},
  191. {3, 1, -1},
  192. {3, 2, 2},
  193. {3, 3, 2},
  194. }
  195. for i, tt := range tests {
  196. called := false
  197. sm := &stateMachine{log: []Entry{{}, {Term: 2}, {Term: 2}}}
  198. sm.next = stepperFunc(func(m Message) {
  199. called = true
  200. if m.Index != tt.w {
  201. t.Errorf("#%d, m.Index = %d, want %d", i, m.Index, tt.w)
  202. }
  203. })
  204. sm.step(Message{Type: msgVote, Index: tt.i, LogTerm: tt.term})
  205. if !called {
  206. t.Fatal("#%d: not called", i)
  207. }
  208. }
  209. }
  210. func TestLogDiff(t *testing.T) {
  211. a := []Entry{{}, {Term: 1}, {Term: 2}}
  212. b := []Entry{{}, {Term: 1}, {Term: 2}}
  213. c := []Entry{{}, {Term: 2}}
  214. d := []Entry(nil)
  215. w := []diff{
  216. diff{1, []*Entry{{Term: 1}, {Term: 1}, {Term: 2}, nilLogEntry}},
  217. diff{2, []*Entry{{Term: 2}, {Term: 2}, noEntry, nilLogEntry}},
  218. }
  219. if g := diffLogs(a, [][]Entry{b, c, d}); !reflect.DeepEqual(w, g) {
  220. t.Errorf("g = %s", g)
  221. t.Errorf("want %s", w)
  222. }
  223. }
  224. type network struct {
  225. tee stepper
  226. ss []stepper
  227. }
  228. // newNetwork initializes a network from nodes. A nil node will be replaced
  229. // with a new *stateMachine. A *stateMachine will get its k, addr, and next
  230. // fields set.
  231. func newNetwork(nodes ...stepper) *network {
  232. nt := &network{ss: nodes}
  233. for i, n := range nodes {
  234. switch v := n.(type) {
  235. case nil:
  236. nt.ss[i] = newStateMachine(len(nodes), i, nt)
  237. case *stateMachine:
  238. v.k = len(nodes)
  239. v.addr = i
  240. if v.next == nil {
  241. v.next = nt
  242. }
  243. default:
  244. nt.ss[i] = v
  245. }
  246. }
  247. return nt
  248. }
  249. func (nt network) step(m Message) {
  250. if nt.tee != nil {
  251. nt.tee.step(m)
  252. }
  253. nt.ss[m.To].step(m)
  254. }
  255. // logs returns all logs in nt prepended with want. If a node is not a
  256. // *stateMachine, its log will be nil.
  257. func (nt network) logs() [][]Entry {
  258. ls := make([][]Entry, len(nt.ss))
  259. for i, node := range nt.ss {
  260. if sm, ok := node.(*stateMachine); ok {
  261. ls[i] = sm.log
  262. }
  263. }
  264. return ls
  265. }
  266. type diff struct {
  267. i int
  268. ents []*Entry // pointers so they can be nil for N/A
  269. }
  270. var noEntry = &Entry{}
  271. var nilLogEntry = &Entry{}
  272. func (d diff) String() string {
  273. s := fmt.Sprintf("[%d] ", d.i)
  274. for i, e := range d.ents {
  275. switch e {
  276. case nilLogEntry:
  277. s += fmt.Sprintf("o")
  278. case noEntry:
  279. s += fmt.Sprintf("-")
  280. case nil:
  281. s += fmt.Sprintf("<nil>")
  282. default:
  283. s += fmt.Sprintf("<%d:%q>", e.Term, string(e.Data))
  284. }
  285. if i != len(d.ents)-1 {
  286. s += "\t\t"
  287. }
  288. }
  289. return s
  290. }
  291. func diffLogs(base []Entry, logs [][]Entry) []diff {
  292. var (
  293. d []diff
  294. max int
  295. )
  296. logs = append([][]Entry{base}, logs...)
  297. for _, log := range logs {
  298. if l := len(log); l > max {
  299. max = l
  300. }
  301. }
  302. ediff := func(i int) (result []*Entry) {
  303. e := make([]*Entry, len(logs))
  304. found := false
  305. for j, log := range logs {
  306. if log == nil {
  307. e[j] = nilLogEntry
  308. continue
  309. }
  310. if len(log) <= i {
  311. e[j] = noEntry
  312. found = true
  313. continue
  314. }
  315. e[j] = &log[i]
  316. if j > 0 {
  317. switch prev := e[j-1]; {
  318. case prev == nilLogEntry:
  319. case prev == noEntry:
  320. case !reflect.DeepEqual(prev, e[j]):
  321. found = true
  322. }
  323. }
  324. }
  325. if found {
  326. return e
  327. }
  328. return nil
  329. }
  330. for i := 0; i < max; i++ {
  331. if e := ediff(i); e != nil {
  332. d = append(d, diff{i, e})
  333. }
  334. }
  335. return d
  336. }
  337. type stepperFunc func(Message)
  338. func (f stepperFunc) step(m Message) { f(m) }
  339. var nopStepper = stepperFunc(func(Message) {})
  340. type nextStepperFunc func(Message, stepper)