Просмотр исходного кода

acme: stop using ctxhttp

The ctxhttp package used to be big and gross before net/http supported
contexts natively. Nowadays it barely does anything. Stop using it,
because it just pulls in the old context package anyway. (We can't
really clean up the ctxhttp package until Go 1.9)

Change-Id: I48b11f2f483783a32cbaa75e244301148a304c08
Reviewed-on: https://go-review.googlesource.com/40110
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Alex Vaghin <ddos@google.com>
Brad Fitzpatrick 8 лет назад
Родитель
Сommit
6022e334c1
2 измененных файлов с 70 добавлено и 20 удалено
  1. 66 18
      acme/acme.go
  2. 4 2
      acme/acme_test.go

+ 66 - 18
acme/acme.go

@@ -37,8 +37,6 @@ import (
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"time"
 	"time"
-
-	"golang.org/x/net/context/ctxhttp"
 )
 )
 
 
 // LetsEncryptURL is the Directory endpoint of Let's Encrypt CA.
 // LetsEncryptURL is the Directory endpoint of Let's Encrypt CA.
@@ -133,7 +131,7 @@ func (c *Client) Discover(ctx context.Context) (Directory, error) {
 	if dirURL == "" {
 	if dirURL == "" {
 		dirURL = LetsEncryptURL
 		dirURL = LetsEncryptURL
 	}
 	}
-	res, err := ctxhttp.Get(ctx, c.HTTPClient, dirURL)
+	res, err := c.get(ctx, dirURL)
 	if err != nil {
 	if err != nil {
 		return Directory{}, err
 		return Directory{}, err
 	}
 	}
@@ -216,7 +214,7 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
 		return cert, curl, err
 		return cert, curl, err
 	}
 	}
 	// slurp issued cert and CA chain, if requested
 	// slurp issued cert and CA chain, if requested
-	cert, err := responseCert(ctx, c.HTTPClient, res, bundle)
+	cert, err := c.responseCert(ctx, res, bundle)
 	return cert, curl, err
 	return cert, curl, err
 }
 }
 
 
@@ -231,13 +229,13 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
 // and has expected features.
 // and has expected features.
 func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
 func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
 	for {
 	for {
-		res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+		res, err := c.get(ctx, url)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 		defer res.Body.Close()
 		defer res.Body.Close()
 		if res.StatusCode == http.StatusOK {
 		if res.StatusCode == http.StatusOK {
-			return responseCert(ctx, c.HTTPClient, res, bundle)
+			return c.responseCert(ctx, res, bundle)
 		}
 		}
 		if res.StatusCode > 299 {
 		if res.StatusCode > 299 {
 			return nil, responseError(res)
 			return nil, responseError(res)
@@ -387,7 +385,7 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization,
 // If a caller needs to poll an authorization until its status is final,
 // If a caller needs to poll an authorization until its status is final,
 // see the WaitAuthorization method.
 // see the WaitAuthorization method.
 func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) {
 func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) {
-	res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+	res, err := c.get(ctx, url)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -456,7 +454,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
 	}
 	}
 
 
 	for {
 	for {
-		res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+		res, err := c.get(ctx, url)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -493,7 +491,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
 //
 //
 // A client typically polls a challenge status using this method.
 // A client typically polls a challenge status using this method.
 func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) {
 func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) {
-	res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+	res, err := c.get(ctx, url)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -708,7 +706,7 @@ func (c *Client) postJWS(ctx context.Context, key crypto.Signer, url string, bod
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	res, err := ctxhttp.Post(ctx, c.HTTPClient, url, "application/jose+json", bytes.NewReader(b))
+	res, err := c.post(ctx, url, "application/jose+json", bytes.NewReader(b))
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -722,7 +720,7 @@ func (c *Client) popNonce(ctx context.Context, url string) (string, error) {
 	c.noncesMu.Lock()
 	c.noncesMu.Lock()
 	defer c.noncesMu.Unlock()
 	defer c.noncesMu.Unlock()
 	if len(c.nonces) == 0 {
 	if len(c.nonces) == 0 {
-		return fetchNonce(ctx, c.HTTPClient, url)
+		return c.fetchNonce(ctx, url)
 	}
 	}
 	var nonce string
 	var nonce string
 	for nonce = range c.nonces {
 	for nonce = range c.nonces {
@@ -749,8 +747,58 @@ func (c *Client) addNonce(h http.Header) {
 	c.nonces[v] = struct{}{}
 	c.nonces[v] = struct{}{}
 }
 }
 
 
-func fetchNonce(ctx context.Context, client *http.Client, url string) (string, error) {
-	resp, err := ctxhttp.Head(ctx, client, url)
+func (c *Client) httpClient() *http.Client {
+	if c.HTTPClient != nil {
+		return c.HTTPClient
+	}
+	return http.DefaultClient
+}
+
+func (c *Client) get(ctx context.Context, urlStr string) (*http.Response, error) {
+	req, err := http.NewRequest("GET", urlStr, nil)
+	if err != nil {
+		return nil, err
+	}
+	return c.do(ctx, req)
+}
+
+func (c *Client) head(ctx context.Context, urlStr string) (*http.Response, error) {
+	req, err := http.NewRequest("HEAD", urlStr, nil)
+	if err != nil {
+		return nil, err
+	}
+	return c.do(ctx, req)
+}
+
+func (c *Client) post(ctx context.Context, urlStr, contentType string, body io.Reader) (*http.Response, error) {
+	req, err := http.NewRequest("POST", urlStr, body)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", contentType)
+	return c.do(ctx, req)
+}
+
+func (c *Client) do(ctx context.Context, req *http.Request) (*http.Response, error) {
+	res, err := c.httpClient().Do(req.WithContext(ctx))
+	if err != nil {
+		select {
+		case <-ctx.Done():
+			// Prefer the unadorned context error.
+			// (The acme package had tests assuming this, previously from ctxhttp's
+			// behavior, predating net/http supporting contexts natively)
+			// TODO(bradfitz): reconsider this in the future. But for now this
+			// requires no test updates.
+			return nil, ctx.Err()
+		default:
+			return nil, err
+		}
+	}
+	return res, nil
+}
+
+func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) {
+	resp, err := c.head(ctx, url)
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
@@ -769,7 +817,7 @@ func nonceFromHeader(h http.Header) string {
 	return h.Get("Replay-Nonce")
 	return h.Get("Replay-Nonce")
 }
 }
 
 
-func responseCert(ctx context.Context, client *http.Client, res *http.Response, bundle bool) ([][]byte, error) {
+func (c *Client) responseCert(ctx context.Context, res *http.Response, bundle bool) ([][]byte, error) {
 	b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1))
 	b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1))
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("acme: response stream: %v", err)
 		return nil, fmt.Errorf("acme: response stream: %v", err)
@@ -793,7 +841,7 @@ func responseCert(ctx context.Context, client *http.Client, res *http.Response,
 		return nil, errors.New("acme: rel=up link is too large")
 		return nil, errors.New("acme: rel=up link is too large")
 	}
 	}
 	for _, url := range up {
 	for _, url := range up {
-		cc, err := chainCert(ctx, client, url, 0)
+		cc, err := c.chainCert(ctx, url, 0)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -836,12 +884,12 @@ func responseError(resp *http.Response) error {
 // if the recursion level reaches maxChainLen.
 // if the recursion level reaches maxChainLen.
 //
 //
 // First chainCert call starts with depth of 0.
 // First chainCert call starts with depth of 0.
-func chainCert(ctx context.Context, client *http.Client, url string, depth int) ([][]byte, error) {
+func (c *Client) chainCert(ctx context.Context, url string, depth int) ([][]byte, error) {
 	if depth >= maxChainLen {
 	if depth >= maxChainLen {
 		return nil, errors.New("acme: certificate chain is too deep")
 		return nil, errors.New("acme: certificate chain is too deep")
 	}
 	}
 
 
-	res, err := ctxhttp.Get(ctx, client, url)
+	res, err := c.get(ctx, url)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -863,7 +911,7 @@ func chainCert(ctx context.Context, client *http.Client, url string, depth int)
 		return nil, errors.New("acme: certificate chain is too large")
 		return nil, errors.New("acme: certificate chain is too large")
 	}
 	}
 	for _, up := range uplink {
 	for _, up := range uplink {
-		cc, err := chainCert(ctx, client, up, depth+1)
+		cc, err := c.chainCert(ctx, up, depth+1)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}

+ 4 - 2
acme/acme_test.go

@@ -980,7 +980,8 @@ func TestNonce_fetch(t *testing.T) {
 	defer ts.Close()
 	defer ts.Close()
 	for ; i < len(tests); i++ {
 	for ; i < len(tests); i++ {
 		test := tests[i]
 		test := tests[i]
-		n, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL)
+		c := &Client{}
+		n, err := c.fetchNonce(context.Background(), ts.URL)
 		if n != test.nonce {
 		if n != test.nonce {
 			t.Errorf("%d: n=%q; want %q", i, n, test.nonce)
 			t.Errorf("%d: n=%q; want %q", i, n, test.nonce)
 		}
 		}
@@ -998,7 +999,8 @@ func TestNonce_fetchError(t *testing.T) {
 		w.WriteHeader(http.StatusTooManyRequests)
 		w.WriteHeader(http.StatusTooManyRequests)
 	}))
 	}))
 	defer ts.Close()
 	defer ts.Close()
-	_, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL)
+	c := &Client{}
+	_, err := c.fetchNonce(context.Background(), ts.URL)
 	e, ok := err.(*Error)
 	e, ok := err.(*Error)
 	if !ok {
 	if !ok {
 		t.Fatalf("err is %T; want *Error", err)
 		t.Fatalf("err is %T; want *Error", err)