Browse Source

Merge pull request #2628 from yichengq/improve-msgappv2

rafthttp: reduce allocs in msgappv2
Yicheng Qin 10 years ago
parent
commit
fc4543a3fd
3 changed files with 84 additions and 28 deletions
  1. 80 24
      rafthttp/msgappv2.go
  2. 2 2
      rafthttp/msgappv2_test.go
  3. 2 2
      rafthttp/stream.go

+ 80 - 24
rafthttp/msgappv2.go

@@ -30,6 +30,8 @@ const (
 	msgTypeLinkHeartbeat uint8 = 0
 	msgTypeAppEntries    uint8 = 1
 	msgTypeApp           uint8 = 2
+
+	msgAppV2BufSize = 1024 * 1024
 )
 
 // msgappv2 stream sends three types of message: linkHeartbeatMessage,
@@ -63,38 +65,67 @@ type msgAppV2Encoder struct {
 	w  io.Writer
 	fs *stats.FollowerStats
 
-	term  uint64
-	index uint64
+	term      uint64
+	index     uint64
+	buf       []byte
+	uint64buf []byte
+	uint8buf  []byte
+}
+
+func newMsgAppV2Encoder(w io.Writer, fs *stats.FollowerStats) *msgAppV2Encoder {
+	return &msgAppV2Encoder{
+		w:         w,
+		fs:        fs,
+		buf:       make([]byte, msgAppV2BufSize),
+		uint64buf: make([]byte, 8),
+		uint8buf:  make([]byte, 1),
+	}
 }
 
 func (enc *msgAppV2Encoder) encode(m raftpb.Message) error {
 	start := time.Now()
 	switch {
 	case isLinkHeartbeatMessage(m):
-		return binary.Write(enc.w, binary.BigEndian, msgTypeLinkHeartbeat)
+		enc.uint8buf[0] = byte(msgTypeLinkHeartbeat)
+		if _, err := enc.w.Write(enc.uint8buf); err != nil {
+			return err
+		}
 	case enc.index == m.Index && enc.term == m.LogTerm && m.LogTerm == m.Term:
-		if err := binary.Write(enc.w, binary.BigEndian, msgTypeAppEntries); err != nil {
+		enc.uint8buf[0] = byte(msgTypeAppEntries)
+		if _, err := enc.w.Write(enc.uint8buf); err != nil {
 			return err
 		}
 		// write length of entries
-		l := len(m.Entries)
-		if err := binary.Write(enc.w, binary.BigEndian, uint64(l)); err != nil {
+		binary.BigEndian.PutUint64(enc.uint64buf, uint64(len(m.Entries)))
+		if _, err := enc.w.Write(enc.uint64buf); err != nil {
 			return err
 		}
-		for i := 0; i < l; i++ {
-			size := m.Entries[i].Size()
-			if err := binary.Write(enc.w, binary.BigEndian, uint64(size)); err != nil {
+		for i := 0; i < len(m.Entries); i++ {
+			// write length of entry
+			binary.BigEndian.PutUint64(enc.uint64buf, uint64(m.Entries[i].Size()))
+			if _, err := enc.w.Write(enc.uint64buf); err != nil {
 				return err
 			}
-			if _, err := enc.w.Write(pbutil.MustMarshal(&m.Entries[i])); err != nil {
-				return err
+			if n := m.Entries[i].Size(); n < msgAppV2BufSize {
+				if _, err := m.Entries[i].MarshalTo(enc.buf); err != nil {
+					return err
+				}
+				if _, err := enc.w.Write(enc.buf[:n]); err != nil {
+					return err
+				}
+			} else {
+				if _, err := enc.w.Write(pbutil.MustMarshal(&m.Entries[i])); err != nil {
+					return err
+				}
 			}
 			enc.index++
 		}
 		// write commit index
-		if err := binary.Write(enc.w, binary.BigEndian, m.Commit); err != nil {
+		binary.BigEndian.PutUint64(enc.uint64buf, m.Commit)
+		if _, err := enc.w.Write(enc.uint64buf); err != nil {
 			return err
 		}
+		enc.fs.Succ(time.Since(start))
 	default:
 		if err := binary.Write(enc.w, binary.BigEndian, msgTypeApp); err != nil {
 			return err
@@ -113,8 +144,8 @@ func (enc *msgAppV2Encoder) encode(m raftpb.Message) error {
 		if l := len(m.Entries); l > 0 {
 			enc.index = m.Entries[l-1].Index
 		}
+		enc.fs.Succ(time.Since(start))
 	}
-	enc.fs.Succ(time.Since(start))
 	return nil
 }
 
@@ -122,8 +153,22 @@ type msgAppV2Decoder struct {
 	r             io.Reader
 	local, remote types.ID
 
-	term  uint64
-	index uint64
+	term      uint64
+	index     uint64
+	buf       []byte
+	uint64buf []byte
+	uint8buf  []byte
+}
+
+func newMsgAppV2Decoder(r io.Reader, local, remote types.ID) *msgAppV2Decoder {
+	return &msgAppV2Decoder{
+		r:         r,
+		local:     local,
+		remote:    remote,
+		buf:       make([]byte, msgAppV2BufSize),
+		uint64buf: make([]byte, 8),
+		uint8buf:  make([]byte, 1),
+	}
 }
 
 func (dec *msgAppV2Decoder) decode() (raftpb.Message, error) {
@@ -131,9 +176,10 @@ func (dec *msgAppV2Decoder) decode() (raftpb.Message, error) {
 		m   raftpb.Message
 		typ uint8
 	)
-	if err := binary.Read(dec.r, binary.BigEndian, &typ); err != nil {
+	if _, err := io.ReadFull(dec.r, dec.uint8buf); err != nil {
 		return m, err
 	}
+	typ = uint8(dec.uint8buf[0])
 	switch typ {
 	case msgTypeLinkHeartbeat:
 		return linkHeartbeatMessage, nil
@@ -148,27 +194,37 @@ func (dec *msgAppV2Decoder) decode() (raftpb.Message, error) {
 		}
 
 		// decode entries
-		var l uint64
-		if err := binary.Read(dec.r, binary.BigEndian, &l); err != nil {
+		if _, err := io.ReadFull(dec.r, dec.uint64buf); err != nil {
 			return m, err
 		}
+		l := binary.BigEndian.Uint64(dec.uint64buf)
 		m.Entries = make([]raftpb.Entry, int(l))
 		for i := 0; i < int(l); i++ {
-			var size uint64
-			if err := binary.Read(dec.r, binary.BigEndian, &size); err != nil {
+			if _, err := io.ReadFull(dec.r, dec.uint64buf); err != nil {
 				return m, err
 			}
-			buf := make([]byte, int(size))
-			if _, err := io.ReadFull(dec.r, buf); err != nil {
-				return m, err
+			size := binary.BigEndian.Uint64(dec.uint64buf)
+			var buf []byte
+			if size < msgAppV2BufSize {
+				buf = dec.buf[:size]
+				if _, err := io.ReadFull(dec.r, buf); err != nil {
+					return m, err
+				}
+			} else {
+				buf = make([]byte, int(size))
+				if _, err := io.ReadFull(dec.r, buf); err != nil {
+					return m, err
+				}
 			}
 			dec.index++
+			// 1 alloc
 			pbutil.MustUnmarshal(&m.Entries[i], buf)
 		}
 		// decode commit index
-		if err := binary.Read(dec.r, binary.BigEndian, &m.Commit); err != nil {
+		if _, err := io.ReadFull(dec.r, dec.uint64buf); err != nil {
 			return m, err
 		}
+		m.Commit = binary.BigEndian.Uint64(dec.uint64buf)
 	case msgTypeApp:
 		var size uint64
 		if err := binary.Read(dec.r, binary.BigEndian, &size); err != nil {

+ 2 - 2
rafthttp/msgappv2_test.go

@@ -103,8 +103,8 @@ func TestMsgAppV2(t *testing.T) {
 		linkHeartbeatMessage,
 	}
 	b := &bytes.Buffer{}
-	enc := &msgAppV2Encoder{w: b, fs: &stats.FollowerStats{}}
-	dec := &msgAppV2Decoder{r: b, local: types.ID(2), remote: types.ID(1)}
+	enc := newMsgAppV2Encoder(b, &stats.FollowerStats{})
+	dec := newMsgAppV2Decoder(b, types.ID(2), types.ID(1))
 
 	for i, tt := range tests {
 		if err := enc.encode(tt); err != nil {

+ 2 - 2
rafthttp/stream.go

@@ -162,7 +162,7 @@ func (cw *streamWriter) run() {
 				}
 				enc = &msgAppEncoder{w: conn.Writer, fs: cw.fs}
 			case streamTypeMsgAppV2:
-				enc = &msgAppV2Encoder{w: conn.Writer, fs: cw.fs}
+				enc = newMsgAppV2Encoder(conn.Writer, cw.fs)
 			case streamTypeMessage:
 				enc = &messageEncoder{w: conn.Writer}
 			default:
@@ -283,7 +283,7 @@ func (cr *streamReader) decodeLoop(rc io.ReadCloser) error {
 	case streamTypeMsgApp:
 		dec = &msgAppDecoder{r: rc, local: cr.from, remote: cr.to, term: cr.msgAppTerm}
 	case streamTypeMsgAppV2:
-		dec = &msgAppV2Decoder{r: rc, local: cr.from, remote: cr.to}
+		dec = newMsgAppV2Decoder(rc, cr.from, cr.to)
 	case streamTypeMessage:
 		dec = &messageDecoder{r: rc}
 	default: