Browse Source

Fixed Reader.Reset and Writer.Reset

Pierre.Curto 5 years ago
parent
commit
6441bd6429
6 changed files with 80 additions and 19 deletions
  1. 18 1
      internal/lz4stream/frame.go
  2. 6 10
      reader.go
  3. 20 0
      reader_test.go
  4. 5 0
      state.go
  5. 3 8
      writer.go
  6. 28 0
      writer_test.go

+ 18 - 1
internal/lz4stream/frame.go

@@ -32,6 +32,16 @@ type Frame struct {
 	checksum   xxh32.XXHZero
 }
 
+// Reset allows reusing the Frame.
+// The Descriptor configuration is not modified.
+func (f *Frame) Reset(num int) {
+	f.Magic = 0
+	f.Descriptor.Checksum = 0
+	f.Descriptor.ContentSize = 0
+	_ = f.Blocks.closeW(f, num)
+	f.Checksum = 0
+}
+
 func (f *Frame) InitW(dst io.Writer, num int) {
 	f.Magic = frameMagic
 	f.Descriptor.initW()
@@ -86,7 +96,6 @@ newFrame:
 }
 
 func (f *Frame) CloseR(src io.Reader) error {
-	f.Magic = 0
 	if !f.Descriptor.Flags.ContentChecksum() {
 		return nil
 	}
@@ -212,9 +221,17 @@ func (b *Blocks) initW(f *Frame, dst io.Writer, num int) {
 
 func (b *Blocks) closeW(f *Frame, num int) error {
 	if num == 1 {
+		if b.Block == nil {
+			// Not initialized yet.
+			return nil
+		}
 		b.Block.CloseW(f)
 		return nil
 	}
+	if b.Blocks == nil {
+		// Not initialized yet.
+		return nil
+	}
 	c := make(chan *FrameDataBlock)
 	b.Blocks <- c
 	c <- nil

+ 6 - 10
reader.go

@@ -127,7 +127,7 @@ close:
 		err = er
 	}
 	r.frame.Descriptor.Flags.BlockSizeIndex().Put(r.data)
-	r.reset(nil)
+	r.Reset(nil)
 	return
 fillbuf:
 	bn = copy(buf, r.data[r.idx:])
@@ -140,21 +140,17 @@ fillbuf:
 	return
 }
 
-func (r *Reader) reset(reader io.Reader) {
-	r.src = reader
-	r.data = nil
-	r.idx = 0
-}
-
 // Reset clears the state of the Reader r such that it is equivalent to its
 // initial state from NewReader, but instead writing to writer.
 // No access to reader is performed.
 //
 // w.Close must be called before Reset.
 func (r *Reader) Reset(reader io.Reader) {
-	r.reset(reader)
-	r.state.state = noState
-	r.state.next(nil)
+	r.frame.Reset(1)
+	r.src = reader
+	r.data = nil
+	r.idx = 0
+	r.state.reset()
 }
 
 // WriteTo efficiently uncompresses the data from the Reader underlying source to w.

+ 20 - 0
reader_test.go

@@ -79,3 +79,23 @@ func TestReader(t *testing.T) {
 		})
 	}
 }
+
+func TestReader_Reset(t *testing.T) {
+	data := pg1661LZ4
+	buf := new(bytes.Buffer)
+	src := bytes.NewReader(data)
+	zr := lz4.NewReader(src)
+
+	// Partial read.
+	_, _ = io.CopyN(buf, zr, int64(len(data))/2)
+
+	buf.Reset()
+	src.Reset(data)
+	zr.Reset(src)
+	if _, err := io.Copy(buf, zr); err != nil {
+		t.Fatal(err)
+	}
+	if !reflect.DeepEqual(buf.Bytes(), pg1661) {
+		t.Fatal("result does not match original")
+	}
+}

+ 5 - 0
state.go

@@ -33,6 +33,11 @@ func (s *_State) init(states []aState) {
 	s.state = states[0]
 }
 
+func (s *_State) reset() {
+	s.state = s.states[0]
+	s.err = nil
+}
+
 // next sets the state to the next one unless it is passed a non nil error.
 // It returns whether or not it is in error.
 func (s *_State) next(err error) bool {

+ 3 - 8
writer.go

@@ -177,15 +177,10 @@ func (w *Writer) Close() (err error) {
 // Reset keeps the previous options unless overwritten by the supplied ones.
 // No access to writer is performed.
 //
-// w.Close must be called before Reset or it will panic.
+// w.Close must be called before Reset or pending data may be dropped.
 func (w *Writer) Reset(writer io.Writer) {
-	switch w.state.state {
-	case newState, closedState, errorState:
-	default:
-		panic(lz4errors.ErrWriterNotClosed)
-	}
-	w.state.state = noState
-	w.state.next(nil)
+	w.frame.Reset(w.num)
+	w.state.reset()
 	w.src = writer
 }
 

+ 28 - 0
writer_test.go

@@ -86,6 +86,34 @@ func TestWriter(t *testing.T) {
 	}
 }
 
+func TestWriter_Reset(t *testing.T) {
+	data := pg1661
+	buf := new(bytes.Buffer)
+	src := bytes.NewReader(data)
+	zw := lz4.NewWriter(buf)
+
+	// Partial write.
+	_, _ = io.CopyN(zw, src, int64(len(data))/2)
+
+	buf.Reset()
+	src.Reset(data)
+	zw.Reset(buf)
+	if _, err := io.Copy(zw, src); err != nil {
+		t.Fatal(err)
+	}
+	if err := zw.Close(); err != nil {
+		t.Fatal(err)
+	}
+	// Cannot compare compressed outputs directly, so compare the uncompressed output.
+	out := new(bytes.Buffer)
+	if _, err := io.Copy(out, lz4.NewReader(buf)); err != nil {
+		t.Fatal(err)
+	}
+	if !reflect.DeepEqual(out.Bytes(), data) {
+		t.Fatal("result does not match original")
+	}
+}
+
 func TestIssue41(t *testing.T) {
 	r, w := io.Pipe()
 	zw := lz4.NewWriter(w)