فهرست منبع

Merge pull request #54 from ianwilkes/master

add Seek() to Reader
Pierre Curto 6 سال پیش
والد
کامیت
645f9b948e
4 فایلهای تغییر یافته به همراه 116 افزوده شده و 0 حذف شده
  1. 27 0
      bench_test.go
  2. 2 0
      errors.go
  3. 28 0
      reader.go
  4. 59 0
      reader_test.go

+ 27 - 0
bench_test.go

@@ -101,6 +101,33 @@ func BenchmarkUncompressDigits(b *testing.B) { benchmarkUncompress(b, digitsLZ4)
 func BenchmarkUncompressTwain(b *testing.B)  { benchmarkUncompress(b, twainLZ4) }
 func BenchmarkUncompressRand(b *testing.B)   { benchmarkUncompress(b, randomLZ4) }
 
+// benchmarkSkipBytes measures how quickly a Reader can skip over the whole
+// uncompressed payload of compressed by seeking forward instead of reading it.
+//
+// The payload is decompressed once up front to learn its uncompressed size,
+// which is then used both as the per-iteration byte count and as the forward
+// seek distance.
+func benchmarkSkipBytes(b *testing.B, compressed []byte) {
+	r := bytes.NewReader(compressed)
+	zr := lz4.NewReader(r)
+
+	// Determine the uncompressed size of the test file.
+	uncompressedSize, err := io.Copy(ioutil.Discard, zr)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	b.SetBytes(uncompressedSize)
+	b.ReportAllocs()
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		r.Reset(compressed)
+		zr.Reset(r)
+		// A rejected seek would silently turn this into a plain decompression
+		// benchmark, so fail loudly instead of ignoring the error.
+		if _, err := zr.Seek(uncompressedSize, io.SeekCurrent); err != nil {
+			b.Fatal(err)
+		}
+		// Seek only records the distance to skip; draining the reader is what
+		// actually consumes it.
+		if _, err := io.Copy(ioutil.Discard, zr); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+// BenchmarkSkipBytesPg1661 runs benchmarkSkipBytes on the pg1661LZ4 fixture.
+func BenchmarkSkipBytesPg1661(b *testing.B) {
+	benchmarkSkipBytes(b, pg1661LZ4)
+}
+
+// BenchmarkSkipBytesDigits runs benchmarkSkipBytes on the digitsLZ4 fixture.
+func BenchmarkSkipBytesDigits(b *testing.B) {
+	benchmarkSkipBytes(b, digitsLZ4)
+}
+
+// BenchmarkSkipBytesTwain runs benchmarkSkipBytes on the twainLZ4 fixture.
+func BenchmarkSkipBytesTwain(b *testing.B) {
+	benchmarkSkipBytes(b, twainLZ4)
+}
+
+// BenchmarkSkipBytesRand runs benchmarkSkipBytes on the randomLZ4 fixture.
+func BenchmarkSkipBytesRand(b *testing.B) {
+	benchmarkSkipBytes(b, randomLZ4)
+}
+
 func benchmarkCompress(b *testing.B, uncompressed []byte) {
 	w := bytes.NewBuffer(nil)
 	zw := lz4.NewWriter(w)

+ 2 - 0
errors.go

@@ -15,6 +15,8 @@ var (
 	ErrInvalid = errors.New("lz4: bad magic number")
 	// ErrBlockDependency is returned when attempting to decompress an archive created with block dependency.
 	ErrBlockDependency = errors.New("lz4: block dependency not supported")
+	// ErrUnsupportedSeek is returned when attempting to Seek any way but forward from the current position.
+	ErrUnsupportedSeek = errors.New("lz4: can only seek forward from io.SeekCurrent")
 )
 
 func recoverBlock(e *error) {

+ 28 - 0
reader.go

@@ -25,6 +25,8 @@ type Reader struct {
 	data     []byte        // Uncompressed data.
 	idx      int           // Index of unread bytes into data.
 	checksum xxh32.XXHZero // Frame hash.
+	skip     int64         // Bytes to skip before next read.
+	dpos     int64         // Position in dest
 }
 
 // NewReader returns a new LZ4 frame decoder.
@@ -275,8 +277,20 @@ func (z *Reader) Read(buf []byte) (int, error) {
 		z.idx = 0
 	}
 
+	if z.skip > int64(len(z.data[z.idx:])) {
+		z.skip -= int64(len(z.data[z.idx:]))
+		z.dpos += int64(len(z.data[z.idx:]))
+		z.idx = len(z.data)
+		return 0, nil
+	}
+
+	z.idx += int(z.skip)
+	z.dpos += z.skip
+	z.skip = 0
+
 	n := copy(buf, z.data[z.idx:])
 	z.idx += n
+	z.dpos += int64(n)
 	if debugFlag {
 		debug("copied %d bytes to input", n)
 	}
@@ -284,6 +298,20 @@ func (z *Reader) Read(buf []byte) (int, error) {
 	return n, nil
 }
 
+// Seek implements io.Seeker in a restricted form: the only supported request
+// is a non-negative offset relative to io.SeekCurrent. Such a seek does no
+// work itself; it records the number of output bytes to drop, and subsequent
+// calls to Read discard them, which in some scenarios is faster than reading
+// and throwing the data away.
+//
+// The returned position is the current output position plus all pending
+// skipped bytes. Any other combination of offset and whence leaves the Reader
+// untouched and returns that same position together with ErrUnsupportedSeek.
+//
+// Note that a later Read may return 0 bytes when everything it would have
+// produced falls inside the skipped range.
+func (z *Reader) Seek(offset int64, whence int) (int64, error) {
+	if whence == io.SeekCurrent && offset >= 0 {
+		z.skip += offset
+		return z.dpos + z.skip, nil
+	}
+	return z.dpos + z.skip, ErrUnsupportedSeek
+}
+
 // Reset discards the Reader's state and makes it equivalent to the
 // result of its original state from NewReader, but reading from r instead.
 // This permits reusing a Reader rather than allocating a new one.

+ 59 - 0
reader_test.go

@@ -56,6 +56,65 @@ func TestReader(t *testing.T) {
 			if got, want := out.Bytes(), raw; !reflect.DeepEqual(got, want) {
 				t.Fatal("uncompressed data does not match original")
 			}
+
+			if len(raw) < 20 {
+				return
+			}
+
+			f2, err := os.Open(fname)
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer f2.Close()
+
+			out.Reset()
+			zr = lz4.NewReader(f2)
+			_, err = io.CopyN(&out, zr, 10)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !reflect.DeepEqual(out.Bytes(), raw[:10]) {
+				t.Fatal("partial read does not match original")
+			}
+
+			pos, err := zr.Seek(-1, io.SeekCurrent)
+			if err == nil {
+				t.Fatal("expected error from invalid seek")
+			}
+			if pos != 10 {
+				t.Fatalf("unexpected position %d", pos)
+			}
+			pos, err = zr.Seek(1, io.SeekStart)
+			if err == nil {
+				t.Fatal("expected error from invalid seek")
+			}
+			if pos != 10 {
+				t.Fatalf("unexpected position %d", pos)
+			}
+			pos, err = zr.Seek(-1, io.SeekEnd)
+			if err == nil {
+				t.Fatal("expected error from invalid seek")
+			}
+			if pos != 10 {
+				t.Fatalf("unexpected position %d", pos)
+			}
+
+			pos, err = zr.Seek(int64(len(raw)-20), io.SeekCurrent)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if pos != int64(len(raw)-10) {
+				t.Fatalf("unexpected position %d", pos)
+			}
+
+			out.Reset()
+			_, err = io.CopyN(&out, zr, 10)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !reflect.DeepEqual(out.Bytes(), raw[len(raw)-10:]) {
+				t.Fatal("after seek, partial read does not match original")
+			}
 		})
 	}
 }