Browse Source

wal: cleanup

Xiang Li 11 years ago
parent
commit
5baefcce26
4 changed files with 177 additions and 176 deletions
  1. 43 0
      wal/block.go
  2. 50 0
      wal/block_test.go
  3. 47 97
      wal/wal.go
  4. 37 79
      wal/wal_test.go

+ 43 - 0
wal/block.go

@@ -0,0 +1,43 @@
+package wal
+
+import (
+	"fmt"
+	"io"
+)
+
+type block struct {
+	t int64
+	l int64
+	d []byte
+}
+
+func writeBlock(w io.Writer, t int64, d []byte) error {
+	if err := writeInt64(w, t); err != nil {
+		return err
+	}
+	if err := writeInt64(w, int64(len(d))); err != nil {
+		return err
+	}
+	_, err := w.Write(d)
+	return err
+}
+
+func readBlock(r io.Reader) (*block, error) {
+	t, err := readInt64(r)
+	if err != nil {
+		return nil, err
+	}
+	l, err := readInt64(r)
+	if err != nil {
+		return nil, unexpectedEOF(err)
+	}
+	d := make([]byte, l)
+	n, err := r.Read(d)
+	if err != nil {
+		return nil, unexpectedEOF(err)
+	}
+	if n != int(l) {
+		return nil, fmt.Errorf("len(data) = %d, want %d", n, l)
+	}
+	return &block{t, l, d}, nil
+}

+ 50 - 0
wal/block_test.go

@@ -0,0 +1,50 @@
+package wal
+
+import (
+	"bytes"
+	"io"
+	"reflect"
+	"testing"
+)
+
+func TestReadBlock(t *testing.T) {
+	tests := []struct {
+		data []byte
+		wb   *block
+		we   error
+	}{
+		{infoBlock, &block{1, 8, infoData}, nil},
+		{[]byte(""), nil, io.EOF},
+		{infoBlock[:len(infoBlock)-len(infoData)-8], nil, io.ErrUnexpectedEOF},
+		{infoBlock[:len(infoBlock)-len(infoData)], nil, io.ErrUnexpectedEOF},
+		{infoBlock[:len(infoBlock)-8], nil, io.ErrUnexpectedEOF},
+	}
+
+	for i, tt := range tests {
+		buf := bytes.NewBuffer(tt.data)
+		b, e := readBlock(buf)
+		if !reflect.DeepEqual(b, tt.wb) {
+			t.Errorf("#%d: block = %v, want %v", i, b, tt.wb)
+		}
+		if !reflect.DeepEqual(e, tt.we) {
+			t.Errorf("#%d: err = %v, want %v", i, e, tt.we)
+		}
+	}
+}
+
+func TestWriteBlock(t *testing.T) {
+	typ := int64(0xABCD)
+	d := []byte("Hello world!")
+	buf := new(bytes.Buffer)
+	writeBlock(buf, typ, d)
+	b, err := readBlock(buf)
+	if err != nil {
+		t.Errorf("err = %v, want nil", err)
+	}
+	if b.t != typ {
+		t.Errorf("type = %d, want %d", b.t, typ)
+	}
+	if !reflect.DeepEqual(b.d, d) {
+		t.Errorf("data = %v, want %v", b.d, d)
+	}
+}

+ 47 - 97
wal/wal.go

@@ -49,63 +49,44 @@ func Open(path string) (*WAL, error) {
 
 
 func (w *WAL) Close() {
 func (w *WAL) Close() {
 	if w.f != nil {
 	if w.f != nil {
-		w.flush()
+		w.Flush()
 		w.f.Close()
 		w.f.Close()
 	}
 	}
 }
 }
 
 
-func (w *WAL) writeInfo(id int64) error {
-	// | 8 bytes | 8 bytes |  8 bytes |
-	// | type    |   len   |   nodeid |
+func (w *WAL) SaveInfo(id int64) error {
 	if err := w.checkAtHead(); err != nil {
 	if err := w.checkAtHead(); err != nil {
 		return err
 		return err
 	}
 	}
-	if err := w.writeInt64(infoType); err != nil {
-		return err
-	}
-	if err := w.writeInt64(8); err != nil {
-		return err
+	// cache the buffer?
+	buf := new(bytes.Buffer)
+	err := binary.Write(buf, binary.LittleEndian, id)
+	if err != nil {
+		panic(err)
 	}
 	}
-	return w.writeInt64(id)
+	return writeBlock(w.bw, infoType, buf.Bytes())
 }
 }
 
 
-func (w *WAL) writeEntry(e *raft.Entry) error {
-	// | 8 bytes | 8 bytes |  variable length |
-	// | type    |   len   |   entry data     |
-	if err := w.writeInt64(entryType); err != nil {
-		return err
-	}
+func (w *WAL) SaveEntry(e *raft.Entry) error {
+	// protobuf?
 	b, err := json.Marshal(e)
 	b, err := json.Marshal(e)
 	if err != nil {
 	if err != nil {
-		return err
-	}
-	n := len(b)
-	if err := w.writeInt64(int64(n)); err != nil {
-		return err
-	}
-	if _, err := w.bw.Write(b); err != nil {
-		return err
+		panic(err)
 	}
 	}
-	return nil
+	return writeBlock(w.bw, entryType, b)
 }
 }
 
 
-func (w *WAL) writeState(s *raft.State) error {
-	// | 8 bytes | 8 bytes |  24 bytes |
-	// | type    |   len   |   state   |
-	if err := w.writeInt64(stateType); err != nil {
-		return err
-	}
-	if err := w.writeInt64(24); err != nil {
-		return err
+func (w *WAL) SaveState(s *raft.State) error {
+	// cache the buffer?
+	buf := new(bytes.Buffer)
+	err := binary.Write(buf, binary.LittleEndian, s)
+	if err != nil {
+		panic(err)
 	}
 	}
-	return binary.Write(w.bw, binary.LittleEndian, s)
+	return writeBlock(w.bw, stateType, buf.Bytes())
 }
 }
 
 
-func (w *WAL) writeInt64(n int64) error {
-	return binary.Write(w.bw, binary.LittleEndian, n)
-}
-
-func (w *WAL) flush() error {
+func (w *WAL) Flush() error {
 	return w.bw.Flush()
 	return w.bw.Flush()
 }
 }
 
 
@@ -126,61 +107,51 @@ type Node struct {
 	State raft.State
 	State raft.State
 }
 }
 
 
-func (w *WAL) ReadNode() (*Node, error) {
+func (w *WAL) LoadNode() (*Node, error) {
 	if err := w.checkAtHead(); err != nil {
 	if err := w.checkAtHead(); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 	br := bufio.NewReader(w.f)
 	br := bufio.NewReader(w.f)
-	n := new(Node)
 
 
 	b, err := readBlock(br)
 	b, err := readBlock(br)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	switch b.t {
-	case infoType:
-		id, err := parseInfo(b.d)
-		if err != nil {
-			return nil, err
-		}
-		n.Id = id
-	default:
-		return nil, fmt.Errorf("type = %d, want %d", b.t, infoType)
+	if b.t != infoType {
+		return nil, fmt.Errorf("the first block of wal is not infoType but %d", b.t)
+	}
+	id, err := loadInfo(b.d)
+	if err != nil {
+		return nil, err
 	}
 	}
 
 
 	ents := make([]raft.Entry, 0)
 	ents := make([]raft.Entry, 0)
 	var state raft.State
 	var state raft.State
-	for {
-		b, err := readBlock(br)
-		if err == io.EOF {
-			break
-		}
-		if err != nil {
-			return nil, err
-		}
+	for b, err = readBlock(br); err == nil; b, err = readBlock(br) {
 		switch b.t {
 		switch b.t {
 		case entryType:
 		case entryType:
-			e, err := parseEntry(b.d)
+			e, err := loadEntry(b.d)
 			if err != nil {
 			if err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
 			ents = append(ents, e)
 			ents = append(ents, e)
 		case stateType:
 		case stateType:
-			s, err := parseState(b.d)
+			s, err := loadState(b.d)
 			if err != nil {
 			if err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
 			state = s
 			state = s
 		default:
 		default:
-			return nil, fmt.Errorf("cannot handle type %d", b.t)
+			return nil, fmt.Errorf("unexpected block type %d", b.t)
 		}
 		}
 	}
 	}
-	n.Ents = ents
-	n.State = state
-	return n, nil
+	if err != io.EOF {
+		return nil, err
+	}
+	return &Node{id, ents, state}, nil
 }
 }
 
 
-func parseInfo(d []byte) (int64, error) {
+func loadInfo(d []byte) (int64, error) {
 	if len(d) != 8 {
 	if len(d) != 8 {
 		return 0, fmt.Errorf("len = %d, want 8", len(d))
 		return 0, fmt.Errorf("len = %d, want 8", len(d))
 	}
 	}
@@ -188,49 +159,21 @@ func parseInfo(d []byte) (int64, error) {
 	return readInt64(buf)
 	return readInt64(buf)
 }
 }
 
 
-func parseEntry(d []byte) (raft.Entry, error) {
+func loadEntry(d []byte) (raft.Entry, error) {
 	var e raft.Entry
 	var e raft.Entry
 	err := json.Unmarshal(d, &e)
 	err := json.Unmarshal(d, &e)
 	return e, err
 	return e, err
 }
 }
 
 
-func parseState(d []byte) (raft.State, error) {
+func loadState(d []byte) (raft.State, error) {
 	var s raft.State
 	var s raft.State
 	buf := bytes.NewBuffer(d)
 	buf := bytes.NewBuffer(d)
 	err := binary.Read(buf, binary.LittleEndian, &s)
 	err := binary.Read(buf, binary.LittleEndian, &s)
 	return s, err
 	return s, err
 }
 }
 
 
-type block struct {
-	t int64
-	l int64
-	d []byte
-}
-
-func readBlock(r io.Reader) (*block, error) {
-	typ, err := readInt64(r)
-	if err != nil {
-		return nil, err
-	}
-	l, err := readInt64(r)
-	if err != nil {
-		if err == io.EOF {
-			err = io.ErrUnexpectedEOF
-		}
-		return nil, err
-	}
-	data := make([]byte, l)
-	n, err := r.Read(data)
-	if err != nil {
-		if err == io.EOF {
-			err = io.ErrUnexpectedEOF
-		}
-		return nil, err
-	}
-	if n != int(l) {
-		return nil, fmt.Errorf("len(data) = %d, want %d", n, l)
-	}
-	return &block{typ, l, data}, nil
+func writeInt64(w io.Writer, n int64) error {
+	return binary.Write(w, binary.LittleEndian, n)
 }
 }
 
 
 func readInt64(r io.Reader) (int64, error) {
 func readInt64(r io.Reader) (int64, error) {
@@ -239,6 +182,13 @@ func readInt64(r io.Reader) (int64, error) {
 	return n, err
 	return n, err
 }
 }
 
 
+func unexpectedEOF(err error) error {
+	if err == io.EOF {
+		return io.ErrUnexpectedEOF
+	}
+	return err
+}
+
 func max(a, b int64) int64 {
 func max(a, b int64) int64 {
 	if a > b {
 	if a > b {
 		return a
 		return a

+ 37 - 79
wal/wal_test.go

@@ -1,8 +1,6 @@
 package wal
 package wal
 
 
 import (
 import (
-	"bytes"
-	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"os"
 	"os"
 	"path"
 	"path"
@@ -12,6 +10,17 @@ import (
 	"github.com/coreos/etcd/raft"
 	"github.com/coreos/etcd/raft"
 )
 )
 
 
+var (
+	infoData  = []byte("\xef\xbe\x00\x00\x00\x00\x00\x00")
+	infoBlock = append([]byte("\x01\x00\x00\x00\x00\x00\x00\x00\b\x00\x00\x00\x00\x00\x00\x00"), infoData...)
+
+	stateData  = []byte("\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00")
+	stateBlock = append([]byte("\x03\x00\x00\x00\x00\x00\x00\x00\x18\x00\x00\x00\x00\x00\x00\x00"), stateData...)
+
+	entryJsonData = []byte("{\"Type\":1,\"Term\":1,\"Data\":\"AQ==\"}")
+	entryBlock    = append([]byte("\x02\x00\x00\x00\x00\x00\x00\x00\x21\x00\x00\x00\x00\x00\x00\x00"), entryJsonData...)
+)
+
 func TestNew(t *testing.T) {
 func TestNew(t *testing.T) {
 	f, err := ioutil.TempFile(os.TempDir(), "waltest")
 	f, err := ioutil.TempFile(os.TempDir(), "waltest")
 	if err != nil {
 	if err != nil {
@@ -37,14 +46,14 @@ func TestNew(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func TestWriteEntry(t *testing.T) {
+func TestSaveEntry(t *testing.T) {
 	p := path.Join(os.TempDir(), "waltest")
 	p := path.Join(os.TempDir(), "waltest")
 	w, err := New(p)
 	w, err := New(p)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 	e := &raft.Entry{1, 1, []byte{1}}
 	e := &raft.Entry{1, 1, []byte{1}}
-	err = w.writeEntry(e)
+	err = w.SaveEntry(e)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -54,9 +63,8 @@ func TestWriteEntry(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	wb := []byte("\x02\x00\x00\x00\x00\x00\x00\x00!\x00\x00\x00\x00\x00\x00\x00{\"Type\":1,\"Term\":1,\"Data\":\"AQ==\"}")
-	if !reflect.DeepEqual(b, wb) {
-		t.Errorf("ent = %q, want %q", b, wb)
+	if !reflect.DeepEqual(b, entryBlock) {
+		t.Errorf("ent = %q, want %q", b, entryBlock)
 	}
 	}
 
 
 	err = os.Remove(p)
 	err = os.Remove(p)
@@ -65,28 +73,28 @@ func TestWriteEntry(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func TestWriteInfo(t *testing.T) {
+func TestSaveInfo(t *testing.T) {
 	p := path.Join(os.TempDir(), "waltest")
 	p := path.Join(os.TempDir(), "waltest")
 	w, err := New(p)
 	w, err := New(p)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 	id := int64(0xBEEF)
 	id := int64(0xBEEF)
-	err = w.writeInfo(id)
+	err = w.SaveInfo(id)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
 	// make sure we can only write info at the head of the wal file
 	// make sure we can only write info at the head of the wal file
 	// still in buffer
 	// still in buffer
-	err = w.writeInfo(id)
+	err = w.SaveInfo(id)
 	if err == nil || err.Error() != "cannot write info at 24, expect 0" {
 	if err == nil || err.Error() != "cannot write info at 24, expect 0" {
 		t.Errorf("err = %v, want cannot write info at 8, expect 0", err)
 		t.Errorf("err = %v, want cannot write info at 8, expect 0", err)
 	}
 	}
 
 
 	// flush to disk
 	// flush to disk
-	w.flush()
-	err = w.writeInfo(id)
+	w.Flush()
+	err = w.SaveInfo(id)
 	if err == nil || err.Error() != "cannot write info at 24, expect 0" {
 	if err == nil || err.Error() != "cannot write info at 24, expect 0" {
 		t.Errorf("err = %v, want cannot write info at 8, expect 0", err)
 		t.Errorf("err = %v, want cannot write info at 8, expect 0", err)
 	}
 	}
@@ -96,9 +104,8 @@ func TestWriteInfo(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	wb := []byte("\x01\x00\x00\x00\x00\x00\x00\x00\b\x00\x00\x00\x00\x00\x00\x00\xef\xbe\x00\x00\x00\x00\x00\x00")
-	if !reflect.DeepEqual(b, wb) {
-		t.Errorf("ent = %q, want %q", b, wb)
+	if !reflect.DeepEqual(b, infoBlock) {
+		t.Errorf("ent = %q, want %q", b, infoBlock)
 	}
 	}
 
 
 	err = os.Remove(p)
 	err = os.Remove(p)
@@ -107,14 +114,14 @@ func TestWriteInfo(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func TestWriteState(t *testing.T) {
+func TestSaveState(t *testing.T) {
 	p := path.Join(os.TempDir(), "waltest")
 	p := path.Join(os.TempDir(), "waltest")
 	w, err := New(p)
 	w, err := New(p)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 	st := &raft.State{1, 1, 1}
 	st := &raft.State{1, 1, 1}
-	err = w.writeState(st)
+	err = w.SaveState(st)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -124,9 +131,8 @@ func TestWriteState(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	wb := []byte("\x03\x00\x00\x00\x00\x00\x00\x00\x18\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00")
-	if !reflect.DeepEqual(b, wb) {
-		t.Errorf("ent = %q, want %q", b, wb)
+	if !reflect.DeepEqual(b, stateBlock) {
+		t.Errorf("ent = %q, want %q", b, stateBlock)
 	}
 	}
 
 
 	err = os.Remove(p)
 	err = os.Remove(p)
@@ -135,9 +141,8 @@ func TestWriteState(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func TestParseInfo(t *testing.T) {
-	data := []byte("\xef\xbe\x00\x00\x00\x00\x00\x00")
-	id, err := parseInfo(data)
+func TestLoadInfo(t *testing.T) {
+	id, err := loadInfo(infoData)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -146,9 +151,8 @@ func TestParseInfo(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func TestParseEntry(t *testing.T) {
-	data := []byte("{\"Type\":1,\"Term\":1,\"Data\":\"AQ==\"}")
-	e, err := parseEntry(data)
+func TestLoadEntry(t *testing.T) {
+	e, err := loadEntry(entryJsonData)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -158,9 +162,8 @@ func TestParseEntry(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func TestParseState(t *testing.T) {
-	data := []byte("\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00")
-	s, err := parseState(data)
+func TestLoadState(t *testing.T) {
+	s, err := loadState(stateData)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -170,70 +173,25 @@ func TestParseState(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func TestReadBlock(t *testing.T) {
-	tests := []struct {
-		data []byte
-		wb   *block
-		we   error
-	}{
-		{
-			[]byte("\x01\x00\x00\x00\x00\x00\x00\x00\b\x00\x00\x00\x00\x00\x00\x00\xef\xbe\x00\x00\x00\x00\x00\x00"),
-			&block{1, 8, []byte("\xef\xbe\x00\x00\x00\x00\x00\x00")},
-			nil,
-		},
-		{
-			[]byte(""),
-			nil,
-			io.EOF,
-		},
-		{
-			[]byte("\x01\x00\x00\x00"),
-			nil,
-			io.ErrUnexpectedEOF,
-		},
-		{
-			[]byte("\x01\x00\x00\x00\x00\x00\x00\x00"),
-			nil,
-			io.ErrUnexpectedEOF,
-		},
-		{
-			[]byte("\x01\x00\x00\x00\x00\x00\x00\x00\b\x00\x00\x00\x00\x00\x00\x00"),
-			nil,
-			io.ErrUnexpectedEOF,
-		},
-	}
-
-	for i, tt := range tests {
-		buf := bytes.NewBuffer(tt.data)
-		b, e := readBlock(buf)
-		if !reflect.DeepEqual(b, tt.wb) {
-			t.Errorf("#%d: block = %v, want %v", i, b, tt.wb)
-		}
-		if !reflect.DeepEqual(e, tt.we) {
-			t.Errorf("#%d: err = %v, want %v", i, e, tt.we)
-		}
-	}
-}
-
-func TestReadNode(t *testing.T) {
+func TestLoadNode(t *testing.T) {
 	p := path.Join(os.TempDir(), "waltest")
 	p := path.Join(os.TempDir(), "waltest")
 	w, err := New(p)
 	w, err := New(p)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 	id := int64(0xBEEF)
 	id := int64(0xBEEF)
-	if err = w.writeInfo(id); err != nil {
+	if err = w.SaveInfo(id); err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 	ents := []raft.Entry{{1, 1, []byte{1}}, {2, 2, []byte{2}}}
 	ents := []raft.Entry{{1, 1, []byte{1}}, {2, 2, []byte{2}}}
 	for _, e := range ents {
 	for _, e := range ents {
-		if err = w.writeEntry(&e); err != nil {
+		if err = w.SaveEntry(&e); err != nil {
 			t.Fatal(err)
 			t.Fatal(err)
 		}
 		}
 	}
 	}
 	sts := []raft.State{{1, 1, 1}, {2, 2, 2}}
 	sts := []raft.State{{1, 1, 1}, {2, 2, 2}}
 	for _, s := range sts {
 	for _, s := range sts {
-		if err = w.writeState(&s); err != nil {
+		if err = w.SaveState(&s); err != nil {
 			t.Fatal(err)
 			t.Fatal(err)
 		}
 		}
 	}
 	}
@@ -243,7 +201,7 @@ func TestReadNode(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	n, err := w.ReadNode()
+	n, err := w.LoadNode()
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}