|
|
@@ -1,7 +1,9 @@
|
|
|
package etcdhttp
|
|
|
|
|
|
import (
|
|
|
+ "bytes"
|
|
|
"errors"
|
|
|
+ "io"
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
|
"net/url"
|
|
|
@@ -13,7 +15,9 @@ import (
|
|
|
"time"
|
|
|
|
|
|
etcdErr "github.com/coreos/etcd/error"
|
|
|
+ "github.com/coreos/etcd/etcdserver"
|
|
|
"github.com/coreos/etcd/etcdserver/etcdserverpb"
|
|
|
+ "github.com/coreos/etcd/raft/raftpb"
|
|
|
"github.com/coreos/etcd/store"
|
|
|
"github.com/coreos/etcd/third_party/code.google.com/p/go.net/context"
|
|
|
)
|
|
|
@@ -747,3 +751,86 @@ func TestAllowMethod(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+// errServer implements the etcd.Server interface for testing.
|
|
|
+// It returns the given error from any Do/Process calls.
|
|
|
+type errServer struct {
|
|
|
+ err error
|
|
|
+}
|
|
|
+
|
|
|
+func (fs *errServer) Do(ctx context.Context, r etcdserverpb.Request) (etcdserver.Response, error) {
|
|
|
+ return etcdserver.Response{}, fs.err
|
|
|
+}
|
|
|
+func (fs *errServer) Process(ctx context.Context, m raftpb.Message) error {
|
|
|
+ return fs.err
|
|
|
+}
|
|
|
+func (fs *errServer) Start() {}
|
|
|
+func (fs *errServer) Stop() {}
|
|
|
+
|
|
|
+// errReader implements io.Reader to facilitate a broken request.
|
|
|
+type errReader struct{}
|
|
|
+
|
|
|
+func (er *errReader) Read(_ []byte) (int, error) { return 0, errors.New("some error") }
|
|
|
+
|
|
|
+func mustMarshalMsg(t *testing.T, m raftpb.Message) []byte {
|
|
|
+ json, err := m.Marshal()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("error marshalling raft Message: %#v", err)
|
|
|
+ }
|
|
|
+ return json
|
|
|
+}
|
|
|
+
|
|
|
+func TestServeRaft(t *testing.T) {
|
|
|
+ testCases := []struct {
|
|
|
+ reqBody io.Reader
|
|
|
+ serverErr error
|
|
|
+ wcode int
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ &errReader{},
|
|
|
+ nil,
|
|
|
+ http.StatusBadRequest,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ strings.NewReader("malformed garbage"),
|
|
|
+ nil,
|
|
|
+ http.StatusBadRequest,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ bytes.NewReader(
|
|
|
+ mustMarshalMsg(
|
|
|
+ t,
|
|
|
+ raftpb.Message{},
|
|
|
+ ),
|
|
|
+ ),
|
|
|
+ errors.New("some error"),
|
|
|
+ http.StatusInternalServerError,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ bytes.NewReader(
|
|
|
+ mustMarshalMsg(
|
|
|
+ t,
|
|
|
+ raftpb.Message{},
|
|
|
+ ),
|
|
|
+ ),
|
|
|
+ nil,
|
|
|
+ http.StatusNoContent,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ for i, tt := range testCases {
|
|
|
+ req, err := http.NewRequest("POST", "foo", tt.reqBody)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("#%d: could not create request: %#v", i, err)
|
|
|
+ }
|
|
|
+ h := &serverHandler{
|
|
|
+ timeout: time.Hour,
|
|
|
+ server: &errServer{tt.serverErr},
|
|
|
+ peers: nil,
|
|
|
+ }
|
|
|
+ rw := httptest.NewRecorder()
|
|
|
+ h.serveRaft(rw, req)
|
|
|
+ if rw.Code != tt.wcode {
|
|
|
+ t.Errorf("#%d: got code=%d, want %d", i, rw.Code, tt.wcode)
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|