Browse Source

Add timeouts to dials

Brad Fitzpatrick 14 years ago
parent
commit
1a17ac6979
1 changed files with 47 additions and 6 deletions
  1. 47 6
      memcache.go

+ 47 - 6
memcache.go

@@ -28,6 +28,7 @@ import (
 	"os"
 	"strings"
 	"sync"
+	"time"
 )
 
 var _ = log.Printf
@@ -209,20 +210,60 @@ func (c *Client) getFreeConn(addr net.Addr) (cn *conn, ok bool) {
 	return cn, true
 }
 
+func (c *Client) netTimeoutNs() int64 {
+	if c.TimeoutNanos != 0 {
+		return c.TimeoutNanos
+	}
+	return DefaultTimeoutNanos
+}
+
+// ConnectTimeoutError is the error type used when it takes
+// too long to connect to the desired host. This level of
+// detail can generally be ignored.
+type ConnectTimeoutError struct {
+	Addr net.Addr
+}
+
+func (cte *ConnectTimeoutError) String() string {
+	return "memcache: connect timeout to " + cte.Addr.String()
+}
+
+func (c *Client) dial(addr net.Addr) (net.Conn, os.Error) {
+	type connError struct {
+		cn  net.Conn
+		err os.Error
+	}
+	ch := make(chan connError)
+	go func() {
+		nc, err := net.Dial(addr.Network(), addr.String())
+		ch <- connError{nc, err}
+	}()
+	select {
+	case ce := <-ch:
+		return ce.cn, ce.err
+	case <-time.After(c.netTimeoutNs()):
+		// Too slow. Fall through.
+	}
+	// Close the conn if it does end up finally coming in
+	go func() {
+		ce := <-ch
+		if ce.err == nil {
+			ce.cn.Close()
+		}
+	}()
+	return nil, &ConnectTimeoutError{addr}
+}
+
 func (c *Client) getConn(addr net.Addr) (*conn, os.Error) {
 	cn, ok := c.getFreeConn(addr)
 	if ok {
 		return cn, nil
 	}
-	nc, err := net.Dial(addr.Network(), addr.String())
+	nc, err := c.dial(addr)
 	if err != nil {
 		return nil, err
 	}
-	if c.TimeoutNanos != 0 {
-		nc.SetTimeout(c.TimeoutNanos)
-	} else {
-		nc.SetTimeout(DefaultTimeoutNanos)
-	}
+	nc.SetTimeout(c.netTimeoutNs())
 	return &conn{
 		nc:   nc,
 		addr: addr,