|
@@ -66,9 +66,11 @@ func NewTransport(info TLSInfo) (*http.Transport, error) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
type TLSInfo struct {
|
|
type TLSInfo struct {
|
|
|
- CertFile string
|
|
|
|
|
- KeyFile string
|
|
|
|
|
- CAFile string
|
|
|
|
|
|
|
+ CertFile string
|
|
|
|
|
+ KeyFile string
|
|
|
|
|
+ CAFile string
|
|
|
|
|
+ TrustedCAFile string
|
|
|
|
|
+ ClientCertAuth bool
|
|
|
|
|
|
|
|
// parseFunc exists to simplify testing. Typically, parseFunc
|
|
// parseFunc exists to simplify testing. Typically, parseFunc
|
|
|
// should be left nil. In that case, tls.X509KeyPair will be used.
|
|
// should be left nil. In that case, tls.X509KeyPair will be used.
|
|
@@ -115,29 +117,47 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) {
|
|
|
return cfg, nil
|
|
return cfg, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// ServerConfig generates a tls.Config object for use by an HTTP server
|
|
|
|
|
|
|
+// cafiles returns a list of CA file paths.
|
|
|
|
|
+func (info TLSInfo) cafiles() []string {
|
|
|
|
|
+ cs := make([]string, 0)
|
|
|
|
|
+ if info.CAFile != "" {
|
|
|
|
|
+ cs = append(cs, info.CAFile)
|
|
|
|
|
+ }
|
|
|
|
|
+ if info.TrustedCAFile != "" {
|
|
|
|
|
+ cs = append(cs, info.TrustedCAFile)
|
|
|
|
|
+ }
|
|
|
|
|
+ return cs
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// ServerConfig generates a tls.Config object for use by an HTTP server.
|
|
|
func (info TLSInfo) ServerConfig() (*tls.Config, error) {
|
|
func (info TLSInfo) ServerConfig() (*tls.Config, error) {
|
|
|
cfg, err := info.baseConfig()
|
|
cfg, err := info.baseConfig()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if info.CAFile != "" {
|
|
|
|
|
|
|
+ cfg.ClientAuth = tls.NoClientCert
|
|
|
|
|
+ if info.CAFile != "" || info.ClientCertAuth {
|
|
|
cfg.ClientAuth = tls.RequireAndVerifyClientCert
|
|
cfg.ClientAuth = tls.RequireAndVerifyClientCert
|
|
|
- cp, err := newCertPool(info.CAFile)
|
|
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ CAFiles := info.cafiles()
|
|
|
|
|
+ if len(CAFiles) > 0 {
|
|
|
|
|
+ cp, err := newCertPool(CAFiles)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
cfg.ClientCAs = cp
|
|
cfg.ClientCAs = cp
|
|
|
- } else {
|
|
|
|
|
- cfg.ClientAuth = tls.NoClientCert
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
return cfg, nil
|
|
return cfg, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// ClientConfig generates a tls.Config object for use by an HTTP client
|
|
|
|
|
-func (info TLSInfo) ClientConfig() (cfg *tls.Config, err error) {
|
|
|
|
|
|
|
+// ClientConfig generates a tls.Config object for use by an HTTP client.
|
|
|
|
|
+func (info TLSInfo) ClientConfig() (*tls.Config, error) {
|
|
|
|
|
+ var cfg *tls.Config
|
|
|
|
|
+ var err error
|
|
|
|
|
+
|
|
|
if !info.Empty() {
|
|
if !info.Empty() {
|
|
|
cfg, err = info.baseConfig()
|
|
cfg, err = info.baseConfig()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
@@ -147,34 +167,40 @@ func (info TLSInfo) ClientConfig() (cfg *tls.Config, err error) {
|
|
|
cfg = &tls.Config{}
|
|
cfg = &tls.Config{}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if info.CAFile != "" {
|
|
|
|
|
- cfg.RootCAs, err = newCertPool(info.CAFile)
|
|
|
|
|
|
|
+ CAFiles := info.cafiles()
|
|
|
|
|
+ if len(CAFiles) > 0 {
|
|
|
|
|
+ cfg.RootCAs, err = newCertPool(CAFiles)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return
|
|
|
|
|
|
|
+ return nil, err
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- return
|
|
|
|
|
|
|
+ return cfg, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// newCertPool creates x509 certPool with provided CA file
|
|
|
|
|
-func newCertPool(CAFile string) (*x509.CertPool, error) {
|
|
|
|
|
|
|
+// newCertPool creates x509 certPool with provided CA files.
|
|
|
|
|
+func newCertPool(CAFiles []string) (*x509.CertPool, error) {
|
|
|
certPool := x509.NewCertPool()
|
|
certPool := x509.NewCertPool()
|
|
|
- pemByte, err := ioutil.ReadFile(CAFile)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return nil, err
|
|
|
|
|
- }
|
|
|
|
|
|
|
|
|
|
- for {
|
|
|
|
|
- var block *pem.Block
|
|
|
|
|
- block, pemByte = pem.Decode(pemByte)
|
|
|
|
|
- if block == nil {
|
|
|
|
|
- return certPool, nil
|
|
|
|
|
- }
|
|
|
|
|
- cert, err := x509.ParseCertificate(block.Bytes)
|
|
|
|
|
|
|
+ for _, CAFile := range CAFiles {
|
|
|
|
|
+ pemByte, err := ioutil.ReadFile(CAFile)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
- certPool.AddCert(cert)
|
|
|
|
|
|
|
+
|
|
|
|
|
+ for {
|
|
|
|
|
+ var block *pem.Block
|
|
|
|
|
+ block, pemByte = pem.Decode(pemByte)
|
|
|
|
|
+ if block == nil {
|
|
|
|
|
+ break
|
|
|
|
|
+ }
|
|
|
|
|
+ cert, err := x509.ParseCertificate(block.Bytes)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ certPool.AddCert(cert)
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ return certPool, nil
|
|
|
}
|
|
}
|