// Copyright 2013 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package netutil import ( "errors" "fmt" "io" "io/ioutil" "net" "net/http" "sync" "sync/atomic" "testing" "time" ) const defaultMaxOpenFiles = 256 const timeout = 5 * time.Second func TestLimitListener(t *testing.T) { const max = 5 attempts := (maxOpenFiles() - max) / 2 if attempts > 256 { // maximum length of accept queue is 128 by default attempts = 256 } l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } defer l.Close() l = LimitListener(l, max) var open int32 go http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if n := atomic.AddInt32(&open, 1); n > max { t.Errorf("%d open connections, want <= %d", n, max) } defer atomic.AddInt32(&open, -1) time.Sleep(10 * time.Millisecond) fmt.Fprint(w, "some body") })) var wg sync.WaitGroup var failed int32 for i := 0; i < attempts; i++ { wg.Add(1) go func() { defer wg.Done() c := http.Client{Timeout: 3 * time.Second} r, err := c.Get("http://" + l.Addr().String()) if err != nil { t.Log(err) atomic.AddInt32(&failed, 1) return } defer r.Body.Close() io.Copy(ioutil.Discard, r.Body) }() } wg.Wait() // We expect some Gets to fail as the kernel's accept queue is filled, // but most should succeed. if int(failed) >= attempts/2 { t.Errorf("%d requests failed within %d attempts", failed, attempts) } } type errorListener struct { net.Listener } func (errorListener) Accept() (net.Conn, error) { return nil, errFake } var errFake = errors.New("fake error from errorListener") // This used to hang. func TestLimitListenerError(t *testing.T) { errCh := make(chan error, 1) go func() { defer close(errCh) const n = 2 ll := LimitListener(errorListener{}, n) for i := 0; i < n+1; i++ { _, err := ll.Accept() if err != errFake { errCh <- fmt.Errorf("Accept error = %v; want errFake", err) return } } }() select { case err := <-errCh: if err != nil { t.Fatalf("server: %v", err) } case <-time.After(timeout): t.Fatal("timeout. deadlock?") } } func TestLimitListenerClose(t *testing.T) { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } defer ln.Close() ln = LimitListener(ln, 1) errCh := make(chan error) go func() { defer close(errCh) c, err := net.DialTimeout("tcp", ln.Addr().String(), timeout) if err != nil { errCh <- err return } c.Close() }() c, err := ln.Accept() if err != nil { t.Fatal(err) } defer c.Close() err = <-errCh if err != nil { t.Fatalf("DialTimeout: %v", err) } acceptDone := make(chan struct{}) go func() { c, err := ln.Accept() if err == nil { c.Close() t.Errorf("Unexpected successful Accept()") } close(acceptDone) }() // Wait a tiny bit to ensure the Accept() is blocking. time.Sleep(10 * time.Millisecond) ln.Close() select { case <-acceptDone: case <-time.After(timeout): t.Fatalf("Accept() still blocking") } }