// node_test.go (5.2 KB)

  1. package raft
  2. import (
  3. "reflect"
  4. "testing"
  5. )
  // Tick settings shared by every test in this file.
  const (
  	// defaultElection is handed to New and Recover as their last argument and
  	// also sizes the tick loop in TestTickMsgHup.
  	defaultElection = 5
  	// defaultHeartbeat is handed to New and Recover just before
  	// defaultElection and also sizes the tick loop in TestTickMsgBeat.
  	defaultHeartbeat = 1
  )
  10. func TestTickMsgHup(t *testing.T) {
  11. n := New(0, defaultHeartbeat, defaultElection)
  12. n.sm = newStateMachine(0, []int64{0, 1, 2})
  13. // simulate to patch the join log
  14. n.Step(Message{From: 1, Type: msgApp, Commit: 1, Entries: []Entry{Entry{}}})
  15. for i := 0; i < defaultElection*2; i++ {
  16. n.Tick()
  17. }
  18. called := false
  19. for _, m := range n.Msgs() {
  20. if m.Type == msgVote {
  21. called = true
  22. }
  23. }
  24. if !called {
  25. t.Errorf("called = %v, want true", called)
  26. }
  27. }
  28. func TestTickMsgBeat(t *testing.T) {
  29. k := 3
  30. n := dictate(New(0, defaultHeartbeat, defaultElection))
  31. n.Next()
  32. for i := 1; i < k; i++ {
  33. n.Add(int64(i), "", nil)
  34. for _, m := range n.Msgs() {
  35. if m.Type == msgApp {
  36. n.Step(Message{From: m.To, ClusterId: m.ClusterId, Type: msgAppResp, Index: m.Index + int64(len(m.Entries))})
  37. }
  38. }
  39. // ignore commit index update messages
  40. n.Msgs()
  41. n.Next()
  42. }
  43. for i := 0; i < defaultHeartbeat+1; i++ {
  44. n.Tick()
  45. }
  46. called := 0
  47. for _, m := range n.Msgs() {
  48. if m.Type == msgApp && len(m.Entries) == 0 {
  49. called++
  50. }
  51. }
  52. // msgBeat -> k-1 append
  53. w := k - 1
  54. if called != w {
  55. t.Errorf("called = %v, want %v", called, w)
  56. }
  57. }
  58. func TestResetElapse(t *testing.T) {
  59. tests := []struct {
  60. msg Message
  61. welapsed tick
  62. }{
  63. {Message{From: 0, To: 1, Type: msgApp, Term: 2, Entries: []Entry{{Term: 1}}}, 0},
  64. {Message{From: 0, To: 1, Type: msgApp, Term: 1, Entries: []Entry{{Term: 1}}}, 1},
  65. {Message{From: 0, To: 1, Type: msgVote, Term: 2, Index: 1, LogTerm: 1}, 0},
  66. {Message{From: 0, To: 1, Type: msgVote, Term: 1}, 1},
  67. }
  68. for i, tt := range tests {
  69. n := New(0, defaultHeartbeat, defaultElection)
  70. n.sm = newStateMachine(0, []int64{0, 1, 2})
  71. n.sm.raftLog.append(0, Entry{Type: Normal, Term: 1})
  72. n.sm.term = 2
  73. n.sm.raftLog.committed = 1
  74. n.Tick()
  75. if n.elapsed != 1 {
  76. t.Errorf("%d: elpased = %d, want %d", i, n.elapsed, 1)
  77. }
  78. n.Step(tt.msg)
  79. if n.elapsed != tt.welapsed {
  80. t.Errorf("%d: elpased = %d, want %d", i, n.elapsed, tt.welapsed)
  81. }
  82. }
  83. }
  84. func TestStartCluster(t *testing.T) {
  85. n := dictate(New(0, defaultHeartbeat, defaultElection))
  86. n.Next()
  87. if len(n.sm.ins) != 1 {
  88. t.Errorf("k = %d, want 1", len(n.sm.ins))
  89. }
  90. if n.sm.id != 0 {
  91. t.Errorf("id = %d, want 0", n.sm.id)
  92. }
  93. if n.sm.state != stateLeader {
  94. t.Errorf("state = %s, want %s", n.sm.state, stateLeader)
  95. }
  96. }
  97. func TestAdd(t *testing.T) {
  98. n := dictate(New(0, defaultHeartbeat, defaultElection))
  99. n.Next()
  100. n.Add(1, "", nil)
  101. n.Next()
  102. if len(n.sm.ins) != 2 {
  103. t.Errorf("k = %d, want 2", len(n.sm.ins))
  104. }
  105. if n.sm.id != 0 {
  106. t.Errorf("id = %d, want 0", n.sm.id)
  107. }
  108. }
  109. func TestRemove(t *testing.T) {
  110. n := dictate(New(0, defaultHeartbeat, defaultElection))
  111. n.Next()
  112. n.Add(1, "", nil)
  113. n.Next()
  114. n.Remove(0)
  115. n.Step(Message{Type: msgAppResp, From: 1, ClusterId: n.ClusterId(), Term: 1, Index: 5})
  116. n.Next()
  117. if len(n.sm.ins) != 1 {
  118. t.Errorf("k = %d, want 1", len(n.sm.ins))
  119. }
  120. if n.sm.id != 0 {
  121. t.Errorf("id = %d, want 0", n.sm.id)
  122. }
  123. }
  124. func TestDenial(t *testing.T) {
  125. logents := []Entry{
  126. {Type: AddNode, Term: 1, Data: []byte(`{"NodeId":1}`)},
  127. {Type: AddNode, Term: 1, Data: []byte(`{"NodeId":2}`)},
  128. {Type: RemoveNode, Term: 1, Data: []byte(`{"NodeId":2}`)},
  129. }
  130. tests := []struct {
  131. ent Entry
  132. wdenied map[int64]bool
  133. }{
  134. {
  135. Entry{Type: AddNode, Term: 1, Data: []byte(`{"NodeId":2}`)},
  136. map[int64]bool{1: false, 2: false},
  137. },
  138. {
  139. Entry{Type: RemoveNode, Term: 1, Data: []byte(`{"NodeId":1}`)},
  140. map[int64]bool{1: true, 2: true},
  141. },
  142. {
  143. Entry{Type: RemoveNode, Term: 1, Data: []byte(`{"NodeId":0}`)},
  144. map[int64]bool{1: false, 2: true},
  145. },
  146. }
  147. for i, tt := range tests {
  148. n := dictate(New(0, defaultHeartbeat, defaultElection))
  149. n.Next()
  150. n.Msgs()
  151. n.sm.raftLog.append(n.sm.raftLog.committed, append(logents, tt.ent)...)
  152. n.sm.raftLog.committed += int64(len(logents) + 1)
  153. n.Next()
  154. for id, denied := range tt.wdenied {
  155. n.Step(Message{From: id, To: 0, ClusterId: n.ClusterId(), Type: msgApp, Term: 1})
  156. w := []Message{}
  157. if denied {
  158. w = []Message{{From: 0, To: id, ClusterId: n.ClusterId(), Term: 1, Type: msgDenied}}
  159. }
  160. if g := n.Msgs(); !reflect.DeepEqual(g, w) {
  161. t.Errorf("#%d: msgs for %d = %+v, want %+v", i, id, g, w)
  162. }
  163. }
  164. }
  165. }
  166. func TestRecover(t *testing.T) {
  167. ents := []Entry{{Term: 1}, {Term: 2}, {Term: 3}}
  168. state := State{Term: 500, Vote: 1, Commit: 3}
  169. n := Recover(0, ents, state, defaultHeartbeat, defaultElection)
  170. if g := n.Next(); !reflect.DeepEqual(g, ents) {
  171. t.Errorf("ents = %+v, want %+v", g, ents)
  172. }
  173. if g := n.sm.term; g.Get() != state.Term {
  174. t.Errorf("term = %d, want %d", g, state.Term)
  175. }
  176. if g := n.sm.vote; g != state.Vote {
  177. t.Errorf("vote = %d, want %d", g, state.Vote)
  178. }
  179. if g := n.sm.raftLog.committed; g != state.Commit {
  180. t.Errorf("committed = %d, want %d", g, state.Commit)
  181. }
  182. if g := n.UnstableEnts(); g != nil {
  183. t.Errorf("unstableEnts = %+v, want nil", g)
  184. }
  185. if g := n.UnstableState(); !reflect.DeepEqual(g, state) {
  186. t.Errorf("unstableState = %+v, want %+v", g, state)
  187. }
  188. if g := n.Msgs(); len(g) != 0 {
  189. t.Errorf("#%d: len(msgs) = %d, want 0", len(g))
  190. }
  191. }
  192. func dictate(n *Node) *Node {
  193. n.Step(Message{From: n.Id(), Type: msgHup})
  194. n.InitCluster(0xBEEF)
  195. n.Add(n.Id(), "", nil)
  196. return n
  197. }