瀏覽代碼

Make an encoder/decoder for records array

Vlad Hanciuta 8 年之前
父節點
當前提交
b51e2317f4
共有 1 個文件被更改,包括 36 次插入31 次删除
  1. 36 31
      record_batch.go

+ 36 - 31
record_batch.go

@@ -12,6 +12,28 @@ import (
 
 const recordBatchOverhead = 49
 
+type recordsArray []*Record
+
+func (e recordsArray) encode(pe packetEncoder) error {
+	for _, r := range e {
+		if err := r.encode(pe); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (e recordsArray) decode(pd packetDecoder) error {
+	for i := range e {
+		rec := &Record{}
+		if err := rec.decode(pd); err != nil {
+			return err
+		}
+		e[i] = rec
+	}
+	return nil
+}
+
 type RecordBatch struct {
 	FirstOffset           int64
 	PartitionLeaderEpoch  int32
@@ -62,29 +84,18 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 		return pe.pop()
 	}
 
-	var re packetEncoder
 	var raw []byte
-
-	switch b.Codec {
-	case CompressionNone:
-		re = pe
-	case CompressionGZIP, CompressionLZ4, CompressionSnappy:
-		if err := b.computeRecordsLength(); err != nil {
+	if b.Codec != CompressionNone {
+		var err error
+		if raw, err = encode(recordsArray(b.Records), nil); err != nil {
 			return err
 		}
-		raw = make([]byte, b.recordsLen)
-		re = &realEncoder{raw: raw}
-	default:
-		return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
 	}
-
-	for _, r := range b.Records {
-		if err := r.encode(re); err != nil {
+	switch b.Codec {
+	case CompressionNone:
+		if err := recordsArray(b.Records).encode(pe); err != nil {
 			return err
 		}
-	}
-
-	switch b.Codec {
 	case CompressionGZIP:
 		var buf bytes.Buffer
 		writer := gzip.NewWriter(&buf)
@@ -107,6 +118,8 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 			return err
 		}
 		b.compressedRecords = buf.Bytes()
+	default:
+		return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
 	}
 	if err := pe.putRawBytes(b.compressedRecords); err != nil {
 		return err
@@ -211,22 +224,14 @@ func (b *RecordBatch) decode(pd packetDecoder) (err error) {
 	default:
 		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", b.Codec)}
 	}
-	recPd := &realDecoder{raw: recBuffer}
 
-	for i := 0; i < numRecs; i++ {
-		rec := &Record{}
-		if err = rec.decode(recPd); err != nil {
-			if err == ErrInsufficientData {
-				b.PartialTrailingRecord = true
-				b.Records = nil
-				return nil
-			}
-			return err
-		}
-		b.Records[i] = rec
+	err = decode(recBuffer, recordsArray(b.Records))
+	if err == ErrInsufficientData {
+		b.PartialTrailingRecord = true
+		b.Records = nil
+		return nil
 	}
-
-	return nil
+	return err
 }
 
 func (b *RecordBatch) computeAttributes() int16 {