Browse Source

Introduce dynamicPushEncoders

Added a dynamicPushEncoder interface that extends pushEncoder with an
adjustLength method, which is called at prepEncoder.pop() time so that
the actual length of the field can be computed.
Also made varintLengthField implement this method so we can avoid a
needless run of prepEncoder for uncompressed records.
Vlad Hanciuta 8 years ago
parent
commit
b37b1580be
6 changed files with 75 additions and 69 deletions
  1. 28 7
      length_field.go
  2. 10 0
      packet_encoder.go
  3. 9 0
      prep_encoder.go
  4. 10 40
      record.go
  5. 16 16
      record_batch.go
  6. 2 6
      record_test.go

+ 28 - 7
length_field.go

@@ -31,22 +31,43 @@ func (l *lengthField) check(curOffset int, buf []byte) error {
 type varintLengthField struct {
 	startOffset int
 	length      int64
+	adjusted    bool
+	size        int
 }
 
-func newVarintLengthField(pd packetDecoder) (*varintLengthField, error) {
-	n, err := pd.getVarint()
-	if err != nil {
-		return nil, err
-	}
-	return &varintLengthField{length: n}, nil
+func (l *varintLengthField) decode(pd packetDecoder) error {
+	var err error
+	l.length, err = pd.getVarint()
+	return err
 }
 
 func (l *varintLengthField) saveOffset(in int) {
 	l.startOffset = in
 }
 
+func (l *varintLengthField) adjustLength(currOffset int) int {
+	l.adjusted = true
+
+	var tmp [binary.MaxVarintLen64]byte
+	l.length = int64(currOffset - l.startOffset - l.size)
+
+	newSize := binary.PutVarint(tmp[:], l.length)
+	diff := newSize - l.size
+	l.size = newSize
+
+	return diff
+}
+
 func (l *varintLengthField) reserveLength() int {
-	return 0
+	return l.size
+}
+
+func (l *varintLengthField) run(curOffset int, buf []byte) error {
+	if !l.adjusted {
+		return PacketEncodingError{"varintLengthField.run called before adjustLength"}
+	}
+	binary.PutVarint(buf[l.startOffset:], l.length)
+	return nil
 }
 
 func (l *varintLengthField) check(curOffset int, buf []byte) error {

+ 10 - 0
packet_encoder.go

@@ -50,3 +50,13 @@ type pushEncoder interface {
 	// of data to the saved offset, based on the data between the saved offset and curOffset.
 	run(curOffset int, buf []byte) error
 }
+
+// dynamicPushEncoder extends the interface of pushEncoder for use cases where the length of the
+// field itself is unknown until its value has been computed (for instance varint encoded length
+// fields).
+type dynamicPushEncoder interface {
+	pushEncoder
+
+	// Called during pop() to adjust the length of the field; returns the change in its encoded size.
+	adjustLength(currOffset int) int
+}

+ 9 - 0
prep_encoder.go

@@ -9,6 +9,7 @@ import (
 )
 
 type prepEncoder struct {
+	stack  []pushEncoder
 	length int
 }
 
@@ -119,10 +120,18 @@ func (pe *prepEncoder) offset() int {
 // stackable
 
 func (pe *prepEncoder) push(in pushEncoder) {
+	in.saveOffset(pe.length)
 	pe.length += in.reserveLength()
+	pe.stack = append(pe.stack, in)
 }
 
 func (pe *prepEncoder) pop() error {
+	in := pe.stack[len(pe.stack)-1]
+	pe.stack = pe.stack[:len(pe.stack)-1]
+	if dpe, ok := in.(dynamicPushEncoder); ok {
+		pe.length += dpe.adjustLength(pe.length)
+	}
+
 	return nil
 }
 

+ 10 - 40
record.go

@@ -1,7 +1,5 @@
 package sarama
 
-import "encoding/binary"
-
 const (
 	controlMask = 0x20
 )
@@ -37,17 +35,12 @@ type Record struct {
 	Value          []byte
 	Headers        []*RecordHeader
 
-	lengthComputed bool
-	length         int64
-	totalLength    int
+	length      varintLengthField
+	totalLength int
 }
 
 func (r *Record) encode(pe packetEncoder) error {
-	if err := r.computeLength(); err != nil {
-		return err
-	}
-
-	pe.putVarint(r.length)
+	pe.push(&r.length)
 	pe.putInt8(r.Attributes)
 	pe.putVarint(r.TimestampDelta)
 	pe.putVarint(r.OffsetDelta)
@@ -65,19 +58,16 @@ func (r *Record) encode(pe packetEncoder) error {
 		}
 	}
 
-	return nil
+	return pe.pop()
 }
 
 func (r *Record) decode(pd packetDecoder) (err error) {
-	length, err := newVarintLengthField(pd)
-	if err != nil {
+	if err := r.length.decode(pd); err != nil {
 		return err
 	}
-	if err = pd.push(length); err != nil {
+	if err = pd.push(&r.length); err != nil {
 		return err
 	}
-	r.length = length.length
-	r.lengthComputed = true
 
 	if r.Attributes, err = pd.getInt8(); err != nil {
 		return err
@@ -118,32 +108,12 @@ func (r *Record) decode(pd packetDecoder) (err error) {
 	return pd.pop()
 }
 
-// Because the length is varint we can't reserve a fixed amount of bytes for it.
-// We use the prepEncoder to figure out the length of the record and then we cache it.
-func (r *Record) computeLength() error {
-	if !r.lengthComputed {
-		r.lengthComputed = true
-
-		var prep prepEncoder
-		if err := r.encode(&prep); err != nil {
-			return err
-		}
-		// subtract 1 because we don't want to include the length field itself (which 1 byte, the
-		// length of varint encoding of 0)
-		r.length = int64(prep.length) - 1
-	}
-
-	return nil
-}
-
 func (r *Record) getTotalLength() (int, error) {
-	if r.totalLength == 0 {
-		if err := r.computeLength(); err != nil {
+	var prep prepEncoder
+	if !r.length.adjusted {
+		if err := r.encode(&prep); err != nil {
 			return 0, err
 		}
-		var buf [binary.MaxVarintLen64]byte
-		r.totalLength = int(r.length) + binary.PutVarint(buf[:], r.length)
 	}
-
-	return r.totalLength, nil
+	return int(r.length.length) + r.length.size, nil
 }

+ 16 - 16
record_batch.go

@@ -59,10 +59,7 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 		if err := pe.pop(); err != nil {
 			return err
 		}
-		if err := pe.pop(); err != nil {
-			return err
-		}
-		return nil
+		return pe.pop()
 	}
 
 	var re packetEncoder
@@ -72,14 +69,9 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 	case CompressionNone:
 		re = pe
 	case CompressionGZIP, CompressionLZ4, CompressionSnappy:
-		for _, r := range b.Records {
-			l, err := r.getTotalLength()
-			if err != nil {
-				return err
-			}
-			b.recordsLen += l
+		if err := b.computeRecordsLength(); err != nil {
+			return err
 		}
-
 		raw = make([]byte, b.recordsLen)
 		re = &realEncoder{raw: raw}
 	default:
@@ -123,11 +115,7 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 	if err := pe.pop(); err != nil {
 		return err
 	}
-	if err := pe.pop(); err != nil {
-		return err
-	}
-
-	return nil
+	return pe.pop()
 }
 
 func (b *RecordBatch) decode(pd packetDecoder) (err error) {
@@ -249,6 +237,18 @@ func (b *RecordBatch) computeAttributes() int16 {
 	return attr
 }
 
+func (b *RecordBatch) computeRecordsLength() error {
+	b.recordsLen = 0
+	for _, r := range b.Records {
+		l, err := r.getTotalLength()
+		if err != nil {
+			return err
+		}
+		b.recordsLen += l
+	}
+	return nil
+}
+
 func (b *RecordBatch) addRecord(r *Record) {
 	b.Records = append(b.Records, r)
 }

+ 2 - 6
record_test.go

@@ -252,14 +252,10 @@ func TestRecordBatchDecoding(t *testing.T) {
 		batch := RecordBatch{}
 		testDecodable(t, tc.name, &batch, tc.encoded)
 		for _, r := range batch.Records {
-			if _, err := r.getTotalLength(); err != nil {
-				t.Fatalf("Unexpected error: %v", err)
-			}
+			r.length = varintLengthField{}
 		}
 		for _, r := range tc.batch.Records {
-			if _, err := r.getTotalLength(); err != nil {
-				t.Fatalf("Unexpected error: %v", err)
-			}
+			r.length = varintLengthField{}
 		}
 		if !reflect.DeepEqual(batch, tc.batch) {
 			t.Errorf(spew.Sprintf("invalid decode of %s\ngot %+v\nwanted %+v", tc.name, batch, tc.batch))