Browse Source

Merge pull request #1771 from xiang90/listener

pkg/transport: add timeout dailer and timeout listener
Xiang Li 11 years ago
parent
commit
c6cbea451a

+ 30 - 0
pkg/transport/timeout_conn.go

@@ -0,0 +1,30 @@
+package transport
+
+import (
+	"net"
+	"time"
+)
+
+type timeoutConn struct {
+	net.Conn
+	wtimeoutd  time.Duration
+	rdtimeoutd time.Duration
+}
+
+func (c timeoutConn) Write(b []byte) (n int, err error) {
+	if c.wtimeoutd > 0 {
+		if err := c.SetWriteDeadline(time.Now().Add(c.wtimeoutd)); err != nil {
+			return 0, err
+		}
+	}
+	return c.Conn.Write(b)
+}
+
+func (c timeoutConn) Read(b []byte) (n int, err error) {
+	if c.rdtimeoutd > 0 {
+		if err := c.SetReadDeadline(time.Now().Add(c.rdtimeoutd)); err != nil {
+			return 0, err
+		}
+	}
+	return c.Conn.Read(b)
+}

+ 22 - 0
pkg/transport/timeout_dailer.go

@@ -0,0 +1,22 @@
+package transport
+
+import (
+	"net"
+	"time"
+)
+
+type rwTimeoutDialer struct {
+	wtimeoutd  time.Duration
+	rdtimeoutd time.Duration
+	net.Dialer
+}
+
+func (d *rwTimeoutDialer) Dial(network, address string) (net.Conn, error) {
+	conn, err := d.Dialer.Dial(network, address)
+	tconn := &timeoutConn{
+		rdtimeoutd: d.rdtimeoutd,
+		wtimeoutd:  d.wtimeoutd,
+		Conn:       conn,
+	}
+	return tconn, err
+}

+ 73 - 0
pkg/transport/timeout_dailer_test.go

@@ -0,0 +1,73 @@
+package transport
+
+import (
+	"net"
+	"testing"
+	"time"
+)
+
+func TestReadWriteTimeoutDialer(t *testing.T) {
+	stop := make(chan struct{})
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("unexpected listen error: %v", err)
+	}
+	ts := testBlockingServer{ln, 2, stop}
+	go ts.Start(t)
+
+	d := rwTimeoutDialer{
+		wtimeoutd:  10 * time.Millisecond,
+		rdtimeoutd: 10 * time.Millisecond,
+	}
+	conn, err := d.Dial("tcp", ln.Addr().String())
+	if err != nil {
+		t.Fatalf("unexpected dial error: %v", err)
+	}
+	defer conn.Close()
+
+	// fill the socket buffer
+	data := make([]byte, 1024*1024)
+	timer := time.AfterFunc(d.wtimeoutd*5, func() {
+		t.Fatal("wait timeout")
+	})
+	defer timer.Stop()
+
+	_, err = conn.Write(data)
+	if operr, ok := err.(*net.OpError); !ok || operr.Op != "write" || !operr.Timeout() {
+		t.Errorf("err = %v, want write i/o timeout error", err)
+	}
+
+	timer.Reset(d.rdtimeoutd * 5)
+
+	conn, err = d.Dial("tcp", ln.Addr().String())
+	if err != nil {
+		t.Fatalf("unexpected dial error: %v", err)
+	}
+	defer conn.Close()
+
+	buf := make([]byte, 10)
+	_, err = conn.Read(buf)
+	if operr, ok := err.(*net.OpError); !ok || operr.Op != "read" || !operr.Timeout() {
+		t.Errorf("err = %v, want write i/o timeout error", err)
+	}
+
+	stop <- struct{}{}
+}
+
+type testBlockingServer struct {
+	ln   net.Listener
+	n    int
+	stop chan struct{}
+}
+
+func (ts *testBlockingServer) Start(t *testing.T) {
+	for i := 0; i < ts.n; i++ {
+		conn, err := ts.ln.Accept()
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer conn.Close()
+	}
+	<-ts.stop
+}

+ 24 - 0
pkg/transport/timeout_listener.go

@@ -0,0 +1,24 @@
+package transport
+
+import (
+	"net"
+	"time"
+)
+
+type rwTimeoutListener struct {
+	net.Listener
+	wtimeoutd  time.Duration
+	rdtimeoutd time.Duration
+}
+
+func (rwln *rwTimeoutListener) Accept() (net.Conn, error) {
+	c, err := rwln.Listener.Accept()
+	if err != nil {
+		return nil, err
+	}
+	return timeoutConn{
+		Conn:       c,
+		wtimeoutd:  rwln.wtimeoutd,
+		rdtimeoutd: rwln.rdtimeoutd,
+	}, nil
+}

+ 64 - 0
pkg/transport/timeout_listener_test.go

@@ -0,0 +1,64 @@
+package transport
+
+import (
+	"net"
+	"testing"
+	"time"
+)
+
+func TestWriteReadTimeoutListener(t *testing.T) {
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("unexpected listen error: %v", err)
+	}
+	wln := rwTimeoutListener{
+		Listener:   ln,
+		wtimeoutd:  10 * time.Millisecond,
+		rdtimeoutd: 10 * time.Millisecond,
+	}
+	stop := make(chan struct{})
+
+	blocker := func() {
+		conn, err := net.Dial("tcp", ln.Addr().String())
+		if err != nil {
+			t.Fatalf("unexpected dail error: %v", err)
+		}
+		defer conn.Close()
+		// block the receiver until the writer timeout
+		<-stop
+	}
+	go blocker()
+
+	conn, err := wln.Accept()
+	if err != nil {
+		t.Fatalf("unexpected accept error: %v", err)
+	}
+	defer conn.Close()
+
+	// fill the socket buffer
+	data := make([]byte, 1024*1024)
+	timer := time.AfterFunc(wln.wtimeoutd*5, func() {
+		t.Fatal("wait timeout")
+	})
+	defer timer.Stop()
+
+	_, err = conn.Write(data)
+	if operr, ok := err.(*net.OpError); !ok || operr.Op != "write" || !operr.Timeout() {
+		t.Errorf("err = %v, want write i/o timeout error", err)
+	}
+	stop <- struct{}{}
+
+	timer.Reset(wln.rdtimeoutd * 5)
+	go blocker()
+
+	conn, err = wln.Accept()
+	if err != nil {
+		t.Fatalf("unexpected accept error: %v", err)
+	}
+	buf := make([]byte, 10)
+	_, err = conn.Read(buf)
+	if operr, ok := err.(*net.OpError); !ok || operr.Op != "read" || !operr.Timeout() {
+		t.Errorf("err = %v, want write i/o timeout error", err)
+	}
+	stop <- struct{}{}
+}