Browse Source

add SkipBytes()

Ian Wilkes 6 years ago
parent
commit
d22e83ee97
3 changed files with 82 additions and 0 deletions
  1. 27 0
      bench_test.go
  2. 22 0
      reader.go
  3. 33 0
      reader_test.go

+ 27 - 0
bench_test.go

@@ -101,6 +101,33 @@ func BenchmarkUncompressDigits(b *testing.B) { benchmarkUncompress(b, digitsLZ4)
 func BenchmarkUncompressTwain(b *testing.B)  { benchmarkUncompress(b, twainLZ4) }
 func BenchmarkUncompressRand(b *testing.B)   { benchmarkUncompress(b, randomLZ4) }
 
+func benchmarkSkipBytes(b *testing.B, compressed []byte) {
+	r := bytes.NewReader(compressed)
+	zr := lz4.NewReader(r)
+
+	// Determine the uncompressed size of testfile.
+	uncompressedSize, err := io.Copy(ioutil.Discard, zr)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	b.SetBytes(uncompressedSize)
+	b.ReportAllocs()
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		r.Reset(compressed)
+		zr.Reset(r)
+		zr.SkipBytes(uncompressedSize)
+		_, _ = io.Copy(ioutil.Discard, zr)
+	}
+}
+
+func BenchmarkSkipBytesPg1661(b *testing.B) { benchmarkSkipBytes(b, pg1661LZ4) }
+func BenchmarkSkipBytesDigits(b *testing.B) { benchmarkSkipBytes(b, digitsLZ4) }
+func BenchmarkSkipBytesTwain(b *testing.B)  { benchmarkSkipBytes(b, twainLZ4) }
+func BenchmarkSkipBytesRand(b *testing.B)   { benchmarkSkipBytes(b, randomLZ4) }
+
 func benchmarkCompress(b *testing.B, uncompressed []byte) {
 	w := bytes.NewBuffer(nil)
 	zw := lz4.NewWriter(w)

+ 22 - 0
reader.go

@@ -2,6 +2,7 @@ package lz4
 
 import (
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -25,6 +26,7 @@ type Reader struct {
 	data     []byte        // Uncompressed data.
 	idx      int           // Index of unread bytes into data.
 	checksum xxh32.XXHZero // Frame hash.
+	skip     int64         // Bytes to skip before next read.
 }
 
 // NewReader returns a new LZ4 frame decoder.
@@ -275,6 +277,15 @@ func (z *Reader) Read(buf []byte) (int, error) {
 		z.idx = 0
 	}
 
+	if z.skip > int64(len(z.data[z.idx:])) {
+		z.skip -= int64(len(z.data[z.idx:]))
+		z.idx = len(z.data)
+		return 0, nil
+	}
+
+	z.idx += int(z.skip)
+	z.skip = 0
+
 	n := copy(buf, z.data[z.idx:])
 	z.idx += n
 	if debugFlag {
@@ -284,6 +295,17 @@ func (z *Reader) Read(buf []byte) (int, error) {
 	return n, nil
 }
 
+// Skip n bytes in the output stream. Equivalent to a forward seek.
+// Note this may cause future calls to Read() to read 0 bytes if all of the
+// data they would have returned is skipped.
+func (z *Reader) SkipBytes(n int64) error {
+	if n < 0 {
+		return errors.New("can only skip forward")
+	}
+	z.skip += n
+	return nil
+}
+
 // Reset discards the Reader's state and makes it equivalent to the
 // result of its original state from NewReader, but reading from r instead.
 // This permits reusing a Reader rather than allocating a new one.

+ 33 - 0
reader_test.go

@@ -56,6 +56,39 @@ func TestReader(t *testing.T) {
 			if got, want := out.Bytes(), raw; !reflect.DeepEqual(got, want) {
 				t.Fatal("uncompressed data does not match original")
 			}
+
+			if len(raw) < 20 {
+				return
+			}
+
+			f2, err := os.Open(fname)
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer f2.Close()
+
+			out.Reset()
+			zr = lz4.NewReader(f2)
+			_, err = io.CopyN(&out, zr, 10)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !reflect.DeepEqual(out.Bytes(), raw[:10]) {
+				t.Fatal("partial read does not match original")
+			}
+			err = zr.SkipBytes(int64(len(raw) - 20))
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			out.Reset()
+			_, err = io.CopyN(&out, zr, 10)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !reflect.DeepEqual(out.Bytes(), raw[len(raw)-10:]) {
+				t.Fatal("after seek, partial read does not match original")
+			}
 		})
 	}
 }