Browse Source

implement io.Seeker

Ian Wilkes 6 years ago
parent
commit
aa78dadf73
4 changed files with 43 additions and 9 deletions
  1. 1 1
      bench_test.go
  2. 2 0
      errors.go
  3. 13 7
      reader.go
  4. 27 1
      reader_test.go

+ 1 - 1
bench_test.go

@@ -118,7 +118,7 @@ func benchmarkSkipBytes(b *testing.B, compressed []byte) {
 	for i := 0; i < b.N; i++ {
 	for i := 0; i < b.N; i++ {
 		r.Reset(compressed)
 		r.Reset(compressed)
 		zr.Reset(r)
 		zr.Reset(r)
-		zr.SkipBytes(uncompressedSize)
+		zr.Seek(uncompressedSize, io.SeekCurrent)
 		_, _ = io.Copy(ioutil.Discard, zr)
 		_, _ = io.Copy(ioutil.Discard, zr)
 	}
 	}
 }
 }

+ 2 - 0
errors.go

@@ -15,6 +15,8 @@ var (
 	ErrInvalid = errors.New("lz4: bad magic number")
 	ErrInvalid = errors.New("lz4: bad magic number")
 	// ErrBlockDependency is returned when attempting to decompress an archive created with block dependency.
 	// ErrBlockDependency is returned when attempting to decompress an archive created with block dependency.
 	ErrBlockDependency = errors.New("lz4: block dependency not supported")
 	ErrBlockDependency = errors.New("lz4: block dependency not supported")
+	// ErrUnsupportedSeek is returned when attempting to Seek any way but forward from the current position.
+	ErrUnsupportedSeek = errors.New("lz4: can only seek forward from io.SeekCurrent")
 )
 )
 
 
 func recoverBlock(e *error) {
 func recoverBlock(e *error) {

+ 13 - 7
reader.go

@@ -2,7 +2,6 @@ package lz4
 
 
 import (
 import (
 	"encoding/binary"
 	"encoding/binary"
-	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
@@ -27,6 +26,7 @@ type Reader struct {
 	idx      int           // Index of unread bytes into data.
 	idx      int           // Index of unread bytes into data.
 	checksum xxh32.XXHZero // Frame hash.
 	checksum xxh32.XXHZero // Frame hash.
 	skip     int64         // Bytes to skip before next read.
 	skip     int64         // Bytes to skip before next read.
+	dpos     int64         // Position in dest
 }
 }
 
 
 // NewReader returns a new LZ4 frame decoder.
 // NewReader returns a new LZ4 frame decoder.
@@ -279,15 +279,18 @@ func (z *Reader) Read(buf []byte) (int, error) {
 
 
 	if z.skip > int64(len(z.data[z.idx:])) {
 	if z.skip > int64(len(z.data[z.idx:])) {
 		z.skip -= int64(len(z.data[z.idx:]))
 		z.skip -= int64(len(z.data[z.idx:]))
+		z.dpos += int64(len(z.data[z.idx:]))
 		z.idx = len(z.data)
 		z.idx = len(z.data)
 		return 0, nil
 		return 0, nil
 	}
 	}
 
 
 	z.idx += int(z.skip)
 	z.idx += int(z.skip)
+	z.dpos += z.skip
 	z.skip = 0
 	z.skip = 0
 
 
 	n := copy(buf, z.data[z.idx:])
 	n := copy(buf, z.data[z.idx:])
 	z.idx += n
 	z.idx += n
+	z.dpos += int64(n)
 	if debugFlag {
 	if debugFlag {
 		debug("copied %d bytes to input", n)
 		debug("copied %d bytes to input", n)
 	}
 	}
@@ -295,15 +298,18 @@ func (z *Reader) Read(buf []byte) (int, error) {
 	return n, nil
 	return n, nil
 }
 }
 
 
-// Skip n bytes in the output stream. Equivalent to a forward seek.
+// Seek implements io.Seeker, but supports seeking forward from the current
+// position only. Any other seek will return an error. Allows skipping output
+// bytes which aren't needed, which in some scenarios is faster than reading
+// and discarding them.
 // Note this may cause future calls to Read() to read 0 bytes if all of the
 // Note this may cause future calls to Read() to read 0 bytes if all of the
 // data they would have returned is skipped.
 // data they would have returned is skipped.
-func (z *Reader) SkipBytes(n int64) error {
-	if n < 0 {
-		return errors.New("can only skip forward")
+func (z *Reader) Seek(offset int64, whence int) (int64, error) {
+	if offset < 0 || whence != io.SeekCurrent {
+		return z.dpos + z.skip, ErrUnsupportedSeek
 	}
 	}
-	z.skip += n
-	return nil
+	z.skip += offset
+	return z.dpos + z.skip, nil
 }
 }
 
 
 // Reset discards the Reader's state and makes it equivalent to the
 // Reset discards the Reader's state and makes it equivalent to the

+ 27 - 1
reader_test.go

@@ -76,10 +76,36 @@ func TestReader(t *testing.T) {
 			if !reflect.DeepEqual(out.Bytes(), raw[:10]) {
 			if !reflect.DeepEqual(out.Bytes(), raw[:10]) {
 				t.Fatal("partial read does not match original")
 				t.Fatal("partial read does not match original")
 			}
 			}
-			err = zr.SkipBytes(int64(len(raw) - 20))
+
+			pos, err := zr.Seek(-1, io.SeekCurrent)
+			if err == nil {
+				t.Fatal("expected error from invalid seek")
+			}
+			if pos != 10 {
+				t.Fatalf("unexpected position %d", pos)
+			}
+			pos, err = zr.Seek(1, io.SeekStart)
+			if err == nil {
+				t.Fatal("expected error from invalid seek")
+			}
+			if pos != 10 {
+				t.Fatalf("unexpected position %d", pos)
+			}
+			pos, err = zr.Seek(-1, io.SeekEnd)
+			if err == nil {
+				t.Fatal("expected error from invalid seek")
+			}
+			if pos != 10 {
+				t.Fatalf("unexpected position %d", pos)
+			}
+
+			pos, err = zr.Seek(int64(len(raw)-20), io.SeekCurrent)
 			if err != nil {
 			if err != nil {
 				t.Fatal(err)
 				t.Fatal(err)
 			}
 			}
+			if pos != int64(len(raw)-10) {
+				t.Fatalf("unexpected position %d", pos)
+			}
 
 
 			out.Reset()
 			out.Reset()
 			_, err = io.CopyN(&out, zr, 10)
 			_, err = io.CopyN(&out, zr, 10)