Browse Source

transport: wrap net.Listener with TLSInfo

Brian Waldon 11 years ago
parent
commit
17459c7bfc
3 changed files with 89 additions and 6 deletions
  1. 3 3
      main.go
  2. 1 1
      test
  3. 85 2
      transport/listener.go

+ 3 - 3
main.go

@@ -168,7 +168,7 @@ func startEtcd() {
 		Info:    cors,
 	}
 
-	l, err := transport.NewListener(*paddr)
+	l, err := transport.NewListener(*paddr, transport.TLSInfo{})
 	if err != nil {
 		log.Fatal(err)
 	}
@@ -182,7 +182,7 @@ func startEtcd() {
 	// Start a client server goroutine for each listen address
 	for _, addr := range *addrs {
 		addr := addr
-		l, err := transport.NewListener(addr)
+		l, err := transport.NewListener(addr, transport.TLSInfo{})
 		if err != nil {
 			log.Fatal(err)
 		}
@@ -212,7 +212,7 @@ func startProxy() {
 	// Start a proxy server goroutine for each listen address
 	for _, addr := range *addrs {
 		addr := addr
-		l, err := transport.NewListener(addr)
+		l, err := transport.NewListener(addr, transport.TLSInfo{})
 		if err != nil {
 			log.Fatal(err)
 		}

+ 1 - 1
test

@@ -15,7 +15,7 @@ COVER=${COVER:-"-cover"}
 source ./build
 
 # Hack: gofmt ./ will recursively check the .git directory. So use *.go for gofmt.
-TESTABLE_AND_FORMATTABLE="client etcdserver etcdserver/etcdhttp etcdserver/etcdserverpb functional proxy raft snap store wait wal"
+TESTABLE_AND_FORMATTABLE="client etcdserver etcdserver/etcdhttp etcdserver/etcdserverpb functional proxy raft snap store wait wal transport"
 TESTABLE="$TESTABLE_AND_FORMATTABLE ./"
 FORMATTABLE="$TESTABLE_AND_FORMATTABLE *.go"
 

+ 85 - 2
transport/listener.go

@@ -1,9 +1,92 @@
 package transport
 
 import (
+	"crypto/tls"
+	"crypto/x509"
+	"encoding/pem"
+	"fmt"
+	"io/ioutil"
 	"net"
 )
 
-func NewListener(addr string) (net.Listener, error) {
-	return net.Listen("tcp", addr)
+func NewListener(addr string, info TLSInfo) (net.Listener, error) {
+	l, err := net.Listen("tcp", addr)
+	if err != nil {
+		return nil, err
+	}
+
+	if !info.Empty() {
+		cfg, err := info.ServerConfig()
+		if err != nil {
+			return nil, err
+		}
+
+		l = tls.NewListener(l, cfg)
+	}
+
+	return l, nil
+}
+
+type TLSInfo struct {
+	CertFile string
+	KeyFile  string
+	CAFile   string
+}
+
+func (info TLSInfo) Empty() bool {
+	return info.CertFile == "" && info.KeyFile == ""
+}
+
+// Generates a tls.Config object for a server from the given files.
+func (info TLSInfo) ServerConfig() (*tls.Config, error) {
+	// Both the key and cert must be present.
+	if info.KeyFile == "" || info.CertFile == "" {
+		return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
+	}
+
+	var cfg tls.Config
+
+	tlsCert, err := tls.LoadX509KeyPair(info.CertFile, info.KeyFile)
+	if err != nil {
+		return nil, err
+	}
+
+	cfg.Certificates = []tls.Certificate{tlsCert}
+
+	if info.CAFile != "" {
+		cfg.ClientAuth = tls.RequireAndVerifyClientCert
+		cp, err := newCertPool(info.CAFile)
+		if err != nil {
+			return nil, err
+		}
+
+		cfg.RootCAs = cp
+		cfg.ClientCAs = cp
+	} else {
+		cfg.ClientAuth = tls.NoClientCert
+	}
+
+	return &cfg, nil
+}
+
+// newCertPool creates x509 certPool with provided CA file
+func newCertPool(CAFile string) (*x509.CertPool, error) {
+	certPool := x509.NewCertPool()
+	pemByte, err := ioutil.ReadFile(CAFile)
+	if err != nil {
+		return nil, err
+	}
+
+	for {
+		var block *pem.Block
+		block, pemByte = pem.Decode(pemByte)
+		if block == nil {
+			return certPool, nil
+		}
+		cert, err := x509.ParseCertificate(block.Bytes)
+		if err != nil {
+			return nil, err
+		}
+		certPool.AddCert(cert)
+	}
 }