Browse Source

*: fix leaky context creation with cancel

Signed-off-by: Gyu-Ho Lee <gyuhox@gmail.com>
Gyu-Ho Lee 8 years ago
parent
commit
9a726b424d

+ 4 - 2
clientv3/balancer_test.go

@@ -84,8 +84,9 @@ func TestBalancerGetBlocking(t *testing.T) {
 	}
 	}
 	blockingOpts := grpc.BalancerGetOptions{BlockingWait: true}
 	blockingOpts := grpc.BalancerGetOptions{BlockingWait: true}
 
 
-	ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*100)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
 	_, _, err := sb.Get(ctx, blockingOpts)
 	_, _, err := sb.Get(ctx, blockingOpts)
+	cancel()
 	if err != context.DeadlineExceeded {
 	if err != context.DeadlineExceeded {
 		t.Errorf("Get() with no up endpoints should timeout, got %v", err)
 		t.Errorf("Get() with no up endpoints should timeout, got %v", err)
 	}
 	}
@@ -124,8 +125,9 @@ func TestBalancerGetBlocking(t *testing.T) {
 		t.Errorf("closing the only connection should triggered balancer to send the all endpoints via Notify chan so that we can establish a connection")
 		t.Errorf("closing the only connection should triggered balancer to send the all endpoints via Notify chan so that we can establish a connection")
 	}
 	}
 	down2(errors.New("error"))
 	down2(errors.New("error"))
-	ctx, _ = context.WithTimeout(context.Background(), time.Millisecond*100)
+	ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond*100)
 	_, _, err = sb.Get(ctx, blockingOpts)
 	_, _, err = sb.Get(ctx, blockingOpts)
+	cancel()
 	if err != context.DeadlineExceeded {
 	if err != context.DeadlineExceeded {
 		t.Errorf("Get() with no up endpoints should timeout, got %v", err)
 		t.Errorf("Get() with no up endpoints should timeout, got %v", err)
 	}
 	}

+ 5 - 3
clientv3/client.go

@@ -143,8 +143,10 @@ func (c *Client) autoSync() {
 		case <-c.ctx.Done():
 		case <-c.ctx.Done():
 			return
 			return
 		case <-time.After(c.cfg.AutoSyncInterval):
 		case <-time.After(c.cfg.AutoSyncInterval):
-			ctx, _ := context.WithTimeout(c.ctx, 5*time.Second)
-			if err := c.Sync(ctx); err != nil && err != c.ctx.Err() {
+			ctx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
+			err := c.Sync(ctx)
+			cancel()
+			if err != nil && err != c.ctx.Err() {
 				logger.Println("Auto sync endpoints failed:", err)
 				logger.Println("Auto sync endpoints failed:", err)
 			}
 			}
 		}
 		}
@@ -429,7 +431,7 @@ func (c *Client) checkVersion() (err error) {
 	errc := make(chan error, len(c.cfg.Endpoints))
 	errc := make(chan error, len(c.cfg.Endpoints))
 	ctx, cancel := context.WithCancel(c.ctx)
 	ctx, cancel := context.WithCancel(c.ctx)
 	if c.cfg.DialTimeout > 0 {
 	if c.cfg.DialTimeout > 0 {
-		ctx, _ = context.WithTimeout(ctx, c.cfg.DialTimeout)
+		ctx, cancel = context.WithTimeout(ctx, c.cfg.DialTimeout)
 	}
 	}
 	wg.Add(len(c.cfg.Endpoints))
 	wg.Add(len(c.cfg.Endpoints))
 	for _, ep := range c.cfg.Endpoints {
 	for _, ep := range c.cfg.Endpoints {

+ 2 - 0
clientv3/concurrency/election.go

@@ -213,6 +213,7 @@ func (e *Election) observe(ctx context.Context, ch chan<- v3.GetResponse) {
 		for !keyDeleted {
 		for !keyDeleted {
 			wr, ok := <-wch
 			wr, ok := <-wch
 			if !ok {
 			if !ok {
+				cancel()
 				return
 				return
 			}
 			}
 			for _, ev := range wr.Events {
 			for _, ev := range wr.Events {
@@ -225,6 +226,7 @@ func (e *Election) observe(ctx context.Context, ch chan<- v3.GetResponse) {
 				select {
 				select {
 				case ch <- *resp:
 				case ch <- *resp:
 				case <-cctx.Done():
 				case <-cctx.Done():
+					cancel()
 					return
 					return
 				}
 				}
 			}
 			}

+ 1 - 0
clientv3/concurrency/session.go

@@ -53,6 +53,7 @@ func NewSession(client *v3.Client, opts ...SessionOption) (*Session, error) {
 	ctx, cancel := context.WithCancel(ops.ctx)
 	ctx, cancel := context.WithCancel(ops.ctx)
 	keepAlive, err := client.KeepAlive(ctx, id)
 	keepAlive, err := client.KeepAlive(ctx, id)
 	if err != nil || keepAlive == nil {
 	if err != nil || keepAlive == nil {
+		cancel()
 		return nil, err
 		return nil, err
 	}
 	}
 
 

+ 2 - 1
etcdserver/server_test.go

@@ -741,8 +741,9 @@ func TestDoProposalTimeout(t *testing.T) {
 	}
 	}
 	srv.applyV2 = &applierV2store{store: srv.store, cluster: srv.cluster}
 	srv.applyV2 = &applierV2store{store: srv.store, cluster: srv.cluster}
 
 
-	ctx, _ := context.WithTimeout(context.Background(), 0)
+	ctx, cancel := context.WithTimeout(context.Background(), 0)
 	_, err := srv.Do(ctx, pb.Request{Method: "PUT"})
 	_, err := srv.Do(ctx, pb.Request{Method: "PUT"})
+	cancel()
 	if err != ErrTimeout {
 	if err != ErrTimeout {
 		t.Fatalf("err = %v, want %v", err, ErrTimeout)
 		t.Fatalf("err = %v, want %v", err, ErrTimeout)
 	}
 	}

+ 3 - 2
integration/cluster.go

@@ -277,10 +277,11 @@ func (c *cluster) addMemberByURL(t *testing.T, clientURL, peerURL string) error
 	cc := MustNewHTTPClient(t, []string{clientURL}, c.cfg.ClientTLS)
 	cc := MustNewHTTPClient(t, []string{clientURL}, c.cfg.ClientTLS)
 	ma := client.NewMembersAPI(cc)
 	ma := client.NewMembersAPI(cc)
 	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
 	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
-	if _, err := ma.Add(ctx, peerURL); err != nil {
+	_, err := ma.Add(ctx, peerURL)
+	cancel()
+	if err != nil {
 		return err
 		return err
 	}
 	}
-	cancel()
 
 
 	// wait for the add node entry applied in the cluster
 	// wait for the add node entry applied in the cluster
 	members := append(c.HTTPMembers(), client.Member{PeerURLs: []string{peerURL}, ClientURLs: []string{}})
 	members := append(c.HTTPMembers(), client.Member{PeerURLs: []string{peerURL}, ClientURLs: []string{}})

+ 1 - 0
tools/functional-tester/etcd-tester/lease_stresser.go

@@ -290,6 +290,7 @@ func (ls *leaseStresser) keepLeaseAlive(leaseID int64) {
 			cancel()
 			cancel()
 			ctx, cancel = context.WithCancel(ls.ctx)
 			ctx, cancel = context.WithCancel(ls.ctx)
 			stream, err = ls.lc.LeaseKeepAlive(ctx)
 			stream, err = ls.lc.LeaseKeepAlive(ctx)
+			cancel()
 			continue
 			continue
 		}
 		}
 		err = stream.Send(&pb.LeaseKeepAliveRequest{ID: leaseID})
 		err = stream.Send(&pb.LeaseKeepAliveRequest{ID: leaseID})