Browse Source

internal/socks: add DialWithConn method to Dialer

This change adds DialWithConn method for allowing package users to use
own net.Conn implementations optionally.

Also makes the deprecated Dialer.Dial return a raw transport connection
instead of a forward proxy connection for preserving the backward
compatibility on proxy.Dialer.Dial method.

Fixes golang/go#25104.

Change-Id: I4259cd10e299c1e36406545708e9f6888191705a
Reviewed-on: https://go-review.googlesource.com/110135
Run-TryBot: Mikio Hara <mikioh.mikioh@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Mikio Hara 7 years ago
parent
commit
75944861c7
2 changed files with 104 additions and 41 deletions
  1. 40 28
      internal/socks/dial_test.go
  2. 64 13
      internal/socks/socks.go

+ 40 - 28
internal/socks/dial_test.go

@@ -17,19 +17,11 @@ import (
 	"golang.org/x/net/internal/sockstest"
 	"golang.org/x/net/internal/sockstest"
 )
 )
 
 
-const (
-	targetNetwork  = "tcp6"
-	targetHostname = "fqdn.doesnotexist"
-	targetHostIP   = "2001:db8::1"
-	targetPort     = "5963"
-)
-
 func TestDial(t *testing.T) {
 func TestDial(t *testing.T) {
 	t.Run("Connect", func(t *testing.T) {
 	t.Run("Connect", func(t *testing.T) {
 		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
 		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
 		if err != nil {
 		if err != nil {
-			t.Error(err)
-			return
+			t.Fatal(err)
 		}
 		}
 		defer ss.Close()
 		defer ss.Close()
 		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
 		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
@@ -41,21 +33,45 @@ func TestDial(t *testing.T) {
 			Username: "username",
 			Username: "username",
 			Password: "password",
 			Password: "password",
 		}).Authenticate
 		}).Authenticate
-		c, err := d.Dial(targetNetwork, net.JoinHostPort(targetHostIP, targetPort))
-		if err == nil {
-			c.(*socks.Conn).BoundAddr()
-			c.Close()
+		c, err := d.DialContext(context.Background(), ss.TargetAddr().Network(), ss.TargetAddr().String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		c.(*socks.Conn).BoundAddr()
+		c.Close()
+	})
+	t.Run("ConnectWithConn", func(t *testing.T) {
+		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer ss.Close()
+		c, err := net.Dial(ss.Addr().Network(), ss.Addr().String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer c.Close()
+		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
+		d.AuthMethods = []socks.AuthMethod{
+			socks.AuthMethodNotRequired,
+			socks.AuthMethodUsernamePassword,
 		}
 		}
+		d.Authenticate = (&socks.UsernamePassword{
+			Username: "username",
+			Password: "password",
+		}).Authenticate
+		a, err := d.DialWithConn(context.Background(), c, ss.TargetAddr().Network(), ss.TargetAddr().String())
 		if err != nil {
 		if err != nil {
-			t.Error(err)
-			return
+			t.Fatal(err)
+		}
+		if _, ok := a.(*socks.Addr); !ok {
+			t.Fatalf("got %+v; want socks.Addr", a)
 		}
 		}
 	})
 	})
 	t.Run("Cancel", func(t *testing.T) {
 	t.Run("Cancel", func(t *testing.T) {
 		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
 		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
 		if err != nil {
 		if err != nil {
-			t.Error(err)
-			return
+			t.Fatal(err)
 		}
 		}
 		defer ss.Close()
 		defer ss.Close()
 		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
 		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
@@ -63,7 +79,7 @@ func TestDial(t *testing.T) {
 		defer cancel()
 		defer cancel()
 		dialErr := make(chan error)
 		dialErr := make(chan error)
 		go func() {
 		go func() {
-			c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort))
+			c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
 			if err == nil {
 			if err == nil {
 				c.Close()
 				c.Close()
 			}
 			}
@@ -73,41 +89,37 @@ func TestDial(t *testing.T) {
 		cancel()
 		cancel()
 		err = <-dialErr
 		err = <-dialErr
 		if perr, nerr := parseDialError(err); perr != context.Canceled && nerr == nil {
 		if perr, nerr := parseDialError(err); perr != context.Canceled && nerr == nil {
-			t.Errorf("got %v; want context.Canceled or equivalent", err)
-			return
+			t.Fatalf("got %v; want context.Canceled or equivalent", err)
 		}
 		}
 	})
 	})
 	t.Run("Deadline", func(t *testing.T) {
 	t.Run("Deadline", func(t *testing.T) {
 		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
 		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
 		if err != nil {
 		if err != nil {
-			t.Error(err)
-			return
+			t.Fatal(err)
 		}
 		}
 		defer ss.Close()
 		defer ss.Close()
 		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
 		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
 		ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
 		ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
 		defer cancel()
 		defer cancel()
-		c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort))
+		c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
 		if err == nil {
 		if err == nil {
 			c.Close()
 			c.Close()
 		}
 		}
 		if perr, nerr := parseDialError(err); perr != context.DeadlineExceeded && nerr == nil {
 		if perr, nerr := parseDialError(err); perr != context.DeadlineExceeded && nerr == nil {
-			t.Errorf("got %v; want context.DeadlineExceeded or equivalent", err)
-			return
+			t.Fatalf("got %v; want context.DeadlineExceeded or equivalent", err)
 		}
 		}
 	})
 	})
 	t.Run("WithRogueServer", func(t *testing.T) {
 	t.Run("WithRogueServer", func(t *testing.T) {
 		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, rogueCmdFunc)
 		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, rogueCmdFunc)
 		if err != nil {
 		if err != nil {
-			t.Error(err)
-			return
+			t.Fatal(err)
 		}
 		}
 		defer ss.Close()
 		defer ss.Close()
 		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
 		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
 		for i := 0; i < 2*len(rogueCmdList); i++ {
 		for i := 0; i < 2*len(rogueCmdList); i++ {
 			ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
 			ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
 			defer cancel()
 			defer cancel()
-			c, err := d.DialContext(ctx, targetNetwork, net.JoinHostPort(targetHostIP, targetPort))
+			c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
 			if err == nil {
 			if err == nil {
 				t.Log(c.(*socks.Conn).BoundAddr())
 				t.Log(c.(*socks.Conn).BoundAddr())
 				c.Close()
 				c.Close()

+ 64 - 13
internal/socks/socks.go

@@ -149,20 +149,13 @@ type Dialer struct {
 // See func Dial of the net package of standard library for a
 // See func Dial of the net package of standard library for a
 // description of the network and address parameters.
 // description of the network and address parameters.
 func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
 func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
-	switch network {
-	case "tcp", "tcp6", "tcp4":
-	default:
-		proxy, dst, _ := d.pathAddrs(address)
-		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("network not implemented")}
-	}
-	switch d.cmd {
-	case CmdConnect, cmdBind:
-	default:
+	if err := d.validateTarget(network, address); err != nil {
 		proxy, dst, _ := d.pathAddrs(address)
 		proxy, dst, _ := d.pathAddrs(address)
-		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("command not implemented")}
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
 	}
 	}
 	if ctx == nil {
 	if ctx == nil {
-		ctx = context.Background()
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
 	}
 	}
 	var err error
 	var err error
 	var c net.Conn
 	var c net.Conn
@@ -185,11 +178,69 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
 	return &Conn{Conn: c, boundAddr: a}, nil
 	return &Conn{Conn: c, boundAddr: a}, nil
 }
 }
 
 
+// DialWithConn initiates a connection from SOCKS server to the target
+// network and address using the connection c that is already
+// connected to the SOCKS server.
+//
+// It returns the connection's local address assigned by the SOCKS
+// server.
+func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) {
+	if err := d.validateTarget(network, address); err != nil {
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+	}
+	if ctx == nil {
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
+	}
+	a, err := d.connect(ctx, c, address)
+	if err != nil {
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+	}
+	return a, nil
+}
+
 // Dial connects to the provided address on the provided network.
 // Dial connects to the provided address on the provided network.
 //
 //
-// Deprecated: Use DialContext instead.
+// Unlike DialContext, it returns a raw transport connection instead
+// of a forward proxy connection.
+//
+// Deprecated: Use DialContext or DialWithConn instead.
 func (d *Dialer) Dial(network, address string) (net.Conn, error) {
 func (d *Dialer) Dial(network, address string) (net.Conn, error) {
-	return d.DialContext(context.Background(), network, address)
+	if err := d.validateTarget(network, address); err != nil {
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+	}
+	var err error
+	var c net.Conn
+	if d.ProxyDial != nil {
+		c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress)
+	} else {
+		c, err = net.Dial(d.proxyNetwork, d.proxyAddress)
+	}
+	if err != nil {
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+	}
+	if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil {
+		return nil, err
+	}
+	return c, nil
+}
+
+func (d *Dialer) validateTarget(network, address string) error {
+	switch network {
+	case "tcp", "tcp6", "tcp4":
+	default:
+		return errors.New("network not implemented")
+	}
+	switch d.cmd {
+	case CmdConnect, cmdBind:
+	default:
+		return errors.New("command not implemented")
+	}
+	return nil
 }
 }
 
 
 func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {
 func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {