Browse Source

*: move baisc tls util funcs to tlsutil pkg

Xiang Li 9 years ago
parent
commit
eb3919e8cf
2 changed files with 78 additions and 47 deletions
  1. 72 0
      pkg/tlsutil/tlsutil.go
  2. 6 47
      pkg/transport/listener.go

+ 72 - 0
pkg/tlsutil/tlsutil.go

@@ -0,0 +1,72 @@
+// Copyright 2016 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tlsutil
+
+import (
+	"crypto/tls"
+	"crypto/x509"
+	"encoding/pem"
+	"io/ioutil"
+)
+
+// NewCertPool creates x509 certPool with provided CA files.
+func NewCertPool(CAFiles []string) (*x509.CertPool, error) {
+	certPool := x509.NewCertPool()
+
+	for _, CAFile := range CAFiles {
+		pemByte, err := ioutil.ReadFile(CAFile)
+		if err != nil {
+			return nil, err
+		}
+
+		for {
+			var block *pem.Block
+			block, pemByte = pem.Decode(pemByte)
+			if block == nil {
+				break
+			}
+			cert, err := x509.ParseCertificate(block.Bytes)
+			if err != nil {
+				return nil, err
+			}
+			certPool.AddCert(cert)
+		}
+	}
+
+	return certPool, nil
+}
+
+// NewCert generates TLS cert by using the given cert,key and parse function.
+func NewCert(certfile, keyfile string, parseFunc func([]byte, []byte) (tls.Certificate, error)) (*tls.Certificate, error) {
+	cert, err := ioutil.ReadFile(certfile)
+	if err != nil {
+		return nil, err
+	}
+
+	key, err := ioutil.ReadFile(keyfile)
+	if err != nil {
+		return nil, err
+	}
+
+	if parseFunc == nil {
+		parseFunc = tls.X509KeyPair
+	}
+
+	tlsCert, err := parseFunc(cert, key)
+	if err != nil {
+		return nil, err
+	}
+	return &tlsCert, nil
+}

+ 6 - 47
pkg/transport/listener.go

@@ -23,7 +23,6 @@ import (
 	"crypto/x509/pkix"
 	"encoding/pem"
 	"fmt"
-	"io/ioutil"
 	"math/big"
 	"net"
 	"net/http"
@@ -31,6 +30,8 @@ import (
 	"path"
 	"strings"
 	"time"
+
+	"github.com/coreos/etcd/pkg/tlsutil"
 )
 
 func NewListener(addr string, scheme string, tlscfg *tls.Config) (net.Listener, error) {
@@ -176,28 +177,13 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) {
 		return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
 	}
 
-	cert, err := ioutil.ReadFile(info.CertFile)
-	if err != nil {
-		return nil, err
-	}
-
-	key, err := ioutil.ReadFile(info.KeyFile)
-	if err != nil {
-		return nil, err
-	}
-
-	parseFunc := info.parseFunc
-	if parseFunc == nil {
-		parseFunc = tls.X509KeyPair
-	}
-
-	tlsCert, err := parseFunc(cert, key)
+	tlsCert, err := tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
 	if err != nil {
 		return nil, err
 	}
 
 	cfg := &tls.Config{
-		Certificates: []tls.Certificate{tlsCert},
+		Certificates: []tls.Certificate{*tlsCert},
 		MinVersion:   tls.VersionTLS10,
 	}
 	return cfg, nil
@@ -229,7 +215,7 @@ func (info TLSInfo) ServerConfig() (*tls.Config, error) {
 
 	CAFiles := info.cafiles()
 	if len(CAFiles) > 0 {
-		cp, err := newCertPool(CAFiles)
+		cp, err := tlsutil.NewCertPool(CAFiles)
 		if err != nil {
 			return nil, err
 		}
@@ -255,7 +241,7 @@ func (info TLSInfo) ClientConfig() (*tls.Config, error) {
 
 	CAFiles := info.cafiles()
 	if len(CAFiles) > 0 {
-		cfg.RootCAs, err = newCertPool(CAFiles)
+		cfg.RootCAs, err = tlsutil.NewCertPool(CAFiles)
 		if err != nil {
 			return nil, err
 		}
@@ -266,30 +252,3 @@ func (info TLSInfo) ClientConfig() (*tls.Config, error) {
 	}
 	return cfg, nil
 }
-
-// newCertPool creates x509 certPool with provided CA files.
-func newCertPool(CAFiles []string) (*x509.CertPool, error) {
-	certPool := x509.NewCertPool()
-
-	for _, CAFile := range CAFiles {
-		pemByte, err := ioutil.ReadFile(CAFile)
-		if err != nil {
-			return nil, err
-		}
-
-		for {
-			var block *pem.Block
-			block, pemByte = pem.Decode(pemByte)
-			if block == nil {
-				break
-			}
-			cert, err := x509.ParseCertificate(block.Bytes)
-			if err != nil {
-				return nil, err
-			}
-			certPool.AddCert(cert)
-		}
-	}
-
-	return certPool, nil
-}