Browse Source

netutil: release semaphore on error

Also rewrite it a bit for clarity (IMO).

LGTM=pzm, r
R=pzm, adg, r
CC=golang-codereviews
https://golang.org/cl/96560043
Brad Fitzpatrick 11 năm trước cách đây
mục cha
commit
a479876f52
2 tập tin đã thay đổi với 43 bổ sung13 xóa
  1. 11 13
      netutil/listen.go
  2. 32 0
      netutil/listen_test.go

+ 11 - 13
netutil/listen.go

@@ -14,37 +14,35 @@ import (
 // LimitListener returns a Listener that accepts at most n simultaneous
 // connections from the provided Listener.
 func LimitListener(l net.Listener, n int) net.Listener {
-	ch := make(chan struct{}, n)
-	for i := 0; i < n; i++ {
-		ch <- struct{}{}
-	}
-	return &limitListener{l, ch}
+	return &limitListener{l, make(chan struct{}, n)}
 }
 
 type limitListener struct {
 	net.Listener
-	ch chan struct{}
+	sem chan struct{}
 }
 
+func (l *limitListener) acquire() { l.sem <- struct{}{} }
+func (l *limitListener) release() { <-l.sem }
+
 func (l *limitListener) Accept() (net.Conn, error) {
-	<-l.ch
+	l.acquire()
 	c, err := l.Listener.Accept()
 	if err != nil {
+		l.release()
 		return nil, err
 	}
-	return &limitListenerConn{Conn: c, ch: l.ch}, nil
+	return &limitListenerConn{Conn: c, release: l.release}, nil
 }
 
 type limitListenerConn struct {
 	net.Conn
-	ch    chan<- struct{}
-	close sync.Once
+	releaseOnce sync.Once
+	release     func()
 }
 
 func (l *limitListenerConn) Close() error {
 	err := l.Conn.Close()
-	l.close.Do(func() {
-		l.ch <- struct{}{}
-	})
+	l.releaseOnce.Do(l.release)
 	return err
 }

+ 32 - 0
netutil/listen_test.go

@@ -10,6 +10,7 @@
 package netutil
 
 import (
+	"errors"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -69,3 +70,34 @@ func TestLimitListener(t *testing.T) {
 		t.Errorf("too many Gets failed: %v", failed)
 	}
 }
+
+type errorListener struct {
+	net.Listener
+}
+
+func (errorListener) Accept() (net.Conn, error) {
+	return nil, errFake
+}
+
+var errFake = errors.New("fake error from errorListener")
+
+// This used to hang.
+func TestLimitListenerError(t *testing.T) {
+	donec := make(chan bool, 1)
+	go func() {
+		const n = 2
+		ll := LimitListener(errorListener{}, n)
+		for i := 0; i < n+1; i++ {
+			_, err := ll.Accept()
+			if err != errFake {
+				t.Fatalf("Accept error = %v; want errFake", err)
+			}
+		}
+		donec <- true
+	}()
+	select {
+	case <-donec:
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout. deadlock?")
+	}
+}