Browse Source

Merge pull request #1770 from yichengq/230

*: set read/write timeout for raft transport and listener
Yicheng Qin 11 years ago
parent
commit
cfb96de413

+ 3 - 2
etcdmain/etcd.go

@@ -35,6 +35,7 @@ import (
 	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/proxy"
 	"github.com/coreos/etcd/proxy"
+	"github.com/coreos/etcd/rafthttp"
 	"github.com/coreos/etcd/version"
 	"github.com/coreos/etcd/version"
 )
 )
 
 
@@ -209,7 +210,7 @@ func startEtcd() (<-chan struct{}, error) {
 		return nil, fmt.Errorf("cannot write to data directory: %v", err)
 		return nil, fmt.Errorf("cannot write to data directory: %v", err)
 	}
 	}
 
 
-	pt, err := transport.NewTransport(peerTLSInfo)
+	pt, err := transport.NewTimeoutTransport(peerTLSInfo, rafthttp.ConnReadTimeout, rafthttp.ConnWriteTimeout)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -230,7 +231,7 @@ func startEtcd() (<-chan struct{}, error) {
 	plns := make([]net.Listener, 0)
 	plns := make([]net.Listener, 0)
 	for _, u := range lpurls {
 	for _, u := range lpurls {
 		var l net.Listener
 		var l net.Listener
-		l, err = transport.NewListener(u.Host, u.Scheme, peerTLSInfo)
+		l, err = transport.NewTimeoutListener(u.Host, u.Scheme, peerTLSInfo, rafthttp.ConnReadTimeout, rafthttp.ConnWriteTimeout)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}

+ 0 - 0
pkg/transport/timeout_dailer.go → pkg/transport/timeout_dialer.go


+ 0 - 0
pkg/transport/timeout_dailer_test.go → pkg/transport/timeout_dialer_test.go


+ 15 - 0
pkg/transport/timeout_listener.go

@@ -21,6 +21,21 @@ import (
 	"time"
 	"time"
 )
 )
 
 
+// NewTimeoutListener returns a listener that listens on the given address.
+// If read/write on the accepted connection blocks longer than its time limit,
+// it will return timeout error.
+func NewTimeoutListener(addr string, scheme string, info TLSInfo, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) {
+	ln, err := NewListener(addr, scheme, info)
+	if err != nil {
+		return nil, err
+	}
+	return &rwTimeoutListener{
+		Listener:   ln,
+		rdtimeoutd: rdtimeoutd,
+		wtimeoutd:  wtimeoutd,
+	}, nil
+}
+
 type rwTimeoutListener struct {
 type rwTimeoutListener struct {
 	net.Listener
 	net.Listener
 	wtimeoutd  time.Duration
 	wtimeoutd  time.Duration

+ 42 - 0
pkg/transport/timeout_transport.go

@@ -0,0 +1,42 @@
+/*
+   Copyright 2014 CoreOS, Inc.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package transport
+
+import (
+	"net"
+	"net/http"
+	"time"
+)
+
+// NewTimeoutTransport returns a transport created using the given TLS info.
+// If read/write on the created connection blocks longer than its time limit,
+// it will return timeout error.
+func NewTimeoutTransport(info TLSInfo, rdtimeoutd, wtimeoutd time.Duration) (*http.Transport, error) {
+	tr, err := NewTransport(info)
+	if err != nil {
+		return nil, err
+	}
+	tr.Dial = (&rwTimeoutDialer{
+		Dialer: net.Dialer{
+			Timeout:   30 * time.Second,
+			KeepAlive: 30 * time.Second,
+		},
+		rdtimeoutd: rdtimeoutd,
+		wtimeoutd:  wtimeoutd,
+	}).Dial
+	return tr, nil
+}

+ 9 - 1
rafthttp/http.go

@@ -17,6 +17,7 @@
 package rafthttp
 package rafthttp
 
 
 import (
 import (
+	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"log"
 	"log"
 	"net/http"
 	"net/http"
@@ -30,6 +31,10 @@ import (
 	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
 	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
 )
 )
 
 
+const (
+	ConnReadLimitByte = 64 * 1024
+)
+
 var (
 var (
 	RaftPrefix       = "/raft"
 	RaftPrefix       = "/raft"
 	RaftStreamPrefix = path.Join(RaftPrefix, "stream")
 	RaftStreamPrefix = path.Join(RaftPrefix, "stream")
@@ -83,7 +88,10 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		return
 		return
 	}
 	}
 
 
-	b, err := ioutil.ReadAll(r.Body)
+	// Limit the data size that could be read from the request body, which ensures that read from
+	// connection will not time out accidentally due to possible block in underlying implementation.
+	limitedr := io.LimitReader(r.Body, ConnReadLimitByte)
+	b, err := ioutil.ReadAll(limitedr)
 	if err != nil {
 	if err != nil {
 		log.Println("rafthttp: error reading raft message:", err)
 		log.Println("rafthttp: error reading raft message:", err)
 		http.Error(w, "error reading raft message", http.StatusBadRequest)
 		http.Error(w, "error reading raft message", http.StatusBadRequest)

+ 3 - 1
rafthttp/sender.go

@@ -35,6 +35,9 @@ const (
 	senderBufSize = connPerSender * 4
 	senderBufSize = connPerSender * 4
 
 
 	appRespBatchMs = 50
 	appRespBatchMs = 50
+
+	ConnReadTimeout  = 5 * time.Second
+	ConnWriteTimeout = 5 * time.Second
 )
 )
 
 
 type Sender interface {
 type Sender interface {
@@ -176,7 +179,6 @@ func (s *sender) initStream(from, to types.ID, term uint64) {
 		return
 		return
 	}
 	}
 	s.strmCln = strmCln
 	s.strmCln = strmCln
-	log.Printf("rafthttp: start stream client with %s in term %d", to, term)
 }
 }
 
 
 func (s *sender) tryStream(m raftpb.Message) bool {
 func (s *sender) tryStream(m raftpb.Message) bool {

+ 10 - 2
rafthttp/streamer.go

@@ -59,6 +59,7 @@ func startStreamServer(w WriteFlusher, to types.ID, term uint64, fs *stats.Follo
 		done: make(chan struct{}),
 		done: make(chan struct{}),
 	}
 	}
 	go s.handle(w)
 	go s.handle(w)
+	log.Printf("rafthttp: stream server to %s at term %d starts", to, term)
 	return s
 	return s
 }
 }
 
 
@@ -85,7 +86,10 @@ func (s *streamServer) stop() {
 func (s *streamServer) stopNotify() <-chan struct{} { return s.done }
 func (s *streamServer) stopNotify() <-chan struct{} { return s.done }
 
 
 func (s *streamServer) handle(w WriteFlusher) {
 func (s *streamServer) handle(w WriteFlusher) {
-	defer close(s.done)
+	defer func() {
+		close(s.done)
+		log.Printf("rafthttp: stream server to %s at term %d is closed", s.to, s.term)
+	}()
 
 
 	ew := &entryWriter{w: w}
 	ew := &entryWriter{w: w}
 	for ents := range s.q {
 	for ents := range s.q {
@@ -145,6 +149,7 @@ func (s *streamClient) start(tr http.RoundTripper, u string, cid types.ID) error
 	}
 	}
 	s.closer = resp.Body
 	s.closer = resp.Body
 	go s.handle(resp.Body)
 	go s.handle(resp.Body)
+	log.Printf("rafthttp: stream client to %s at term %d starts", s.to, s.term)
 	return nil
 	return nil
 }
 }
 
 
@@ -163,7 +168,10 @@ func (s *streamClient) isStopped() bool {
 }
 }
 
 
 func (s *streamClient) handle(r io.Reader) {
 func (s *streamClient) handle(r io.Reader) {
-	defer close(s.done)
+	defer func() {
+		close(s.done)
+		log.Printf("rafthttp: stream client to %s at term %d is closed", s.to, s.term)
+	}()
 
 
 	er := &entryReader{r: r}
 	er := &entryReader{r: r}
 	for {
 	for {