浏览代码

acme: build up full chain certs when requested

The latest ACME spec (v3) changed the wording to:

    ... the server MUST send one or more link relation header
    fields [RFC5988] with relation "up", each indicating a single
    certificate resource for the issuer of this certificate.  The server
    MAY also include the "up" links from these resources to enable the
    client to build a full certificate chain.

See https://tools.ietf.org/html/draft-ietf-acme-acme-03#section-6.3.1.

Before this change, Client was fetching only the first "up" link, but never
checked to follow the chain further. To my knowledge, Let's Encrypt never
provided a chain longer than 1, this is just to make the Client future proof.

Also fixes google/acme#26.

Change-Id: I35cf5f1997b21a0b2a2d0a732043a7e04b7f1c45
Reviewed-on: https://go-review.googlesource.com/26693
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Alex Vaghin 9 年之前
父节点
当前提交
6575f7ea32
共有 2 个文件被更改,包括 169 次插入39 次删除
  1. 100 28
      acme/internal/acme/acme.go
  2. 69 11
      acme/internal/acme/acme_test.go

+ 100 - 28
acme/internal/acme/acme.go

@@ -23,6 +23,7 @@ import (
 	"encoding/pem"
 	"errors"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"math/big"
 	"net/http"
@@ -32,11 +33,17 @@ import (
 	"time"
 
 	"golang.org/x/net/context"
+	"golang.org/x/net/context/ctxhttp"
 )
 
 // LetsEncryptURL is the Directory endpoint of Let's Encrypt CA.
 const LetsEncryptURL = "https://acme-v01.api.letsencrypt.org/directory"
 
+const (
+	maxChainLen = 5       // max depth and breadth of a certificate chain
+	maxCertSize = 1 << 20 // max size of a certificate, in bytes
+)
+
 // Client is an ACME client.
 // The only required field is Key. An example of creating a client with a new key
 // is as follows:
@@ -117,13 +124,17 @@ func (c *Client) Discover() (Directory, error) {
 	return *c.dir, nil
 }
 
-// CreateCert requests a new certificate.
+// CreateCert requests a new certificate using the Certificate Signing Request csr encoded in DER format.
+// The exp argument indicates the desired certificate validity duration. CA may issue a certificate
+// with a different duration.
+// If the bundle argument is true, the returned value will also contain the CA (issuer) certificate chain.
+//
 // In the case where CA server does not provide the issued certificate in the response,
 // CreateCert will poll certURL using c.FetchCert, which will result in additional round-trips.
 // In such scenario the caller can cancel the polling with ctx.
 //
-// If the bundle is true, the returned value will also contain CA (the issuer) certificate.
-// The csr is a DER encoded certificate signing request.
+// CreateCert returns an error if the CA's response or chain was unreasonably large.
+// Callers are encouraged to parse the returned value to ensure the certificate is valid and has the expected features.
 func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration, bundle bool) (der [][]byte, certURL string, err error) {
 	if _, err := c.Discover(); err != nil {
 		return nil, "", err
@@ -159,8 +170,8 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
 		cert, err := c.FetchCert(ctx, curl, bundle)
 		return cert, curl, err
 	}
-	// slurp issued cert and ca, if requested
-	cert, err := responseCert(c.httpClient(), res, bundle)
+	// slurp issued cert and CA chain, if requested
+	cert, err := responseCert(ctx, c.httpClient(), res, bundle)
 	return cert, curl, err
 }
 
@@ -168,16 +179,20 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
 // It retries the request until the certificate is successfully retrieved,
 // context is cancelled by the caller or an error response is received.
 //
-// The returned value will also contain CA (the issuer) certificate if bundle is true.
+// The returned value will also contain the CA (issuer) certificate if the bundle argument is true.
+//
+// FetchCert returns an error if the CA's response or chain was unreasonably large.
+// Callers are encouraged to parse the returned value to ensure the certificate is valid
+// and has expected features.
 func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
 	for {
-		res, err := c.httpClient().Get(url)
+		res, err := ctxhttp.Get(ctx, c.httpClient(), url)
 		if err != nil {
 			return nil, err
 		}
 		defer res.Body.Close()
 		if res.StatusCode == http.StatusOK {
-			return responseCert(c.httpClient(), res, bundle)
+			return responseCert(ctx, c.httpClient(), res, bundle)
 		}
 		if res.StatusCode > 299 {
 			return nil, responseError(res)
@@ -502,45 +517,56 @@ func (c *Client) doReg(url string, typ string, acct *Account) (*Account, error)
 	if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
 		return nil, fmt.Errorf("acme: invalid response: %v", err)
 	}
+	var tos string
+	if v := linkHeader(res.Header, "terms-of-service"); len(v) > 0 {
+		tos = v[0]
+	}
+	var authz string
+	if v := linkHeader(res.Header, "next"); len(v) > 0 {
+		authz = v[0]
+	}
 	return &Account{
 		URI:            res.Header.Get("Location"),
 		Contact:        v.Contact,
 		AgreedTerms:    v.Agreement,
-		CurrentTerms:   linkHeader(res.Header, "terms-of-service"),
-		Authz:          linkHeader(res.Header, "next"),
+		CurrentTerms:   tos,
+		Authz:          authz,
 		Authorizations: v.Authorizations,
 		Certificates:   v.Certificates,
 	}, nil
 }
 
-func responseCert(client *http.Client, res *http.Response, bundle bool) ([][]byte, error) {
-	b, err := ioutil.ReadAll(res.Body)
+func responseCert(ctx context.Context, client *http.Client, res *http.Response, bundle bool) ([][]byte, error) {
+	b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1))
 	if err != nil {
 		return nil, fmt.Errorf("acme: response stream: %v", err)
 	}
+	if len(b) > maxCertSize {
+		return nil, errors.New("acme: certificate is too big")
+	}
 	cert := [][]byte{b}
 	if !bundle {
 		return cert, nil
 	}
 
-	// append ca cert
+	// Append CA chain cert(s).
+	// At least one is required according to the spec:
+	// https://tools.ietf.org/html/draft-ietf-acme-acme-03#section-6.3.1
 	up := linkHeader(res.Header, "up")
-	if up == "" {
+	if len(up) == 0 {
 		return nil, errors.New("acme: rel=up link not found")
 	}
-	res, err = client.Get(up)
-	if err != nil {
-		return nil, err
+	if len(up) > maxChainLen {
+		return nil, errors.New("acme: rel=up link is too large")
 	}
-	defer res.Body.Close()
-	if res.StatusCode != http.StatusOK {
-		return nil, responseError(res)
-	}
-	b, err = ioutil.ReadAll(res.Body)
-	if err != nil {
-		return nil, err
+	for _, url := range up {
+		cc, err := chainCert(ctx, client, url, 0)
+		if err != nil {
+			return nil, err
+		}
+		cert = append(cert, cc...)
 	}
-	return append(cert, b), nil
+	return cert, nil
 }
 
 // responseError creates an error of Error type from resp.
@@ -572,6 +598,48 @@ func responseError(resp *http.Response) error {
 	}
 }
 
+// chainCert fetches CA certificate chain recursively by following "up" links.
+// Each recursive call increments the depth by 1, resulting in an error
+// if the recursion level reaches maxChainLen.
+//
+// First chainCert call starts with depth of 0.
+func chainCert(ctx context.Context, client *http.Client, url string, depth int) ([][]byte, error) {
+	if depth >= maxChainLen {
+		return nil, errors.New("acme: certificate chain is too deep")
+	}
+
+	res, err := ctxhttp.Get(ctx, client, url)
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+	if res.StatusCode != http.StatusOK {
+		return nil, responseError(res)
+	}
+	b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1))
+	if err != nil {
+		return nil, err
+	}
+	if len(b) > maxCertSize {
+		return nil, errors.New("acme: certificate is too big")
+	}
+	chain := [][]byte{b}
+
+	uplink := linkHeader(res.Header, "up")
+	if len(uplink) > maxChainLen {
+		return nil, errors.New("acme: certificate chain is too large")
+	}
+	for _, up := range uplink {
+		cc, err := chainCert(ctx, client, up, depth+1)
+		if err != nil {
+			return nil, err
+		}
+		chain = append(chain, cc...)
+	}
+
+	return chain, nil
+}
+
 func fetchNonce(client *http.Client, url string) (string, error) {
 	resp, err := client.Head(url)
 	if err != nil {
@@ -585,7 +653,11 @@ func fetchNonce(client *http.Client, url string) (string, error) {
 	return enc, nil
 }
 
-func linkHeader(h http.Header, rel string) string {
+// linkHeader returns URI-Reference values of all Link headers
+// with relation-type rel.
+// See https://tools.ietf.org/html/rfc5988#section-5 for details.
+func linkHeader(h http.Header, rel string) []string {
+	var links []string
 	for _, v := range h["Link"] {
 		parts := strings.Split(v, ";")
 		for _, p := range parts {
@@ -594,11 +666,11 @@ func linkHeader(h http.Header, rel string) string {
 				continue
 			}
 			if v := strings.Trim(p[4:], `"`); v == rel {
-				return strings.Trim(parts[0], "<>")
+				links = append(links, strings.Trim(parts[0], "<>"))
 			}
 		}
 	}
-	return ""
+	return links
 }
 
 func retryAfter(v string) (time.Duration, error) {

+ 69 - 11
acme/internal/acme/acme_test.go

@@ -5,6 +5,7 @@
 package acme
 
 import (
+	"bytes"
 	"crypto/rand"
 	"crypto/x509"
 	"crypto/x509/pkix"
@@ -635,15 +636,22 @@ func TestNewCert(t *testing.T) {
 }
 
 func TestFetchCert(t *testing.T) {
-	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.Write([]byte{1})
+	var count byte
+	var ts *httptest.Server
+	ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		count++
+		if count < 3 {
+			up := fmt.Sprintf("<%s>;rel=up", ts.URL)
+			w.Header().Set("link", up)
+		}
+		w.Write([]byte{count})
 	}))
 	defer ts.Close()
-	res, err := (&Client{}).FetchCert(context.Background(), ts.URL, false)
+	res, err := (&Client{}).FetchCert(context.Background(), ts.URL, true)
 	if err != nil {
 		t.Fatalf("FetchCert: %v", err)
 	}
-	cert := [][]byte{{1}}
+	cert := [][]byte{{1}, {2}, {3}}
 	if !reflect.DeepEqual(res, cert) {
 		t.Errorf("res = %v; want %v", res, cert)
 	}
@@ -691,6 +699,52 @@ func TestFetchCertCancel(t *testing.T) {
 	}
 }
 
+func TestFetchCertDepth(t *testing.T) {
+	var count byte
+	var ts *httptest.Server
+	ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		count++
+		if count > maxChainLen+1 {
+			t.Errorf("count = %d; want at most %d", count, maxChainLen+1)
+			w.WriteHeader(http.StatusInternalServerError)
+		}
+		w.Header().Set("link", fmt.Sprintf("<%s>;rel=up", ts.URL))
+		w.Write([]byte{count})
+	}))
+	defer ts.Close()
+	_, err := (&Client{}).FetchCert(context.Background(), ts.URL, true)
+	if err == nil {
+		t.Errorf("err is nil")
+	}
+}
+
+func TestFetchCertBreadth(t *testing.T) {
+	var ts *httptest.Server
+	ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		for i := 0; i < maxChainLen+1; i++ {
+			w.Header().Add("link", fmt.Sprintf("<%s>;rel=up", ts.URL))
+		}
+		w.Write([]byte{1})
+	}))
+	defer ts.Close()
+	_, err := (&Client{}).FetchCert(context.Background(), ts.URL, true)
+	if err == nil {
+		t.Errorf("err is nil")
+	}
+}
+
+func TestFetchCertSize(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		b := bytes.Repeat([]byte{1}, maxCertSize+1)
+		w.Write(b)
+	}))
+	defer ts.Close()
+	_, err := (&Client{}).FetchCert(context.Background(), ts.URL, false)
+	if err == nil {
+		t.Errorf("err is nil")
+	}
+}
+
 func TestFetchNonce(t *testing.T) {
 	tests := []struct {
 		code  int
@@ -729,16 +783,20 @@ func TestLinkHeader(t *testing.T) {
 		`<https://example.com/acme/new-authz>;rel="next"`,
 		`<https://example.com/acme/recover-reg>; rel=recover`,
 		`<https://example.com/acme/terms>; foo=bar; rel="terms-of-service"`,
+		`<dup>;rel="next"`,
 	}}
-	tests := []struct{ in, out string }{
-		{"next", "https://example.com/acme/new-authz"},
-		{"recover", "https://example.com/acme/recover-reg"},
-		{"terms-of-service", "https://example.com/acme/terms"},
-		{"empty", ""},
+	tests := []struct {
+		rel string
+		out []string
+	}{
+		{"next", []string{"https://example.com/acme/new-authz", "dup"}},
+		{"recover", []string{"https://example.com/acme/recover-reg"}},
+		{"terms-of-service", []string{"https://example.com/acme/terms"}},
+		{"empty", nil},
 	}
 	for i, test := range tests {
-		if v := linkHeader(h, test.in); v != test.out {
-			t.Errorf("%d: parseLinkHeader(%q): %q; want %q", i, test.in, v, test.out)
+		if v := linkHeader(h, test.rel); !reflect.DeepEqual(v, test.out) {
+			t.Errorf("%d: linkHeader(%q): %v; want %v", i, test.rel, v, test.out)
 		}
 	}
 }