Quellcode durchsuchen

Add DialWithExplicitTLS

Julien Laffaye vor 5 Jahren
Ursprung
Commit
ac1574d383
2 geänderte Dateien mit 37 neuen und 7 gelöschten Zeilen
  1. 35 7
      ftp.go
  2. 2 0
      status.go

+ 35 - 7
ftp.go

@@ -50,6 +50,7 @@ type dialOptions struct {
 	context     context.Context
 	dialer      net.Dialer
 	tlsConfig   *tls.Config
+	explicitTLS bool
 	conn        net.Conn
 	disableEPSV bool
 	location    *time.Location
@@ -90,7 +91,7 @@ func Dial(addr string, options ...DialOption) (*ServerConn, error) {
 
 		if do.dialFunc != nil {
 			tconn, err = do.dialFunc("tcp", addr)
-		} else if do.tlsConfig != nil {
+		} else if do.tlsConfig != nil && !do.explicitTLS {
 			tconn, err = tls.DialWithDialer(&do.dialer, "tcp", addr, do.tlsConfig)
 		} else {
 			ctx := do.context
@@ -111,15 +112,10 @@ func Dial(addr string, options ...DialOption) (*ServerConn, error) {
 	// If we use the domain name, we might not resolve to the same IP.
 	remoteAddr := tconn.RemoteAddr().(*net.TCPAddr)
 
-	var sourceConn io.ReadWriteCloser = tconn
-	if do.debugOutput != nil {
-		sourceConn = newDebugWrapper(tconn, do.debugOutput)
-	}
-
 	c := &ServerConn{
 		options:  do,
 		features: make(map[string]string),
-		conn:     textproto.NewConn(sourceConn),
+		conn:     textproto.NewConn(do.wrapConn(tconn)),
 		host:     remoteAddr.IP.String(),
 	}
 
@@ -129,6 +125,15 @@ func Dial(addr string, options ...DialOption) (*ServerConn, error) {
 		return nil, err
 	}
 
+	if do.explicitTLS {
+		if err := c.authTLS(); err != nil {
+			_ = c.Quit()
+			return nil, err
+		}
+		tconn = tls.Client(tconn, do.tlsConfig)
+		c.conn = textproto.NewConn(do.wrapConn(tconn))
+	}
+
 	err = c.feat()
 	if err != nil {
 		c.Quit()
@@ -198,6 +203,15 @@ func DialWithTLS(tlsConfig *tls.Config) DialOption {
 	}}
 }
 
+// DialWithExplicitTLS returns a DialOption that configures the ServerConn to be upgraded to TLS
+// See DialWithTLS for general TLS documentation
+func DialWithExplicitTLS(tlsConfig *tls.Config) DialOption {
+	return DialOption{func(do *dialOptions) {
+		do.explicitTLS = true
+		do.tlsConfig = tlsConfig
+	}}
+}
+
 // DialWithDebugOutput returns a DialOption that configures the ServerConn to write to the Writer
 // everything it reads from the server
 func DialWithDebugOutput(w io.Writer) DialOption {
@@ -218,6 +232,14 @@ func DialWithDialFunc(f func(network, address string) (net.Conn, error)) DialOpt
 	}}
 }
 
+func (o *dialOptions) wrapConn(netConn net.Conn) io.ReadWriteCloser {
+	if o.debugOutput == nil {
+		return netConn
+	}
+
+	return newDebugWrapper(netConn, o.debugOutput)
+}
+
 // Connect is an alias to Dial, for backward compatibility
 func Connect(addr string) (*ServerConn, error) {
 	return Dial(addr)
@@ -269,6 +291,12 @@ func (c *ServerConn) Login(user, password string) error {
 	return err
 }
 
+// authTLS upgrades the connection to use TLS
+func (c *ServerConn) authTLS() error {
+	_, _, err := c.cmd(StatusAuthOK, "AUTH TLS")
+	return err
+}
+
 // feat issues a FEAT FTP command to list the additional commands supported by
 // the remote FTP server.
 // FEAT is described in RFC 2389

+ 2 - 0
status.go

@@ -25,6 +25,7 @@ const (
 	StatusLoggedIn              = 230
 	StatusLoggedOut             = 231
 	StatusLogoutAck             = 232
+	StatusAuthOK                = 234
 	StatusRequestedFileActionOK = 250
 	StatusPathCreated           = 257
 
@@ -73,6 +74,7 @@ var statusText = map[int]string{
 	StatusLoggedIn:              "User logged in, proceed.",
 	StatusLoggedOut:             "User logged out; service terminated.",
 	StatusLogoutAck:             "Logout command noted, will complete when transfer done.",
+	StatusAuthOK:                "AUTH command OK",
 	StatusRequestedFileActionOK: "Requested file action okay, completed.",
 	StatusPathCreated:           "Path created.",