retrier_test.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. package retrier
  2. import (
  3. "context"
  4. "errors"
  5. "testing"
  6. "time"
  7. )
  8. var i int
  9. func genWork(returns []error) func() error {
  10. i = 0
  11. return func() error {
  12. i++
  13. if i > len(returns) {
  14. return nil
  15. }
  16. return returns[i-1]
  17. }
  18. }
  19. func genWorkWithCtx() func(ctx context.Context) error {
  20. i = 0
  21. return func(ctx context.Context) error {
  22. select {
  23. case <-ctx.Done():
  24. return errFoo
  25. default:
  26. i++
  27. }
  28. return nil
  29. }
  30. }
  31. func TestRetrier(t *testing.T) {
  32. r := New([]time.Duration{0, 10 * time.Millisecond}, WhitelistClassifier{errFoo})
  33. err := r.Run(genWork([]error{errFoo, errFoo}))
  34. if err != nil {
  35. t.Error(err)
  36. }
  37. if i != 3 {
  38. t.Error("run wrong number of times")
  39. }
  40. err = r.Run(genWork([]error{errFoo, errBar}))
  41. if err != errBar {
  42. t.Error(err)
  43. }
  44. if i != 2 {
  45. t.Error("run wrong number of times")
  46. }
  47. err = r.Run(genWork([]error{errBar, errBaz}))
  48. if err != errBar {
  49. t.Error(err)
  50. }
  51. if i != 1 {
  52. t.Error("run wrong number of times")
  53. }
  54. }
  55. func TestRetrierCtx(t *testing.T) {
  56. ctx, cancel := context.WithCancel(context.Background())
  57. r := New([]time.Duration{0, 10 * time.Millisecond}, WhitelistClassifier{})
  58. err := r.RunCtx(ctx, genWorkWithCtx())
  59. if err != nil {
  60. t.Error(err)
  61. }
  62. if i != 1 {
  63. t.Error("run wrong number of times")
  64. }
  65. cancel()
  66. err = r.RunCtx(ctx, genWorkWithCtx())
  67. if err != errFoo {
  68. t.Error("context must be cancelled")
  69. }
  70. if i != 0 {
  71. t.Error("run wrong number of times")
  72. }
  73. }
  74. func TestRetrierNone(t *testing.T) {
  75. r := New(nil, nil)
  76. i = 0
  77. err := r.Run(func() error {
  78. i++
  79. return errFoo
  80. })
  81. if err != errFoo {
  82. t.Error(err)
  83. }
  84. if i != 1 {
  85. t.Error("run wrong number of times")
  86. }
  87. i = 0
  88. err = r.Run(func() error {
  89. i++
  90. return nil
  91. })
  92. if err != nil {
  93. t.Error(err)
  94. }
  95. if i != 1 {
  96. t.Error("run wrong number of times")
  97. }
  98. }
  99. func TestRetrierJitter(t *testing.T) {
  100. r := New([]time.Duration{0, 10 * time.Millisecond, 4 * time.Hour}, nil)
  101. if r.calcSleep(0) != 0 {
  102. t.Error("Incorrect sleep calculated")
  103. }
  104. if r.calcSleep(1) != 10*time.Millisecond {
  105. t.Error("Incorrect sleep calculated")
  106. }
  107. if r.calcSleep(2) != 4*time.Hour {
  108. t.Error("Incorrect sleep calculated")
  109. }
  110. r.SetJitter(0.25)
  111. for i := 0; i < 20; i++ {
  112. if r.calcSleep(0) != 0 {
  113. t.Error("Incorrect sleep calculated")
  114. }
  115. slp := r.calcSleep(1)
  116. if slp < 7500*time.Microsecond || slp > 12500*time.Microsecond {
  117. t.Error("Incorrect sleep calculated")
  118. }
  119. slp = r.calcSleep(2)
  120. if slp < 3*time.Hour || slp > 5*time.Hour {
  121. t.Error("Incorrect sleep calculated")
  122. }
  123. }
  124. r.SetJitter(-1)
  125. if r.jitter != 0.25 {
  126. t.Error("Invalid jitter value accepted")
  127. }
  128. r.SetJitter(2)
  129. if r.jitter != 0.25 {
  130. t.Error("Invalid jitter value accepted")
  131. }
  132. }
  133. func TestRetrierThreadSafety(t *testing.T) {
  134. r := New([]time.Duration{0}, nil)
  135. for i := 0; i < 2; i++ {
  136. go func() {
  137. r.Run(func() error {
  138. return errors.New("error")
  139. })
  140. }()
  141. }
  142. }
  143. func ExampleRetrier() {
  144. r := New(ConstantBackoff(3, 100*time.Millisecond), nil)
  145. err := r.Run(func() error {
  146. // do some work
  147. return nil
  148. })
  149. if err != nil {
  150. // handle the case where the work failed three times
  151. }
  152. }