Pārlūkot izejas kodu

acme/autocert: use valid certificates from the cache during renewal

Currently, the renewal flow will check the cache before renewing to make
sure it is actually necessary. This change modifies this flow to update
the local state so the cached cert is actually used by the manager.

Fixes golang/go#22960

Change-Id: I16668e8098616190938ee52858294b59bc1a5160
Reviewed-on: https://go-review.googlesource.com/89995
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Brad Morgan 7 gadi atpakaļ
vecāks
revīzija
c3a3ad6d03
2 mainītis faili ar 168 papildinājumiem un 7 dzēšanām
  1. 22 7
      acme/autocert/renewal.go
  2. 146 0
      acme/autocert/renewal_test.go

+ 22 - 7
acme/autocert/renewal.go

@@ -71,12 +71,21 @@ func (dr *domainRenewal) renew() {
 	testDidRenewLoop(next, err)
 }
 
+// updateState locks and replaces the relevant Manager.state item with the given
+// state. It additionally updates dr.key with the given state's key.
+func (dr *domainRenewal) updateState(state *certState) {
+	dr.m.stateMu.Lock()
+	defer dr.m.stateMu.Unlock()
+	dr.key = state.key
+	dr.m.state[dr.domain] = state
+}
+
 // do is similar to Manager.createCert but it doesn't lock a Manager.state item.
 // Instead, it requests a new certificate independently and, upon success,
 // replaces dr.m.state item with a new one and updates cache for the given domain.
 //
-// It may return immediately if the expiration date of the currently cached cert
-// is far enough in the future.
+// It may lock and update the Manager.state if the expiration date of the currently
+// cached cert is far enough in the future.
 //
 // The returned value is a time interval after which the renewal should occur again.
 func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
@@ -85,7 +94,16 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
 	if tlscert, err := dr.m.cacheGet(ctx, dr.domain); err == nil {
 		next := dr.next(tlscert.Leaf.NotAfter)
 		if next > dr.m.renewBefore()+renewJitter {
-			return next, nil
+			signer, ok := tlscert.PrivateKey.(crypto.Signer)
+			if ok {
+				state := &certState{
+					key:  signer,
+					cert: tlscert.Certificate,
+					leaf: tlscert.Leaf,
+				}
+				dr.updateState(state)
+				return next, nil
+			}
 		}
 	}
 
@@ -105,10 +123,7 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
 	if err := dr.m.cachePut(ctx, dr.domain, tlscert); err != nil {
 		return 0, err
 	}
-	dr.m.stateMu.Lock()
-	defer dr.m.stateMu.Unlock()
-	// m.state is guaranteed to be non-nil at this point
-	dr.m.state[dr.domain] = state
+	dr.updateState(state)
 	return dr.next(leaf.NotAfter), nil
 }
 

+ 146 - 0
acme/autocert/renewal_test.go

@@ -189,3 +189,149 @@ func TestRenewFromCache(t *testing.T) {
 	case <-done:
 	}
 }
+
+func TestRenewFromCacheAlreadyRenewed(t *testing.T) {
+	const domain = "example.org"
+
+	// use EC key to run faster on 386
+	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	if err != nil {
+		t.Fatal(err)
+	}
+	man := &Manager{
+		Prompt:      AcceptTOS,
+		Cache:       newMemCache(),
+		RenewBefore: 24 * time.Hour,
+		Client: &acme.Client{
+			Key:          key,
+			DirectoryURL: "invalid",
+		},
+	}
+	defer man.stopRenew()
+
+	// cache a recently renewed cert with a different private key
+	newKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	if err != nil {
+		t.Fatal(err)
+	}
+	now := time.Now()
+	newCert, err := dateDummyCert(newKey.Public(), now.Add(-2*time.Hour), now.Add(time.Hour*24*90), domain)
+	if err != nil {
+		t.Fatal(err)
+	}
+	newLeaf, err := validCert(domain, [][]byte{newCert}, newKey)
+	if err != nil {
+		t.Fatal(err)
+	}
+	newTLSCert := &tls.Certificate{PrivateKey: newKey, Certificate: [][]byte{newCert}, Leaf: newLeaf}
+	if err := man.cachePut(context.Background(), domain, newTLSCert); err != nil {
+		t.Fatal(err)
+	}
+
+	// set internal state to an almost expired cert
+	oldCert, err := dateDummyCert(key.Public(), now.Add(-2*time.Hour), now.Add(time.Minute), domain)
+	if err != nil {
+		t.Fatal(err)
+	}
+	oldLeaf, err := validCert(domain, [][]byte{oldCert}, key)
+	if err != nil {
+		t.Fatal(err)
+	}
+	man.stateMu.Lock()
+	if man.state == nil {
+		man.state = make(map[string]*certState)
+	}
+	s := &certState{
+		key:  key,
+		cert: [][]byte{oldCert},
+		leaf: oldLeaf,
+	}
+	man.state[domain] = s
+	man.stateMu.Unlock()
+
+	// veriy the renewal accepted the newer cached cert
+	defer func() {
+		testDidRenewLoop = func(next time.Duration, err error) {}
+	}()
+	done := make(chan struct{})
+	testDidRenewLoop = func(next time.Duration, err error) {
+		defer close(done)
+		if err != nil {
+			t.Errorf("testDidRenewLoop: %v", err)
+		}
+		// Next should be about 90 days
+		// Previous expiration was within 1 min.
+		future := 88 * 24 * time.Hour
+		if next < future {
+			t.Errorf("testDidRenewLoop: next = %v; want >= %v", next, future)
+		}
+
+		// ensure the cached cert was not modified
+		tlscert, err := man.cacheGet(context.Background(), domain)
+		if err != nil {
+			t.Fatalf("man.cacheGet: %v", err)
+		}
+		if !tlscert.Leaf.NotAfter.Equal(newLeaf.NotAfter) {
+			t.Errorf("cache leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter)
+		}
+
+		// verify the old cert is also replaced in memory
+		man.stateMu.Lock()
+		defer man.stateMu.Unlock()
+		s := man.state[domain]
+		if s == nil {
+			t.Fatalf("m.state[%q] is nil", domain)
+		}
+		stateKey := s.key.Public().(*ecdsa.PublicKey)
+		if stateKey.X.Cmp(newKey.X) != 0 || stateKey.Y.Cmp(newKey.Y) != 0 {
+			t.Fatalf("state key was not updated from cache x: %v y: %v; want x: %v y: %v", stateKey.X, stateKey.Y, newKey.X, newKey.Y)
+		}
+		tlscert, err = s.tlscert()
+		if err != nil {
+			t.Fatalf("s.tlscert: %v", err)
+		}
+		if !tlscert.Leaf.NotAfter.Equal(newLeaf.NotAfter) {
+			t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter)
+		}
+
+		// verify the private key is replaced in the renewal state
+		r := man.renewal[domain]
+		if r == nil {
+			t.Fatalf("m.renewal[%q] is nil", domain)
+		}
+		renewalKey := r.key.Public().(*ecdsa.PublicKey)
+		if renewalKey.X.Cmp(newKey.X) != 0 || renewalKey.Y.Cmp(newKey.Y) != 0 {
+			t.Fatalf("renewal private key was not updated from cache x: %v y: %v; want x: %v y: %v", renewalKey.X, renewalKey.Y, newKey.X, newKey.Y)
+		}
+
+	}
+
+	// assert the expiring cert is returned from state
+	hello := &tls.ClientHelloInfo{ServerName: domain}
+	tlscert, err := man.GetCertificate(hello)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !oldLeaf.NotAfter.Equal(tlscert.Leaf.NotAfter) {
+		t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, oldLeaf.NotAfter)
+	}
+
+	// trigger renew
+	go man.renew(domain, s.key, s.leaf.NotAfter)
+
+	// wait for renew loop
+	select {
+	case <-time.After(10 * time.Second):
+		t.Fatal("renew took too long to occur")
+	case <-done:
+		// assert the new cert is returned from state after renew
+		hello := &tls.ClientHelloInfo{ServerName: domain}
+		tlscert, err := man.GetCertificate(hello)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if !newTLSCert.Leaf.NotAfter.Equal(tlscert.Leaf.NotAfter) {
+			t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newTLSCert.Leaf.NotAfter)
+		}
+	}
+}