Browse Source

pkg/transport: generate certs

Anthony Romano 10 years ago
parent
commit
a69c709839
2 changed files with 106 additions and 0 deletions
  1. 85 0
      pkg/transport/listener.go
  2. 21 0
      pkg/transport/listener_test.go

+ 85 - 0
pkg/transport/listener.go

@@ -15,13 +15,21 @@
 package transport
 package transport
 
 
 import (
 import (
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
 	"crypto/tls"
 	"crypto/tls"
 	"crypto/x509"
 	"crypto/x509"
+	"crypto/x509/pkix"
 	"encoding/pem"
 	"encoding/pem"
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
+	"math/big"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
+	"os"
+	"path"
+	"strings"
 	"time"
 	"time"
 )
 )
 
 
@@ -79,6 +87,8 @@ type TLSInfo struct {
 	TrustedCAFile  string
 	TrustedCAFile  string
 	ClientCertAuth bool
 	ClientCertAuth bool
 
 
+	selfCert bool
+
 	// parseFunc exists to simplify testing. Typically, parseFunc
 	// parseFunc exists to simplify testing. Typically, parseFunc
 	// should be left nil. In that case, tls.X509KeyPair will be used.
 	// should be left nil. In that case, tls.X509KeyPair will be used.
 	parseFunc func([]byte, []byte) (tls.Certificate, error)
 	parseFunc func([]byte, []byte) (tls.Certificate, error)
@@ -92,6 +102,78 @@ func (info TLSInfo) Empty() bool {
 	return info.CertFile == "" && info.KeyFile == ""
 	return info.CertFile == "" && info.KeyFile == ""
 }
 }
 
 
+func SelfCert(dirpath string, hosts []string) (info TLSInfo, err error) {
+	if err = os.MkdirAll(dirpath, 0700); err != nil {
+		return
+	}
+
+	certPath := path.Join(dirpath, "cert.pem")
+	keyPath := path.Join(dirpath, "key.pem")
+	_, errcert := os.Stat(certPath)
+	_, errkey := os.Stat(keyPath)
+	if errcert == nil && errkey == nil {
+		info.CertFile = certPath
+		info.KeyFile = keyPath
+		info.selfCert = true
+		return
+	}
+
+	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
+	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+	if err != nil {
+		return
+	}
+
+	tmpl := x509.Certificate{
+		SerialNumber: serialNumber,
+		Subject:      pkix.Name{Organization: []string{"etcd"}},
+		NotBefore:    time.Now(),
+		NotAfter:     time.Now().Add(365 * (24 * time.Hour)),
+
+		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+		BasicConstraintsValid: true,
+	}
+
+	for _, host := range hosts {
+		if ip := net.ParseIP(host); ip != nil {
+			tmpl.IPAddresses = append(tmpl.IPAddresses, ip)
+		} else {
+			tmpl.DNSNames = append(tmpl.DNSNames, strings.Split(host, ":")[0])
+		}
+	}
+
+	priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
+	if err != nil {
+		return
+	}
+
+	derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv)
+	if err != nil {
+		return
+	}
+
+	certOut, err := os.Create(certPath)
+	if err != nil {
+		return
+	}
+	pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+	certOut.Close()
+
+	b, err := x509.MarshalECPrivateKey(priv)
+	if err != nil {
+		return
+	}
+	keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
+	if err != nil {
+		return
+	}
+	pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
+	keyOut.Close()
+
+	return SelfCert(dirpath, hosts)
+}
+
 func (info TLSInfo) baseConfig() (*tls.Config, error) {
 func (info TLSInfo) baseConfig() (*tls.Config, error) {
 	if info.KeyFile == "" || info.CertFile == "" {
 	if info.KeyFile == "" || info.CertFile == "" {
 		return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
 		return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
@@ -182,6 +264,9 @@ func (info TLSInfo) ClientConfig() (*tls.Config, error) {
 		}
 		}
 	}
 	}
 
 
+	if info.selfCert {
+		cfg.InsecureSkipVerify = true
+	}
 	return cfg, nil
 	return cfg, nil
 }
 }
 
 

+ 21 - 0
pkg/transport/listener_test.go

@@ -54,6 +54,10 @@ func TestNewListenerTLSInfo(t *testing.T) {
 	defer os.Remove(tmp)
 	defer os.Remove(tmp)
 	tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
 	tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
 	tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
 	tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
+	testNewListenerTLSInfoAccept(t, tlsInfo)
+}
+
+func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
 	ln, err := NewListener("127.0.0.1:0", "https", tlsInfo)
 	ln, err := NewListener("127.0.0.1:0", "https", tlsInfo)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("unexpected NewListener error: %v", err)
 		t.Fatalf("unexpected NewListener error: %v", err)
@@ -249,3 +253,20 @@ func TestNewListenerUnixSocket(t *testing.T) {
 	}
 	}
 	l.Close()
 	l.Close()
 }
 }
+
+// TestNewListenerTLSInfoSelfCert tests that a new certificate accepts connections.
+func TestNewListenerTLSInfoSelfCert(t *testing.T) {
+	tmpdir, err := ioutil.TempDir(os.TempDir(), "tlsdir")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(tmpdir)
+	tlsinfo, err := SelfCert(tmpdir, []string{"127.0.0.1"})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if tlsinfo.Empty() {
+		t.Fatalf("tlsinfo should have certs (%+v)", tlsinfo)
+	}
+	testNewListenerTLSInfoAccept(t, tlsinfo)
+}