Browse Source

Frame is now independent of Writer and Reader

Pierre.Curto 5 years ago
parent
commit
406cae5c2e
5 changed files with 98 additions and 90 deletions
  1. 67 63
      frame.go
  2. 9 9
      frame_test.go
  3. 2 2
      lz4.go
  4. 6 7
      reader.go
  5. 14 9
      writer.go

+ 67 - 63
frame.go

@@ -11,7 +11,12 @@ import (
 
 //go:generate go run gen.go
 
+func NewFrame() *Frame {
+	return &Frame{}
+}
+
 type Frame struct {
+	buf        [15]byte // frame descriptor needs at most 4(magic)+2(flags)+8(content size)+1(checksum)=15 bytes
 	Magic      uint32
 	Descriptor FrameDescriptor
 	Blocks     Blocks
@@ -19,34 +24,34 @@ type Frame struct {
 	checksum   xxh32.XXHZero
 }
 
-func (f *Frame) initW(w *Writer) {
+func (f *Frame) initW(dst io.Writer, num int) {
 	f.Magic = frameMagic
-	f.Descriptor.initW(w)
-	f.Blocks.initW(w)
+	f.Descriptor.initW()
+	f.Blocks.initW(f, dst, num)
 	f.checksum.Reset()
 }
 
-func (f *Frame) closeW(w *Writer) error {
-	if err := f.Blocks.closeW(w); err != nil {
+func (f *Frame) closeW(dst io.Writer, num int) error {
+	if err := f.Blocks.closeW(f, num); err != nil {
 		return err
 	}
-	buf := w.buf[:0]
+	buf := f.buf[:0]
 	// End mark (data block size of uint32(0)).
 	buf = append(buf, 0, 0, 0, 0)
 	if f.Descriptor.Flags.ContentChecksum() {
 		buf = f.checksum.Sum(buf)
 	}
-	_, err := w.src.Write(buf)
+	_, err := dst.Write(buf)
 	return err
 }
 
-func (f *Frame) initR(r *Reader) error {
+func (f *Frame) initR(src io.Reader) error {
 	if f.Magic > 0 {
 		// Header already read.
 		return nil
 	}
 newFrame:
-	if err := readUint32(r.src, r.buf[:], &f.Magic); err != nil {
+	if err := readUint32(src, f.buf[:], &f.Magic); err != nil {
 		return err
 	}
 	switch m := f.Magic; {
@@ -54,30 +59,30 @@ newFrame:
 	// All 16 values of frameSkipMagic are valid.
 	case m>>8 == frameSkipMagic>>8:
 		var skip uint32
-		if err := binary.Read(r.src, binary.LittleEndian, &skip); err != nil {
+		if err := binary.Read(src, binary.LittleEndian, &skip); err != nil {
 			return err
 		}
-		if _, err := io.CopyN(ioutil.Discard, r.src, int64(skip)); err != nil {
+		if _, err := io.CopyN(ioutil.Discard, src, int64(skip)); err != nil {
 			return err
 		}
 		goto newFrame
 	default:
 		return ErrInvalidFrame
 	}
-	if err := f.Descriptor.initR(r); err != nil {
+	if err := f.Descriptor.initR(f, src); err != nil {
 		return err
 	}
-	f.Blocks.initR(r)
+	f.Blocks.initR(f)
 	f.checksum.Reset()
 	return nil
 }
 
-func (f *Frame) closeR(r *Reader) error {
+func (f *Frame) closeR(src io.Reader) error {
 	f.Magic = 0
 	if !f.Descriptor.Flags.ContentChecksum() {
 		return nil
 	}
-	if err := readUint32(r.src, r.buf[:], &f.Checksum); err != nil {
+	if err := readUint32(src, f.buf[:], &f.Checksum); err != nil {
 		return err
 	}
 	if c := f.checksum.Sum32(); c != f.Checksum {
@@ -92,20 +97,20 @@ type FrameDescriptor struct {
 	Checksum    uint8
 }
 
-func (fd *FrameDescriptor) initW(_ *Writer) {
+func (fd *FrameDescriptor) initW() {
 	fd.Flags.VersionSet(1)
 	fd.Flags.BlockIndependenceSet(true)
 }
 
-func (fd *FrameDescriptor) write(w *Writer) error {
+func (fd *FrameDescriptor) write(f *Frame, dst io.Writer) error {
 	if fd.Checksum > 0 {
 		// Header already written.
 		return nil
 	}
 
-	buf := w.buf[:4+2]
+	buf := f.buf[:4+2]
 	// Write the magic number here even though it belongs to the Frame.
-	binary.LittleEndian.PutUint32(buf, w.frame.Magic)
+	binary.LittleEndian.PutUint32(buf, f.Magic)
 	binary.LittleEndian.PutUint16(buf[4:], uint16(fd.Flags))
 
 	if fd.Flags.Size() {
@@ -115,14 +120,14 @@ func (fd *FrameDescriptor) write(w *Writer) error {
 	fd.Checksum = descriptorChecksum(buf[4:])
 	buf = append(buf, fd.Checksum)
 
-	_, err := w.src.Write(buf)
+	_, err := dst.Write(buf)
 	return err
 }
 
-func (fd *FrameDescriptor) initR(r *Reader) error {
+func (fd *FrameDescriptor) initR(f *Frame, src io.Reader) error {
 	// Read the flags and the checksum, hoping that there is no content size.
-	buf := r.buf[:3]
-	if _, err := io.ReadFull(r.src, buf); err != nil {
+	buf := f.buf[:3]
+	if _, err := io.ReadFull(src, buf); err != nil {
 		return err
 	}
 	descr := binary.LittleEndian.Uint16(buf)
@@ -130,7 +135,7 @@ func (fd *FrameDescriptor) initR(r *Reader) error {
 	if fd.Flags.Size() {
 		// Append the 8 missing bytes.
 		buf = buf[:3+8]
-		if _, err := io.ReadFull(r.src, buf[3:]); err != nil {
+		if _, err := io.ReadFull(src, buf[3:]); err != nil {
 			return err
 		}
 		fd.ContentSize = binary.LittleEndian.Uint64(buf[2:])
@@ -157,15 +162,15 @@ type Blocks struct {
 	err    error
 }
 
-func (b *Blocks) initW(w *Writer) {
-	size := w.frame.Descriptor.Flags.BlockSizeIndex()
-	if w.isNotConcurrent() {
+func (b *Blocks) initW(f *Frame, dst io.Writer, num int) {
+	size := f.Descriptor.Flags.BlockSizeIndex()
+	if num == 1 {
 		b.Blocks = nil
 		b.Block = newFrameDataBlock(size)
 		return
 	}
-	if cap(b.Blocks) != w.num {
-		b.Blocks = make(chan chan *FrameDataBlock, w.num)
+	if cap(b.Blocks) != num {
+		b.Blocks = make(chan chan *FrameDataBlock, num)
 	}
 	// goroutine managing concurrent block compression goroutines.
 	go func() {
@@ -185,7 +190,7 @@ func (b *Blocks) initW(w *Writer) {
 			// Do not attempt to write the block upon any previous failure.
 			if b.err == nil {
 				// Write the block.
-				if err := block.write(w); err != nil && b.err == nil {
+				if err := block.write(f, dst); err != nil && b.err == nil {
 					// Keep the first error.
 					b.err = err
 					// All pending compression goroutines need to shut down, so we need to keep going.
@@ -196,9 +201,9 @@ func (b *Blocks) initW(w *Writer) {
 	}()
 }
 
-func (b *Blocks) closeW(w *Writer) error {
-	if w.isNotConcurrent() {
-		b.Block.closeW(w)
+func (b *Blocks) closeW(f *Frame, num int) error {
+	if num == 1 {
+		b.Block.closeW(f)
 		b.Block = nil
 		return nil
 	}
@@ -211,8 +216,8 @@ func (b *Blocks) closeW(w *Writer) error {
 	return err
 }
 
-func (b *Blocks) initR(r *Reader) {
-	size := r.frame.Descriptor.Flags.BlockSizeIndex()
+func (b *Blocks) initR(f *Frame) {
+	size := f.Descriptor.Flags.BlockSizeIndex()
 	b.Block = newFrameDataBlock(size)
 }
 
@@ -226,50 +231,48 @@ type FrameDataBlock struct {
 	Checksum uint32
 }
 
-func (b *FrameDataBlock) closeW(w *Writer) {
-	size := w.frame.Descriptor.Flags.BlockSizeIndex()
+func (b *FrameDataBlock) closeW(f *Frame) {
+	size := f.Descriptor.Flags.BlockSizeIndex()
 	size.put(b.Data)
 }
 
 // Block compression errors are ignored since the buffer is sized appropriately.
-func (b *FrameDataBlock) compress(w *Writer, src []byte, ht []int) *FrameDataBlock {
-	dst := b.Data[:len(src)] // trigger the incompressible flag in CompressBlock
+func (b *FrameDataBlock) compress(f *Frame, src []byte, ht []int, level CompressionLevel) *FrameDataBlock {
+	data := b.Data[:len(src)] // trigger the incompressible flag in CompressBlock
 	var n int
-	switch w.level {
+	switch level {
 	case Fast:
-		n, _ = CompressBlock(src, dst, ht)
+		n, _ = CompressBlock(src, data, ht)
 	default:
-		n, _ = CompressBlockHC(src, dst, w.level, ht)
+		n, _ = CompressBlockHC(src, data, level, ht)
 	}
 	if n == 0 {
 		b.Size.uncompressedSet(true)
-		dst = src
+		data = src
 	} else {
 		b.Size.uncompressedSet(false)
-		dst = dst[:n]
+		data = data[:n]
 	}
-	b.Data = dst
-	b.Size.sizeSet(len(dst))
+	b.Data = data
+	b.Size.sizeSet(len(data))
 
-	if w.frame.Descriptor.Flags.BlockChecksum() {
+	if f.Descriptor.Flags.BlockChecksum() {
 		b.Checksum = xxh32.ChecksumZero(src)
 	}
-	if w.frame.Descriptor.Flags.ContentChecksum() {
-		_, _ = w.frame.checksum.Write(src)
+	if f.Descriptor.Flags.ContentChecksum() {
+		_, _ = f.checksum.Write(src)
 	}
 	return b
 }
 
-func (b *FrameDataBlock) write(w *Writer) error {
-	buf := w.buf[:]
-	out := w.src
-
+func (b *FrameDataBlock) write(f *Frame, dst io.Writer) error {
+	buf := f.buf[:]
 	binary.LittleEndian.PutUint32(buf, uint32(b.Size))
-	if _, err := out.Write(buf[:4]); err != nil {
+	if _, err := dst.Write(buf[:4]); err != nil {
 		return err
 	}
 
-	if _, err := out.Write(b.Data); err != nil {
+	if _, err := dst.Write(b.Data); err != nil {
 		return err
 	}
 
@@ -277,13 +280,14 @@ func (b *FrameDataBlock) write(w *Writer) error {
 		return nil
 	}
 	binary.LittleEndian.PutUint32(buf, b.Checksum)
-	_, err := out.Write(buf[:4])
+	_, err := dst.Write(buf[:4])
 	return err
 }
 
-func (b *FrameDataBlock) uncompress(r *Reader, dst []byte) (int, error) {
+func (b *FrameDataBlock) uncompress(f *Frame, src io.Reader, dst []byte) (int, error) {
+	buf := f.buf[:]
 	var x uint32
-	if err := readUint32(r.src, r.buf[:], &x); err != nil {
+	if err := readUint32(src, buf, &x); err != nil {
 		return 0, err
 	}
 	b.Size = DataBlockSize(x)
@@ -301,7 +305,7 @@ func (b *FrameDataBlock) uncompress(r *Reader, dst []byte) (int, error) {
 		data = dst
 	}
 	data = data[:size]
-	if _, err := io.ReadFull(r.src, data); err != nil {
+	if _, err := io.ReadFull(src, data); err != nil {
 		return 0, err
 	}
 	if isCompressed {
@@ -312,16 +316,16 @@ func (b *FrameDataBlock) uncompress(r *Reader, dst []byte) (int, error) {
 		data = dst[:n]
 	}
 
-	if r.frame.Descriptor.Flags.BlockChecksum() {
-		if err := readUint32(r.src, r.buf[:], &b.Checksum); err != nil {
+	if f.Descriptor.Flags.BlockChecksum() {
+		if err := readUint32(src, buf, &b.Checksum); err != nil {
 			return 0, err
 		}
 		if c := xxh32.ChecksumZero(data); c != b.Checksum {
 			return 0, fmt.Errorf("%w: got %x; expected %x", ErrInvalidBlockChecksum, c, b.Checksum)
 		}
 	}
-	if r.frame.Descriptor.Flags.ContentChecksum() {
-		_, _ = r.frame.checksum.Write(data)
+	if f.Descriptor.Flags.ContentChecksum() {
+		_, _ = f.checksum.Write(data)
 	}
 	return len(data), nil
 }

+ 9 - 9
frame_test.go

@@ -22,9 +22,10 @@ func TestFrameDescriptor(t *testing.T) {
 		s := tc.flags
 		label := fmt.Sprintf("%02x %02x %02x", s[0], s[1], s[2])
 		t.Run(label, func(t *testing.T) {
-			r := &Reader{src: strings.NewReader(tc.flags)}
+			r := strings.NewReader(tc.flags)
+			f := NewFrame()
 			var fd FrameDescriptor
-			if err := fd.initR(r); err != nil {
+			if err := fd.initR(f, r); err != nil {
 				t.Fatal(err)
 			}
 
@@ -46,9 +47,9 @@ func TestFrameDescriptor(t *testing.T) {
 
 			buf := new(bytes.Buffer)
 			w := &Writer{src: buf}
-			fd.initW(nil)
+			fd.initW()
 			fd.Checksum = 0
-			if err := fd.write(w); err != nil {
+			if err := fd.write(f, w); err != nil {
 				t.Fatal(err)
 			}
 			if got, want := buf.String(), tc.flags; got != want {
@@ -83,17 +84,16 @@ func TestFrameDataBlock(t *testing.T) {
 			data := tc.data
 			size := tc.size
 			zbuf := new(bytes.Buffer)
-			w := &Writer{src: zbuf, level: Fast}
+			f := NewFrame()
 
 			block := newFrameDataBlock(size.index())
-			block.compress(w, []byte(data), nil)
-			if err := block.write(w); err != nil {
+			block.compress(f, []byte(data), nil, Fast)
+			if err := block.write(f, zbuf); err != nil {
 				t.Fatal(err)
 			}
 
 			buf := make([]byte, size)
-			r := &Reader{src: zbuf}
-			n, err := block.uncompress(r, buf)
+			n, err := block.uncompress(f, zbuf, buf)
 			if err != nil {
 				t.Fatal(err)
 			}

+ 2 - 2
lz4.go

@@ -30,8 +30,6 @@ const (
 	ErrInvalidSourceShortBuffer _error = "lz4: invalid source or destination buffer too short"
 	// ErrInvalidFrame is returned when reading an invalid LZ4 archive.
 	ErrInvalidFrame _error = "lz4: bad magic number"
-	// ErrUnsupportedSeek is returned when attempting to Seek any way but forward from the current position.
-	ErrUnsupportedSeek _error = "lz4: can only seek forward from io.SeekCurrent"
 	// ErrInternalUnhandledState is an internal error.
 	ErrInternalUnhandledState _error = "lz4: unhandled state"
 	// ErrInvalidHeaderChecksum is returned when reading a frame.
@@ -48,4 +46,6 @@ const (
 	ErrOptionInvalidBlockSize _error = "lz4: invalid block size"
 	// ErrOptionNotApplicable is returned when trying to apply an option to an object not supporting it.
 	ErrOptionNotApplicable _error = "lz4: option not applicable"
+	// ErrWriterNotClosed is returned when attempting to reset an unclosed writer.
+	ErrWriterNotClosed _error = "lz4: writer not closed"
 )

+ 6 - 7
reader.go

@@ -14,7 +14,7 @@ var readerStates = []aState{
 
 // NewReader returns a new LZ4 frame decoder.
 func NewReader(r io.Reader) *Reader {
-	zr := new(Reader)
+	zr := &Reader{frame: NewFrame()}
 	zr.state.init(readerStates)
 	_ = zr.Apply(defaultOnBlockDone)
 	return zr.Reset(r)
@@ -22,9 +22,8 @@ func NewReader(r io.Reader) *Reader {
 
 type Reader struct {
 	state   _State
-	buf     [11]byte  // frame descriptor needs at most 2+8+1=11 bytes
 	src     io.Reader // source reader
-	frame   Frame     // frame being read
+	frame   *Frame    // frame being read
 	data    []byte    // pending data
 	idx     int       // size of pending data
 	handler func(int)
@@ -68,7 +67,7 @@ func (r *Reader) Read(buf []byte) (n int, err error) {
 		return 0, r.state.err
 	case newState:
 		// First initialization.
-		if err = r.frame.initR(r); r.state.next(err) {
+		if err = r.frame.initR(r.src); r.state.next(err) {
 			return
 		}
 		r.data = r.frame.Descriptor.Flags.BlockSizeIndex().get()
@@ -88,7 +87,7 @@ func (r *Reader) Read(buf []byte) (n int, err error) {
 	r.data = r.data[:cap(r.data)]
 	for len(buf) >= len(r.data) {
 		// Input buffer large enough and no pending data: uncompress directly into it.
-		switch bn, err = r.frame.Blocks.Block.uncompress(r, buf); err {
+		switch bn, err = r.frame.Blocks.Block.uncompress(r.frame, r.src, buf); err {
 		case nil:
 			r.handler(bn)
 			n += bn
@@ -104,7 +103,7 @@ func (r *Reader) Read(buf []byte) (n int, err error) {
 		return
 	}
 	// Read the next block.
-	switch bn, err = r.frame.Blocks.Block.uncompress(r, r.data); err {
+	switch bn, err = r.frame.Blocks.Block.uncompress(r.frame, r.src, r.data); err {
 	case nil:
 		r.handler(bn)
 		r.data = r.data[:bn]
@@ -116,7 +115,7 @@ func (r *Reader) Read(buf []byte) (n int, err error) {
 close:
 	r.handler(bn)
 	n += bn
-	if er := r.frame.closeR(r); er != nil {
+	if er := r.frame.closeR(r.src); er != nil {
 		err = er
 	}
 	r.frame.Descriptor.Flags.BlockSizeIndex().put(r.data)

+ 14 - 9
writer.go

@@ -12,7 +12,7 @@ var writerStates = []aState{
 
 // NewWriter returns a new LZ4 frame encoder.
 func NewWriter(w io.Writer) *Writer {
-	zw := new(Writer)
+	zw := &Writer{frame: NewFrame()}
 	zw.state.init(writerStates)
 	_ = zw.Apply(DefaultBlockSizeOption, DefaultChecksumOption, DefaultConcurrency, defaultOnBlockDone)
 	return zw.Reset(w)
@@ -20,11 +20,10 @@ func NewWriter(w io.Writer) *Writer {
 
 type Writer struct {
 	state   _State
-	buf     [15]byte         // frame descriptor needs at most 4(magic)+4+8+1=11 bytes
 	src     io.Writer        // destination writer
 	level   CompressionLevel // how hard to try
 	num     int              // concurrency level
-	frame   Frame            // frame being built
+	frame   *Frame           // frame being built
 	ht      []int            // hash table (set if no concurrency)
 	data    []byte           // pending data
 	idx     int              // size of pending data
@@ -47,6 +46,7 @@ func (w *Writer) Apply(options ...Option) (err error) {
 			return
 		}
 	}
+	w.Reset(w.src)
 	return
 }
 
@@ -61,7 +61,7 @@ func (w *Writer) Write(buf []byte) (n int, err error) {
 	case closedState, errorState:
 		return 0, w.state.err
 	case newState:
-		if err = w.frame.Descriptor.write(w); w.state.next(err) {
+		if err = w.frame.Descriptor.write(w.frame, w.src); w.state.next(err) {
 			return
 		}
 	default:
@@ -103,7 +103,7 @@ func (w *Writer) write(data []byte, direct bool) error {
 	if w.isNotConcurrent() {
 		defer w.handler(len(data))
 		block := w.frame.Blocks.Block
-		return block.compress(w, data, w.ht).write(w)
+		return block.compress(w.frame, data, w.ht, w.level).write(w.frame, w.src)
 	}
 	size := w.frame.Descriptor.Flags.BlockSizeIndex()
 	c := make(chan *FrameDataBlock)
@@ -112,7 +112,7 @@ func (w *Writer) write(data []byte, direct bool) error {
 		defer w.handler(len(data))
 		b := newFrameDataBlock(size)
 		zdata := b.Data
-		c <- b.compress(w, data, nil)
+		c <- b.compress(w.frame, data, nil, w.level)
 		// Wait for the compressed or uncompressed data to no longer be in use
 		// and free the allocated buffers
 		if b.Size.uncompressed() {
@@ -154,7 +154,7 @@ func (w *Writer) Close() (err error) {
 		size.put(w.data)
 		w.data = nil
 	}
-	return w.frame.closeW(w)
+	return w.frame.closeW(w.src, w.num)
 }
 
 // Reset clears the state of the Writer w such that it is equivalent to its
@@ -162,12 +162,17 @@ func (w *Writer) Close() (err error) {
 // Reset keeps the previous options unless overwritten by the supplied ones.
 // No access to writer is performed.
 //
-// w.Close must be called before Reset.
+// w.Close must be called before Reset, otherwise Reset panics with ErrWriterNotClosed.
 func (w *Writer) Reset(writer io.Writer) *Writer {
+	switch w.state.state {
+	case newState, closedState, errorState:
+	default:
+		panic(ErrWriterNotClosed)
+	}
 	w.state.state = noState
 	w.state.next(nil)
 	w.src = writer
-	w.frame.initW(w)
+	w.frame.initW(w.src, w.num)
 	size := w.frame.Descriptor.Flags.BlockSizeIndex()
 	w.data = size.get()
 	w.idx = 0