Browse Source

etcdserver/api/v3rpc: support watch fragmentation

Signed-off-by: Gyuho Lee <leegyuho@amazon.com>
Gyuho Lee 6 years ago
parent
commit
5a678bb4e3

+ 20 - 0
etcdserver/api/v3rpc/rpctypes/metadatafields.go

@@ -0,0 +1,20 @@
+// Copyright 2018 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rpctypes
+
+var (
+	TokenFieldNameGRPC    = "token"
+	TokenFieldNameSwagger = "authorization"
+)

+ 87 - 12
etcdserver/api/v3rpc/watch.go

@@ -31,6 +31,9 @@ import (
 type watchServer struct {
 	clusterID int64
 	memberID  int64
+
+	maxRequestBytes int
+
 	raftTimer etcdserver.RaftTimer
 	watchable mvcc.WatchableKV
 
@@ -39,11 +42,12 @@ type watchServer struct {
 
 func NewWatchServer(s *etcdserver.EtcdServer) pb.WatchServer {
 	return &watchServer{
-		clusterID: int64(s.Cluster().ID()),
-		memberID:  int64(s.ID()),
-		raftTimer: s,
-		watchable: s.Watchable(),
-		ag:        s,
+		clusterID:       int64(s.Cluster().ID()),
+		memberID:        int64(s.ID()),
+		maxRequestBytes: int(s.Cfg.MaxRequestBytes + grpcOverheadBytes),
+		raftTimer:       s,
+		watchable:       s.Watchable(),
+		ag:              s,
 	}
 }
 
@@ -83,6 +87,9 @@ const (
 type serverWatchStream struct {
 	clusterID int64
 	memberID  int64
+
+	maxRequestBytes int
+
 	raftTimer etcdserver.RaftTimer
 
 	watchable mvcc.WatchableKV
@@ -92,12 +99,14 @@ type serverWatchStream struct {
 	ctrlStream  chan *pb.WatchResponse
 
 	// mu protects progress, prevKV
-	mu sync.Mutex
+	mu sync.RWMutex
 	// progress tracks the watchID that stream might need to send
 	// progress to.
 	// TODO: combine progress and prevKV into a single struct?
 	progress map[mvcc.WatchID]bool
 	prevKV   map[mvcc.WatchID]bool
+	// records fragmented watch IDs
+	fragment map[mvcc.WatchID]bool
 
 	// closec indicates the stream is closed.
 	closec chan struct{}
@@ -112,6 +121,9 @@ func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
 	sws := serverWatchStream{
 		clusterID: ws.clusterID,
 		memberID:  ws.memberID,
+
+		maxRequestBytes: ws.maxRequestBytes,
+
 		raftTimer: ws.raftTimer,
 
 		watchable: ws.watchable,
@@ -122,6 +134,7 @@ func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
 		ctrlStream: make(chan *pb.WatchResponse, ctrlStreamBufLen),
 		progress:   make(map[mvcc.WatchID]bool),
 		prevKV:     make(map[mvcc.WatchID]bool),
+		fragment:   make(map[mvcc.WatchID]bool),
 		closec:     make(chan struct{}),
 
 		ag: ws.ag,
@@ -238,6 +251,9 @@ func (sws *serverWatchStream) recvLoop() error {
 				if creq.PrevKv {
 					sws.prevKV[id] = true
 				}
+				if creq.Fragment {
+					sws.fragment[id] = true
+				}
 				sws.mu.Unlock()
 			}
 			wr := &pb.WatchResponse{
@@ -264,9 +280,17 @@ func (sws *serverWatchStream) recvLoop() error {
 					sws.mu.Lock()
 					delete(sws.progress, mvcc.WatchID(id))
 					delete(sws.prevKV, mvcc.WatchID(id))
+					delete(sws.fragment, mvcc.WatchID(id))
 					sws.mu.Unlock()
 				}
 			}
+		case *pb.WatchRequest_ProgressRequest:
+			if uv.ProgressRequest != nil {
+				sws.ctrlStream <- &pb.WatchResponse{
+					Header:  sws.newResponseHeader(sws.watchStream.Rev()),
+					WatchId: -1, // response is not associated with any WatchId and will be broadcast to all watch channels
+				}
+			}
 		default:
 			// we probably should not shutdown the entire stream when
 			// receive an valid command.
@@ -310,9 +334,9 @@ func (sws *serverWatchStream) sendLoop() {
 			// or define protocol buffer with []mvccpb.Event.
 			evs := wresp.Events
 			events := make([]*mvccpb.Event, len(evs))
-			sws.mu.Lock()
+			sws.mu.RLock()
 			needPrevKV := sws.prevKV[wresp.WatchID]
-			sws.mu.Unlock()
+			sws.mu.RUnlock()
 			for i := range evs {
 				events[i] = &evs[i]
 
@@ -342,11 +366,23 @@ func (sws *serverWatchStream) sendLoop() {
 			}
 
 			mvcc.ReportEventReceived(len(evs))
-			if err := sws.gRPCStream.Send(wr); err != nil {
-				if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
-					plog.Debugf("failed to send watch response to gRPC stream (%q)", err.Error())
+
+			sws.mu.RLock()
+			fragmented, ok := sws.fragment[wresp.WatchID]
+			sws.mu.RUnlock()
+
+			var serr error
+			if !fragmented && !ok {
+				serr = sws.gRPCStream.Send(wr)
+			} else {
+				serr = sendFragments(wr, sws.maxRequestBytes, sws.gRPCStream.Send)
+			}
+
+			if serr != nil {
+				if isClientCtxErr(sws.gRPCStream.Context().Err(), serr) {
+					plog.Debugf("failed to send watch response to gRPC stream (%q)", serr.Error())
 				} else {
-					plog.Warningf("failed to send watch response to gRPC stream (%q)", err.Error())
+					plog.Warningf("failed to send watch response to gRPC stream (%q)", serr.Error())
 				}
 				return
 			}
@@ -409,6 +445,45 @@ func (sws *serverWatchStream) sendLoop() {
 	}
 }
 
+func sendFragments(
+	wr *pb.WatchResponse,
+	maxRequestBytes int,
+	sendFunc func(*pb.WatchResponse) error) error {
+	// no need to fragment if total request size is smaller
+	// than max request limit or response contains only one event
+	if wr.Size() < maxRequestBytes || len(wr.Events) < 2 {
+		return sendFunc(wr)
+	}
+
+	ow := *wr
+	ow.Events = make([]*mvccpb.Event, 0)
+	ow.Fragment = true
+
+	var idx int
+	for {
+		cur := ow
+		for _, ev := range wr.Events[idx:] {
+			cur.Events = append(cur.Events, ev)
+			if len(cur.Events) > 1 && cur.Size() >= maxRequestBytes {
+				cur.Events = cur.Events[:len(cur.Events)-1]
+				break
+			}
+			idx++
+		}
+		if idx == len(wr.Events) {
+			// last response has no more fragment
+			cur.Fragment = false
+		}
+		if err := sendFunc(&cur); err != nil {
+			return err
+		}
+		if !cur.Fragment {
+			break
+		}
+	}
+	return nil
+}
+
 func (sws *serverWatchStream) close() {
 	sws.watchStream.Close()
 	close(sws.closec)

+ 95 - 0
etcdserver/api/v3rpc/watch_test.go

@@ -0,0 +1,95 @@
+// Copyright 2018 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package v3rpc
+
+import (
+	"bytes"
+	"math"
+	"testing"
+
+	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
+	"github.com/coreos/etcd/mvcc/mvccpb"
+)
+
+func TestSendFragment(t *testing.T) {
+	tt := []struct {
+		wr              *pb.WatchResponse
+		maxRequestBytes int
+		fragments       int
+		werr            error
+	}{
+		{ // large limit should not fragment
+			wr:              createResponse(100, 1),
+			maxRequestBytes: math.MaxInt32,
+			fragments:       1,
+		},
+		{ // large limit for two messages, expect no fragment
+			wr:              createResponse(10, 2),
+			maxRequestBytes: 50,
+			fragments:       1,
+		},
+		{ // limit is small but only one message, expect no fragment
+			wr:              createResponse(1024, 1),
+			maxRequestBytes: 1,
+			fragments:       1,
+		},
+		{ // exceed limit only when combined, expect fragments
+			wr:              createResponse(11, 5),
+			maxRequestBytes: 20,
+			fragments:       5,
+		},
+		{ // 5 events with each event exceeding limits, expect fragments
+			wr:              createResponse(15, 5),
+			maxRequestBytes: 10,
+			fragments:       5,
+		},
+		{ // 4 events with some combined events exceeding limits
+			wr:              createResponse(10, 4),
+			maxRequestBytes: 35,
+			fragments:       2,
+		},
+	}
+
+	for i := range tt {
+		fragmentedResp := make([]*pb.WatchResponse, 0)
+		testSend := func(wr *pb.WatchResponse) error {
+			fragmentedResp = append(fragmentedResp, wr)
+			return nil
+		}
+		err := sendFragments(tt[i].wr, tt[i].maxRequestBytes, testSend)
+		if err != tt[i].werr {
+			t.Errorf("#%d: expected error %v, got %v", i, tt[i].werr, err)
+		}
+		got := len(fragmentedResp)
+		if got != tt[i].fragments {
+			t.Errorf("#%d: expected response number %d, got %d", i, tt[i].fragments, got)
+		}
+		if got > 0 && fragmentedResp[got-1].Fragment {
+			t.Errorf("#%d: expected fragment=false in last response, got %+v", i, fragmentedResp[got-1])
+		}
+	}
+}
+
+func createResponse(dataSize, events int) (resp *pb.WatchResponse) {
+	resp = &pb.WatchResponse{Events: make([]*mvccpb.Event, events)}
+	for i := range resp.Events {
+		resp.Events[i] = &mvccpb.Event{
+			Kv: &mvccpb.KeyValue{
+				Key: bytes.Repeat([]byte("a"), dataSize),
+			},
+		}
+	}
+	return resp
+}