Browse Source

wal: fix wal

Xiang Li 11 years ago
parent
commit
bdb954b2f5
14 changed files with 644 additions and 501 deletions
  1. 41 0
      crc/crc.go
  2. 0 0
      raft/raftpb/genproto.sh
  3. 122 24
      raft/raftpb/raft.pb.go
  4. 4 0
      raft/raftpb/raft.proto
  5. 90 0
      wal/decoder.go
  6. 48 0
      wal/encoder.go
  7. 31 0
      wal/multi_readcloser.go
  8. 4 37
      wal/record.go
  9. 4 4
      wal/record.pb.go
  10. 1 1
      wal/record.proto
  11. 15 4
      wal/record_test.go
  12. 91 0
      wal/util.go
  13. 144 261
      wal/wal.go
  14. 49 170
      wal/wal_test.go

+ 41 - 0
crc/crc.go

@@ -0,0 +1,41 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package crc
+
+import (
+	"hash"
+	"hash/crc32"
+)
+
+// The size of a CRC-32 checksum in bytes.
+const Size = 4
+
+// digest is a CRC-32 hash.Hash32 whose running checksum can start from a
+// previously computed value instead of zero.
+type digest struct {
+	crc  uint32       // running checksum
+	seed uint32       // value the checksum started from; restored by Reset
+	tab  *crc32.Table // polynomial table used for every update
+}
+
+// New creates a new hash.Hash32 computing the CRC-32 checksum
+// using the polynomial represented by the Table.
+// Modified by xiangli to take a prevcrc: the returned hash continues
+// from prev, so writing more data extends the checksum prev was taken over.
+func New(prev uint32, tab *crc32.Table) hash.Hash32 {
+	return &digest{crc: prev, seed: prev, tab: tab}
+}
+
+// Size returns the number of bytes Sum appends.
+func (d *digest) Size() int { return Size }
+
+// BlockSize returns 1: writes of any length are accepted.
+func (d *digest) BlockSize() int { return 1 }
+
+// Reset restores the digest to its initial state, i.e. the prev value
+// it was created with (hash.Hash requires Reset to return to the
+// initial state, which is not necessarily zero for a seeded digest).
+func (d *digest) Reset() { d.crc = d.seed }
+
+// Write folds p into the running checksum. It never fails.
+func (d *digest) Write(p []byte) (n int, err error) {
+	d.crc = crc32.Update(d.crc, d.tab, p)
+	return len(p), nil
+}
+
+// Sum32 returns the current checksum without changing the digest.
+func (d *digest) Sum32() uint32 { return d.crc }
+
+// Sum appends the current checksum to in, most significant byte first,
+// and returns the resulting slice.
+func (d *digest) Sum(in []byte) []byte {
+	s := d.Sum32()
+	return append(in, byte(s>>24), byte(s>>16), byte(s>>8), byte(s))
+}

+ 0 - 0
raft/raftpb/genproto.sh


+ 122 - 24
raft/raftpb/raft.pb.go

@@ -9,6 +9,7 @@
 		raft.proto
 
 	It has these top-level messages:
+		Info
 		Entry
 		Snapshot
 		Message
@@ -17,19 +18,27 @@
 package raftpb
 
 import proto "code.google.com/p/gogoprotobuf/proto"
-import json "encoding/json"
 import math "math"
 
 // discarding unused import gogoproto "code.google.com/p/gogoprotobuf/gogoproto/gogo.pb"
 
 import io "io"
+import fmt "fmt"
 import code_google_com_p_gogoprotobuf_proto "code.google.com/p/gogoprotobuf/proto"
 
-// Reference proto, json, and math imports to suppress error if they are not otherwise used.
+// Reference imports to suppress errors if they are not otherwise used.
 var _ = proto.Marshal
-var _ = &json.SyntaxError{}
 var _ = math.Inf
 
+type Info struct {
+	Id               int64  `protobuf:"varint,1,req,name=id" json:"id"`
+	XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *Info) Reset()         { *m = Info{} }
+func (m *Info) String() string { return proto.CompactTextString(m) }
+func (*Info) ProtoMessage()    {}
+
 type Entry struct {
 	Type             int64  `protobuf:"varint,1,req,name=type" json:"type"`
 	Term             int64  `protobuf:"varint,2,req,name=term" json:"term"`
@@ -85,6 +94,63 @@ func (*State) ProtoMessage()    {}
 
 func init() {
 }
+func (m *Info) Unmarshal(data []byte) error {
+	l := len(data)
+	index := 0
+	for index < l {
+		var wire uint64
+		for shift := uint(0); ; shift += 7 {
+			if index >= l {
+				return io.ErrUnexpectedEOF
+			}
+			b := data[index]
+			index++
+			wire |= (uint64(b) & 0x7F) << shift
+			if b < 0x80 {
+				break
+			}
+		}
+		fieldNum := int32(wire >> 3)
+		wireType := int(wire & 0x7)
+		switch fieldNum {
+		case 1:
+			if wireType != 0 {
+				return fmt.Errorf("proto: wrong wireType = %d for field Id", wireType)
+			}
+			for shift := uint(0); ; shift += 7 {
+				if index >= l {
+					return io.ErrUnexpectedEOF
+				}
+				b := data[index]
+				index++
+				m.Id |= (int64(b) & 0x7F) << shift
+				if b < 0x80 {
+					break
+				}
+			}
+		default:
+			var sizeOfWire int
+			for {
+				sizeOfWire++
+				wire >>= 7
+				if wire == 0 {
+					break
+				}
+			}
+			index -= sizeOfWire
+			skippy, err := code_google_com_p_gogoprotobuf_proto.Skip(data[index:])
+			if err != nil {
+				return err
+			}
+			if (index + skippy) > l {
+				return io.ErrUnexpectedEOF
+			}
+			m.XXX_unrecognized = append(m.XXX_unrecognized, data[index:index+skippy]...)
+			index += skippy
+		}
+	}
+	return nil
+}
 func (m *Entry) Unmarshal(data []byte) error {
 	l := len(data)
 	index := 0
@@ -106,7 +172,7 @@ func (m *Entry) Unmarshal(data []byte) error {
 		switch fieldNum {
 		case 1:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -121,7 +187,7 @@ func (m *Entry) Unmarshal(data []byte) error {
 			}
 		case 2:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Term", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -136,7 +202,7 @@ func (m *Entry) Unmarshal(data []byte) error {
 			}
 		case 3:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Index", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -151,7 +217,7 @@ func (m *Entry) Unmarshal(data []byte) error {
 			}
 		case 4:
 			if wireType != 2 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType)
 			}
 			var byteLen int
 			for shift := uint(0); ; shift += 7 {
@@ -215,7 +281,7 @@ func (m *Snapshot) Unmarshal(data []byte) error {
 		switch fieldNum {
 		case 1:
 			if wireType != 2 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType)
 			}
 			var byteLen int
 			for shift := uint(0); ; shift += 7 {
@@ -237,7 +303,7 @@ func (m *Snapshot) Unmarshal(data []byte) error {
 			index = postIndex
 		case 2:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Nodes", wireType)
 			}
 			var v int64
 			for shift := uint(0); ; shift += 7 {
@@ -254,7 +320,7 @@ func (m *Snapshot) Unmarshal(data []byte) error {
 			m.Nodes = append(m.Nodes, v)
 		case 3:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Index", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -269,7 +335,7 @@ func (m *Snapshot) Unmarshal(data []byte) error {
 			}
 		case 4:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Term", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -326,7 +392,7 @@ func (m *Message) Unmarshal(data []byte) error {
 		switch fieldNum {
 		case 1:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -341,7 +407,7 @@ func (m *Message) Unmarshal(data []byte) error {
 			}
 		case 2:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field To", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -356,7 +422,7 @@ func (m *Message) Unmarshal(data []byte) error {
 			}
 		case 3:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field From", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -371,7 +437,7 @@ func (m *Message) Unmarshal(data []byte) error {
 			}
 		case 4:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Term", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -386,7 +452,7 @@ func (m *Message) Unmarshal(data []byte) error {
 			}
 		case 5:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field LogTerm", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -401,7 +467,7 @@ func (m *Message) Unmarshal(data []byte) error {
 			}
 		case 6:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Index", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -416,7 +482,7 @@ func (m *Message) Unmarshal(data []byte) error {
 			}
 		case 7:
 			if wireType != 2 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Entries", wireType)
 			}
 			var msglen int
 			for shift := uint(0); ; shift += 7 {
@@ -439,7 +505,7 @@ func (m *Message) Unmarshal(data []byte) error {
 			index = postIndex
 		case 8:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Commit", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -454,7 +520,7 @@ func (m *Message) Unmarshal(data []byte) error {
 			}
 		case 9:
 			if wireType != 2 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Snapshot", wireType)
 			}
 			var msglen int
 			for shift := uint(0); ; shift += 7 {
@@ -520,7 +586,7 @@ func (m *State) Unmarshal(data []byte) error {
 		switch fieldNum {
 		case 1:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Term", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -535,7 +601,7 @@ func (m *State) Unmarshal(data []byte) error {
 			}
 		case 2:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Vote", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -550,7 +616,7 @@ func (m *State) Unmarshal(data []byte) error {
 			}
 		case 3:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field Commit", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -565,7 +631,7 @@ func (m *State) Unmarshal(data []byte) error {
 			}
 		case 4:
 			if wireType != 0 {
-				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+				return fmt.Errorf("proto: wrong wireType = %d for field LastIndex", wireType)
 			}
 			for shift := uint(0); ; shift += 7 {
 				if index >= l {
@@ -601,6 +667,15 @@ func (m *State) Unmarshal(data []byte) error {
 	}
 	return nil
 }
+func (m *Info) Size() (n int) {
+	var l int
+	_ = l
+	n += 1 + sovRaft(uint64(m.Id))
+	if m.XXX_unrecognized != nil {
+		n += len(m.XXX_unrecognized)
+	}
+	return n
+}
 func (m *Entry) Size() (n int) {
 	var l int
 	_ = l
@@ -680,6 +755,29 @@ func sovRaft(x uint64) (n int) {
 func sozRaft(x uint64) (n int) {
 	return sovRaft(uint64((x << 1) ^ uint64((int64(x) >> 63))))
 }
+func (m *Info) Marshal() (data []byte, err error) {
+	size := m.Size()
+	data = make([]byte, size)
+	n, err := m.MarshalTo(data)
+	if err != nil {
+		return nil, err
+	}
+	return data[:n], nil
+}
+
+func (m *Info) MarshalTo(data []byte) (n int, err error) {
+	var i int
+	_ = i
+	var l int
+	_ = l
+	data[i] = 0x8
+	i++
+	i = encodeVarintRaft(data, i, uint64(m.Id))
+	if m.XXX_unrecognized != nil {
+		i += copy(data[i:], m.XXX_unrecognized)
+	}
+	return i, nil
+}
 func (m *Entry) Marshal() (data []byte, err error) {
 	size := m.Size()
 	data = make([]byte, size)

+ 4 - 0
raft/raftpb/raft.proto

@@ -7,6 +7,10 @@ option (gogoproto.sizer_all) = true;
 option (gogoproto.unmarshaler_all) = true;
 option (gogoproto.goproto_getters_all) = false;
 
+message Info {
+	required int64 id   = 1 [(gogoproto.nullable) = false];
+}
+
 message Entry {
 	required int64 type  = 1 [(gogoproto.nullable) = false];
 	required int64 term  = 2 [(gogoproto.nullable) = false];

+ 90 - 0
wal/decoder.go

@@ -0,0 +1,90 @@
+package wal
+
+import (
+	"bufio"
+	"encoding/binary"
+	"hash"
+	"io"
+
+	"github.com/coreos/etcd/crc"
+	"github.com/coreos/etcd/raft/raftpb"
+)
+
+// decoder reads length-prefixed Records from a stream and checks each
+// non-crc record against a running CRC of the record payloads.
+type decoder struct {
+	br  *bufio.Reader // buffered reader over the underlying stream
+	c   io.Closer     // closes the underlying stream
+	crc hash.Hash32   // running checksum, updated with the Data of each decoded record
+}
+
+// newDecoder returns a decoder that reads from rc and whose running
+// CRC starts at 0. Closing the decoder closes rc.
+func newDecoder(rc io.ReadCloser) *decoder {
+	return &decoder{
+		br:  bufio.NewReader(rc),
+		c:   rc,
+		crc: crc.New(0, crcTable),
+	}
+}
+
+// decode reads the next record from the stream into rec.
+//
+// rec is reset first. The record is read as an int64 length followed by
+// that many bytes of marshaled Record. Records of type crcType are
+// returned without checksum verification; for every other record the
+// payload is folded into the running CRC and compared with the CRC
+// stored in the record.
+//
+// It returns the read error (io.EOF at a clean end of stream), the
+// unmarshal error, io.ErrUnexpectedEOF for a negative length prefix, or
+// the result of rec.validate.
+func (d *decoder) decode(rec *Record) error {
+	rec.Reset()
+	l, err := readInt64(d.br)
+	if err != nil {
+		return err
+	}
+	// A negative length can only come from a corrupted length prefix.
+	// make would panic on it, so report it like any other broken record.
+	if l < 0 {
+		return io.ErrUnexpectedEOF
+	}
+	data := make([]byte, l)
+	if _, err = io.ReadFull(d.br, data); err != nil {
+		return err
+	}
+	if err := rec.Unmarshal(data); err != nil {
+		return err
+	}
+	// skip crc checking if the record type is crcType
+	if rec.Type == crcType {
+		return nil
+	}
+	d.crc.Write(rec.Data)
+	return rec.validate(d.crc.Sum32())
+}
+
+// updateCRC replaces the running CRC with a fresh one seeded at prevCrc.
+func (d *decoder) updateCRC(prevCrc uint32) {
+	d.crc = crc.New(prevCrc, crcTable)
+}
+
+// lastCRC returns the current value of the running CRC.
+func (d *decoder) lastCRC() uint32 {
+	return d.crc.Sum32()
+}
+
+// close closes the underlying stream.
+func (d *decoder) close() error {
+	return d.c.Close()
+}
+
+// mustUnmarshalInfo decodes d into a raftpb.Info.
+// It panics if d cannot be unmarshaled: the caller has already matched
+// the record's crc, so a failure here is not an expected condition.
+func mustUnmarshalInfo(d []byte) raftpb.Info {
+	info := raftpb.Info{}
+	err := info.Unmarshal(d)
+	if err != nil {
+		panic(err)
+	}
+	return info
+}
+
+// mustUnmarshalEntry decodes d into a raftpb.Entry and panics if d
+// cannot be unmarshaled.
+func mustUnmarshalEntry(d []byte) raftpb.Entry {
+	entry := raftpb.Entry{}
+	err := entry.Unmarshal(d)
+	if err != nil {
+		panic(err)
+	}
+	return entry
+}
+
+// mustUnmarshalState decodes d into a raftpb.State and panics if d
+// cannot be unmarshaled.
+func mustUnmarshalState(d []byte) raftpb.State {
+	state := raftpb.State{}
+	err := state.Unmarshal(d)
+	if err != nil {
+		panic(err)
+	}
+	return state
+}
+
+// readInt64 reads eight bytes from r and returns them interpreted as a
+// little-endian int64. It returns io.EOF if r is already exhausted and
+// io.ErrUnexpectedEOF if fewer than eight bytes are available.
+func readInt64(r io.Reader) (int64, error) {
+	var raw [8]byte
+	if _, err := io.ReadFull(r, raw[:]); err != nil {
+		return 0, err
+	}
+	return int64(binary.LittleEndian.Uint64(raw[:])), nil
+}

+ 48 - 0
wal/encoder.go

@@ -0,0 +1,48 @@
+package wal
+
+import (
+	"bufio"
+	"encoding/binary"
+	"hash"
+	"io"
+
+	"github.com/coreos/etcd/crc"
+)
+
+// encoder writes length-prefixed Records to a buffered writer and
+// stamps each record with a running CRC of the record payloads.
+type encoder struct {
+	bw  *bufio.Writer // buffered writer over the destination
+	crc hash.Hash32   // running checksum, updated with the Data of each encoded record
+}
+
+// newEncoder returns an encoder that writes to w and whose running CRC
+// continues from prevCrc.
+func newEncoder(w io.Writer, prevCrc uint32) *encoder {
+	enc := new(encoder)
+	enc.bw = bufio.NewWriter(w)
+	enc.crc = crc.New(prevCrc, crcTable)
+	return enc
+}
+
+// encode appends rec to the buffered output as an int64 length followed
+// by the marshaled record. Before marshaling, rec.Data is folded into
+// the running CRC and rec.Crc is overwritten with the resulting value.
+// Nothing reaches the underlying writer until the buffer fills or
+// flush is called.
+func (e *encoder) encode(rec *Record) error {
+	e.crc.Write(rec.Data)
+	rec.Crc = e.crc.Sum32()
+
+	payload, err := rec.Marshal()
+	if err != nil {
+		return err
+	}
+	err = writeInt64(e.bw, int64(len(payload)))
+	if err == nil {
+		_, err = e.bw.Write(payload)
+	}
+	return err
+}
+
+// flush writes any buffered bytes to the underlying writer.
+func (e *encoder) flush() error {
+	return e.bw.Flush()
+}
+
+// buffered returns the number of bytes written but not yet flushed.
+func (e *encoder) buffered() int {
+	return e.bw.Buffered()
+}
+
+// writeInt64 writes n to w as eight little-endian bytes and returns
+// the error, if any, from that write.
+func writeInt64(w io.Writer, n int64) error {
+	var raw [8]byte
+	binary.LittleEndian.PutUint64(raw[:], uint64(n))
+	_, err := w.Write(raw[:])
+	return err
+}

+ 31 - 0
wal/multi_readcloser.go

@@ -0,0 +1,31 @@
+package wal
+
+import "io"
+
+// multiReadCloser presents several ReadClosers as one: Read is served
+// by reader and Close closes every element of closers.
+type multiReadCloser struct {
+	closers []io.Closer
+	reader  io.Reader
+}
+
+// Close closes every underlying closer, continuing past failures so no
+// closer is left open, and returns the first error encountered (nil if
+// every Close succeeded).
+func (mc *multiReadCloser) Close() error {
+	var firstErr error
+	for _, c := range mc.closers {
+		if err := c.Close(); err != nil && firstErr == nil {
+			firstErr = err
+		}
+	}
+	return firstErr
+}
+
+// Read reads from the combined reader.
+func (mc *multiReadCloser) Read(p []byte) (int, error) {
+	return mc.reader.Read(p)
+}
+
+// MultiReadCloser returns an io.ReadCloser that reads the given
+// readClosers one after another, in order (via io.MultiReader), and
+// whose Close closes all of them.
+func MultiReadCloser(readClosers ...io.ReadCloser) io.ReadCloser {
+	cs := make([]io.Closer, len(readClosers))
+	rs := make([]io.Reader, len(readClosers))
+	for i := range readClosers {
+		cs[i] = readClosers[i]
+		rs[i] = readClosers[i]
+	}
+	r := io.MultiReader(rs...)
+	return &multiReadCloser{cs, r}
+}

+ 4 - 37
wal/record.go

@@ -16,43 +16,10 @@ limitations under the License.
 
 package wal
 
-import (
-	"encoding/binary"
-	"io"
-)
-
-func writeRecord(w io.Writer, rec *Record) error {
-	data, err := rec.Marshal()
-	if err != nil {
-		return err
-	}
-
-	if err := writeInt64(w, int64(len(data))); err != nil {
-		return err
+func (rec *Record) validate(crc uint32) error {
+	if rec.Crc == crc {
+		return nil
 	}
-	_, err = w.Write(data)
-	return err
-}
-
-func readRecord(r io.Reader, rec *Record) error {
 	rec.Reset()
-	l, err := readInt64(r)
-	if err != nil {
-		return err
-	}
-	d := make([]byte, l)
-	if _, err = io.ReadFull(r, d); err != nil {
-		return err
-	}
-	return rec.Unmarshal(d)
-}
-
-func writeInt64(w io.Writer, n int64) error {
-	return binary.Write(w, binary.LittleEndian, n)
-}
-
-func readInt64(r io.Reader) (int64, error) {
-	var n int64
-	err := binary.Read(r, binary.LittleEndian, &n)
-	return n, err
+	return ErrCRCMismatch
 }

+ 4 - 4
wal/record.pb.go

@@ -29,7 +29,7 @@ var _ = math.Inf
 
 type Record struct {
 	Type             int64  `protobuf:"varint,1,req,name=type" json:"type"`
-	Crc              int32  `protobuf:"varint,2,req,name=crc" json:"crc"`
+	Crc              uint32 `protobuf:"varint,2,req,name=crc" json:"crc"`
 	Data             []byte `protobuf:"bytes,3,opt,name=data" json:"data,omitempty"`
 	XXX_unrecognized []byte `json:"-"`
 }
@@ -84,7 +84,7 @@ func (m *Record) Unmarshal(data []byte) error {
 				}
 				b := data[index]
 				index++
-				m.Crc |= (int32(b) & 0x7F) << shift
+				m.Crc |= (uint32(b) & 0x7F) << shift
 				if b < 0x80 {
 					break
 				}
@@ -138,7 +138,7 @@ func (m *Record) Size() (n int) {
 	var l int
 	_ = l
 	n += 1 + sovRecord(uint64(m.Type))
-	n += 1 + sovRecord(uint64(uint32(m.Crc)))
+	n += 1 + sovRecord(uint64(m.Crc))
 	if m.Data != nil {
 		l = len(m.Data)
 		n += 1 + l + sovRecord(uint64(l))
@@ -182,7 +182,7 @@ func (m *Record) MarshalTo(data []byte) (n int, err error) {
 	i = encodeVarintRecord(data, i, uint64(m.Type))
 	data[i] = 0x10
 	i++
-	i = encodeVarintRecord(data, i, uint64(uint32(m.Crc)))
+	i = encodeVarintRecord(data, i, uint64(m.Crc))
 	if m.Data != nil {
 		data[i] = 0x1a
 		i++

+ 1 - 1
wal/record.proto

@@ -9,6 +9,6 @@ option (gogoproto.goproto_getters_all) = false;
 
 message Record {
 	required int64 type  = 1 [(gogoproto.nullable) = false];
-	required int32 crc   = 2 [(gogoproto.nullable) = false];
+	required uint32 crc  = 2 [(gogoproto.nullable) = false];
 	optional bytes data  = 3;
 }

+ 15 - 4
wal/record_test.go

@@ -18,28 +18,36 @@ package wal
 
 import (
 	"bytes"
+	"hash/crc32"
 	"io"
+	"io/ioutil"
 	"reflect"
 	"testing"
 )
 
 func TestReadRecord(t *testing.T) {
+	badInfoRecord := make([]byte, len(infoRecord))
+	copy(badInfoRecord, infoRecord)
+	badInfoRecord[len(badInfoRecord)-1] = 'a'
+
 	tests := []struct {
 		data []byte
 		wr   *Record
 		we   error
 	}{
-		{infoRecord, &Record{Type: 1, Crc: 0, Data: infoData}, nil},
+		{infoRecord, &Record{Type: 1, Crc: crc32.Checksum(infoData, crcTable), Data: infoData}, nil},
 		{[]byte(""), &Record{}, io.EOF},
 		{infoRecord[:len(infoRecord)-len(infoData)-8], &Record{}, io.ErrUnexpectedEOF},
 		{infoRecord[:len(infoRecord)-len(infoData)], &Record{}, io.ErrUnexpectedEOF},
 		{infoRecord[:len(infoRecord)-8], &Record{}, io.ErrUnexpectedEOF},
+		{badInfoRecord, &Record{}, ErrCRCMismatch},
 	}
 
 	rec := &Record{}
 	for i, tt := range tests {
 		buf := bytes.NewBuffer(tt.data)
-		e := readRecord(buf, rec)
+		decoder := newDecoder(ioutil.NopCloser(buf))
+		e := decoder.decode(rec)
 		if !reflect.DeepEqual(rec, tt.wr) {
 			t.Errorf("#%d: block = %v, want %v", i, rec, tt.wr)
 		}
@@ -55,8 +63,11 @@ func TestWriteRecord(t *testing.T) {
 	typ := int64(0xABCD)
 	d := []byte("Hello world!")
 	buf := new(bytes.Buffer)
-	writeRecord(buf, &Record{Type: typ, Crc: 0, Data: d})
-	err := readRecord(buf, b)
+	e := newEncoder(buf, 0)
+	e.encode(&Record{Type: typ, Data: d})
+	e.flush()
+	decoder := newDecoder(ioutil.NopCloser(buf))
+	err := decoder.decode(b)
 	if err != nil {
 		t.Errorf("err = %v, want nil", err)
 	}

+ 91 - 0
wal/util.go

@@ -0,0 +1,91 @@
+package wal
+
+import (
+	"fmt"
+	"log"
+	"os"
+)
+
+// Exist reports whether dirpath can be read as a directory and
+// contains at least one entry.
+func Exist(dirpath string) bool {
+	names, err := readDir(dirpath)
+	return err == nil && len(names) != 0
+}
+
+// searchIndex returns the array index of the last name whose raft
+// index section is not greater than the given raft index, and true.
+// It returns -1 and false if no name qualifies.
+// The input names should be sorted; it panics on a name that is not a
+// valid wal name.
+func searchIndex(names []string, index int64) (int, bool) {
+	pos := len(names) - 1
+	for pos >= 0 {
+		_, curIndex, err := parseWalName(names[pos])
+		if err != nil {
+			panic("parse correct name error")
+		}
+		if curIndex <= index {
+			return pos, true
+		}
+		pos--
+	}
+	return -1, false
+}
+
+// isValidSeq checks whether the sequence numbers of names increase by
+// exactly one from each name to the next.
+// names should have been sorted based on sequence number.
+// A previous sequence number of 0 is treated as "nothing seen yet", so
+// the step that follows a name with sequence 0 is not checked.
+// It panics on a name that is not a valid wal name.
+func isValidSeq(names []string) bool {
+	prev := int64(0)
+	for i := 0; i < len(names); i++ {
+		seq, _, err := parseWalName(names[i])
+		if err != nil {
+			panic("parse correct name error")
+		}
+		if prev != 0 && seq != prev+1 {
+			return false
+		}
+		prev = seq
+	}
+	return true
+}
+
+// readDir returns the names of all entries in the directory dirpath,
+// in the order the directory yields them.
+// On any error opening or reading the directory it returns nil and
+// that error.
+func readDir(dirpath string) ([]string, error) {
+	d, openErr := os.Open(dirpath)
+	if openErr != nil {
+		return nil, openErr
+	}
+	defer d.Close()
+
+	entries, readErr := d.Readdirnames(-1)
+	if readErr != nil {
+		return nil, readErr
+	}
+	return entries, nil
+}
+
+// checkWalNames returns, in their original order, the names that parse
+// as wal file names. Each name that fails to parse is logged and
+// dropped. The result is never nil.
+func checkWalNames(names []string) []string {
+	valid := []string{}
+	for i := range names {
+		_, _, err := parseWalName(names[i])
+		if err == nil {
+			valid = append(valid, names[i])
+			continue
+		}
+		log.Printf("parse %s: %v", names[i], err)
+	}
+	return valid
+}
+
+// parseWalName extracts the sequence number and raft index from a wal
+// file name of the form "<seq>-<index>.wal", where both numbers are
+// hexadecimal. It returns the scan error if the name does not match,
+// or a "bad wal name" error if fewer than two numbers were scanned.
+func parseWalName(str string) (seq, index int64, err error) {
+	scanned, scanErr := fmt.Sscanf(str, "%016x-%016x.wal", &seq, &index)
+	if scanErr != nil {
+		return seq, index, scanErr
+	}
+	if scanned != 2 {
+		return seq, index, fmt.Errorf("bad wal name: %s", str)
+	}
+	return seq, index, nil
+}
+
+// max returns the larger of a and b.
+func max(a, b int64) int64 {
+	if b >= a {
+		return b
+	}
+	return a
+}

+ 144 - 261
wal/wal.go

@@ -17,47 +17,49 @@ limitations under the License.
 package wal
 
 import (
-	"bufio"
-	"bytes"
+	"errors"
 	"fmt"
+	"hash/crc32"
 	"io"
 	"log"
 	"os"
 	"path"
 	"sort"
 
-	"github.com/coreos/etcd/raft"
+	"github.com/coreos/etcd/raft/raftpb"
 )
 
 const (
 	infoType int64 = iota + 1
 	entryType
 	stateType
+	crcType
 )
 
 var (
-	ErrIdMismatch = fmt.Errorf("unmatch id")
-	ErrNotFound   = fmt.Errorf("wal file is not found")
+	ErrIdMismatch  = fmt.Errorf("wal: unmatch id")
+	ErrNotFound    = fmt.Errorf("wal: file is not found")
+	ErrCRCMismatch = errors.New("wal: crc mismatch")
+	crcTable       = crc32.MakeTable(crc32.Castagnoli)
 )
 
+// WAL is a logical representation of the stable storage.
+// WAL is either in read mode or append mode but not both.
+// A newly created WAL is in append mode, and ready for appending records.
+// A just opened WAL is in read mode, and ready for reading records.
+// The WAL will be ready for appending after reading out all the previous records.
 type WAL struct {
-	f   *os.File
-	bw  *bufio.Writer
-	buf *bytes.Buffer
-}
+	dir string // the living directory of the underlay files
 
-func newWAL(f *os.File) *WAL {
-	return &WAL{f, bufio.NewWriter(f), new(bytes.Buffer)}
-}
+	ri      int64    // index of entry to start reading
+	decoder *decoder // decoder to decode records
 
-func Exist(dirpath string) bool {
-	names, err := readDir(dirpath)
-	if err != nil {
-		return false
-	}
-	return len(names) != 0
+	f       *os.File // underlay file opened for appending, sync
+	seq     int64    // current sequence of the wal file
+	encoder *encoder // encoder to encode records
 }
 
+// Create creates a WAL ready for appending records.
 func Create(dirpath string) (*WAL, error) {
 	log.Printf("path=%s wal.create", dirpath)
 	if Exist(dirpath) {
@@ -68,11 +70,24 @@ func Create(dirpath string) (*WAL, error) {
 	if err != nil {
 		return nil, err
 	}
-	return newWAL(f), nil
+	w := &WAL{
+		dir:     dirpath,
+		seq:     0,
+		f:       f,
+		encoder: newEncoder(f, 0),
+	}
+	if err := w.saveCrc(0); err != nil {
+		return nil, err
+	}
+	return w, nil
 }
 
-func Open(dirpath string) (*WAL, error) {
-	log.Printf("path=%s wal.append", dirpath)
+// OpenFromIndex opens the WAL files containing all the entries after
+// the given index.
+// The returned WAL is ready to read. The WAL cannot be appended to before
+// reading out all of its previous records.
+func OpenFromIndex(dirpath string, index int64) (*WAL, error) {
+	log.Printf("path=%s wal.load index=%d", dirpath, index)
 	names, err := readDir(dirpath)
 	if err != nil {
 		return nil, err
@@ -82,298 +97,166 @@ func Open(dirpath string) (*WAL, error) {
 		return nil, ErrNotFound
 	}
 
-	name := names[len(names)-1]
-	p := path.Join(dirpath, name)
-	f, err := os.OpenFile(p, os.O_WRONLY|os.O_APPEND, 0)
-	if err != nil {
-		return nil, err
-	}
-	return newWAL(f), nil
-}
-
-// index should be the index of last log entry currently.
-// Cut closes current file written and creates a new one to append.
-func (w *WAL) Cut(index int64) error {
-	log.Printf("path=%s wal.cut index=%d", w.f.Name(), index)
-	fpath := w.f.Name()
-	seq, _, err := parseWalName(path.Base(fpath))
-	if err != nil {
-		panic("parse correct name error")
-	}
-	fpath = path.Join(path.Dir(fpath), fmt.Sprintf("%016x-%016x.wal", seq+1, index))
-	f, err := os.OpenFile(fpath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
-	if err != nil {
-		return err
-	}
-
-	w.Sync()
-	w.f.Close()
-	w.f = f
-	w.bw = bufio.NewWriter(f)
-	return nil
-}
-
-func (w *WAL) Sync() error {
-	if err := w.bw.Flush(); err != nil {
-		return err
-	}
-	return w.f.Sync()
-}
+	sort.Sort(sort.StringSlice(names))
 
-func (w *WAL) Close() {
-	log.Printf("path=%s wal.close", w.f.Name())
-	if w.f != nil {
-		w.Sync()
-		w.f.Close()
+	nameIndex, ok := searchIndex(names, index)
+	if !ok || !isValidSeq(names[nameIndex:]) {
+		return nil, ErrNotFound
 	}
-}
 
-func (w *WAL) SaveInfo(i *raft.Info) error {
-	log.Printf("path=%s wal.saveInfo id=%d", w.f.Name(), i.Id)
-	if err := w.checkAtHead(); err != nil {
-		return err
-	}
-	b, err := i.Marshal()
-	if err != nil {
-		panic(err)
+	// open the wal files for reading
+	rcs := make([]io.ReadCloser, 0)
+	for _, name := range names[nameIndex:] {
+		f, err := os.Open(path.Join(dirpath, name))
+		if err != nil {
+			return nil, err
+		}
+		rcs = append(rcs, f)
 	}
-	rec := &Record{Type: infoType, Data: b}
-	return writeRecord(w.bw, rec)
-}
+	rc := MultiReadCloser(rcs...)
 
-func (w *WAL) SaveEntry(e *raft.Entry) error {
-	b, err := e.Marshal()
+	// open the latest wal file for appending
+	last := path.Join(dirpath, names[len(names)-1])
+	f, err := os.OpenFile(last, os.O_WRONLY|os.O_APPEND, 0)
 	if err != nil {
-		panic(err)
+		rc.Close()
+		return nil, err
 	}
-	rec := &Record{Type: entryType, Data: b}
-	return writeRecord(w.bw, rec)
-}
 
-func (w *WAL) SaveState(s *raft.State) error {
-	log.Printf("path=%s wal.saveState state=\"%+v\"", w.f.Name(), s)
-	b, err := s.Marshal()
-	if err != nil {
-		panic(err)
-	}
-	rec := &Record{Type: stateType, Data: b}
-	return writeRecord(w.bw, rec)
-}
+	// create a WAL ready for reading
+	w := &WAL{
+		ri:      index,
+		decoder: newDecoder(rc),
 
-func (w *WAL) checkAtHead() error {
-	o, err := w.f.Seek(0, os.SEEK_CUR)
-	if err != nil {
-		return err
+		f: f,
 	}
-	if o != 0 || w.bw.Buffered() != 0 {
-		return fmt.Errorf("cannot write info at %d, expect 0", max(o, int64(w.bw.Buffered())))
-	}
-	return nil
+	return w, nil
 }
 
-type Node struct {
-	Id    int64
-	Ents  []raft.Entry
-	State raft.State
-
-	// index of the first entry
-	index int64
-}
-
-func newNode(index int64) *Node {
-	return &Node{Ents: make([]raft.Entry, 0), index: index + 1}
-}
+// ReadAll reads out all records of the current WAL.
+// After ReadAll, the WAL will be ready for appending new records.
+func (w *WAL) ReadAll() (int64, raftpb.State, []raftpb.Entry, error) {
+	var id int64
+	var state raftpb.State
+	var entries []raftpb.Entry
 
-func (n *Node) load(path string) error {
-	f, err := os.Open(path)
-	if err != nil {
-		return err
-	}
-	defer f.Close()
-	br := bufio.NewReader(f)
 	rec := &Record{}
-
-	err = readRecord(br, rec)
-	if err != nil {
-		return err
-	}
-	if rec.Type != infoType {
-		return fmt.Errorf("the first block of wal is not infoType but %d", rec.Type)
-	}
-	i, err := loadInfo(rec.Data)
-	if err != nil {
-		return err
-	}
-	if n.Id != 0 && n.Id != i.Id {
-		return ErrIdMismatch
-	}
-	n.Id = i.Id
-
-	for err = readRecord(br, rec); err == nil; err = readRecord(br, rec) {
+	decoder := w.decoder
+	var err error
+	for err = decoder.decode(rec); err == nil; err = decoder.decode(rec) {
 		switch rec.Type {
 		case entryType:
-			e, err := loadEntry(rec.Data)
-			if err != nil {
-				return err
-			}
-			if e.Index >= n.index {
-				n.Ents = append(n.Ents[:e.Index-n.index], e)
+			e := mustUnmarshalEntry(rec.Data)
+			if e.Index > w.ri {
+				entries = append(entries[:e.Index-w.ri-1], e)
 			}
 		case stateType:
-			s, err := loadState(rec.Data)
-			if err != nil {
-				return err
+			state = mustUnmarshalState(rec.Data)
+		case infoType:
+			i := mustUnmarshalInfo(rec.Data)
+			if id != 0 && id != i.Id {
+				state.Reset()
+				return 0, state, nil, ErrIdMismatch
 			}
-			n.State = s
+			id = i.Id
+		case crcType:
+			crc := decoder.crc.Sum32()
+			// current crc of decoder must match the crc of the record.
+			// no need to match a 0 crc, since the decoder is a new one in this case.
+			if crc != 0 && rec.validate(crc) != nil {
+				state.Reset()
+				return 0, state, nil, ErrCRCMismatch
+			}
+			decoder.updateCRC(rec.Crc)
 		default:
-			return fmt.Errorf("unexpected block type %d", rec.Type)
+			state.Reset()
+			return 0, state, nil, fmt.Errorf("unexpected block type %d", rec.Type)
 		}
 	}
 	if err != io.EOF {
-		return err
+		state.Reset()
+		return 0, state, nil, err
 	}
-	return nil
-}
 
-func (n *Node) startFrom(index int64) error {
-	diff := int(index - n.index)
-	if diff > len(n.Ents) {
-		return ErrNotFound
-	}
-	n.Ents = n.Ents[diff:]
-	return nil
+	// close decoder, disable reading
+	w.decoder.close()
+	w.ri = 0
+
+	// create encoder (chain crc with the decoder), enable appending
+	w.encoder = newEncoder(w.f, w.decoder.lastCRC())
+	w.decoder = nil
+	return id, state, entries, nil
 }
 
-// Read loads all entries after index (index is not included).
-func Read(dirpath string, index int64) (*Node, error) {
-	log.Printf("path=%s wal.load index=%d", dirpath, index)
-	names, err := readDir(dirpath)
+// index should be the index of last log entry.
+// Cut closes current file written and creates a new one ready to append.
+func (w *WAL) Cut(index int64) error {
+	log.Printf("wal.cut index=%d", index)
+
+	// create a new wal file with name sequence + 1
+	fpath := path.Join(w.dir, fmt.Sprintf("%016x-%016x.wal", w.seq+1, index))
+	f, err := os.OpenFile(fpath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
 	if err != nil {
-		return nil, err
-	}
-	names = checkWalNames(names)
-	if len(names) == 0 {
-		return nil, ErrNotFound
+		return err
 	}
 
-	sort.Sort(sort.StringSlice(names))
-	nameIndex, ok := searchIndex(names, index)
-	if !ok || !isValidSeq(names[nameIndex:]) {
-		return nil, ErrNotFound
-	}
+	w.Sync()
+	w.f.Close()
 
-	_, initIndex, err := parseWalName(names[nameIndex])
-	if err != nil {
-		panic("parse correct name error")
-	}
-	n := newNode(initIndex)
-	for _, name := range names[nameIndex:] {
-		if err := n.load(path.Join(dirpath, name)); err != nil {
-			return nil, err
-		}
-	}
-	if err := n.startFrom(index + 1); err != nil {
-		return nil, ErrNotFound
-	}
-	return n, nil
+	// update writer and save the previous crc
+	w.f = f
+	w.seq++
+	prevCrc := w.encoder.crc.Sum32()
+	w.encoder = newEncoder(w.f, prevCrc)
+	return w.saveCrc(prevCrc)
 }
 
-// The input names should be sorted.
-// serachIndex returns the array index of the last name that has
-// a smaller raft index section than the given raft index.
-func searchIndex(names []string, index int64) (int, bool) {
-	for i := len(names) - 1; i >= 0; i-- {
-		name := names[i]
-		_, curIndex, err := parseWalName(name)
-		if err != nil {
-			panic("parse correct name error")
-		}
-		if index >= curIndex {
-			return i, true
+func (w *WAL) Sync() error {
+	if w.encoder != nil {
+		if err := w.encoder.flush(); err != nil {
+			return err
 		}
 	}
-	return -1, false
+	return w.f.Sync()
 }
 
-// names should have been sorted based on sequence number.
-// isValidSeq checks whether seq increases continuously.
-func isValidSeq(names []string) bool {
-	var lastSeq int64
-	for _, name := range names {
-		curSeq, _, err := parseWalName(name)
-		if err != nil {
-			panic("parse correct name error")
-		}
-		if lastSeq != 0 && lastSeq != curSeq-1 {
-			return false
-		}
-		lastSeq = curSeq
+func (w *WAL) Close() {
+	log.Printf("path=%s wal.close", w.f.Name())
+	if w.f != nil {
+		w.Sync()
+		w.f.Close()
 	}
-	return true
 }
 
-func loadInfo(d []byte) (raft.Info, error) {
-	var i raft.Info
-	err := i.Unmarshal(d)
+func (w *WAL) SaveInfo(i *raftpb.Info) error {
+	log.Printf("path=%s wal.saveInfo id=%d", w.f.Name(), i.Id)
+	b, err := i.Marshal()
 	if err != nil {
 		panic(err)
 	}
-	return i, err
+	rec := &Record{Type: infoType, Data: b}
+	return w.encoder.encode(rec)
 }
 
-func loadEntry(d []byte) (raft.Entry, error) {
-	var e raft.Entry
-	err := e.Unmarshal(d)
+func (w *WAL) SaveEntry(e *raftpb.Entry) error {
+	b, err := e.Marshal()
 	if err != nil {
 		panic(err)
 	}
-	return e, err
-}
-
-func loadState(d []byte) (raft.State, error) {
-	var s raft.State
-	err := s.Unmarshal(d)
-	return s, err
+	rec := &Record{Type: entryType, Data: b}
+	return w.encoder.encode(rec)
 }
 
-// readDir returns the filenames in wal directory.
-func readDir(dirpath string) ([]string, error) {
-	dir, err := os.Open(dirpath)
-	if err != nil {
-		return nil, err
-	}
-	defer dir.Close()
-	names, err := dir.Readdirnames(-1)
+func (w *WAL) SaveState(s *raftpb.State) error {
+	log.Printf("path=%s wal.saveState state=\"%+v\"", w.f.Name(), s)
+	b, err := s.Marshal()
 	if err != nil {
-		return nil, err
-	}
-	return names, nil
-}
-
-func checkWalNames(names []string) []string {
-	wnames := make([]string, 0)
-	for _, name := range names {
-		if _, _, err := parseWalName(name); err != nil {
-			log.Printf("parse %s: %v", name, err)
-			continue
-		}
-		wnames = append(wnames, name)
-	}
-	return wnames
-}
-
-func parseWalName(str string) (seq, index int64, err error) {
-	var num int
-	num, err = fmt.Sscanf(str, "%016x-%016x.wal", &seq, &index)
-	if num != 2 && err == nil {
-		err = fmt.Errorf("bad wal name: %s", str)
+		panic(err)
 	}
-	return
+	rec := &Record{Type: stateType, Data: b}
+	return w.encoder.encode(rec)
 }
 
-func max(a, b int64) int64 {
-	if a > b {
-		return a
-	}
-	return b
+func (w *WAL) saveCrc(prevCrc uint32) error {
+	return w.encoder.encode(&Record{Type: crcType, Crc: prevCrc})
 }

+ 49 - 170
wal/wal_test.go

@@ -24,18 +24,12 @@ import (
 	"reflect"
 	"testing"
 
-	"github.com/coreos/etcd/raft"
+	"github.com/coreos/etcd/raft/raftpb"
 )
 
 var (
 	infoData   = []byte("\b\xef\xfd\x02") // marshaled Info whose Id is 0xBEEF
-	infoRecord = append([]byte("\n\x00\x00\x00\x00\x00\x00\x00\b\x01\x10\x00\x1a\x04"), infoData...)
-
-	stateData   = []byte("\b\x01\x10\x01\x18\x01")
-	stateRecord = append([]byte("\f\x00\x00\x00\x00\x00\x00\x00\b\x03\x10\x00\x1a\x06"), stateData...)
-
-	entryData   = []byte("\b\x01\x10\x01\x18\x01\x22\x01\x01")
-	entryRecord = append([]byte("\x0f\x00\x00\x00\x00\x00\x00\x00\b\x02\x10\x00\x1a\t"), entryData...)
+	infoRecord = append([]byte("\x0e\x00\x00\x00\x00\x00\x00\x00\b\x01\x10\x99\xb5\xe4\xd0\x03\x1a\x04"), infoData...) // infoData framed as a record: 8 leading bytes holding 14 (the byte count that follows), then the encoded record fields; NOTE(review): presumably type, crc, data - confirm against record.proto
 
 	firstWalName = "0000000000000000-0000000000000000.wal" // two 16-digit zero fields; presumably seq 0 and index 0 - confirm against the wal name parser
 )
@@ -70,15 +64,20 @@ func TestNewForInitedDir(t *testing.T) {
 	}
 }
 
-func TestAppend(t *testing.T) {
-	p, err := ioutil.TempDir(os.TempDir(), "waltest")
+func TestOpenFromIndex(t *testing.T) {
+	dir, err := ioutil.TempDir(os.TempDir(), "waltest")
 	if err != nil {
 		t.Fatal(err)
 	}
-	defer os.RemoveAll(p)
+	defer os.RemoveAll(dir)
 
-	os.Create(path.Join(p, firstWalName))
-	w, err := Open(p)
+	f, err := os.Create(path.Join(dir, firstWalName))
+	if err != nil {
+		t.Fatal(err)
+	}
+	f.Close()
+
+	w, err := OpenFromIndex(dir, 0)
 	if err != nil {
 		t.Fatalf("err = %v, want nil", err)
 	}
@@ -88,8 +87,13 @@ func TestAppend(t *testing.T) {
 	w.Close()
 
 	wname := fmt.Sprintf("%016x-%016x.wal", 2, 10)
-	os.Create(path.Join(p, wname))
-	w, err = Open(p)
+	f, err = os.Create(path.Join(dir, wname))
+	if err != nil {
+		t.Fatal(err)
+	}
+	f.Close()
+
+	w, err = OpenFromIndex(dir, 5)
 	if err != nil {
 		t.Fatalf("err = %v, want nil", err)
 	}
@@ -97,16 +101,13 @@ func TestAppend(t *testing.T) {
 		t.Errorf("name = %+v, want %+v", g, wname)
 	}
 	w.Close()
-}
 
-func TestAppendForUninitedDir(t *testing.T) {
-	p, err := ioutil.TempDir(os.TempDir(), "waltest")
+	emptydir, err := ioutil.TempDir(os.TempDir(), "waltestempty")
 	if err != nil {
 		t.Fatal(err)
 	}
-	defer os.RemoveAll(p)
-
-	if _, err = Open(p); err != ErrNotFound {
+	defer os.RemoveAll(emptydir)
+	if _, err = OpenFromIndex(emptydir, 0); err != ErrNotFound {
 		t.Errorf("err = %v, want %v", err, ErrNotFound)
 	}
 }
@@ -132,7 +133,7 @@ func TestCut(t *testing.T) {
 		t.Errorf("name = %s, want %s", g, wname)
 	}
 
-	e := &raft.Entry{Type: 1, Index: 1, Term: 1, Data: []byte{1}}
+	e := &raftpb.Entry{Type: 1, Index: 1, Term: 1, Data: []byte{1}}
 	if err := w.SaveEntry(e); err != nil {
 		t.Fatal(err)
 	}
@@ -145,75 +146,7 @@ func TestCut(t *testing.T) {
 	}
 }
 
-func TestSaveEntry(t *testing.T) {
-	p, err := ioutil.TempDir(os.TempDir(), "waltest")
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer os.RemoveAll(p)
-
-	w, err := Create(p)
-	if err != nil {
-		t.Fatal(err)
-	}
-	e := &raft.Entry{Type: 1, Index: 1, Term: 1, Data: []byte{1}}
-	err = w.SaveEntry(e)
-	if err != nil {
-		t.Fatal(err)
-	}
-	w.Close()
-
-	b, err := ioutil.ReadFile(path.Join(p, firstWalName))
-	if err != nil {
-		t.Fatal(err)
-	}
-	if !reflect.DeepEqual(b, entryRecord) {
-		t.Errorf("ent = %q, want %q", b, entryRecord)
-	}
-}
-
-func TestSaveInfo(t *testing.T) {
-	p, err := ioutil.TempDir(os.TempDir(), "waltest")
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer os.RemoveAll(p)
-
-	w, err := Create(p)
-	if err != nil {
-		t.Fatal(err)
-	}
-	i := &raft.Info{Id: int64(0xBEEF)}
-	err = w.SaveInfo(i)
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	// make sure we can only write info at the head of the wal file
-	// still in buffer
-	err = w.SaveInfo(i)
-	if err == nil || err.Error() != "cannot write info at 18, expect 0" {
-		t.Errorf("err = %v, want cannot write info at 18, expect 0", err)
-	}
-
-	// sync to disk
-	w.Sync()
-	err = w.SaveInfo(i)
-	if err == nil || err.Error() != "cannot write info at 18, expect 0" {
-		t.Errorf("err = %v, want cannot write info at 18, expect 0", err)
-	}
-	w.Close()
-
-	b, err := ioutil.ReadFile(path.Join(p, firstWalName))
-	if err != nil {
-		t.Fatal(err)
-	}
-	if !reflect.DeepEqual(b, infoRecord) {
-		t.Errorf("ent = %q, want %q", b, infoRecord)
-	}
-}
-
-func TestSaveState(t *testing.T) {
+func TestRecover(t *testing.T) {
 	p, err := ioutil.TempDir(os.TempDir(), "waltest")
 	if err != nil {
 		t.Fatal(err)
@@ -224,76 +157,17 @@ func TestSaveState(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	st := &raft.State{Term: 1, Vote: 1, Commit: 1}
-	err = w.SaveState(st)
-	if err != nil {
-		t.Fatal(err)
-	}
-	w.Close()
-
-	b, err := ioutil.ReadFile(path.Join(p, firstWalName))
-	if err != nil {
-		t.Fatal(err)
-	}
-	if !reflect.DeepEqual(b, stateRecord) {
-		t.Errorf("ent = %q, want %q", b, stateRecord)
-	}
-}
-
-func TestLoadInfo(t *testing.T) {
-	i, err := loadInfo(infoData)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if i.Id != 0xBEEF {
-		t.Errorf("id = %x, want 0xBEEF", i.Id)
-	}
-}
-
-func TestLoadEntry(t *testing.T) {
-	e, err := loadEntry(entryData)
-	if err != nil {
-		t.Fatal(err)
-	}
-	we := raft.Entry{Type: 1, Index: 1, Term: 1, Data: []byte{1}}
-	if !reflect.DeepEqual(e, we) {
-		t.Errorf("ent = %v, want %v", e, we)
-	}
-}
-
-func TestLoadState(t *testing.T) {
-	s, err := loadState(stateData)
-	if err != nil {
-		t.Fatal(err)
-	}
-	ws := raft.State{Term: 1, Vote: 1, Commit: 1}
-	if !reflect.DeepEqual(s, ws) {
-		t.Errorf("state = %v, want %v", s, ws)
-	}
-}
-
-func TestNodeLoad(t *testing.T) {
-	p, err := ioutil.TempDir(os.TempDir(), "waltest")
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer os.RemoveAll(p)
-
-	w, err := Create(p)
-	if err != nil {
-		t.Fatal(err)
-	}
-	i := &raft.Info{Id: int64(0xBEEF)}
+	i := &raftpb.Info{Id: int64(0xBEEF)}
 	if err = w.SaveInfo(i); err != nil {
 		t.Fatal(err)
 	}
-	ents := []raft.Entry{{Type: 1, Index: 1, Term: 1, Data: []byte{1}}, {Type: 2, Index: 2, Term: 2, Data: []byte{2}}}
+	ents := []raftpb.Entry{{Type: 1, Index: 1, Term: 1, Data: []byte{1}}, {Type: 2, Index: 2, Term: 2, Data: []byte{2}}}
 	for _, e := range ents {
 		if err = w.SaveEntry(&e); err != nil {
 			t.Fatal(err)
 		}
 	}
-	sts := []raft.State{{Term: 1, Vote: 1, Commit: 1}, {Term: 2, Vote: 2, Commit: 2}}
+	sts := []raftpb.State{{Term: 1, Vote: 1, Commit: 1}, {Term: 2, Vote: 2, Commit: 2}}
 	for _, s := range sts {
 		if err = w.SaveState(&s); err != nil {
 			t.Fatal(err)
@@ -301,20 +175,24 @@ func TestNodeLoad(t *testing.T) {
 	}
 	w.Close()
 
-	n := newNode(0)
-	if err := n.load(path.Join(p, firstWalName)); err != nil {
+	if w, err = OpenFromIndex(p, 0); err != nil {
+		t.Fatal(err)
+	}
+	id, state, entries, err := w.ReadAll()
+	if err != nil {
 		t.Fatal(err)
 	}
-	if n.Id != i.Id {
-		t.Errorf("id = %d, want %d", n.Id, i.Id)
+
+	if id != i.Id {
+		t.Errorf("id = %d, want %d", id, i.Id)
 	}
-	if !reflect.DeepEqual(n.Ents, ents) {
-		t.Errorf("ents = %+v, want %+v", n.Ents, ents)
+	if !reflect.DeepEqual(entries, ents) {
+		t.Errorf("ents = %+v, want %+v", entries, ents)
 	}
 	// only the latest state is recorded
 	s := sts[len(sts)-1]
-	if !reflect.DeepEqual(n.State, s) {
-		t.Errorf("state = %+v, want %+v", n.State, s)
+	if !reflect.DeepEqual(state, s) {
+		t.Errorf("state = %+v, want %+v", state, s)
 	}
 }
 
@@ -385,7 +263,7 @@ func TestScanWalName(t *testing.T) {
 	}
 }
 
-func TestRead(t *testing.T) {
+func TestRecoverAfterCut(t *testing.T) {
 	p, err := ioutil.TempDir(os.TempDir(), "waltest")
 	if err != nil {
 		t.Fatal(err)
@@ -396,7 +274,7 @@ func TestRead(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	info := &raft.Info{Id: int64(0xBEEF)}
+	info := &raftpb.Info{Id: int64(0xBEEF)}
 	if err = w.SaveInfo(info); err != nil {
 		t.Fatal(err)
 	}
@@ -404,7 +282,7 @@ func TestRead(t *testing.T) {
 		t.Fatal(err)
 	}
 	for i := 1; i < 10; i++ {
-		e := raft.Entry{Index: int64(i)}
+		e := raftpb.Entry{Index: int64(i)}
 		if err = w.SaveEntry(&e); err != nil {
 			t.Fatal(err)
 		}
@@ -421,22 +299,23 @@ func TestRead(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	for i := 0; i < 15; i++ {
-		n, err := Read(p, int64(i))
-		if i <= 3 || i >= 10 {
+	for i := 0; i < 10; i++ {
+		w, err := OpenFromIndex(p, int64(i))
+		if i <= 3 {
 			if err != ErrNotFound {
 				t.Errorf("#%d: err = %v, want %v", i, err, ErrNotFound)
 			}
 			continue
 		}
+		id, _, entries, err := w.ReadAll()
 		if err != nil {
 			t.Errorf("#%d: err = %v, want nil", i, err)
 			continue
 		}
-		if n.Id != info.Id {
-			t.Errorf("#%d: id = %d, want %d", n.Id, info.Id)
+		if id != info.Id {
+			t.Errorf("#%d: id = %d, want %d", id, info.Id)
 		}
-		for j, e := range n.Ents {
+		for j, e := range entries {
 			if e.Index != int64(j+i+1) {
 				t.Errorf("#%d: ents[%d].Index = %+v, want %+v", i, j, e.Index, j+i+1)
 			}