Browse Source

client: elevate context to caller of KeysAPI

Brian Waldon 11 years ago
parent
commit
17c6f21d68
3 changed files with 42 additions and 42 deletions
  1. 16 23
      client/keys.go
  2. 13 6
      discovery/discovery.go
  3. 13 13
      discovery/discovery_test.go

+ 16 - 23
client/keys.go

@@ -41,31 +41,30 @@ var (
 	ErrKeyExists   = errors.New("client: key already exists")
 	ErrKeyExists   = errors.New("client: key already exists")
 )
 )
 
 
-func NewKeysAPI(c httpActionDo, to time.Duration) KeysAPI {
+func NewKeysAPI(c httpActionDo) KeysAPI {
 	return &httpKeysAPI{
 	return &httpKeysAPI{
-		client:  c,
-		prefix:  DefaultV2KeysPrefix,
-		timeout: to,
+		client: c,
+		prefix: DefaultV2KeysPrefix,
 	}
 	}
 }
 }
 
 
-func NewDiscoveryKeysAPI(c httpActionDo, to time.Duration) KeysAPI {
+func NewDiscoveryKeysAPI(c httpActionDo) KeysAPI {
 	return &httpKeysAPI{
 	return &httpKeysAPI{
-		client:  c,
-		prefix:  "",
-		timeout: to,
+		client: c,
+		prefix: "",
 	}
 	}
 }
 }
 
 
 type KeysAPI interface {
 type KeysAPI interface {
-	Create(key, value string, ttl time.Duration) (*Response, error)
-	Get(key string) (*Response, error)
+	Create(ctx context.Context, key, value string, ttl time.Duration) (*Response, error)
+	Get(ctx context.Context, key string) (*Response, error)
+
 	Watch(key string, idx uint64) Watcher
 	Watch(key string, idx uint64) Watcher
 	RecursiveWatch(key string, idx uint64) Watcher
 	RecursiveWatch(key string, idx uint64) Watcher
 }
 }
 
 
 type Watcher interface {
 type Watcher interface {
-	Next() (*Response, error)
+	Next(context.Context) (*Response, error)
 }
 }
 
 
 type Response struct {
 type Response struct {
@@ -88,12 +87,11 @@ func (n *Node) String() string {
 }
 }
 
 
 type httpKeysAPI struct {
 type httpKeysAPI struct {
-	client  httpActionDo
-	prefix  string
-	timeout time.Duration
+	client httpActionDo
+	prefix string
 }
 }
 
 
-func (k *httpKeysAPI) Create(key, val string, ttl time.Duration) (*Response, error) {
+func (k *httpKeysAPI) Create(ctx context.Context, key, val string, ttl time.Duration) (*Response, error) {
 	create := &createAction{
 	create := &createAction{
 		Prefix: k.prefix,
 		Prefix: k.prefix,
 		Key:    key,
 		Key:    key,
@@ -104,9 +102,7 @@ func (k *httpKeysAPI) Create(key, val string, ttl time.Duration) (*Response, err
 		create.TTL = &uttl
 		create.TTL = &uttl
 	}
 	}
 
 
-	ctx, cancel := context.WithTimeout(context.Background(), k.timeout)
 	resp, body, err := k.client.Do(ctx, create)
 	resp, body, err := k.client.Do(ctx, create)
-	cancel()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -114,16 +110,14 @@ func (k *httpKeysAPI) Create(key, val string, ttl time.Duration) (*Response, err
 	return unmarshalHTTPResponse(resp.StatusCode, body)
 	return unmarshalHTTPResponse(resp.StatusCode, body)
 }
 }
 
 
-func (k *httpKeysAPI) Get(key string) (*Response, error) {
+func (k *httpKeysAPI) Get(ctx context.Context, key string) (*Response, error) {
 	get := &getAction{
 	get := &getAction{
 		Prefix:    k.prefix,
 		Prefix:    k.prefix,
 		Key:       key,
 		Key:       key,
 		Recursive: false,
 		Recursive: false,
 	}
 	}
 
 
-	ctx, cancel := context.WithTimeout(context.Background(), k.timeout)
 	resp, body, err := k.client.Do(ctx, get)
 	resp, body, err := k.client.Do(ctx, get)
-	cancel()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -160,9 +154,8 @@ type httpWatcher struct {
 	nextWait waitAction
 	nextWait waitAction
 }
 }
 
 
-func (hw *httpWatcher) Next() (*Response, error) {
-	//TODO(bcwaldon): This needs to be cancellable by the calling user
-	httpresp, body, err := hw.client.Do(context.Background(), &hw.nextWait)
+func (hw *httpWatcher) Next(ctx context.Context) (*Response, error) {
+	httpresp, body, err := hw.client.Do(ctx, &hw.nextWait)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}

+ 13 - 6
discovery/discovery.go

@@ -29,6 +29,7 @@ import (
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
+	"github.com/coreos/etcd/Godeps/_workspace/src/code.google.com/p/go.net/context"
 	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/jonboulle/clockwork"
 	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/jonboulle/clockwork"
 	"github.com/coreos/etcd/client"
 	"github.com/coreos/etcd/client"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/pkg/types"
@@ -110,7 +111,7 @@ func New(durl string, id types.ID, config string) (Discoverer, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	dc := client.NewDiscoveryKeysAPI(c, client.DefaultRequestTimeout)
+	dc := client.NewDiscoveryKeysAPI(c)
 	return &discovery{
 	return &discovery{
 		cluster: token,
 		cluster: token,
 		id:      id,
 		id:      id,
@@ -150,21 +151,25 @@ func (d *discovery) Discover() (string, error) {
 }
 }
 
 
 func (d *discovery) createSelf() error {
 func (d *discovery) createSelf() error {
-	resp, err := d.c.Create(d.selfKey(), d.config, -1)
+	ctx, cancel := context.WithTimeout(context.Background(), client.DefaultRequestTimeout)
+	resp, err := d.c.Create(ctx, d.selfKey(), d.config, -1)
+	cancel()
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
 	// ensure self appears on the server we connected to
 	// ensure self appears on the server we connected to
 	w := d.c.Watch(d.selfKey(), resp.Node.CreatedIndex)
 	w := d.c.Watch(d.selfKey(), resp.Node.CreatedIndex)
-	_, err = w.Next()
+	_, err = w.Next(context.Background())
 	return err
 	return err
 }
 }
 
 
 func (d *discovery) checkCluster() (client.Nodes, int, error) {
 func (d *discovery) checkCluster() (client.Nodes, int, error) {
 	configKey := path.Join("/", d.cluster, "_config")
 	configKey := path.Join("/", d.cluster, "_config")
+	ctx, cancel := context.WithTimeout(context.Background(), client.DefaultRequestTimeout)
 	// find cluster size
 	// find cluster size
-	resp, err := d.c.Get(path.Join(configKey, "size"))
+	resp, err := d.c.Get(ctx, path.Join(configKey, "size"))
+	cancel()
 	if err != nil {
 	if err != nil {
 		if err == client.ErrKeyNoExist {
 		if err == client.ErrKeyNoExist {
 			return nil, 0, ErrSizeNotFound
 			return nil, 0, ErrSizeNotFound
@@ -179,7 +184,9 @@ func (d *discovery) checkCluster() (client.Nodes, int, error) {
 		return nil, 0, ErrBadSizeKey
 		return nil, 0, ErrBadSizeKey
 	}
 	}
 
 
-	resp, err = d.c.Get(d.cluster)
+	ctx, cancel = context.WithTimeout(context.Background(), client.DefaultRequestTimeout)
+	resp, err = d.c.Get(ctx, d.cluster)
+	cancel()
 	if err != nil {
 	if err != nil {
 		if err == client.ErrTimeout {
 		if err == client.ErrTimeout {
 			return d.checkClusterRetry()
 			return d.checkClusterRetry()
@@ -254,7 +261,7 @@ func (d *discovery) waitNodes(nodes client.Nodes, size int) (client.Nodes, error
 	// wait for others
 	// wait for others
 	for len(all) < size {
 	for len(all) < size {
 		log.Printf("discovery: found %d peer(s), waiting for %d more", len(all), size-len(all))
 		log.Printf("discovery: found %d peer(s), waiting for %d more", len(all), size-len(all))
-		resp, err := w.Next()
+		resp, err := w.Next(context.Background())
 		if err != nil {
 		if err != nil {
 			if err == client.ErrTimeout {
 			if err == client.ErrTimeout {
 				return d.waitNodesRetry()
 				return d.waitNodesRetry()

+ 13 - 13
discovery/discovery_test.go

@@ -21,13 +21,13 @@ import (
 	"math/rand"
 	"math/rand"
 	"net/http"
 	"net/http"
 	"os"
 	"os"
+	"reflect"
 	"sort"
 	"sort"
 	"strconv"
 	"strconv"
-
-	"reflect"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
+	"github.com/coreos/etcd/Godeps/_workspace/src/code.google.com/p/go.net/context"
 	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/jonboulle/clockwork"
 	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/jonboulle/clockwork"
 	"github.com/coreos/etcd/client"
 	"github.com/coreos/etcd/client"
 )
 )
@@ -397,7 +397,7 @@ type clientWithResp struct {
 	w  client.Watcher
 	w  client.Watcher
 }
 }
 
 
-func (c *clientWithResp) Create(key string, value string, ttl time.Duration) (*client.Response, error) {
+func (c *clientWithResp) Create(ctx context.Context, key string, value string, ttl time.Duration) (*client.Response, error) {
 	if len(c.rs) == 0 {
 	if len(c.rs) == 0 {
 		return &client.Response{}, nil
 		return &client.Response{}, nil
 	}
 	}
@@ -406,7 +406,7 @@ func (c *clientWithResp) Create(key string, value string, ttl time.Duration) (*c
 	return r, nil
 	return r, nil
 }
 }
 
 
-func (c *clientWithResp) Get(key string) (*client.Response, error) {
+func (c *clientWithResp) Get(ctx context.Context, key string) (*client.Response, error) {
 	if len(c.rs) == 0 {
 	if len(c.rs) == 0 {
 		return &client.Response{}, client.ErrKeyNoExist
 		return &client.Response{}, client.ErrKeyNoExist
 	}
 	}
@@ -428,11 +428,11 @@ type clientWithErr struct {
 	w   client.Watcher
 	w   client.Watcher
 }
 }
 
 
-func (c *clientWithErr) Create(key string, value string, ttl time.Duration) (*client.Response, error) {
+func (c *clientWithErr) Create(ctx context.Context, key string, value string, ttl time.Duration) (*client.Response, error) {
 	return &client.Response{}, c.err
 	return &client.Response{}, c.err
 }
 }
 
 
-func (c *clientWithErr) Get(key string) (*client.Response, error) {
+func (c *clientWithErr) Get(ctx context.Context, key string) (*client.Response, error) {
 	return &client.Response{}, c.err
 	return &client.Response{}, c.err
 }
 }
 
 
@@ -448,7 +448,7 @@ type watcherWithResp struct {
 	rs []*client.Response
 	rs []*client.Response
 }
 }
 
 
-func (w *watcherWithResp) Next() (*client.Response, error) {
+func (w *watcherWithResp) Next(context.Context) (*client.Response, error) {
 	if len(w.rs) == 0 {
 	if len(w.rs) == 0 {
 		return &client.Response{}, nil
 		return &client.Response{}, nil
 	}
 	}
@@ -461,7 +461,7 @@ type watcherWithErr struct {
 	err error
 	err error
 }
 }
 
 
-func (w *watcherWithErr) Next() (*client.Response, error) {
+func (w *watcherWithErr) Next(context.Context) (*client.Response, error) {
 	return &client.Response{}, w.err
 	return &client.Response{}, w.err
 }
 }
 
 
@@ -472,20 +472,20 @@ type clientWithRetry struct {
 	failTimes int
 	failTimes int
 }
 }
 
 
-func (c *clientWithRetry) Create(key string, value string, ttl time.Duration) (*client.Response, error) {
+func (c *clientWithRetry) Create(ctx context.Context, key string, value string, ttl time.Duration) (*client.Response, error) {
 	if c.failCount < c.failTimes {
 	if c.failCount < c.failTimes {
 		c.failCount++
 		c.failCount++
 		return nil, client.ErrTimeout
 		return nil, client.ErrTimeout
 	}
 	}
-	return c.clientWithResp.Create(key, value, ttl)
+	return c.clientWithResp.Create(ctx, key, value, ttl)
 }
 }
 
 
-func (c *clientWithRetry) Get(key string) (*client.Response, error) {
+func (c *clientWithRetry) Get(ctx context.Context, key string) (*client.Response, error) {
 	if c.failCount < c.failTimes {
 	if c.failCount < c.failTimes {
 		c.failCount++
 		c.failCount++
 		return nil, client.ErrTimeout
 		return nil, client.ErrTimeout
 	}
 	}
-	return c.clientWithResp.Get(key)
+	return c.clientWithResp.Get(ctx, key)
 }
 }
 
 
 // watcherWithRetry will timeout all requests up to failTimes
 // watcherWithRetry will timeout all requests up to failTimes
@@ -495,7 +495,7 @@ type watcherWithRetry struct {
 	failTimes int
 	failTimes int
 }
 }
 
 
-func (w *watcherWithRetry) Next() (*client.Response, error) {
+func (w *watcherWithRetry) Next(context.Context) (*client.Response, error) {
 	if w.failCount < w.failTimes {
 	if w.failCount < w.failTimes {
 		w.failCount++
 		w.failCount++
 		return nil, client.ErrTimeout
 		return nil, client.ErrTimeout