Browse Source

Merge pull request #1779 from yichengq/233

rafthttp: limit the buffer for every read correctly
Yicheng Qin 11 years ago
parent
commit
d67eea4a7d
4 changed files with 79 additions and 3 deletions
  1. 41 0
      pkg/ioutils/reader.go
  2. 35 0
      pkg/ioutils/reader_test.go
  3. 2 2
      rafthttp/http.go
  4. 1 1
      test

+ 41 - 0
pkg/ioutils/reader.go

@@ -0,0 +1,41 @@
+/*
+   Copyright 2014 CoreOS, Inc.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package ioutils
+
+import "io"
+
+// NewLimitedBufferReader returns a reader that reads from the given reader
+// but limits the amount of data returned to at most n bytes.
+func NewLimitedBufferReader(r io.Reader, n int) io.Reader {
+	return &limitedBufferReader{
+		r: r,
+		n: n,
+	}
+}
+
+type limitedBufferReader struct {
+	r io.Reader
+	n int
+}
+
+func (r *limitedBufferReader) Read(p []byte) (n int, err error) {
+	np := p
+	if len(np) > r.n {
+		np = np[:r.n]
+	}
+	return r.r.Read(np)
+}

+ 35 - 0
pkg/ioutils/reader_test.go

@@ -0,0 +1,35 @@
+/*
+   Copyright 2014 CoreOS, Inc.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package ioutils
+
+import (
+	"bytes"
+	"testing"
+)
+
+func TestLimitedBufferReaderRead(t *testing.T) {
+	buf := bytes.NewBuffer(make([]byte, 10))
+	ln := 1
+	lr := NewLimitedBufferReader(buf, ln)
+	n, err := lr.Read(make([]byte, 10))
+	if err != nil {
+		t.Fatalf("unexpected read error: %v", err)
+	}
+	if n != ln {
+		t.Errorf("len(data read) = %d, want %d", n, ln)
+	}
+}

+ 2 - 2
rafthttp/http.go

@@ -17,7 +17,6 @@
 package rafthttp
 
 import (
-	"io"
 	"io/ioutil"
 	"log"
 	"net/http"
@@ -25,6 +24,7 @@ import (
 	"strconv"
 	"strings"
 
+	"github.com/coreos/etcd/pkg/ioutils"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/raft/raftpb"
 
@@ -90,7 +90,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
 	// Limit the data size that could be read from the request body, which ensures that read from
 	// connection will not time out accidentally due to possible block in underlying implementation.
-	limitedr := io.LimitReader(r.Body, ConnReadLimitByte)
+	limitedr := ioutils.NewLimitedBufferReader(r.Body, ConnReadLimitByte)
 	b, err := ioutil.ReadAll(limitedr)
 	if err != nil {
 		log.Println("rafthttp: error reading raft message:", err)

+ 1 - 1
test

@@ -15,7 +15,7 @@ COVER=${COVER:-"-cover"}
 source ./build
 
 # Hack: gofmt ./ will recursively check the .git directory. So use *.go for gofmt.
-TESTABLE_AND_FORMATTABLE="client discovery error etcdctl/command etcdmain etcdserver etcdserver/etcdhttp etcdserver/etcdhttp/httptypes etcdserver/etcdserverpb integration migrate pkg/flags pkg/types pkg/transport pkg/wait proxy raft rafthttp snap store wal"
+TESTABLE_AND_FORMATTABLE="client discovery error etcdctl/command etcdmain etcdserver etcdserver/etcdhttp etcdserver/etcdhttp/httptypes etcdserver/etcdserverpb integration migrate pkg/flags pkg/ioutils pkg/types pkg/transport pkg/wait proxy raft rafthttp snap store wal"
 FORMATTABLE="$TESTABLE_AND_FORMATTABLE *.go etcdctl/"
 
 # user has not provided PKG override