Browse Source

Writer tests pass

Pierre.Curto 5 years ago
parent
commit
23f1199d93
9 changed files with 118 additions and 126 deletions
  1. 10 8
      frame.go
  2. 12 16
      lz4.go
  3. 8 2
      options.go
  4. 2 4
      reader.go
  5. 5 4
      reader_test.go
  6. 0 1
      state.go
  7. 5 6
      state_gen.go
  8. 16 21
      writer.go
  9. 60 64
      writer_test.go

+ 10 - 8
frame.go

@@ -31,11 +31,11 @@ func (f *Frame) closeW(w *Writer) error {
 		return err
 	}
 	buf := w.buf[:0]
+	// End mark (data block size of uint32(0)).
+	buf = append(buf, 0, 0, 0, 0)
 	if f.Descriptor.Flags.ContentChecksum() {
 		buf = f.checksum.Sum(buf)
 	}
-	// End mark (data block size of uint32(0)).
-	buf = append(buf, 0, 0, 0, 0)
 	_, err := w.src.Write(buf)
 	return err
 }
@@ -62,7 +62,7 @@ newFrame:
 		}
 		goto newFrame
 	default:
-		return ErrInvalid
+		return ErrInvalidFrame
 	}
 	if err := f.Descriptor.initR(r); err != nil {
 		return err
@@ -103,14 +103,16 @@ func (fd *FrameDescriptor) write(w *Writer) error {
 		return nil
 	}
 
-	buf := w.buf[:2]
-	binary.LittleEndian.PutUint16(buf, uint16(fd.Flags))
+	buf := w.buf[:4+2]
+	// Write the magic number here even though it belongs to the Frame.
+	binary.LittleEndian.PutUint32(buf, w.frame.Magic)
+	binary.LittleEndian.PutUint16(buf[4:], uint16(fd.Flags))
 
 	if fd.Flags.Size() {
-		buf = buf[:10]
-		binary.LittleEndian.PutUint64(buf[2:], fd.ContentSize)
+		buf = buf[:4+2+8]
+		binary.LittleEndian.PutUint64(buf[4+2:], fd.ContentSize)
 	}
-	fd.Checksum = descriptorChecksum(buf)
+	fd.Checksum = descriptorChecksum(buf[4:])
 	buf = append(buf, fd.Checksum)
 
 	_, err := w.src.Write(buf)

+ 12 - 16
lz4.go

@@ -28,28 +28,24 @@ const (
 	// ErrInvalidSourceShortBuffer is returned by UncompressBlock or CompressBLock when a compressed
 	// block is corrupted or the destination buffer is not large enough for the uncompressed data.
 	ErrInvalidSourceShortBuffer _error = "lz4: invalid source or destination buffer too short"
-	// ErrClosed is returned when calling Write/Read or Close on an already closed Writer/Reader.
-	ErrClosed _error = "lz4: closed Writer"
-	// ErrInvalid is returned when reading an invalid LZ4 archive.
-	ErrInvalid _error = "lz4: bad magic number"
-	// ErrBlockDependency is returned when attempting to decompress an archive created with block dependency.
-	ErrBlockDependency _error = "lz4: block dependency not supported"
+	// ErrInvalidFrame is returned when reading an invalid LZ4 archive.
+	ErrInvalidFrame _error = "lz4: bad magic number"
 	// ErrUnsupportedSeek is returned when attempting to Seek any way but forward from the current position.
 	ErrUnsupportedSeek _error = "lz4: can only seek forward from io.SeekCurrent"
 	// ErrInternalUnhandledState is an internal error.
 	ErrInternalUnhandledState _error = "lz4: unhandled state"
-	// ErrInvalidHeaderChecksum
+	// ErrInvalidHeaderChecksum is returned when reading a frame.
 	ErrInvalidHeaderChecksum _error = "lz4: invalid header checksum"
-	// ErrInvalidBlockChecksum
+	// ErrInvalidBlockChecksum is returned when reading a frame.
 	ErrInvalidBlockChecksum _error = "lz4: invalid block checksum"
-	// ErrInvalidFrameChecksum
+	// ErrInvalidFrameChecksum is returned when reading a frame.
 	ErrInvalidFrameChecksum _error = "lz4: invalid frame checksum"
-	// ErrInvalidCompressionLevel
-	ErrInvalidCompressionLevel _error = "lz4: invalid compression level"
-	// ErrCannotApplyOptions
-	ErrCannotApplyOptions _error = "lz4: cannot apply options"
-	// ErrInvalidBlockSize
-	ErrInvalidBlockSize _error = "lz4: invalid block size"
-	// ErrOptionNotApplicable
+	// ErrOptionInvalidCompressionLevel is returned when the supplied compression level is invalid.
+	ErrOptionInvalidCompressionLevel _error = "lz4: invalid compression level"
+	// ErrOptionClosedOrError is returned when an option is applied to a closed or in error object.
+	ErrOptionClosedOrError _error = "lz4: cannot apply options on closed or in error object"
+	// ErrOptionInvalidBlockSize is returned when
+	ErrOptionInvalidBlockSize _error = "lz4: invalid block size"
+	// ErrOptionNotApplicable is returned when trying to apply an option to an object not supporting it.
 	ErrOptionNotApplicable _error = "lz4: option not applicable"
 )

+ 8 - 2
options.go

@@ -2,6 +2,7 @@ package lz4
 
 import (
 	"fmt"
+	"reflect"
 	"runtime"
 	"sync"
 )
@@ -17,6 +18,11 @@ type (
 	Option func(Applier) error
 )
 
+func (o Option) String() string {
+	//TODO proper naming of options
+	return reflect.TypeOf(o).String()
+}
+
 // Default options.
 var (
 	defaultBlockSizeOption = BlockSizeOption(Block4Mb)
@@ -98,7 +104,7 @@ func BlockSizeOption(size BlockSize) Option {
 			return ErrOptionNotApplicable
 		}
 		if !size.isValid() {
-			return fmt.Errorf("%w: %d", ErrInvalidBlockSize, size)
+			return fmt.Errorf("%w: %d", ErrOptionInvalidBlockSize, size)
 		}
 		w.frame.Descriptor.Flags.BlockSizeIndexSet(size.index())
 		return nil
@@ -188,7 +194,7 @@ func CompressionLevelOption(level CompressionLevel) Option {
 		switch level {
 		case Fast, Level1, Level2, Level3, Level4, Level5, Level6, Level7, Level8, Level9:
 		default:
-			return fmt.Errorf("%w: %d", ErrInvalidCompressionLevel, level)
+			return fmt.Errorf("%w: %d", ErrOptionInvalidCompressionLevel, level)
 		}
 		w.level = level
 		return nil

+ 2 - 4
reader.go

@@ -7,8 +7,7 @@ import (
 var readerStates = []aState{
 	noState:     newState,
 	errorState:  newState,
-	newState:    headerState,
-	headerState: readState,
+	newState:    readState,
 	readState:   closedState,
 	closedState: newState,
 }
@@ -40,7 +39,7 @@ func (r *Reader) Apply(options ...Option) (err error) {
 	case errorState:
 		return r.state.err
 	default:
-		return ErrCannotApplyOptions
+		return ErrOptionClosedOrError
 	}
 	for _, o := range options {
 		if err = o(r); err != nil {
@@ -69,7 +68,6 @@ func (r *Reader) Read(buf []byte) (n int, err error) {
 		return 0, r.state.err
 	case newState:
 		// First initialization.
-		r.state.next(nil)
 		if err = r.frame.initR(r); r.state.next(err) {
 			return
 		}

+ 5 - 4
reader_test.go

@@ -42,9 +42,9 @@ func TestReader(t *testing.T) {
 				t.Fatal(err)
 			}
 
-			var out bytes.Buffer
+			out := new(bytes.Buffer)
 			zr := lz4.NewReader(f)
-			n, err := io.Copy(&out, zr)
+			n, err := io.Copy(out, zr)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -69,7 +69,7 @@ func TestReader(t *testing.T) {
 
 			out.Reset()
 			zr = lz4.NewReader(f2)
-			_, err = io.CopyN(&out, zr, 10)
+			_, err = io.CopyN(out, zr, 10)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -78,6 +78,7 @@ func TestReader(t *testing.T) {
 			}
 			return
 
+			//TODO add Reader.Seek
 			pos, err := zr.Seek(-1, io.SeekCurrent)
 			if err == nil {
 				t.Fatal("expected error from invalid seek")
@@ -109,7 +110,7 @@ func TestReader(t *testing.T) {
 			}
 
 			out.Reset()
-			_, err = io.CopyN(&out, zr, 10)
+			_, err = io.CopyN(out, zr, 10)
 			if err != nil {
 				t.Fatal(err)
 			}

+ 0 - 1
state.go

@@ -12,7 +12,6 @@ const (
 	noState     aState = iota // uninitialized reader
 	errorState                // unrecoverable error encountered
 	newState                  // instantiated object
-	headerState               // processing header
 	readState                 // reading data
 	writeState                // writing data
 	closedState               // all done

+ 5 - 6
state_gen.go

@@ -11,15 +11,14 @@ func _() {
 	_ = x[noState-0]
 	_ = x[errorState-1]
 	_ = x[newState-2]
-	_ = x[headerState-3]
-	_ = x[readState-4]
-	_ = x[writeState-5]
-	_ = x[closedState-6]
+	_ = x[readState-3]
+	_ = x[writeState-4]
+	_ = x[closedState-5]
 }
 
-const _aState_name = "noStateerrorStatenewStateheaderStatereadStatewriteStateclosedState"
+const _aState_name = "noStateerrorStatenewStatereadStatewriteStateclosedState"
 
-var _aState_index = [...]uint8{0, 7, 17, 25, 36, 45, 55, 66}
+var _aState_index = [...]uint8{0, 7, 17, 25, 34, 44, 55}
 
 func (i aState) String() string {
 	if i >= aState(len(_aState_index)-1) {

+ 16 - 21
writer.go

@@ -4,8 +4,7 @@ import "io"
 
 var writerStates = []aState{
 	noState:     newState,
-	newState:    headerState,
-	headerState: writeState,
+	newState:    writeState,
 	writeState:  closedState,
 	closedState: newState,
 	errorState:  newState,
@@ -21,7 +20,7 @@ func NewWriter(w io.Writer) *Writer {
 
 type Writer struct {
 	state   _State
-	buf     [11]byte         // frame descriptor needs at most 4+8+1=11 bytes
+	buf     [15]byte         // frame descriptor needs at most 4(magic)+4+8+1=11 bytes
 	src     io.Writer        // destination writer
 	level   CompressionLevel // how hard to try
 	num     int              // concurrency level
@@ -41,7 +40,7 @@ func (w *Writer) Apply(options ...Option) (err error) {
 	case errorState:
 		return w.state.err
 	default:
-		return ErrCannotApplyOptions
+		return ErrOptionClosedOrError
 	}
 	for _, o := range options {
 		if err = o(w); err != nil {
@@ -62,7 +61,6 @@ func (w *Writer) Write(buf []byte) (n int, err error) {
 	case closedState, errorState:
 		return 0, w.state.err
 	case newState:
-		w.state.next(nil)
 		if err = w.frame.Descriptor.write(w); w.state.next(err) {
 			return
 		}
@@ -74,7 +72,7 @@ func (w *Writer) Write(buf []byte) (n int, err error) {
 	for len(buf) > 0 {
 		if w.idx == 0 && len(buf) >= zn {
 			// Avoid a copy as there is enough data for a block.
-			if err = w.write(); err != nil {
+			if err = w.write(buf[:zn], false); err != nil {
 				return
 			}
 			n += zn
@@ -93,7 +91,7 @@ func (w *Writer) Write(buf []byte) (n int, err error) {
 		}
 
 		// Buffer full.
-		if err = w.write(); err != nil {
+		if err = w.write(w.data, true); err != nil {
 			return
 		}
 		w.idx = 0
@@ -101,10 +99,11 @@ func (w *Writer) Write(buf []byte) (n int, err error) {
 	return
 }
 
-func (w *Writer) write() error {
+func (w *Writer) write(data []byte, direct bool) error {
 	if w.isNotConcurrent() {
-		defer w.handler(len(w.data))
-		return w.frame.Blocks.Block.compress(w, w.data, w.ht).write(w)
+		defer w.handler(len(data))
+		block := w.frame.Blocks.Block
+		return block.compress(w, data, w.ht).write(w)
 	}
 	size := w.frame.Descriptor.Flags.BlockSizeIndex()
 	c := make(chan *FrameDataBlock)
@@ -122,20 +121,18 @@ func (w *Writer) write() error {
 		size.put(data)
 		<-c
 		size.put(zdata)
-	}(c, w.data, size)
+	}(c, data, size)
 
-	if w.idx > 0 {
-		// Not closed.
+	if direct {
 		w.data = size.get()
 	}
-	w.idx = 0
 
 	return nil
 }
 
 // Close closes the Writer, flushing any unwritten data to the underlying io.Writer,
 // but does not close the underlying io.Writer.
-func (w *Writer) Close() error {
+func (w *Writer) Close() (err error) {
 	switch w.state.state {
 	case writeState:
 	case errorState:
@@ -143,21 +140,19 @@ func (w *Writer) Close() error {
 	default:
 		return nil
 	}
-	var err error
 	defer func() { w.state.next(err) }()
-	if idx := w.idx; idx > 0 {
+	if w.idx > 0 {
 		// Flush pending data.
-		w.data = w.data[:idx]
-		w.idx = 0
-		if err = w.write(); err != nil {
+		if err = w.write(w.data[:w.idx], false); err != nil {
 			return err
 		}
-		w.data = nil
+		w.idx = 0
 	}
 	if w.isNotConcurrent() {
 		htPool.Put(w.ht)
 		size := w.frame.Descriptor.Flags.BlockSizeIndex()
 		size.put(w.data)
+		w.data = nil
 	}
 	return w.frame.closeW(w)
 }

+ 60 - 64
_writer_test.go → writer_test.go

@@ -8,6 +8,8 @@ import (
 	"os"
 	"reflect"
 	"testing"
+
+	"github.com/pierrec/lz4"
 )
 
 func TestWriter(t *testing.T) {
@@ -23,72 +25,67 @@ func TestWriter(t *testing.T) {
 	}
 
 	for _, fname := range goldenFiles {
-		for _, size := range []int{0, 4} {
-			for _, header := range []Header{
-				{}, // Default header.
-				{BlockChecksum: true},
-				{NoChecksum: true},
-				{BlockMaxSize: 64 << 10}, // 64Kb
-				{CompressionLevel: 10},
-				{Size: 123},
-			} {
-				label := fmt.Sprintf("%s/%s", fname, header)
-				t.Run(label, func(t *testing.T) {
-					fname := fname
-					header := header
-					t.Parallel()
-
-					raw, err := ioutil.ReadFile(fname)
-					if err != nil {
-						t.Fatal(err)
-					}
-					r := bytes.NewReader(raw)
-
-					// Compress.
-					var zout bytes.Buffer
-					zw := NewWriter(&zout)
-					zw.Header = header
-					zw.WithConcurrency(size)
-					_, err = io.Copy(zw, r)
-					if err != nil {
-						t.Fatal(err)
-					}
-					err = zw.Close()
-					if err != nil {
-						t.Fatal(err)
-					}
-
-					// Uncompress.
-					var out bytes.Buffer
-					zr := NewReader(&zout)
-					n, err := io.Copy(&out, zr)
-					if err != nil {
-						t.Fatal(err)
-					}
-
-					// The uncompressed data must be the same as the initial input.
-					if got, want := int(n), len(raw); got != want {
-						t.Errorf("invalid sizes: got %d; want %d", got, want)
-					}
-
-					if got, want := out.Bytes(), raw; !reflect.DeepEqual(got, want) {
-						t.Fatal("uncompressed data does not match original")
-					}
-				})
-			}
+		for _, option := range []lz4.Option{
+			lz4.ConcurrencyOption(1),
+			//lz4.BlockChecksumOption(true),
+			//lz4.SizeOption(123),
+			//lz4.ConcurrencyOption(2),
+		} {
+			label := fmt.Sprintf("%s/%s", fname, option)
+			t.Run(label, func(t *testing.T) {
+				fname := fname
+				t.Parallel()
+
+				raw, err := ioutil.ReadFile(fname)
+				if err != nil {
+					t.Fatal(err)
+				}
+				r := bytes.NewReader(raw)
+
+				// Compress.
+				zout := new(bytes.Buffer)
+				zw := lz4.NewWriter(zout)
+				if err := zw.Apply(option); err != nil {
+					t.Fatal(err)
+				}
+				_, err = io.Copy(zw, r)
+				if err != nil {
+					t.Fatal(err)
+				}
+				err = zw.Close()
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				// Uncompress.
+				out := new(bytes.Buffer)
+				zr := lz4.NewReader(zout)
+				n, err := io.Copy(out, zr)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				// The uncompressed data must be the same as the initial input.
+				if got, want := int(n), len(raw); got != want {
+					t.Errorf("invalid sizes: got %d; want %d", got, want)
+				}
+
+				if got, want := out.Bytes(), raw; !reflect.DeepEqual(got, want) {
+					t.Fatal("uncompressed data does not match original")
+				}
+			})
 		}
 	}
 }
 
 func TestIssue41(t *testing.T) {
 	r, w := io.Pipe()
-	zw := NewWriter(w)
-	zr := NewReader(r)
+	zw := lz4.NewWriter(w)
+	zr := lz4.NewReader(r)
 
 	data := "x"
 	go func() {
 		_, _ = fmt.Fprint(zw, data)
-		_ = zw.Flush()
 		_ = zw.Close()
 		_ = w.Close()
 	}()
@@ -110,7 +107,7 @@ func TestIssue43(t *testing.T) {
 		}
 		defer f.Close()
 
-		zw := NewWriter(w)
+		zw := lz4.NewWriter(w)
 		defer zw.Close()
 
 		_, err = io.Copy(zw, f)
@@ -118,7 +115,7 @@ func TestIssue43(t *testing.T) {
 			t.Fatal(err)
 		}
 	}()
-	_, err := io.Copy(ioutil.Discard, NewReader(r))
+	_, err := io.Copy(ioutil.Discard, lz4.NewReader(r))
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -131,16 +128,15 @@ func TestIssue51(t *testing.T) {
 	}
 
 	zbuf := make([]byte, 8192)
-	ht := make([]int, htSize)
 
-	n, err := CompressBlock(data, zbuf, ht)
+	n, err := lz4.CompressBlock(data, zbuf, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
 	zbuf = zbuf[:n]
 
 	buf := make([]byte, 8192)
-	n, err = UncompressBlock(zbuf, buf)
+	n, err = lz4.UncompressBlock(zbuf, buf)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -157,11 +153,11 @@ func TestIssue71(t *testing.T) {
 	} {
 		t.Run(tc, func(t *testing.T) {
 			src := []byte(tc)
-			bound := CompressBlockBound(len(tc))
+			bound := lz4.CompressBlockBound(len(tc))
 
 			// Small buffer.
 			zSmall := make([]byte, bound-1)
-			n, err := CompressBlock(src, zSmall, nil)
+			n, err := lz4.CompressBlock(src, zSmall, nil)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -171,7 +167,7 @@ func TestIssue71(t *testing.T) {
 
 			// Large enough buffer.
 			zLarge := make([]byte, bound)
-			n, err = CompressBlock(src, zLarge, nil)
+			n, err = lz4.CompressBlock(src, zLarge, nil)
 			if err != nil {
 				t.Fatal(err)
 			}