소스 검색

Merge pull request #1140 from bcwaldon/TLS

client server TLS
Brian Waldon 11 년 전
부모
커밋
b754406f10
3개의 변경된 파일118개의 추가작업 그리고 4개의 파일을 삭제
  1. 25 3
      main.go
  2. 1 1
      test
  3. 92 0
      transport/listener.go

+ 25 - 3
main.go

@@ -19,6 +19,7 @@ import (
 	"github.com/coreos/etcd/raft"
 	"github.com/coreos/etcd/snap"
 	"github.com/coreos/etcd/store"
+	"github.com/coreos/etcd/transport"
 	"github.com/coreos/etcd/wal"
 )
 
@@ -48,6 +49,8 @@ var (
 		proxyFlagValueReadonly,
 		proxyFlagValueOn,
 	}
+
+	clientTLSInfo = transport.TLSInfo{}
 )
 
 func init() {
@@ -58,6 +61,10 @@ func init() {
 	peers.Set("0x1=localhost:8080")
 	addrs.Set("127.0.0.1:4001")
 	proxyFlag.Set(proxyFlagValueOff)
+
+	flag.StringVar(&clientTLSInfo.CAFile, "ca-file", "", "Path to the client server TLS CA file.")
+	flag.StringVar(&clientTLSInfo.CertFile, "cert-file", "", "Path to the client server TLS cert file.")
+	flag.StringVar(&clientTLSInfo.KeyFile, "key-file", "", "Path to the client server TLS key file.")
 }
 
 func main() {
@@ -167,18 +174,28 @@ func startEtcd() {
 		Info:    cors,
 	}
 
+	l, err := transport.NewListener(*paddr, transport.TLSInfo{})
+	if err != nil {
+		log.Fatal(err)
+	}
+
 	// Start the peer server in a goroutine
 	go func() {
 		log.Print("Listening for peers on ", *paddr)
-		log.Fatal(http.ListenAndServe(*paddr, ph))
+		log.Fatal(http.Serve(l, ph))
 	}()
 
 	// Start a client server goroutine for each listen address
 	for _, addr := range *addrs {
 		addr := addr
+		l, err := transport.NewListener(addr, clientTLSInfo)
+		if err != nil {
+			log.Fatal(err)
+		}
+
 		go func() {
 			log.Print("Listening for client requests on ", addr)
-			log.Fatal(http.ListenAndServe(addr, ch))
+			log.Fatal(http.Serve(l, ch))
 		}()
 	}
 }
@@ -201,9 +218,14 @@ func startProxy() {
 	// Start a proxy server goroutine for each listen address
 	for _, addr := range *addrs {
 		addr := addr
+		l, err := transport.NewListener(addr, clientTLSInfo)
+		if err != nil {
+			log.Fatal(err)
+		}
+
 		go func() {
 			log.Print("Listening for client requests on ", addr)
-			log.Fatal(http.ListenAndServe(addr, ph))
+			log.Fatal(http.Serve(l, ph))
 		}()
 	}
 }

+ 1 - 1
test

@@ -15,7 +15,7 @@ COVER=${COVER:-"-cover"}
 source ./build
 
 # Hack: gofmt ./ will recursively check the .git directory. So use *.go for gofmt.
-TESTABLE_AND_FORMATTABLE="client etcdserver etcdserver/etcdhttp etcdserver/etcdserverpb functional proxy raft snap store wait wal"
+TESTABLE_AND_FORMATTABLE="client etcdserver etcdserver/etcdhttp etcdserver/etcdserverpb functional proxy raft snap store wait wal transport"
 TESTABLE="$TESTABLE_AND_FORMATTABLE ./"
 FORMATTABLE="$TESTABLE_AND_FORMATTABLE *.go"
 

+ 92 - 0
transport/listener.go

@@ -0,0 +1,92 @@
+package transport
+
+import (
+	"crypto/tls"
+	"crypto/x509"
+	"encoding/pem"
+	"fmt"
+	"io/ioutil"
+	"net"
+)
+
+func NewListener(addr string, info TLSInfo) (net.Listener, error) {
+	l, err := net.Listen("tcp", addr)
+	if err != nil {
+		return nil, err
+	}
+
+	if !info.Empty() {
+		cfg, err := info.ServerConfig()
+		if err != nil {
+			return nil, err
+		}
+
+		l = tls.NewListener(l, cfg)
+	}
+
+	return l, nil
+}
+
+type TLSInfo struct {
+	CertFile string
+	KeyFile  string
+	CAFile   string
+}
+
+func (info TLSInfo) Empty() bool {
+	return info.CertFile == "" && info.KeyFile == ""
+}
+
+// Generates a tls.Config object for a server from the given files.
+func (info TLSInfo) ServerConfig() (*tls.Config, error) {
+	// Both the key and cert must be present.
+	if info.KeyFile == "" || info.CertFile == "" {
+		return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
+	}
+
+	var cfg tls.Config
+
+	tlsCert, err := tls.LoadX509KeyPair(info.CertFile, info.KeyFile)
+	if err != nil {
+		return nil, err
+	}
+
+	cfg.Certificates = []tls.Certificate{tlsCert}
+
+	if info.CAFile != "" {
+		cfg.ClientAuth = tls.RequireAndVerifyClientCert
+		cp, err := newCertPool(info.CAFile)
+		if err != nil {
+			return nil, err
+		}
+
+		cfg.RootCAs = cp
+		cfg.ClientCAs = cp
+	} else {
+		cfg.ClientAuth = tls.NoClientCert
+	}
+
+	return &cfg, nil
+}
+
+// newCertPool creates x509 certPool with provided CA file
+func newCertPool(CAFile string) (*x509.CertPool, error) {
+	certPool := x509.NewCertPool()
+	pemByte, err := ioutil.ReadFile(CAFile)
+	if err != nil {
+		return nil, err
+	}
+
+	for {
+		var block *pem.Block
+		block, pemByte = pem.Decode(pemByte)
+		if block == nil {
+			return certPool, nil
+		}
+		cert, err := x509.ParseCertificate(block.Bytes)
+		if err != nil {
+			return nil, err
+		}
+		certPool.AddCert(cert)
+	}
+}