Browse Source

raft: introduce top-level context in message struct

swingbach@gmail.com 9 years ago
parent
commit
c36a40ca15
3 changed files with 47 additions and 4 deletions
  1. 4 4
      raft/raft.go
  2. 42 0
      raft/raftpb/raft.pb.go
  3. 1 0
      raft/raftpb/raft.proto

+ 4 - 4
raft/raft.go

@@ -547,11 +547,11 @@ func (r *raft) campaign(t CampaignType) {
 		r.logger.Infof("%x [logterm: %d, index: %d] sent vote request to %x at term %d",
 			r.id, r.raftLog.lastTerm(), r.raftLog.lastIndex(), id, r.Term)
 
-		var entries []pb.Entry
+		var ctx []byte
 		if t == campaignTransfer {
-			entries = []pb.Entry{{Data: []byte(t)}}
+			ctx = []byte(t)
 		}
-		r.send(pb.Message{To: id, Type: pb.MsgVote, Index: r.raftLog.lastIndex(), LogTerm: r.raftLog.lastTerm(), Entries: entries})
+		r.send(pb.Message{To: id, Type: pb.MsgVote, Index: r.raftLog.lastIndex(), LogTerm: r.raftLog.lastTerm(), Context: ctx})
 	}
 }
 
@@ -594,7 +594,7 @@ func (r *raft) Step(m pb.Message) error {
 	case m.Term > r.Term:
 		lead := m.From
 		if m.Type == pb.MsgVote {
-			force := len(m.Entries) == 1 && bytes.Equal(m.Entries[0].Data, []byte(campaignTransfer))
+			force := bytes.Equal(m.Context, []byte(campaignTransfer))
 			inLease := r.checkQuorum && r.state != StateCandidate && r.electionElapsed < r.electionTimeout
 			if !force && inLease {
 				// If a server receives a RequestVote request within the minimum election timeout

+ 42 - 0
raft/raftpb/raft.pb.go

@@ -236,6 +236,7 @@ type Message struct {
 	Snapshot         Snapshot    `protobuf:"bytes,9,opt,name=snapshot" json:"snapshot"`
 	Reject           bool        `protobuf:"varint,10,opt,name=reject" json:"reject"`
 	RejectHint       uint64      `protobuf:"varint,11,opt,name=rejectHint" json:"rejectHint"`
+	Context          []byte      `protobuf:"bytes,12,opt,name=context" json:"context,omitempty"`
 	XXX_unrecognized []byte      `json:"-"`
 }
 
@@ -464,6 +465,12 @@ func (m *Message) MarshalTo(data []byte) (int, error) {
 	data[i] = 0x58
 	i++
 	i = encodeVarintRaft(data, i, uint64(m.RejectHint))
+	if m.Context != nil {
+		data[i] = 0x62
+		i++
+		i = encodeVarintRaft(data, i, uint64(len(m.Context)))
+		i += copy(data[i:], m.Context)
+	}
 	if m.XXX_unrecognized != nil {
 		i += copy(data[i:], m.XXX_unrecognized)
 	}
@@ -655,6 +662,10 @@ func (m *Message) Size() (n int) {
 	n += 1 + l + sovRaft(uint64(l))
 	n += 2
 	n += 1 + sovRaft(uint64(m.RejectHint))
+	if m.Context != nil {
+		l = len(m.Context)
+		n += 1 + l + sovRaft(uint64(l))
+	}
 	if m.XXX_unrecognized != nil {
 		n += len(m.XXX_unrecognized)
 	}
@@ -1348,6 +1359,37 @@ func (m *Message) Unmarshal(data []byte) error {
 					break
 				}
 			}
+		case 12:
+			if wireType != 2 {
+				return fmt.Errorf("proto: wrong wireType = %d for field Context", wireType)
+			}
+			var byteLen int
+			for shift := uint(0); ; shift += 7 {
+				if shift >= 64 {
+					return ErrIntOverflowRaft
+				}
+				if iNdEx >= l {
+					return io.ErrUnexpectedEOF
+				}
+				b := data[iNdEx]
+				iNdEx++
+				byteLen |= (int(b) & 0x7F) << shift
+				if b < 0x80 {
+					break
+				}
+			}
+			if byteLen < 0 {
+				return ErrInvalidLengthRaft
+			}
+			postIndex := iNdEx + byteLen
+			if postIndex > l {
+				return io.ErrUnexpectedEOF
+			}
+			m.Context = append(m.Context[:0], data[iNdEx:postIndex]...)
+			if m.Context == nil {
+				m.Context = []byte{}
+			}
+			iNdEx = postIndex
 		default:
 			iNdEx = preIndex
 			skippy, err := skipRaft(data[iNdEx:])

+ 1 - 0
raft/raftpb/raft.proto

@@ -64,6 +64,7 @@ message Message {
 	optional Snapshot    snapshot    = 9  [(gogoproto.nullable) = false];
 	optional bool        reject      = 10 [(gogoproto.nullable) = false];
 	optional uint64      rejectHint  = 11 [(gogoproto.nullable) = false];
+	optional bytes       context     = 12;
 }
 
 message HardState {