Browse Source

acme: stop using ctxhttp

The ctxhttp package used to be big and gross before net/http supported
contexts natively. Nowadays it barely does anything. Stop using it,
because it just pulls in the old context package anyway. (We can't
really clean up the ctxhttp package until Go 1.9)

Change-Id: I48b11f2f483783a32cbaa75e244301148a304c08
Reviewed-on: https://go-review.googlesource.com/40110
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Alex Vaghin <ddos@google.com>
Brad Fitzpatrick 8 năm trước cách đây
mục cha
commit
6022e334c1
2 tập tin đã thay đổi với 70 bổ sung20 xóa
  1. 66 18
      acme/acme.go
  2. 4 2
      acme/acme_test.go

+ 66 - 18
acme/acme.go

@@ -37,8 +37,6 @@ import (
 	"strings"
 	"sync"
 	"time"
-
-	"golang.org/x/net/context/ctxhttp"
 )
 
 // LetsEncryptURL is the Directory endpoint of Let's Encrypt CA.
@@ -133,7 +131,7 @@ func (c *Client) Discover(ctx context.Context) (Directory, error) {
 	if dirURL == "" {
 		dirURL = LetsEncryptURL
 	}
-	res, err := ctxhttp.Get(ctx, c.HTTPClient, dirURL)
+	res, err := c.get(ctx, dirURL)
 	if err != nil {
 		return Directory{}, err
 	}
@@ -216,7 +214,7 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
 		return cert, curl, err
 	}
 	// slurp issued cert and CA chain, if requested
-	cert, err := responseCert(ctx, c.HTTPClient, res, bundle)
+	cert, err := c.responseCert(ctx, res, bundle)
 	return cert, curl, err
 }
 
@@ -231,13 +229,13 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
 // and has expected features.
 func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
 	for {
-		res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+		res, err := c.get(ctx, url)
 		if err != nil {
 			return nil, err
 		}
 		defer res.Body.Close()
 		if res.StatusCode == http.StatusOK {
-			return responseCert(ctx, c.HTTPClient, res, bundle)
+			return c.responseCert(ctx, res, bundle)
 		}
 		if res.StatusCode > 299 {
 			return nil, responseError(res)
@@ -387,7 +385,7 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization,
 // If a caller needs to poll an authorization until its status is final,
 // see the WaitAuthorization method.
 func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) {
-	res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+	res, err := c.get(ctx, url)
 	if err != nil {
 		return nil, err
 	}
@@ -456,7 +454,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
 	}
 
 	for {
-		res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+		res, err := c.get(ctx, url)
 		if err != nil {
 			return nil, err
 		}
@@ -493,7 +491,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
 //
 // A client typically polls a challenge status using this method.
 func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) {
-	res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+	res, err := c.get(ctx, url)
 	if err != nil {
 		return nil, err
 	}
@@ -708,7 +706,7 @@ func (c *Client) postJWS(ctx context.Context, key crypto.Signer, url string, bod
 	if err != nil {
 		return nil, err
 	}
-	res, err := ctxhttp.Post(ctx, c.HTTPClient, url, "application/jose+json", bytes.NewReader(b))
+	res, err := c.post(ctx, url, "application/jose+json", bytes.NewReader(b))
 	if err != nil {
 		return nil, err
 	}
@@ -722,7 +720,7 @@ func (c *Client) popNonce(ctx context.Context, url string) (string, error) {
 	c.noncesMu.Lock()
 	defer c.noncesMu.Unlock()
 	if len(c.nonces) == 0 {
-		return fetchNonce(ctx, c.HTTPClient, url)
+		return c.fetchNonce(ctx, url)
 	}
 	var nonce string
 	for nonce = range c.nonces {
@@ -749,8 +747,58 @@ func (c *Client) addNonce(h http.Header) {
 	c.nonces[v] = struct{}{}
 }
 
-func fetchNonce(ctx context.Context, client *http.Client, url string) (string, error) {
-	resp, err := ctxhttp.Head(ctx, client, url)
+func (c *Client) httpClient() *http.Client {
+	if c.HTTPClient != nil {
+		return c.HTTPClient
+	}
+	return http.DefaultClient
+}
+
+func (c *Client) get(ctx context.Context, urlStr string) (*http.Response, error) {
+	req, err := http.NewRequest("GET", urlStr, nil)
+	if err != nil {
+		return nil, err
+	}
+	return c.do(ctx, req)
+}
+
+func (c *Client) head(ctx context.Context, urlStr string) (*http.Response, error) {
+	req, err := http.NewRequest("HEAD", urlStr, nil)
+	if err != nil {
+		return nil, err
+	}
+	return c.do(ctx, req)
+}
+
+func (c *Client) post(ctx context.Context, urlStr, contentType string, body io.Reader) (*http.Response, error) {
+	req, err := http.NewRequest("POST", urlStr, body)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", contentType)
+	return c.do(ctx, req)
+}
+
+func (c *Client) do(ctx context.Context, req *http.Request) (*http.Response, error) {
+	res, err := c.httpClient().Do(req.WithContext(ctx))
+	if err != nil {
+		select {
+		case <-ctx.Done():
+			// Prefer the unadorned context error.
+			// (The acme package had tests assuming this, previously from ctxhttp's
+			// behavior, predating net/http supporting contexts natively)
+			// TODO(bradfitz): reconsider this in the future. But for now this
+			// requires no test updates.
+			return nil, ctx.Err()
+		default:
+			return nil, err
+		}
+	}
+	return res, nil
+}
+
+func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) {
+	resp, err := c.head(ctx, url)
 	if err != nil {
 		return "", err
 	}
@@ -769,7 +817,7 @@ func nonceFromHeader(h http.Header) string {
 	return h.Get("Replay-Nonce")
 }
 
-func responseCert(ctx context.Context, client *http.Client, res *http.Response, bundle bool) ([][]byte, error) {
+func (c *Client) responseCert(ctx context.Context, res *http.Response, bundle bool) ([][]byte, error) {
 	b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1))
 	if err != nil {
 		return nil, fmt.Errorf("acme: response stream: %v", err)
@@ -793,7 +841,7 @@ func responseCert(ctx context.Context, client *http.Client, res *http.Response,
 		return nil, errors.New("acme: rel=up link is too large")
 	}
 	for _, url := range up {
-		cc, err := chainCert(ctx, client, url, 0)
+		cc, err := c.chainCert(ctx, url, 0)
 		if err != nil {
 			return nil, err
 		}
@@ -836,12 +884,12 @@ func responseError(resp *http.Response) error {
 // if the recursion level reaches maxChainLen.
 //
 // First chainCert call starts with depth of 0.
-func chainCert(ctx context.Context, client *http.Client, url string, depth int) ([][]byte, error) {
+func (c *Client) chainCert(ctx context.Context, url string, depth int) ([][]byte, error) {
 	if depth >= maxChainLen {
 		return nil, errors.New("acme: certificate chain is too deep")
 	}
 
-	res, err := ctxhttp.Get(ctx, client, url)
+	res, err := c.get(ctx, url)
 	if err != nil {
 		return nil, err
 	}
@@ -863,7 +911,7 @@ func chainCert(ctx context.Context, client *http.Client, url string, depth int)
 		return nil, errors.New("acme: certificate chain is too large")
 	}
 	for _, up := range uplink {
-		cc, err := chainCert(ctx, client, up, depth+1)
+		cc, err := c.chainCert(ctx, up, depth+1)
 		if err != nil {
 			return nil, err
 		}

+ 4 - 2
acme/acme_test.go

@@ -980,7 +980,8 @@ func TestNonce_fetch(t *testing.T) {
 	defer ts.Close()
 	for ; i < len(tests); i++ {
 		test := tests[i]
-		n, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL)
+		c := &Client{}
+		n, err := c.fetchNonce(context.Background(), ts.URL)
 		if n != test.nonce {
 			t.Errorf("%d: n=%q; want %q", i, n, test.nonce)
 		}
@@ -998,7 +999,8 @@ func TestNonce_fetchError(t *testing.T) {
 		w.WriteHeader(http.StatusTooManyRequests)
 	}))
 	defer ts.Close()
-	_, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL)
+	c := &Client{}
+	_, err := c.fetchNonce(context.Background(), ts.URL)
 	e, ok := err.(*Error)
 	if !ok {
 		t.Fatalf("err is %T; want *Error", err)