snapshot_test.go 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. package raft
  2. import (
  3. "testing"
  4. "github.com/stretchr/testify/assert"
  5. "github.com/stretchr/testify/mock"
  6. )
  7. // Ensure that a snapshot occurs when there are existing logs.
  8. func TestSnapshot(t *testing.T) {
  9. runServerWithMockStateMachine(Leader, func(s Server, m *mock.Mock) {
  10. m.On("Save").Return([]byte("foo"), nil)
  11. m.On("Recovery", []byte("foo")).Return(nil)
  12. s.Do(&testCommand1{})
  13. err := s.TakeSnapshot()
  14. assert.NoError(t, err)
  15. assert.Equal(t, s.(*server).snapshot.LastIndex, uint64(2))
  16. // Repeat to make sure new snapshot gets created.
  17. s.Do(&testCommand1{})
  18. err = s.TakeSnapshot()
  19. assert.NoError(t, err)
  20. assert.Equal(t, s.(*server).snapshot.LastIndex, uint64(4))
  21. // Restart server.
  22. s.Stop()
  23. s.Start()
  24. // Recover from snapshot.
  25. err = s.LoadSnapshot()
  26. assert.NoError(t, err)
  27. })
  28. }
  29. // Ensure that a snapshot request can be sent and received.
  30. func TestSnapshotRequest(t *testing.T) {
  31. runServerWithMockStateMachine(Follower, func(s Server, m *mock.Mock) {
  32. m.On("Recovery", []byte("bar")).Return(nil)
  33. // Send snapshot request.
  34. resp := s.RequestSnapshot(&SnapshotRequest{LastIndex: 5, LastTerm: 1})
  35. assert.Equal(t, resp.Success, true)
  36. assert.Equal(t, s.State(), Snapshotting)
  37. // Send recovery request.
  38. resp2 := s.SnapshotRecoveryRequest(&SnapshotRecoveryRequest{
  39. LeaderName: "1",
  40. LastIndex: 5,
  41. LastTerm: 2,
  42. Peers: make([]*Peer, 0),
  43. State: []byte("bar"),
  44. })
  45. assert.Equal(t, resp2.Success, true)
  46. })
  47. }
  48. func runServerWithMockStateMachine(state string, fn func(s Server, m *mock.Mock)) {
  49. var m mockStateMachine
  50. s := newTestServer("1", &testTransporter{})
  51. s.(*server).stateMachine = &m
  52. if err := s.Start(); err != nil {
  53. panic("server start error: " + err.Error())
  54. }
  55. if state == Leader {
  56. if _, err := s.Do(&DefaultJoinCommand{Name: s.Name()}); err != nil {
  57. panic("unable to join server to self: " + err.Error())
  58. }
  59. }
  60. defer s.Stop()
  61. fn(s, &m.Mock)
  62. }