|
|
@@ -21,6 +21,7 @@ import (
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"net/url"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
@@ -149,8 +150,11 @@ type ServerConfig struct {
|
|
|
type server struct {
|
|
|
lg *zap.Logger
|
|
|
|
|
|
- from url.URL
|
|
|
- to url.URL
|
|
|
+ from url.URL
|
|
|
+ fromPort int
|
|
|
+ to url.URL
|
|
|
+ toPort int
|
|
|
+
|
|
|
tlsInfo transport.TLSInfo
|
|
|
dialTimeout time.Duration
|
|
|
|
|
|
@@ -198,8 +202,9 @@ func NewServer(cfg ServerConfig) Server {
|
|
|
s := &server{
|
|
|
lg: cfg.Logger,
|
|
|
|
|
|
- from: cfg.From,
|
|
|
- to: cfg.To,
|
|
|
+ from: cfg.From,
|
|
|
+ to: cfg.To,
|
|
|
+
|
|
|
tlsInfo: cfg.TLSInfo,
|
|
|
dialTimeout: cfg.DialTimeout,
|
|
|
|
|
|
@@ -215,6 +220,16 @@ func NewServer(cfg ServerConfig) Server {
|
|
|
pauseRxc: make(chan struct{}),
|
|
|
}
|
|
|
|
|
|
+ _, fromPort, err := net.SplitHostPort(cfg.From.Host)
|
|
|
+ if err == nil {
|
|
|
+ s.fromPort, err = strconv.Atoi(fromPort)
|
|
|
+ }
|
|
|
+ var toPort string
|
|
|
+ _, toPort, err = net.SplitHostPort(cfg.To.Host)
|
|
|
+ if err == nil {
|
|
|
+ s.toPort, _ = strconv.Atoi(toPort)
|
|
|
+ }
|
|
|
+
|
|
|
if s.dialTimeout == 0 {
|
|
|
s.dialTimeout = defaultDialTimeout
|
|
|
}
|
|
|
@@ -239,12 +254,16 @@ func NewServer(cfg ServerConfig) Server {
|
|
|
s.to.Scheme = "tcp"
|
|
|
}
|
|
|
|
|
|
+ addr := fmt.Sprintf(":%d", s.fromPort)
|
|
|
+ if s.fromPort == 0 { // unix
|
|
|
+ addr = s.from.Host
|
|
|
+ }
|
|
|
+
|
|
|
var ln net.Listener
|
|
|
- var err error
|
|
|
if !s.tlsInfo.Empty() {
|
|
|
- ln, err = transport.NewListener(s.from.Host, s.from.Scheme, &s.tlsInfo)
|
|
|
+ ln, err = transport.NewListener(addr, s.from.Scheme, &s.tlsInfo)
|
|
|
} else {
|
|
|
- ln, err = net.Listen(s.from.Scheme, s.from.Host)
|
|
|
+ ln, err = net.Listen(s.from.Scheme, addr)
|
|
|
}
|
|
|
if err != nil {
|
|
|
s.errc <- err
|