Browse Source

Introduce dynamicPushEncoders

Added dynamicPushEncoder interface that extends the pushEncoder with an
adjustLength method that will be called by prepEncoder.pop()
time so that it computes the actual length of the field.
Also made varintLengthField implement this method so we can avoid a
needless run of prepEncoder for uncompressed records.
Vlad Hanciuta 8 years ago
parent
commit
b37b1580be
6 changed files with 75 additions and 69 deletions
  1. 28 7
      length_field.go
  2. 10 0
      packet_encoder.go
  3. 9 0
      prep_encoder.go
  4. 10 40
      record.go
  5. 16 16
      record_batch.go
  6. 2 6
      record_test.go

+ 28 - 7
length_field.go

@@ -31,22 +31,43 @@ func (l *lengthField) check(curOffset int, buf []byte) error {
 type varintLengthField struct {
 type varintLengthField struct {
 	startOffset int
 	startOffset int
 	length      int64
 	length      int64
+	adjusted    bool
+	size        int
 }
 }
 
 
-func newVarintLengthField(pd packetDecoder) (*varintLengthField, error) {
-	n, err := pd.getVarint()
-	if err != nil {
-		return nil, err
-	}
-	return &varintLengthField{length: n}, nil
+func (l *varintLengthField) decode(pd packetDecoder) error {
+	var err error
+	l.length, err = pd.getVarint()
+	return err
 }
 }
 
 
 func (l *varintLengthField) saveOffset(in int) {
 func (l *varintLengthField) saveOffset(in int) {
 	l.startOffset = in
 	l.startOffset = in
 }
 }
 
 
+func (l *varintLengthField) adjustLength(currOffset int) int {
+	l.adjusted = true
+
+	var tmp [binary.MaxVarintLen64]byte
+	l.length = int64(currOffset - l.startOffset - l.size)
+
+	newSize := binary.PutVarint(tmp[:], l.length)
+	diff := newSize - l.size
+	l.size = newSize
+
+	return diff
+}
+
 func (l *varintLengthField) reserveLength() int {
 func (l *varintLengthField) reserveLength() int {
-	return 0
+	return l.size
+}
+
+func (l *varintLengthField) run(curOffset int, buf []byte) error {
+	if !l.adjusted {
+		return PacketEncodingError{"varintLengthField.run called before adjustLength"}
+	}
+	binary.PutVarint(buf[l.startOffset:], l.length)
+	return nil
 }
 }
 
 
 func (l *varintLengthField) check(curOffset int, buf []byte) error {
 func (l *varintLengthField) check(curOffset int, buf []byte) error {

+ 10 - 0
packet_encoder.go

@@ -50,3 +50,13 @@ type pushEncoder interface {
 	// of data to the saved offset, based on the data between the saved offset and curOffset.
 	// of data to the saved offset, based on the data between the saved offset and curOffset.
 	run(curOffset int, buf []byte) error
 	run(curOffset int, buf []byte) error
 }
 }
+
+// dynamicPushEncoder extends the interface of pushEncoder for uses cases where the length of the
+// fields itself is unknown until its value was computed (for instance varint encoded lenght
+// fields).
+type dynamicPushEncoder interface {
+	pushEncoder
+
+	// Called during pop() to adjust the length of the field.
+	adjustLength(currOffset int) int
+}

+ 9 - 0
prep_encoder.go

@@ -9,6 +9,7 @@ import (
 )
 )
 
 
 type prepEncoder struct {
 type prepEncoder struct {
+	stack  []pushEncoder
 	length int
 	length int
 }
 }
 
 
@@ -119,10 +120,18 @@ func (pe *prepEncoder) offset() int {
 // stackable
 // stackable
 
 
 func (pe *prepEncoder) push(in pushEncoder) {
 func (pe *prepEncoder) push(in pushEncoder) {
+	in.saveOffset(pe.length)
 	pe.length += in.reserveLength()
 	pe.length += in.reserveLength()
+	pe.stack = append(pe.stack, in)
 }
 }
 
 
 func (pe *prepEncoder) pop() error {
 func (pe *prepEncoder) pop() error {
+	in := pe.stack[len(pe.stack)-1]
+	pe.stack = pe.stack[:len(pe.stack)-1]
+	if dpe, ok := in.(dynamicPushEncoder); ok {
+		pe.length += dpe.adjustLength(pe.length)
+	}
+
 	return nil
 	return nil
 }
 }
 
 

+ 10 - 40
record.go

@@ -1,7 +1,5 @@
 package sarama
 package sarama
 
 
-import "encoding/binary"
-
 const (
 const (
 	controlMask = 0x20
 	controlMask = 0x20
 )
 )
@@ -37,17 +35,12 @@ type Record struct {
 	Value          []byte
 	Value          []byte
 	Headers        []*RecordHeader
 	Headers        []*RecordHeader
 
 
-	lengthComputed bool
-	length         int64
-	totalLength    int
+	length      varintLengthField
+	totalLength int
 }
 }
 
 
 func (r *Record) encode(pe packetEncoder) error {
 func (r *Record) encode(pe packetEncoder) error {
-	if err := r.computeLength(); err != nil {
-		return err
-	}
-
-	pe.putVarint(r.length)
+	pe.push(&r.length)
 	pe.putInt8(r.Attributes)
 	pe.putInt8(r.Attributes)
 	pe.putVarint(r.TimestampDelta)
 	pe.putVarint(r.TimestampDelta)
 	pe.putVarint(r.OffsetDelta)
 	pe.putVarint(r.OffsetDelta)
@@ -65,19 +58,16 @@ func (r *Record) encode(pe packetEncoder) error {
 		}
 		}
 	}
 	}
 
 
-	return nil
+	return pe.pop()
 }
 }
 
 
 func (r *Record) decode(pd packetDecoder) (err error) {
 func (r *Record) decode(pd packetDecoder) (err error) {
-	length, err := newVarintLengthField(pd)
-	if err != nil {
+	if err := r.length.decode(pd); err != nil {
 		return err
 		return err
 	}
 	}
-	if err = pd.push(length); err != nil {
+	if err = pd.push(&r.length); err != nil {
 		return err
 		return err
 	}
 	}
-	r.length = length.length
-	r.lengthComputed = true
 
 
 	if r.Attributes, err = pd.getInt8(); err != nil {
 	if r.Attributes, err = pd.getInt8(); err != nil {
 		return err
 		return err
@@ -118,32 +108,12 @@ func (r *Record) decode(pd packetDecoder) (err error) {
 	return pd.pop()
 	return pd.pop()
 }
 }
 
 
-// Because the length is varint we can't reserve a fixed amount of bytes for it.
-// We use the prepEncoder to figure out the length of the record and then we cache it.
-func (r *Record) computeLength() error {
-	if !r.lengthComputed {
-		r.lengthComputed = true
-
-		var prep prepEncoder
-		if err := r.encode(&prep); err != nil {
-			return err
-		}
-		// subtract 1 because we don't want to include the length field itself (which 1 byte, the
-		// length of varint encoding of 0)
-		r.length = int64(prep.length) - 1
-	}
-
-	return nil
-}
-
 func (r *Record) getTotalLength() (int, error) {
 func (r *Record) getTotalLength() (int, error) {
-	if r.totalLength == 0 {
-		if err := r.computeLength(); err != nil {
+	var prep prepEncoder
+	if !r.length.adjusted {
+		if err := r.encode(&prep); err != nil {
 			return 0, err
 			return 0, err
 		}
 		}
-		var buf [binary.MaxVarintLen64]byte
-		r.totalLength = int(r.length) + binary.PutVarint(buf[:], r.length)
 	}
 	}
-
-	return r.totalLength, nil
+	return int(r.length.length) + r.length.size, nil
 }
 }

+ 16 - 16
record_batch.go

@@ -59,10 +59,7 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 		if err := pe.pop(); err != nil {
 		if err := pe.pop(); err != nil {
 			return err
 			return err
 		}
 		}
-		if err := pe.pop(); err != nil {
-			return err
-		}
-		return nil
+		return pe.pop()
 	}
 	}
 
 
 	var re packetEncoder
 	var re packetEncoder
@@ -72,14 +69,9 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 	case CompressionNone:
 	case CompressionNone:
 		re = pe
 		re = pe
 	case CompressionGZIP, CompressionLZ4, CompressionSnappy:
 	case CompressionGZIP, CompressionLZ4, CompressionSnappy:
-		for _, r := range b.Records {
-			l, err := r.getTotalLength()
-			if err != nil {
-				return err
-			}
-			b.recordsLen += l
+		if err := b.computeRecordsLength(); err != nil {
+			return err
 		}
 		}
-
 		raw = make([]byte, b.recordsLen)
 		raw = make([]byte, b.recordsLen)
 		re = &realEncoder{raw: raw}
 		re = &realEncoder{raw: raw}
 	default:
 	default:
@@ -123,11 +115,7 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 	if err := pe.pop(); err != nil {
 	if err := pe.pop(); err != nil {
 		return err
 		return err
 	}
 	}
-	if err := pe.pop(); err != nil {
-		return err
-	}
-
-	return nil
+	return pe.pop()
 }
 }
 
 
 func (b *RecordBatch) decode(pd packetDecoder) (err error) {
 func (b *RecordBatch) decode(pd packetDecoder) (err error) {
@@ -249,6 +237,18 @@ func (b *RecordBatch) computeAttributes() int16 {
 	return attr
 	return attr
 }
 }
 
 
+func (b *RecordBatch) computeRecordsLength() error {
+	b.recordsLen = 0
+	for _, r := range b.Records {
+		l, err := r.getTotalLength()
+		if err != nil {
+			return err
+		}
+		b.recordsLen += l
+	}
+	return nil
+}
+
 func (b *RecordBatch) addRecord(r *Record) {
 func (b *RecordBatch) addRecord(r *Record) {
 	b.Records = append(b.Records, r)
 	b.Records = append(b.Records, r)
 }
 }

+ 2 - 6
record_test.go

@@ -252,14 +252,10 @@ func TestRecordBatchDecoding(t *testing.T) {
 		batch := RecordBatch{}
 		batch := RecordBatch{}
 		testDecodable(t, tc.name, &batch, tc.encoded)
 		testDecodable(t, tc.name, &batch, tc.encoded)
 		for _, r := range batch.Records {
 		for _, r := range batch.Records {
-			if _, err := r.getTotalLength(); err != nil {
-				t.Fatalf("Unexpected error: %v", err)
-			}
+			r.length = varintLengthField{}
 		}
 		}
 		for _, r := range tc.batch.Records {
 		for _, r := range tc.batch.Records {
-			if _, err := r.getTotalLength(); err != nil {
-				t.Fatalf("Unexpected error: %v", err)
-			}
+			r.length = varintLengthField{}
 		}
 		}
 		if !reflect.DeepEqual(batch, tc.batch) {
 		if !reflect.DeepEqual(batch, tc.batch) {
 			t.Errorf(spew.Sprintf("invalid decode of %s\ngot %+v\nwanted %+v", tc.name, batch, tc.batch))
 			t.Errorf(spew.Sprintf("invalid decode of %s\ngot %+v\nwanted %+v", tc.name, batch, tc.batch))