Browse Source

integration: test TLS reload

Signed-off-by: Gyu-Ho Lee <gyuhox@gmail.com>
Gyu-Ho Lee 8 years ago
parent
commit
22943e7e06
3 changed files with 274 additions and 0 deletions
  1. 7 0
      integration/cluster.go
  2. 62 0
      integration/util_test.go
  3. 205 0
      integration/v3_grpc_test.go

+ 7 - 0
integration/cluster.go

@@ -76,6 +76,13 @@ var (
 		ClientCertAuth: true,
 	}
 
+	testTLSInfoExpired = transport.TLSInfo{
+		KeyFile:        "./fixtures-expired/server-key.pem",
+		CertFile:       "./fixtures-expired/server.pem",
+		TrustedCAFile:  "./fixtures-expired/etcd-root-ca.pem",
+		ClientCertAuth: true,
+	}
+
 	plog = capnslog.NewPackageLogger("github.com/coreos/etcd", "integration")
 )
 

+ 62 - 0
integration/util_test.go

@@ -0,0 +1,62 @@
+// Copyright 2017 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration
+
+import (
+	"io"
+	"os"
+	"path/filepath"
+
+	"github.com/coreos/etcd/pkg/transport"
+)
+
+// copyTLSFiles clones certs files to dst directory.
+func copyTLSFiles(ti transport.TLSInfo, dst string) (transport.TLSInfo, error) {
+	ci := transport.TLSInfo{
+		KeyFile:        filepath.Join(dst, "server-key.pem"),
+		CertFile:       filepath.Join(dst, "server.pem"),
+		TrustedCAFile:  filepath.Join(dst, "etcd-root-ca.pem"),
+		ClientCertAuth: ti.ClientCertAuth,
+	}
+	if err := copyFile(ti.KeyFile, ci.KeyFile); err != nil {
+		return transport.TLSInfo{}, err
+	}
+	if err := copyFile(ti.CertFile, ci.CertFile); err != nil {
+		return transport.TLSInfo{}, err
+	}
+	if err := copyFile(ti.TrustedCAFile, ci.TrustedCAFile); err != nil {
+		return transport.TLSInfo{}, err
+	}
+	return ci, nil
+}
+
+func copyFile(src, dst string) error {
+	f, err := os.Open(src)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+
+	w, err := os.Create(dst)
+	if err != nil {
+		return err
+	}
+	defer w.Close()
+
+	if _, err = io.Copy(w, f); err != nil {
+		return err
+	}
+	return w.Sync()
+}

+ 205 - 0
integration/v3_grpc_test.go

@@ -16,17 +16,21 @@ package integration
 
 import (
 	"bytes"
+	"crypto/tls"
 	"fmt"
+	"io/ioutil"
 	"math/rand"
 	"os"
 	"reflect"
 	"testing"
 	"time"
 
+	"github.com/coreos/etcd/clientv3"
 	"github.com/coreos/etcd/etcdserver/api/v3rpc"
 	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
 	"github.com/coreos/etcd/pkg/testutil"
+
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/metadata"
@@ -1374,6 +1378,207 @@ func TestTLSGRPCAcceptSecureAll(t *testing.T) {
 	}
 }
 
+// TestTLSReloadAtomicReplace ensures server reloads expired/valid certs
+// when all certs are atomically replaced by directory renaming.
+// And expects server to reject client requests, and vice versa.
+func TestTLSReloadAtomicReplace(t *testing.T) {
+	defer testutil.AfterTest(t)
+
+	// clone valid,expired certs to separate directories for atomic renaming
+	vDir, err := ioutil.TempDir(os.TempDir(), "fixtures-valid")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(vDir)
+	ts, err := copyTLSFiles(testTLSInfo, vDir)
+	if err != nil {
+		t.Fatal(err)
+	}
+	eDir, err := ioutil.TempDir(os.TempDir(), "fixtures-expired")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(eDir)
+	if _, err = copyTLSFiles(testTLSInfoExpired, eDir); err != nil {
+		t.Fatal(err)
+	}
+
+	tDir, err := ioutil.TempDir(os.TempDir(), "fixtures")
+	if err != nil {
+		t.Fatal(err)
+	}
+	os.RemoveAll(tDir)
+	defer os.RemoveAll(tDir)
+
+	// start with valid certs
+	clus := NewClusterV3(t, &ClusterConfig{Size: 1, PeerTLS: &ts, ClientTLS: &ts})
+	defer clus.Terminate(t)
+
+	// concurrent client dialing while certs transition from valid to expired
+	errc := make(chan error, 1)
+	go func() {
+		for {
+			cc, err := ts.ClientConfig()
+			if err != nil {
+				if os.IsNotExist(err) {
+					// from concurrent renaming
+					continue
+				}
+				t.Fatal(err)
+			}
+			cli, cerr := clientv3.New(clientv3.Config{
+				Endpoints:   []string{clus.Members[0].GRPCAddr()},
+				DialTimeout: time.Second,
+				TLS:         cc,
+			})
+			if cerr != nil {
+				errc <- cerr
+				return
+			}
+			cli.Close()
+		}
+	}()
+
+	// replace certs directory with expired ones
+	if err = os.Rename(vDir, tDir); err != nil {
+		t.Fatal(err)
+	}
+	if err = os.Rename(eDir, vDir); err != nil {
+		t.Fatal(err)
+	}
+
+	// after rename,
+	// 'vDir' contains expired certs
+	// 'tDir' contains valid certs
+	// 'eDir' does not exist
+
+	select {
+	case err = <-errc:
+		if err != grpc.ErrClientConnTimeout {
+			t.Fatalf("expected %v, got %v", grpc.ErrClientConnTimeout, err)
+		}
+	case <-time.After(5 * time.Second):
+		t.Fatal("failed to receive dial timeout error")
+	}
+
+	// now, replace expired certs back with valid ones
+	if err = os.Rename(tDir, eDir); err != nil {
+		t.Fatal(err)
+	}
+	if err = os.Rename(vDir, tDir); err != nil {
+		t.Fatal(err)
+	}
+	if err = os.Rename(eDir, vDir); err != nil {
+		t.Fatal(err)
+	}
+
+	// new incoming client request should trigger
+	// listener to reload valid certs
+	var tls *tls.Config
+	tls, err = ts.ClientConfig()
+	if err != nil {
+		t.Fatal(err)
+	}
+	var cl *clientv3.Client
+	cl, err = clientv3.New(clientv3.Config{
+		Endpoints:   []string{clus.Members[0].GRPCAddr()},
+		DialTimeout: time.Second,
+		TLS:         tls,
+	})
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+	cl.Close()
+}
+
+// TestTLSReloadCopy ensures server reloads expired/valid certs
+// when new certs are copied over, one by one. And expects server
+// to reject client requests, and vice versa.
+func TestTLSReloadCopy(t *testing.T) {
+	defer testutil.AfterTest(t)
+
+	// clone certs directory, free to overwrite
+	cDir, err := ioutil.TempDir(os.TempDir(), "fixtures-test")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(cDir)
+	ts, err := copyTLSFiles(testTLSInfo, cDir)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// start with valid certs
+	clus := NewClusterV3(t, &ClusterConfig{Size: 1, PeerTLS: &ts, ClientTLS: &ts})
+	defer clus.Terminate(t)
+
+	// concurrent client dialing while certs transition from valid to expired
+	errc := make(chan error, 1)
+	go func() {
+		for {
+			cc, err := ts.ClientConfig()
+			if err != nil {
+				// from concurrent certs overwriting
+				switch err.Error() {
+				case "tls: private key does not match public key":
+					fallthrough
+				case "tls: failed to find any PEM data in key input":
+					continue
+				}
+				t.Fatal(err)
+			}
+			cli, cerr := clientv3.New(clientv3.Config{
+				Endpoints:   []string{clus.Members[0].GRPCAddr()},
+				DialTimeout: time.Second,
+				TLS:         cc,
+			})
+			if cerr != nil {
+				errc <- cerr
+				return
+			}
+			cli.Close()
+		}
+	}()
+
+	// overwrite valid certs with expired ones
+	// (e.g. simulate cert expiration in practice)
+	if _, err = copyTLSFiles(testTLSInfoExpired, cDir); err != nil {
+		t.Fatal(err)
+	}
+
+	select {
+	case gerr := <-errc:
+		if gerr != grpc.ErrClientConnTimeout {
+			t.Fatalf("expected %v, got %v", grpc.ErrClientConnTimeout, gerr)
+		}
+	case <-time.After(5 * time.Second):
+		t.Fatal("failed to receive dial timeout error")
+	}
+
+	// now, replace expired certs back with valid ones
+	if _, err = copyTLSFiles(testTLSInfo, cDir); err != nil {
+		t.Fatal(err)
+	}
+
+	// new incoming client request should trigger
+	// listener to reload valid certs
+	var tls *tls.Config
+	tls, err = ts.ClientConfig()
+	if err != nil {
+		t.Fatal(err)
+	}
+	var cl *clientv3.Client
+	cl, err = clientv3.New(clientv3.Config{
+		Endpoints:   []string{clus.Members[0].GRPCAddr()},
+		DialTimeout: time.Second,
+		TLS:         tls,
+	})
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+	cl.Close()
+}
+
 func TestGRPCRequireLeader(t *testing.T) {
 	defer testutil.AfterTest(t)