Browse Source

raft: make State a protobuf type

Xiang Li 11 years ago
parent
commit
8e06333d45
7 changed files with 238 additions and 18 deletions
  1. 1 1
      etcd/participant.go
  2. 1 1
      raft/node.go
  3. 4 6
      raft/raft.go
  4. 210 0
      raft/state.pb.go
  5. 14 0
      raft/state.proto
  6. 3 5
      wal/wal.go
  7. 5 5
      wal/wal_test.go

+ 1 - 1
etcd/participant.go

@@ -391,7 +391,7 @@ func (p *participant) save(ents []raft.Entry, state raft.State) {
 			log.Panicf("id=%x participant.save saveEntryErr=%q", p.id, err)
 		}
 	}
-	if state != raft.EmptyState {
+	if !state.IsEmpty() {
 		if err := p.w.SaveState(&state); err != nil {
 			log.Panicf("id=%x participant.save saveStateErr=%q", p.id, err)
 		}

+ 1 - 1
raft/node.go

@@ -238,7 +238,7 @@ func (n *Node) UnstableEnts() []Entry {
 }
 
 func (n *Node) UnstableState() State {
-	if n.sm.unstableState == EmptyState {
+	if n.sm.unstableState.IsEmpty() {
 		return EmptyState
 	}
 	s := n.sm.unstableState

+ 4 - 6
raft/raft.go

@@ -66,12 +66,6 @@ func (st stateType) String() string {
 	return stmap[int64(st)]
 }
 
-type State struct {
-	Term   int64
-	Vote   int64
-	Commit int64
-}
-
 var EmptyState = State{}
 
 type Message struct {
@@ -594,3 +588,7 @@ func (sm *stateMachine) loadState(state State) {
 	sm.setTerm(state.Term)
 	sm.setVote(state.Vote)
 }
+
+func (s *State) IsEmpty() bool {
+	return s.Term == 0
+}

+ 210 - 0
raft/state.pb.go

@@ -0,0 +1,210 @@
+// Code generated by protoc-gen-gogo.
+// source: state.proto
+// DO NOT EDIT!
+
+/*
+	Package raft is a generated protocol buffer package.
+
+	It is generated from these files:
+		state.proto
+
+	It has these top-level messages:
+		State
+*/
+package raft
+
+import proto "code.google.com/p/gogoprotobuf/proto"
+import json "encoding/json"
+import math "math"
+
+// discarding unused import gogoproto "code.google.com/p/gogoprotobuf/gogoproto/gogo.pb"
+
+import io "io"
+import code_google_com_p_gogoprotobuf_proto "code.google.com/p/gogoprotobuf/proto"
+
+// Reference proto, json, and math imports to suppress error if they are not otherwise used.
+var _ = proto.Marshal
+var _ = &json.SyntaxError{}
+var _ = math.Inf
+
+type State struct {
+	Term             int64  `protobuf:"varint,1,req,name=term" json:"term"`
+	Vote             int64  `protobuf:"varint,2,req,name=vote" json:"vote"`
+	Commit           int64  `protobuf:"varint,3,req,name=commit" json:"commit"`
+	XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *State) Reset()         { *m = State{} }
+func (m *State) String() string { return proto.CompactTextString(m) }
+func (*State) ProtoMessage()    {}
+
+func init() {
+}
+func (m *State) Unmarshal(data []byte) error {
+	l := len(data)
+	index := 0
+	for index < l {
+		var wire uint64
+		for shift := uint(0); ; shift += 7 {
+			if index >= l {
+				return io.ErrUnexpectedEOF
+			}
+			b := data[index]
+			index++
+			wire |= (uint64(b) & 0x7F) << shift
+			if b < 0x80 {
+				break
+			}
+		}
+		fieldNum := int32(wire >> 3)
+		wireType := int(wire & 0x7)
+		switch fieldNum {
+		case 1:
+			if wireType != 0 {
+				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+			}
+			for shift := uint(0); ; shift += 7 {
+				if index >= l {
+					return io.ErrUnexpectedEOF
+				}
+				b := data[index]
+				index++
+				m.Term |= (int64(b) & 0x7F) << shift
+				if b < 0x80 {
+					break
+				}
+			}
+		case 2:
+			if wireType != 0 {
+				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+			}
+			for shift := uint(0); ; shift += 7 {
+				if index >= l {
+					return io.ErrUnexpectedEOF
+				}
+				b := data[index]
+				index++
+				m.Vote |= (int64(b) & 0x7F) << shift
+				if b < 0x80 {
+					break
+				}
+			}
+		case 3:
+			if wireType != 0 {
+				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+			}
+			for shift := uint(0); ; shift += 7 {
+				if index >= l {
+					return io.ErrUnexpectedEOF
+				}
+				b := data[index]
+				index++
+				m.Commit |= (int64(b) & 0x7F) << shift
+				if b < 0x80 {
+					break
+				}
+			}
+		default:
+			var sizeOfWire int
+			for {
+				sizeOfWire++
+				wire >>= 7
+				if wire == 0 {
+					break
+				}
+			}
+			index -= sizeOfWire
+			skippy, err := code_google_com_p_gogoprotobuf_proto.Skip(data[index:])
+			if err != nil {
+				return err
+			}
+			if (index + skippy) > l {
+				return io.ErrUnexpectedEOF
+			}
+			m.XXX_unrecognized = append(m.XXX_unrecognized, data[index:index+skippy]...)
+			index += skippy
+		}
+	}
+	return nil
+}
+func (m *State) Size() (n int) {
+	var l int
+	_ = l
+	n += 1 + sovState(uint64(m.Term))
+	n += 1 + sovState(uint64(m.Vote))
+	n += 1 + sovState(uint64(m.Commit))
+	if m.XXX_unrecognized != nil {
+		n += len(m.XXX_unrecognized)
+	}
+	return n
+}
+
+func sovState(x uint64) (n int) {
+	for {
+		n++
+		x >>= 7
+		if x == 0 {
+			break
+		}
+	}
+	return n
+}
+func sozState(x uint64) (n int) {
+	return sovState(uint64((x << 1) ^ uint64((int64(x) >> 63))))
+}
+func (m *State) Marshal() (data []byte, err error) {
+	size := m.Size()
+	data = make([]byte, size)
+	n, err := m.MarshalTo(data)
+	if err != nil {
+		return nil, err
+	}
+	return data[:n], nil
+}
+
+func (m *State) MarshalTo(data []byte) (n int, err error) {
+	var i int
+	_ = i
+	var l int
+	_ = l
+	data[i] = 0x8
+	i++
+	i = encodeVarintState(data, i, uint64(m.Term))
+	data[i] = 0x10
+	i++
+	i = encodeVarintState(data, i, uint64(m.Vote))
+	data[i] = 0x18
+	i++
+	i = encodeVarintState(data, i, uint64(m.Commit))
+	if m.XXX_unrecognized != nil {
+		i += copy(data[i:], m.XXX_unrecognized)
+	}
+	return i, nil
+}
+func encodeFixed64State(data []byte, offset int, v uint64) int {
+	data[offset] = uint8(v)
+	data[offset+1] = uint8(v >> 8)
+	data[offset+2] = uint8(v >> 16)
+	data[offset+3] = uint8(v >> 24)
+	data[offset+4] = uint8(v >> 32)
+	data[offset+5] = uint8(v >> 40)
+	data[offset+6] = uint8(v >> 48)
+	data[offset+7] = uint8(v >> 56)
+	return offset + 8
+}
+func encodeFixed32State(data []byte, offset int, v uint32) int {
+	data[offset] = uint8(v)
+	data[offset+1] = uint8(v >> 8)
+	data[offset+2] = uint8(v >> 16)
+	data[offset+3] = uint8(v >> 24)
+	return offset + 4
+}
+func encodeVarintState(data []byte, offset int, v uint64) int {
+	for v >= 1<<7 {
+		data[offset] = uint8(v&0x7f | 0x80)
+		v >>= 7
+		offset++
+	}
+	data[offset] = uint8(v)
+	return offset + 1
+}

+ 14 - 0
raft/state.proto

@@ -0,0 +1,14 @@
+package raft;
+
+import "code.google.com/p/gogoprotobuf/gogoproto/gogo.proto";
+
+option (gogoproto.marshaler_all) = true;
+option (gogoproto.sizer_all) = true;
+option (gogoproto.unmarshaler_all) = true;
+option (gogoproto.goproto_getters_all) = false;
+
+message State {
+	required int64 term   = 1 [(gogoproto.nullable) = false];
+	required int64 vote   = 2 [(gogoproto.nullable) = false];
+	required int64 commit = 3 [(gogoproto.nullable) = false];
+}

+ 3 - 5
wal/wal.go

@@ -106,12 +106,11 @@ func (w *WAL) SaveEntry(e *raft.Entry) error {
 
 func (w *WAL) SaveState(s *raft.State) error {
 	log.Printf("path=%s wal.saveState state=\"%+v\"", w.f.Name(), s)
-	w.buf.Reset()
-	err := binary.Write(w.buf, binary.LittleEndian, s)
+	b, err := s.Marshal()
 	if err != nil {
 		panic(err)
 	}
-	return writeBlock(w.bw, stateType, w.buf.Bytes())
+	return writeBlock(w.bw, stateType, b)
 }
 
 func (w *WAL) checkAtHead() error {
@@ -197,8 +196,7 @@ func loadEntry(d []byte) (raft.Entry, error) {
 
 func loadState(d []byte) (raft.State, error) {
 	var s raft.State
-	buf := bytes.NewBuffer(d)
-	err := binary.Read(buf, binary.LittleEndian, &s)
+	err := s.Unmarshal(d)
 	return s, err
 }
 

+ 5 - 5
wal/wal_test.go

@@ -30,8 +30,8 @@ var (
 	infoData  = []byte("\b\xef\xfd\x02")
 	infoBlock = append([]byte("\x01\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00"), infoData...)
 
-	stateData  = []byte("\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00")
-	stateBlock = append([]byte("\x03\x00\x00\x00\x00\x00\x00\x00\x18\x00\x00\x00\x00\x00\x00\x00"), stateData...)
+	stateData  = []byte("\b\x01\x10\x01\x18\x01")
+	stateBlock = append([]byte("\x03\x00\x00\x00\x00\x00\x00\x00\x06\x00\x00\x00\x00\x00\x00\x00"), stateData...)
 
 	entryData  = []byte("\b\x01\x10\x01\x18\x01\x22\x01\x01")
 	entryBlock = append([]byte("\x02\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x00\x00"), entryData...)
@@ -136,7 +136,7 @@ func TestSaveState(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	st := &raft.State{1, 1, 1}
+	st := &raft.State{Term: 1, Vote: 1, Commit: 1}
 	err = w.SaveState(st)
 	if err != nil {
 		t.Fatal(err)
@@ -183,7 +183,7 @@ func TestLoadState(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	ws := raft.State{1, 1, 1}
+	ws := raft.State{Term: 1, Vote: 1, Commit: 1}
 	if !reflect.DeepEqual(s, ws) {
 		t.Errorf("state = %v, want %v", s, ws)
 	}
@@ -205,7 +205,7 @@ func TestLoadNode(t *testing.T) {
 			t.Fatal(err)
 		}
 	}
-	sts := []raft.State{{1, 1, 1}, {2, 2, 2}}
+	sts := []raft.State{{Term: 1, Vote: 1, Commit: 1}, {Term: 2, Vote: 2, Commit: 2}}
 	for _, s := range sts {
 		if err = w.SaveState(&s); err != nil {
 			t.Fatal(err)