Browse Source

work started on tests

Pierre.Curto 5 years ago
parent
commit
2b087337a4
11 changed files with 88 additions and 111 deletions
  1. 0 0
      _writer_test.go
  2. 2 2
      bench_test.go
  3. 1 1
      example_test.go
  4. 26 27
      frame.go
  5. 0 35
      internal/xxh32/xxh32zero.go
  6. 6 2
      lz4.go
  7. 8 8
      options.go
  8. 11 10
      reader.go
  9. 2 1
      reader_test.go
  10. 3 4
      state.go
  11. 29 21
      writer.go

+ 0 - 0
writer_test.go → _writer_test.go


+ 2 - 2
bench_test.go

@@ -128,7 +128,7 @@ func BenchmarkSkipBytesRand(b *testing.B)   { benchmarkSkipBytes(b, randomLZ4) }
 
 
 func benchmarkCompress(b *testing.B, uncompressed []byte) {
 func benchmarkCompress(b *testing.B, uncompressed []byte) {
 	w := bytes.NewBuffer(nil)
 	w := bytes.NewBuffer(nil)
-	zw, _ := lz4.NewWriter(w)
+	zw := lz4.NewWriter(w)
 	r := bytes.NewReader(uncompressed)
 	r := bytes.NewReader(uncompressed)
 
 
 	// Determine the compressed size of testfile.
 	// Determine the compressed size of testfile.
@@ -161,7 +161,7 @@ func BenchmarkCompressRand(b *testing.B)   { benchmarkCompress(b, random) }
 func BenchmarkWriterReset(b *testing.B) {
 func BenchmarkWriterReset(b *testing.B) {
 	b.ReportAllocs()
 	b.ReportAllocs()
 
 
-	zw, _ := lz4.NewWriter(nil)
+	zw := lz4.NewWriter(nil)
 	src := mustLoadFile("testdata/gettysburg.txt")
 	src := mustLoadFile("testdata/gettysburg.txt")
 	var buf bytes.Buffer
 	var buf bytes.Buffer
 
 

+ 1 - 1
example_test.go

@@ -16,7 +16,7 @@ func Example() {
 
 
 	// The pipe will uncompress the data from the writer.
 	// The pipe will uncompress the data from the writer.
 	pr, pw := io.Pipe()
 	pr, pw := io.Pipe()
-	zw, _ := lz4.NewWriter(pw)
+	zw := lz4.NewWriter(pw)
 	zr := lz4.NewReader(pr)
 	zr := lz4.NewReader(pr)
 
 
 	go func() {
 	go func() {

+ 26 - 27
frame.go

@@ -19,14 +19,14 @@ type Frame struct {
 	checksum   xxh32.XXHZero
 	checksum   xxh32.XXHZero
 }
 }
 
 
-func (f *Frame) initW(w *_Writer) {
+func (f *Frame) initW(w *Writer) {
 	f.Magic = frameMagic
 	f.Magic = frameMagic
 	f.Descriptor.initW(w)
 	f.Descriptor.initW(w)
 	f.Blocks.initW(w)
 	f.Blocks.initW(w)
 	f.checksum.Reset()
 	f.checksum.Reset()
 }
 }
 
 
-func (f *Frame) closeW(w *_Writer) error {
+func (f *Frame) closeW(w *Writer) error {
 	if err := f.Blocks.closeW(w); err != nil {
 	if err := f.Blocks.closeW(w); err != nil {
 		return err
 		return err
 	}
 	}
@@ -40,7 +40,7 @@ func (f *Frame) closeW(w *_Writer) error {
 	return err
 	return err
 }
 }
 
 
-func (f *Frame) initR(r *_Reader) error {
+func (f *Frame) initR(r *Reader) error {
 	if f.Magic > 0 {
 	if f.Magic > 0 {
 		// Header already read.
 		// Header already read.
 		return nil
 		return nil
@@ -72,7 +72,7 @@ newFrame:
 	return nil
 	return nil
 }
 }
 
 
-func (f *Frame) closeR(r *_Reader) error {
+func (f *Frame) closeR(r *Reader) error {
 	f.Magic = 0
 	f.Magic = 0
 	if !f.Descriptor.Flags.ContentChecksum() {
 	if !f.Descriptor.Flags.ContentChecksum() {
 		return nil
 		return nil
@@ -92,28 +92,27 @@ type FrameDescriptor struct {
 	Checksum    uint8
 	Checksum    uint8
 }
 }
 
 
-func (fd *FrameDescriptor) initW(_ *_Writer) {
+func (fd *FrameDescriptor) initW(_ *Writer) {
 	fd.Flags.VersionSet(1)
 	fd.Flags.VersionSet(1)
 	fd.Flags.BlockIndependenceSet(false)
 	fd.Flags.BlockIndependenceSet(false)
 }
 }
 
 
-func (fd *FrameDescriptor) write(w *_Writer) error {
+func (fd *FrameDescriptor) write(w *Writer) error {
 	if fd.Checksum > 0 {
 	if fd.Checksum > 0 {
 		// Header already written.
 		// Header already written.
 		return nil
 		return nil
 	}
 	}
 
 
-	buf := w.buf[:]
+	buf := w.buf[:2]
 	binary.LittleEndian.PutUint16(buf, uint16(fd.Flags))
 	binary.LittleEndian.PutUint16(buf, uint16(fd.Flags))
 
 
 	var checksum uint32
 	var checksum uint32
 	if fd.Flags.Size() {
 	if fd.Flags.Size() {
-		checksum = xxh32.ChecksumZero10(uint16(fd.Flags), fd.ContentSize)
-		binary.LittleEndian.PutUint64(buf[2:], fd.ContentSize)
 		buf = buf[:10]
 		buf = buf[:10]
+		binary.LittleEndian.PutUint64(buf[2:], fd.ContentSize)
+		checksum = xxh32.ChecksumZero(buf)
 	} else {
 	} else {
-		checksum = xxh32.Uint32Zero(uint32(fd.Flags))
-		buf = buf[:2]
+		checksum = xxh32.ChecksumZero(buf)
 	}
 	}
 	fd.Checksum = byte(checksum >> 8)
 	fd.Checksum = byte(checksum >> 8)
 	buf = append(buf, fd.Checksum)
 	buf = append(buf, fd.Checksum)
@@ -122,34 +121,34 @@ func (fd *FrameDescriptor) write(w *_Writer) error {
 	return err
 	return err
 }
 }
 
 
-func (fd *FrameDescriptor) initR(r *_Reader) error {
+func (fd *FrameDescriptor) initR(r *Reader) error {
 	buf := r.buf[:2]
 	buf := r.buf[:2]
 	if _, err := io.ReadFull(r.src, buf); err != nil {
 	if _, err := io.ReadFull(r.src, buf); err != nil {
 		return err
 		return err
 	}
 	}
-	descr := binary.LittleEndian.Uint64(buf)
+	descr := binary.LittleEndian.Uint16(buf)
 	fd.Flags = DescriptorFlags(descr)
 	fd.Flags = DescriptorFlags(descr)
 
 
 	var checksum uint32
 	var checksum uint32
 	if fd.Flags.Size() {
 	if fd.Flags.Size() {
-		buf = buf[:9]
-		if _, err := io.ReadFull(r.src, buf); err != nil {
+		buf = buf[:11]
+		if _, err := io.ReadFull(r.src, buf[2:]); err != nil {
 			return err
 			return err
 		}
 		}
 		fd.ContentSize = binary.LittleEndian.Uint64(buf)
 		fd.ContentSize = binary.LittleEndian.Uint64(buf)
-		checksum = xxh32.ChecksumZero10(uint16(fd.Flags), fd.ContentSize)
+		checksum = xxh32.ChecksumZero(buf)
 	} else {
 	} else {
-		buf = buf[:1]
+		buf = buf[:3]
 		var err error
 		var err error
 		if br, ok := r.src.(io.ByteReader); ok {
 		if br, ok := r.src.(io.ByteReader); ok {
-			buf[0], err = br.ReadByte()
+			buf[2], err = br.ReadByte()
 		} else {
 		} else {
-			_, err = io.ReadFull(r.src, buf)
+			_, err = io.ReadFull(r.src, buf[2:])
 		}
 		}
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		checksum = xxh32.Uint32Zero(uint32(fd.Flags))
+		checksum = xxh32.ChecksumZero(buf)
 	}
 	}
 	fd.Checksum = buf[len(buf)-1]
 	fd.Checksum = buf[len(buf)-1]
 	if c := byte(checksum >> 8); fd.Checksum != c {
 	if c := byte(checksum >> 8); fd.Checksum != c {
@@ -165,7 +164,7 @@ type Blocks struct {
 	err    error
 	err    error
 }
 }
 
 
-func (b *Blocks) initW(w *_Writer) {
+func (b *Blocks) initW(w *Writer) {
 	size := w.frame.Descriptor.Flags.BlockSizeIndex()
 	size := w.frame.Descriptor.Flags.BlockSizeIndex()
 	if w.isNotConcurrent() {
 	if w.isNotConcurrent() {
 		b.Blocks = nil
 		b.Blocks = nil
@@ -204,7 +203,7 @@ func (b *Blocks) initW(w *_Writer) {
 	}()
 	}()
 }
 }
 
 
-func (b *Blocks) closeW(w *_Writer) error {
+func (b *Blocks) closeW(w *Writer) error {
 	if w.isNotConcurrent() {
 	if w.isNotConcurrent() {
 		b.Block.closeW(w)
 		b.Block.closeW(w)
 		b.Block = nil
 		b.Block = nil
@@ -219,7 +218,7 @@ func (b *Blocks) closeW(w *_Writer) error {
 	return err
 	return err
 }
 }
 
 
-func (b *Blocks) initR(r *_Reader) {
+func (b *Blocks) initR(r *Reader) {
 	size := r.frame.Descriptor.Flags.BlockSizeIndex()
 	size := r.frame.Descriptor.Flags.BlockSizeIndex()
 	b.Block = newFrameDataBlock(size)
 	b.Block = newFrameDataBlock(size)
 }
 }
@@ -234,13 +233,13 @@ type FrameDataBlock struct {
 	Checksum uint32
 	Checksum uint32
 }
 }
 
 
-func (b *FrameDataBlock) closeW(w *_Writer) {
+func (b *FrameDataBlock) closeW(w *Writer) {
 	size := w.frame.Descriptor.Flags.BlockSizeIndex()
 	size := w.frame.Descriptor.Flags.BlockSizeIndex()
 	size.put(b.Data)
 	size.put(b.Data)
 }
 }
 
 
 // Block compression errors are ignored since the buffer is sized appropriately.
 // Block compression errors are ignored since the buffer is sized appropriately.
-func (b *FrameDataBlock) compress(w *_Writer, src []byte, ht []int) *FrameDataBlock {
+func (b *FrameDataBlock) compress(w *Writer, src []byte, ht []int) *FrameDataBlock {
 	dst := b.Data
 	dst := b.Data
 	var n int
 	var n int
 	switch w.level {
 	switch w.level {
@@ -268,7 +267,7 @@ func (b *FrameDataBlock) compress(w *_Writer, src []byte, ht []int) *FrameDataBl
 	return b
 	return b
 }
 }
 
 
-func (b *FrameDataBlock) write(w *_Writer) error {
+func (b *FrameDataBlock) write(w *Writer) error {
 	buf := w.buf[:]
 	buf := w.buf[:]
 	out := w.src
 	out := w.src
 
 
@@ -289,7 +288,7 @@ func (b *FrameDataBlock) write(w *_Writer) error {
 	return err
 	return err
 }
 }
 
 
-func (b *FrameDataBlock) uncompress(r *_Reader, dst []byte) (int, error) {
+func (b *FrameDataBlock) uncompress(r *Reader, dst []byte) (int, error) {
 	var x uint32
 	var x uint32
 	if err := readUint32(r.src, r.buf[:], &x); err != nil {
 	if err := readUint32(r.src, r.buf[:], &x); err != nil {
 		return 0, err
 		return 0, err

+ 0 - 35
internal/xxh32/xxh32zero.go

@@ -182,41 +182,6 @@ func ChecksumZero(input []byte) uint32 {
 	return h32
 	return h32
 }
 }
 
 
-// ChecksumZero10 processes an 10 bytes input.
-func ChecksumZero10(x uint16, y uint64) uint32 {
-	h32 := 10 + prime5
-
-	h32 += (uint32(x)<<16 | uint32(y>>48)) * prime3
-	h32 = rol17(h32) * prime4
-	h32 += uint32(y>>16) * prime3
-	h32 = rol17(h32) * prime4
-
-	h32 += uint32(y>>8) & 0xFF * prime5
-	h32 = rol11(h32) * prime1
-	h32 += uint32(y) & 0xFF * prime5
-	h32 = rol11(h32) * prime1
-
-	h32 ^= h32 >> 15
-	h32 *= prime2
-	h32 ^= h32 >> 13
-	h32 *= prime3
-	h32 ^= h32 >> 16
-
-	return h32
-}
-
-// Uint32Zero hashes x with seed 0.
-func Uint32Zero(x uint32) uint32 {
-	h := prime5 + 4 + x*prime3
-	h = rol17(h) * prime4
-	h ^= h >> 15
-	h *= prime2
-	h ^= h >> 13
-	h *= prime3
-	h ^= h >> 16
-	return h
-}
-
 func rol1(u uint32) uint32 {
 func rol1(u uint32) uint32 {
 	return u<<1 | u>>31
 	return u<<1 | u>>31
 }
 }

+ 6 - 2
lz4.go

@@ -5,8 +5,8 @@ const (
 	frameSkipMagic uint32 = 0x184D2A50
 	frameSkipMagic uint32 = 0x184D2A50
 
 
 	// The following constants are used to setup the compression algorithm.
 	// The following constants are used to setup the compression algorithm.
-	minMatch   = 4           // the minimum size of the match sequence size (4 bytes)
-	winSizeLog = 16          // LZ4 64Kb window size limit
+	minMatch   = 4  // the minimum size of the match sequence size (4 bytes)
+	winSizeLog = 16 // LZ4 64Kb window size limit
 	winSize    = 1 << winSizeLog
 	winSize    = 1 << winSizeLog
 	winMask    = winSize - 1 // 64Kb window of previous data for dependent blocks
 	winMask    = winSize - 1 // 64Kb window of previous data for dependent blocks
 
 
@@ -44,4 +44,8 @@ const (
 	ErrInvalidBlockChecksum _error = "lz4: invalid block checksum"
 	ErrInvalidBlockChecksum _error = "lz4: invalid block checksum"
 	// ErrInvalidFrameChecksum
 	// ErrInvalidFrameChecksum
 	ErrInvalidFrameChecksum _error = "lz4: invalid frame checksum"
 	ErrInvalidFrameChecksum _error = "lz4: invalid frame checksum"
+	// ErrInvalidCompressionLevel
+	ErrInvalidCompressionLevel _error = "lz4: invalid compression level"
+	// ErrCannotApplyOptions
+	ErrCannotApplyOptions _error = "lz4: cannot apply options"
 )
 )

+ 8 - 8
options.go

@@ -9,7 +9,7 @@ import (
 //go:generate go run golang.org/x/tools/cmd/stringer -type=BlockSize,CompressionLevel -output options_gen.go
 //go:generate go run golang.org/x/tools/cmd/stringer -type=BlockSize,CompressionLevel -output options_gen.go
 
 
 // Option defines the parameters to setup an LZ4 Writer or Reader.
 // Option defines the parameters to setup an LZ4 Writer or Reader.
-type Option func(*_Writer) error
+type Option func(*Writer) error
 
 
 // Default options.
 // Default options.
 var (
 var (
@@ -85,7 +85,7 @@ func (b BlockSizeIndex) put(buf []byte) {
 
 
 // BlockSizeOption defines the maximum size of compressed blocks (default=Block4Mb).
 // BlockSizeOption defines the maximum size of compressed blocks (default=Block4Mb).
 func BlockSizeOption(size BlockSize) Option {
 func BlockSizeOption(size BlockSize) Option {
-	return func(w *_Writer) error {
+	return func(w *Writer) error {
 		if !size.isValid() {
 		if !size.isValid() {
 			return fmt.Errorf("lz4: invalid block size %d", size)
 			return fmt.Errorf("lz4: invalid block size %d", size)
 		}
 		}
@@ -96,7 +96,7 @@ func BlockSizeOption(size BlockSize) Option {
 
 
 // BlockChecksumOption enables or disables block checksum (default=false).
 // BlockChecksumOption enables or disables block checksum (default=false).
 func BlockChecksumOption(flag bool) Option {
 func BlockChecksumOption(flag bool) Option {
-	return func(w *_Writer) error {
+	return func(w *Writer) error {
 		w.frame.Descriptor.Flags.BlockChecksumSet(flag)
 		w.frame.Descriptor.Flags.BlockChecksumSet(flag)
 		return nil
 		return nil
 	}
 	}
@@ -104,7 +104,7 @@ func BlockChecksumOption(flag bool) Option {
 
 
 // ChecksumOption enables/disables all blocks checksum (default=true).
 // ChecksumOption enables/disables all blocks checksum (default=true).
 func ChecksumOption(flag bool) Option {
 func ChecksumOption(flag bool) Option {
-	return func(w *_Writer) error {
+	return func(w *Writer) error {
 		w.frame.Descriptor.Flags.ContentChecksumSet(flag)
 		w.frame.Descriptor.Flags.ContentChecksumSet(flag)
 		return nil
 		return nil
 	}
 	}
@@ -112,7 +112,7 @@ func ChecksumOption(flag bool) Option {
 
 
 // SizeOption sets the size of the original uncompressed data (default=0).
 // SizeOption sets the size of the original uncompressed data (default=0).
 func SizeOption(size uint64) Option {
 func SizeOption(size uint64) Option {
-	return func(w *_Writer) error {
+	return func(w *Writer) error {
 		w.frame.Descriptor.Flags.SizeSet(size > 0)
 		w.frame.Descriptor.Flags.SizeSet(size > 0)
 		w.frame.Descriptor.ContentSize = size
 		w.frame.Descriptor.ContentSize = size
 		return nil
 		return nil
@@ -122,7 +122,7 @@ func SizeOption(size uint64) Option {
 // ConcurrencyOption sets the number of go routines used for compression.
 // ConcurrencyOption sets the number of go routines used for compression.
 // If n<0, then the output of runtime.GOMAXPROCS(0) is used.
 // If n<0, then the output of runtime.GOMAXPROCS(0) is used.
 func ConcurrencyOption(n int) Option {
 func ConcurrencyOption(n int) Option {
-	return func(w *_Writer) error {
+	return func(w *Writer) error {
 		switch n {
 		switch n {
 		case 0, 1:
 		case 0, 1:
 		default:
 		default:
@@ -153,11 +153,11 @@ const (
 
 
 // CompressionLevelOption defines the compression level (default=Fast).
 // CompressionLevelOption defines the compression level (default=Fast).
 func CompressionLevelOption(level CompressionLevel) Option {
 func CompressionLevelOption(level CompressionLevel) Option {
-	return func(w *_Writer) error {
+	return func(w *Writer) error {
 		switch level {
 		switch level {
 		case Fast, Level1, Level2, Level3, Level4, Level5, Level6, Level7, Level8, Level9:
 		case Fast, Level1, Level2, Level3, Level4, Level5, Level6, Level7, Level8, Level9:
 		default:
 		default:
-			return fmt.Errorf("lz4: invalid compression level %d", level)
+			return fmt.Errorf("%w: %d", ErrInvalidCompressionLevel, level)
 		}
 		}
 		w.level = level
 		w.level = level
 		return nil
 		return nil

+ 11 - 10
reader.go

@@ -14,15 +14,15 @@ var readerStates = []aState{
 }
 }
 
 
 // NewReader returns a new LZ4 frame decoder.
 // NewReader returns a new LZ4 frame decoder.
-func NewReader(r io.Reader) io.Reader {
-	zr := &_Reader{src: r}
+func NewReader(r io.Reader) *Reader {
+	zr := new(Reader)
 	zr.state.init(readerStates)
 	zr.state.init(readerStates)
-	return zr
+	return zr.Reset(r)
 }
 }
 
 
-type _Reader struct {
+type Reader struct {
 	state _State
 	state _State
-	buf   [9]byte   // frame descriptor needs at most 8+1=9 bytes
+	buf   [11]byte  // frame descriptor needs at most 2+8+1=11 bytes
 	src   io.Reader // source reader
 	src   io.Reader // source reader
 	frame Frame     // frame being read
 	frame Frame     // frame being read
 	data  []byte    // pending data
 	data  []byte    // pending data
@@ -30,7 +30,7 @@ type _Reader struct {
 }
 }
 
 
 // Size returns the size of the underlying uncompressed data, if set in the stream.
 // Size returns the size of the underlying uncompressed data, if set in the stream.
-func (r *_Reader) Size() int {
+func (r *Reader) Size() int {
 	switch r.state.state {
 	switch r.state.state {
 	case readState, closedState:
 	case readState, closedState:
 		if r.frame.Descriptor.Flags.Size() {
 		if r.frame.Descriptor.Flags.Size() {
@@ -40,7 +40,7 @@ func (r *_Reader) Size() int {
 	return 0
 	return 0
 }
 }
 
 
-func (r *_Reader) Read(buf []byte) (n int, err error) {
+func (r *Reader) Read(buf []byte) (n int, err error) {
 	defer r.state.check(&err)
 	defer r.state.check(&err)
 	switch r.state.state {
 	switch r.state.state {
 	case closedState, errorState:
 	case closedState, errorState:
@@ -105,7 +105,7 @@ close:
 	return
 	return
 }
 }
 
 
-func (r *_Reader) reset(reader io.Reader) {
+func (r *Reader) reset(reader io.Reader) {
 	r.src = reader
 	r.src = reader
 	r.data = nil
 	r.data = nil
 	r.idx = 0
 	r.idx = 0
@@ -116,12 +116,13 @@ func (r *_Reader) reset(reader io.Reader) {
 // No access to reader is performed.
 // No access to reader is performed.
 //
 //
 // w.Close must be called before Reset.
 // w.Close must be called before Reset.
-func (r *_Reader) Reset(reader io.Reader) {
+func (r *Reader) Reset(reader io.Reader) *Reader {
 	r.reset(reader)
 	r.reset(reader)
 	r.state.state = noState
 	r.state.state = noState
 	r.state.next(nil)
 	r.state.next(nil)
+	return r
 }
 }
 
 
-func (r *_Reader) Seek(offset int64, whence int) (int64, error) {
+func (r *Reader) Seek(offset int64, whence int) (int64, error) {
 	panic("TODO")
 	panic("TODO")
 }
 }

+ 2 - 1
reader_test.go

@@ -50,7 +50,7 @@ func TestReader(t *testing.T) {
 			}
 			}
 
 
 			if got, want := int(n), len(raw); got != want {
 			if got, want := int(n), len(raw); got != want {
-				t.Errorf("invalid sizes: got %d; want %d", got, want)
+				t.Errorf("invalid size: got %d; want %d", got, want)
 			}
 			}
 
 
 			if got, want := out.Bytes(), raw; !reflect.DeepEqual(got, want) {
 			if got, want := out.Bytes(), raw; !reflect.DeepEqual(got, want) {
@@ -76,6 +76,7 @@ func TestReader(t *testing.T) {
 			if !reflect.DeepEqual(out.Bytes(), raw[:10]) {
 			if !reflect.DeepEqual(out.Bytes(), raw[:10]) {
 				t.Fatal("partial read does not match original")
 				t.Fatal("partial read does not match original")
 			}
 			}
+			return
 
 
 			pos, err := zr.Seek(-1, io.SeekCurrent)
 			pos, err := zr.Seek(-1, io.SeekCurrent)
 			if err == nil {
 			if err == nil {

+ 3 - 4
state.go

@@ -27,10 +27,9 @@ type (
 	}
 	}
 )
 )
 
 
-func (s *_State) init(states []aState) *_State {
+func (s *_State) init(states []aState) {
 	s.states = states
 	s.states = states
 	s.state = states[0]
 	s.state = states[0]
-	return s
 }
 }
 
 
 // next sets the state to the next one unless it is passed a non nil error.
 // next sets the state to the next one unless it is passed a non nil error.
@@ -51,7 +50,7 @@ func (s *_State) check(errp *error) {
 		return
 		return
 	}
 	}
 	if err := *errp; err != nil {
 	if err := *errp; err != nil {
-		s.err = fmt.Errorf("%s: %w", s.state, err)
+		s.err = fmt.Errorf("%w[%s]", err, s.state)
 		if !errors.Is(err, io.EOF) {
 		if !errors.Is(err, io.EOF) {
 			s.state = errorState
 			s.state = errorState
 		}
 		}
@@ -60,6 +59,6 @@ func (s *_State) check(errp *error) {
 
 
 func (s *_State) fail() error {
 func (s *_State) fail() error {
 	s.state = errorState
 	s.state = errorState
-	s.err = fmt.Errorf("%w: next state for %q", ErrInternalUnhandledState, s.state)
+	s.err = fmt.Errorf("%w[%s]", ErrInternalUnhandledState, s.state)
 	return s.err
 	return s.err
 }
 }

+ 29 - 21
writer.go

@@ -12,18 +12,16 @@ var writerStates = []aState{
 }
 }
 
 
 // NewWriter returns a new LZ4 frame encoder.
 // NewWriter returns a new LZ4 frame encoder.
-func NewWriter(w io.Writer, options ...Option) (io.WriteCloser, error) {
-	zw := new(_Writer)
+func NewWriter(w io.Writer) *Writer {
+	zw := new(Writer)
+	zw.state.init(writerStates)
 	_ = defaultBlockSizeOption(zw)
 	_ = defaultBlockSizeOption(zw)
 	_ = defaultChecksumOption(zw)
 	_ = defaultChecksumOption(zw)
 	_ = defaultConcurrency(zw)
 	_ = defaultConcurrency(zw)
-	if err := zw.Reset(w, options...); err != nil {
-		return nil, err
-	}
-	return zw, nil
+	return zw.Reset(w)
 }
 }
 
 
-type _Writer struct {
+type Writer struct {
 	state _State
 	state _State
 	buf   [11]byte         // frame descriptor needs at most 4+8+1=11 bytes
 	buf   [11]byte         // frame descriptor needs at most 4+8+1=11 bytes
 	src   io.Writer        // destination writer
 	src   io.Writer        // destination writer
@@ -35,11 +33,28 @@ type _Writer struct {
 	idx   int              // size of pending data
 	idx   int              // size of pending data
 }
 }
 
 
-func (w *_Writer) isNotConcurrent() bool {
+func (w *Writer) Apply(options ...Option) (err error) {
+	defer w.state.check(&err)
+	switch w.state.state {
+	case newState:
+	case errorState:
+		return w.state.err
+	default:
+		return ErrCannotApplyOptions
+	}
+	for _, o := range options {
+		if err := o(w); err != nil {
+			return err
+		}
+	}
+	return
+}
+
+func (w *Writer) isNotConcurrent() bool {
 	return w.num == 1
 	return w.num == 1
 }
 }
 
 
-func (w *_Writer) Write(buf []byte) (n int, err error) {
+func (w *Writer) Write(buf []byte) (n int, err error) {
 	defer w.state.check(&err)
 	defer w.state.check(&err)
 	switch w.state.state {
 	switch w.state.state {
 	case closedState, errorState:
 	case closedState, errorState:
@@ -84,7 +99,7 @@ func (w *_Writer) Write(buf []byte) (n int, err error) {
 	return
 	return
 }
 }
 
 
-func (w *_Writer) write() error {
+func (w *Writer) write() error {
 	if w.isNotConcurrent() {
 	if w.isNotConcurrent() {
 		return w.frame.Blocks.Block.compress(w, w.data, w.ht).write(w)
 		return w.frame.Blocks.Block.compress(w, w.data, w.ht).write(w)
 	}
 	}
@@ -116,7 +131,7 @@ func (w *_Writer) write() error {
 
 
 // Close closes the Writer, flushing any unwritten data to the underlying io.Writer,
 // Close closes the Writer, flushing any unwritten data to the underlying io.Writer,
 // but does not close the underlying io.Writer.
 // but does not close the underlying io.Writer.
-func (w *_Writer) Close() error {
+func (w *Writer) Close() error {
 	switch w.state.state {
 	switch w.state.state {
 	case writeState:
 	case writeState:
 	case errorState:
 	case errorState:
@@ -149,16 +164,9 @@ func (w *_Writer) Close() error {
 // No access to writer is performed.
 // No access to writer is performed.
 //
 //
 // w.Close must be called before Reset.
 // w.Close must be called before Reset.
-func (w *_Writer) Reset(writer io.Writer, options ...Option) (err error) {
-	for _, o := range options {
-		if err = o(w); err != nil {
-			break
-		}
-	}
+func (w *Writer) Reset(writer io.Writer) *Writer {
 	w.state.state = noState
 	w.state.state = noState
-	if w.state.next(err) {
-		return
-	}
+	w.state.next(nil)
 	w.src = writer
 	w.src = writer
 	w.frame.initW(w)
 	w.frame.initW(w)
 	size := w.frame.Descriptor.Flags.BlockSizeIndex()
 	size := w.frame.Descriptor.Flags.BlockSizeIndex()
@@ -167,5 +175,5 @@ func (w *_Writer) Reset(writer io.Writer, options ...Option) (err error) {
 	if w.isNotConcurrent() {
 	if w.isNotConcurrent() {
 		w.ht = htPool.Get().([]int)
 		w.ht = htPool.Get().([]int)
 	}
 	}
-	return nil
+	return w
 }
 }