Explorar o código

Merge pull request #20 from tukeJonny/retrier-runctx

Add RunCtx to retrier
Evan Huus %!s(int64=6) %!d(string=hai) anos
pai
achega
842e16ec2c
Modificáronse 3 ficheiros con 67 adicións e 7 borrados
  1. 2 3
      .travis.yml
  2. 27 4
      retrier/retrier.go
  3. 38 0
      retrier/retrier_test.go

+ 2 - 3
.travis.yml

@@ -1,6 +1,5 @@
 language: go
 
 go:
-  - 1.2
-  - 1.6
-  - 1.10
+  - 1.7
+  - "1.10"

+ 27 - 4
retrier/retrier.go

@@ -2,6 +2,7 @@
 package retrier
 
 import (
+	"context"
 	"math/rand"
 	"sync"
 	"time"
@@ -33,15 +34,23 @@ func New(backoff []time.Duration, class Classifier) *Retrier {
 	}
 }
 
-// Run executes the given work function, then classifies its return value based on the classifier used
+// Run executes the given work function by executing RunCtx without context.Context.
+func (r *Retrier) Run(work func() error) error {
+	return r.RunCtx(context.Background(), func(ctx context.Context) error {
+		// never use ctx
+		return work()
+	})
+}
+
+// RunCtx executes the given work function, then classifies its return value based on the classifier used
 // to construct the Retrier. If the result is Succeed or Fail, the return value of the work function is
 // returned to the caller. If the result is Retry, then Run sleeps according to the its backoff policy
 // before retrying. If the total number of retries is exceeded then the return value of the work function
 // is returned to the caller regardless.
-func (r *Retrier) Run(work func() error) error {
+func (r *Retrier) RunCtx(ctx context.Context, work func(ctx context.Context) error) error {
 	retries := 0
 	for {
-		ret := work()
+		ret := work(ctx)
 
 		switch r.class.Classify(ret) {
 		case Succeed, Fail:
@@ -50,12 +59,26 @@ func (r *Retrier) Run(work func() error) error {
 			if retries >= len(r.backoff) {
 				return ret
 			}
-			time.Sleep(r.calcSleep(retries))
+
+			timeout := time.After(r.calcSleep(retries))
+			if err := r.sleep(ctx, timeout); err != nil {
+				return err
+			}
+
 			retries++
 		}
 	}
 }
 
+func (r *Retrier) sleep(ctx context.Context, t <-chan time.Time) error {
+	select {
+	case <-t:
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
 func (r *Retrier) calcSleep(i int) time.Duration {
 	// lock unsafe rand prng
 	r.randMu.Lock()

+ 38 - 0
retrier/retrier_test.go

@@ -1,6 +1,7 @@
 package retrier
 
 import (
+	"context"
 	"errors"
 	"testing"
 	"time"
@@ -19,6 +20,19 @@ func genWork(returns []error) func() error {
 	}
 }
 
+func genWorkWithCtx() func(ctx context.Context) error {
+	i = 0
+	return func(ctx context.Context) error {
+		select {
+		case <-ctx.Done():
+			return errFoo
+		default:
+			i++
+		}
+		return nil
+	}
+}
+
 func TestRetrier(t *testing.T) {
 	r := New([]time.Duration{0, 10 * time.Millisecond}, WhitelistClassifier{errFoo})
 
@@ -47,6 +61,30 @@ func TestRetrier(t *testing.T) {
 	}
 }
 
+func TestRetrierCtx(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+
+	r := New([]time.Duration{0, 10 * time.Millisecond}, WhitelistClassifier{})
+
+	err := r.RunCtx(ctx, genWorkWithCtx())
+	if err != nil {
+		t.Error(err)
+	}
+	if i != 1 {
+		t.Error("run wrong number of times")
+	}
+
+	cancel()
+
+	err = r.RunCtx(ctx, genWorkWithCtx())
+	if err != errFoo {
+		t.Error("context must be cancelled")
+	}
+	if i != 0 {
+		t.Error("run wrong number of times")
+	}
+}
+
 func TestRetrierNone(t *testing.T) {
 	r := New(nil, nil)