Browse Source

etcdserver/api/v3rpc: support watch fragmentation with max request bytes

Signed-off-by: Gyuho Lee <gyuhox@gmail.com>
Gyuho Lee 7 years ago
parent
commit
294b5745d6
2 changed files with 168 additions and 7 deletions
  1. 73 7
      etcdserver/api/v3rpc/watch.go
  2. 95 0
      etcdserver/api/v3rpc/watch_test.go

+ 73 - 7
etcdserver/api/v3rpc/watch.go

@@ -37,6 +37,8 @@ type watchServer struct {
 	clusterID int64
 	memberID  int64
 
+	maxRequestBytes int
+
 	sg        etcdserver.RaftStatusGetter
 	watchable mvcc.WatchableKV
 	ag        AuthGetter
@@ -50,6 +52,8 @@ func NewWatchServer(s *etcdserver.EtcdServer) pb.WatchServer {
 		clusterID: int64(s.Cluster().ID()),
 		memberID:  int64(s.ID()),
 
+		maxRequestBytes: int(s.Cfg.MaxRequestBytes + grpcOverheadBytes),
+
 		sg:        s,
 		watchable: s.Watchable(),
 		ag:        s,
@@ -102,6 +106,8 @@ type serverWatchStream struct {
 	clusterID int64
 	memberID  int64
 
+	maxRequestBytes int
+
 	sg        etcdserver.RaftStatusGetter
 	watchable mvcc.WatchableKV
 	ag        AuthGetter
@@ -110,13 +116,15 @@ type serverWatchStream struct {
 	watchStream mvcc.WatchStream
 	ctrlStream  chan *pb.WatchResponse
 
-	// mu protects progress, prevKV
+	// mu protects progress, prevKV, fragment
 	mu sync.RWMutex
 	// tracks the watchID that stream might need to send progress to
 	// TODO: combine progress and prevKV into a single struct?
 	progress map[mvcc.WatchID]bool
 	// record watch IDs that need return previous key-value pair
 	prevKV map[mvcc.WatchID]bool
+	// records fragmented watch IDs
+	fragment map[mvcc.WatchID]bool
 
 	// closec indicates the stream is closed.
 	closec chan struct{}
@@ -132,6 +140,8 @@ func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
 		clusterID: ws.clusterID,
 		memberID:  ws.memberID,
 
+		maxRequestBytes: ws.maxRequestBytes,
+
 		sg:        ws.sg,
 		watchable: ws.watchable,
 		ag:        ws.ag,
@@ -143,6 +153,7 @@ func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
 
 		progress: make(map[mvcc.WatchID]bool),
 		prevKV:   make(map[mvcc.WatchID]bool),
+		fragment: make(map[mvcc.WatchID]bool),
 
 		closec: make(chan struct{}),
 	}
@@ -268,6 +279,9 @@ func (sws *serverWatchStream) recvLoop() error {
 				if creq.PrevKv {
 					sws.prevKV[id] = true
 				}
+				if creq.Fragment {
+					sws.fragment[id] = true
+				}
 				sws.mu.Unlock()
 			}
 			wr := &pb.WatchResponse{
@@ -298,6 +312,7 @@ func (sws *serverWatchStream) recvLoop() error {
 					sws.mu.Lock()
 					delete(sws.progress, mvcc.WatchID(id))
 					delete(sws.prevKV, mvcc.WatchID(id))
+					delete(sws.fragment, mvcc.WatchID(id))
 					sws.mu.Unlock()
 				}
 			}
@@ -376,18 +391,30 @@ func (sws *serverWatchStream) sendLoop() {
 			}
 
 			mvcc.ReportEventReceived(len(evs))
-			if err := sws.gRPCStream.Send(wr); err != nil {
-				if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
+
+			sws.mu.RLock()
+			fragmented, ok := sws.fragment[wresp.WatchID]
+			sws.mu.RUnlock()
+
+			var serr error
+			if !fragmented && !ok {
+				serr = sws.gRPCStream.Send(wr)
+			} else {
+				serr = sendFragments(wr, sws.maxRequestBytes, sws.gRPCStream.Send)
+			}
+
+			if serr != nil {
+				if isClientCtxErr(sws.gRPCStream.Context().Err(), serr) {
 					if sws.lg != nil {
-						sws.lg.Debug("failed to send watch response to gRPC stream", zap.Error(err))
+						sws.lg.Debug("failed to send watch response to gRPC stream", zap.Error(serr))
 					} else {
-						plog.Debugf("failed to send watch response to gRPC stream (%q)", err.Error())
+						plog.Debugf("failed to send watch response to gRPC stream (%q)", serr.Error())
 					}
 				} else {
 					if sws.lg != nil {
-						sws.lg.Warn("failed to send watch response to gRPC stream", zap.Error(err))
+						sws.lg.Warn("failed to send watch response to gRPC stream", zap.Error(serr))
 					} else {
-						plog.Warningf("failed to send watch response to gRPC stream (%q)", err.Error())
+						plog.Warningf("failed to send watch response to gRPC stream (%q)", serr.Error())
 					}
 				}
 				return
@@ -469,6 +496,45 @@ func (sws *serverWatchStream) sendLoop() {
 	}
 }
 
+// sendFragments hands wr to sendFunc, splitting it into several responses
+// when it is too large to go out in one piece.
+//
+// wr is passed through unchanged when its encoded size is below
+// maxRequestBytes or when it holds fewer than two events. Otherwise its
+// events are packed, in order, into successive copies of wr. Each copy takes
+// as many events as keep its size under maxRequestBytes, but always at least
+// one, so a copy holding a single oversized event may still exceed the limit.
+// Every copy except the last has Fragment set to true.
+//
+// The first error returned by sendFunc aborts the transfer and is returned.
+func sendFragments(
+	wr *pb.WatchResponse,
+	maxRequestBytes int,
+	sendFunc func(*pb.WatchResponse) error) error {
+	// small enough, or nothing to split: send the response as it is
+	if wr.Size() < maxRequestBytes || len(wr.Events) < 2 {
+		return sendFunc(wr)
+	}
+
+	// template shared by all fragments; Fragment is switched on up front
+	// so that the flag is accounted for whenever a fragment is sized
+	tmpl := *wr
+	tmpl.Events = make([]*mvccpb.Event, 0)
+	tmpl.Fragment = true
+
+	next := 0
+	for more := true; more; {
+		frag := tmpl
+		for next < len(wr.Events) {
+			frag.Events = append(frag.Events, wr.Events[next])
+			if len(frag.Events) > 1 && frag.Size() >= maxRequestBytes {
+				// the event just added does not fit; leave it
+				// for the following fragment
+				frag.Events = frag.Events[:len(frag.Events)-1]
+				break
+			}
+			next++
+		}
+
+		// only the fragment that consumes the final event clears the flag
+		more = next < len(wr.Events)
+		frag.Fragment = more
+
+		if err := sendFunc(&frag); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (sws *serverWatchStream) close() {
 	sws.watchStream.Close()
 	close(sws.closec)

+ 95 - 0
etcdserver/api/v3rpc/watch_test.go

@@ -0,0 +1,95 @@
+// Copyright 2018 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package v3rpc
+
+import (
+	"bytes"
+	"math"
+	"testing"
+
+	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
+	"github.com/coreos/etcd/mvcc/mvccpb"
+)
+
+// TestSendFragment runs sendFragments over a table of responses and size
+// limits. For each case it checks the returned error, the number of responses
+// handed to the send function, that Fragment is set on every response but the
+// last, and that the events come out complete and in their original order.
+func TestSendFragment(t *testing.T) {
+	tt := []struct {
+		wr              *pb.WatchResponse
+		maxRequestBytes int
+		fragments       int
+		werr            error
+	}{
+		{ // large limit should not fragment
+			wr:              createResponse(100, 1),
+			maxRequestBytes: math.MaxInt32,
+			fragments:       1,
+		},
+		{ // large limit for two messages, expect no fragment
+			wr:              createResponse(10, 2),
+			maxRequestBytes: 50,
+			fragments:       1,
+		},
+		{ // limit is small but only one message, expect no fragment
+			wr:              createResponse(1024, 1),
+			maxRequestBytes: 1,
+			fragments:       1,
+		},
+		{ // exceed limit only when combined, expect fragments
+			wr:              createResponse(11, 5),
+			maxRequestBytes: 20,
+			fragments:       5,
+		},
+		{ // 5 events with each event exceeding limits, expect fragments
+			wr:              createResponse(15, 5),
+			maxRequestBytes: 10,
+			fragments:       5,
+		},
+		{ // 4 events with some combined events exceeding limits
+			wr:              createResponse(10, 4),
+			maxRequestBytes: 35,
+			fragments:       2,
+		},
+	}
+
+	for i, tc := range tt {
+		var fragmentedResp []*pb.WatchResponse
+		testSend := func(wr *pb.WatchResponse) error {
+			fragmentedResp = append(fragmentedResp, wr)
+			return nil
+		}
+		err := sendFragments(tc.wr, tc.maxRequestBytes, testSend)
+		if err != tc.werr {
+			t.Errorf("#%d: expected error %v, got %v", i, tc.werr, err)
+		}
+		got := len(fragmentedResp)
+		if got != tc.fragments {
+			t.Errorf("#%d: expected response number %d, got %d", i, tc.fragments, got)
+		}
+		if got > 0 && fragmentedResp[got-1].Fragment {
+			t.Errorf("#%d: expected fragment=false in last response, got %+v", i, fragmentedResp[got-1])
+		}
+
+		// every response before the last must announce that more follow,
+		// and together the responses must carry the original events
+		var evs []*mvccpb.Event
+		for j, resp := range fragmentedResp {
+			if j < got-1 && !resp.Fragment {
+				t.Errorf("#%d: expected fragment=true in response %d, got %+v", i, j, resp)
+			}
+			evs = append(evs, resp.Events...)
+		}
+		if len(evs) != len(tc.wr.Events) {
+			t.Errorf("#%d: expected %d events in total, got %d", i, len(tc.wr.Events), len(evs))
+			continue
+		}
+		for j := range evs {
+			// events are forwarded by pointer, so identity is enough
+			if evs[j] != tc.wr.Events[j] {
+				t.Errorf("#%d: event %d does not match the original event at that position", i, j)
+			}
+		}
+	}
+}
+
+// createResponse returns a watch response holding the given number of
+// events. Each event carries a key-value pair whose key is dataSize
+// repetitions of the byte 'a'; no other field is populated.
+func createResponse(dataSize, events int) (resp *pb.WatchResponse) {
+	evs := make([]*mvccpb.Event, 0, events)
+	for len(evs) < events {
+		// every event gets its own freshly allocated key
+		key := bytes.Repeat([]byte("a"), dataSize)
+		evs = append(evs, &mvccpb.Event{Kv: &mvccpb.KeyValue{Key: key}})
+	}
+	return &pb.WatchResponse{Events: evs}
+}