record_batch.go 5.4 KB


  1. package sarama
  2. import (
  3. "bytes"
  4. "compress/gzip"
  5. "fmt"
  6. "io/ioutil"
  7. "github.com/eapache/go-xerial-snappy"
  8. "github.com/pierrec/lz4"
  9. )
  10. const recordBatchOverhead = 49
// recordsArray is a slice of records whose encode/decode methods process
// every element in order. RecordBatch converts its Records to this type so
// they can be handed to the package-level encode/decode helpers as one unit.
type recordsArray []*Record
  12. func (e recordsArray) encode(pe packetEncoder) error {
  13. for _, r := range e {
  14. if err := r.encode(pe); err != nil {
  15. return err
  16. }
  17. }
  18. return nil
  19. }
  20. func (e recordsArray) decode(pd packetDecoder) error {
  21. for i := range e {
  22. rec := &Record{}
  23. if err := rec.decode(pd); err != nil {
  24. return err
  25. }
  26. e[i] = rec
  27. }
  28. return nil
  29. }
// RecordBatch is a batch of records in the record batch layout handled by
// encode and decode in this file (encode only supports Version 2).
type RecordBatch struct {
	// Header fields, in the order encode writes them. The batch length and
	// the CRC are not stored here: encode emits them through fields pushed
	// onto the packet encoder, and decode uses the length only to size the
	// records payload.
	FirstOffset          int64
	PartitionLeaderEpoch int32
	Version              int8             // encode rejects anything other than 2
	Codec                CompressionCodec // attributes masked with compressionCodecMask
	Control              bool             // attributes&controlMask is set
	LastOffsetDelta      int32
	FirstTimestamp       int64
	MaxTimestamp         int64
	ProducerID           int64
	ProducerEpoch        int16
	FirstSequence        int32

	// Records holds the batch's records. decode pre-sizes it from the record
	// count in the header (leaving it untouched when that count is negative);
	// addRecord appends to it.
	Records []*Record

	// PartialTrailingRecord is set by decode, which also clears Records, when
	// it runs into ErrInsufficientData while decoding the records payload.
	PartialTrailingRecord bool

	// compressedRecords caches the compressed records payload built by the
	// first encode with a codec other than CompressionNone; later encodes
	// write it verbatim instead of recompressing.
	compressedRecords []byte

	// recordsLen is the sum of the records' total lengths, as computed by
	// computeRecordsLength.
	recordsLen int
}
  47. func (b *RecordBatch) encode(pe packetEncoder) error {
  48. if b.Version != 2 {
  49. return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
  50. }
  51. pe.putInt64(b.FirstOffset)
  52. pe.push(&lengthField{})
  53. pe.putInt32(b.PartitionLeaderEpoch)
  54. pe.putInt8(b.Version)
  55. pe.push(newCRC32Field(crcCastagnoli))
  56. pe.putInt16(b.computeAttributes())
  57. pe.putInt32(b.LastOffsetDelta)
  58. pe.putInt64(b.FirstTimestamp)
  59. pe.putInt64(b.MaxTimestamp)
  60. pe.putInt64(b.ProducerID)
  61. pe.putInt16(b.ProducerEpoch)
  62. pe.putInt32(b.FirstSequence)
  63. if err := pe.putArrayLength(len(b.Records)); err != nil {
  64. return err
  65. }
  66. if b.compressedRecords != nil {
  67. if err := pe.putRawBytes(b.compressedRecords); err != nil {
  68. return err
  69. }
  70. if err := pe.pop(); err != nil {
  71. return err
  72. }
  73. return pe.pop()
  74. }
  75. var raw []byte
  76. if b.Codec != CompressionNone {
  77. var err error
  78. if raw, err = encode(recordsArray(b.Records), nil); err != nil {
  79. return err
  80. }
  81. }
  82. switch b.Codec {
  83. case CompressionNone:
  84. if err := recordsArray(b.Records).encode(pe); err != nil {
  85. return err
  86. }
  87. case CompressionGZIP:
  88. var buf bytes.Buffer
  89. writer := gzip.NewWriter(&buf)
  90. if _, err := writer.Write(raw); err != nil {
  91. return err
  92. }
  93. if err := writer.Close(); err != nil {
  94. return err
  95. }
  96. b.compressedRecords = buf.Bytes()
  97. case CompressionSnappy:
  98. b.compressedRecords = snappy.Encode(raw)
  99. case CompressionLZ4:
  100. var buf bytes.Buffer
  101. writer := lz4.NewWriter(&buf)
  102. if _, err := writer.Write(raw); err != nil {
  103. return err
  104. }
  105. if err := writer.Close(); err != nil {
  106. return err
  107. }
  108. b.compressedRecords = buf.Bytes()
  109. default:
  110. return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
  111. }
  112. if err := pe.putRawBytes(b.compressedRecords); err != nil {
  113. return err
  114. }
  115. if err := pe.pop(); err != nil {
  116. return err
  117. }
  118. return pe.pop()
  119. }
  120. func (b *RecordBatch) decode(pd packetDecoder) (err error) {
  121. if b.FirstOffset, err = pd.getInt64(); err != nil {
  122. return err
  123. }
  124. var batchLen int32
  125. if batchLen, err = pd.getInt32(); err != nil {
  126. return err
  127. }
  128. if b.PartitionLeaderEpoch, err = pd.getInt32(); err != nil {
  129. return err
  130. }
  131. if b.Version, err = pd.getInt8(); err != nil {
  132. return err
  133. }
  134. if err = pd.push(&crc32Field{polynomial: crcCastagnoli}); err != nil {
  135. return err
  136. }
  137. var attributes int16
  138. if attributes, err = pd.getInt16(); err != nil {
  139. return err
  140. }
  141. b.Codec = CompressionCodec(int8(attributes) & compressionCodecMask)
  142. b.Control = attributes&controlMask == controlMask
  143. if b.LastOffsetDelta, err = pd.getInt32(); err != nil {
  144. return err
  145. }
  146. if b.FirstTimestamp, err = pd.getInt64(); err != nil {
  147. return err
  148. }
  149. if b.MaxTimestamp, err = pd.getInt64(); err != nil {
  150. return err
  151. }
  152. if b.ProducerID, err = pd.getInt64(); err != nil {
  153. return err
  154. }
  155. if b.ProducerEpoch, err = pd.getInt16(); err != nil {
  156. return err
  157. }
  158. if b.FirstSequence, err = pd.getInt32(); err != nil {
  159. return err
  160. }
  161. numRecs, err := pd.getArrayLength()
  162. if err != nil {
  163. return err
  164. }
  165. if numRecs >= 0 {
  166. b.Records = make([]*Record, numRecs)
  167. }
  168. bufSize := int(batchLen) - recordBatchOverhead
  169. recBuffer, err := pd.getRawBytes(bufSize)
  170. if err != nil {
  171. return err
  172. }
  173. if err = pd.pop(); err != nil {
  174. return err
  175. }
  176. switch b.Codec {
  177. case CompressionNone:
  178. case CompressionGZIP:
  179. reader, err := gzip.NewReader(bytes.NewReader(recBuffer))
  180. if err != nil {
  181. return err
  182. }
  183. if recBuffer, err = ioutil.ReadAll(reader); err != nil {
  184. return err
  185. }
  186. case CompressionSnappy:
  187. if recBuffer, err = snappy.Decode(recBuffer); err != nil {
  188. return err
  189. }
  190. case CompressionLZ4:
  191. reader := lz4.NewReader(bytes.NewReader(recBuffer))
  192. if recBuffer, err = ioutil.ReadAll(reader); err != nil {
  193. return err
  194. }
  195. default:
  196. return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", b.Codec)}
  197. }
  198. err = decode(recBuffer, recordsArray(b.Records))
  199. if err == ErrInsufficientData {
  200. b.PartialTrailingRecord = true
  201. b.Records = nil
  202. return nil
  203. }
  204. return err
  205. }
  206. func (b *RecordBatch) computeAttributes() int16 {
  207. attr := int16(b.Codec) & int16(compressionCodecMask)
  208. if b.Control {
  209. attr |= controlMask
  210. }
  211. return attr
  212. }
  213. func (b *RecordBatch) computeRecordsLength() error {
  214. b.recordsLen = 0
  215. for _, r := range b.Records {
  216. l, err := r.getTotalLength()
  217. if err != nil {
  218. return err
  219. }
  220. b.recordsLen += l
  221. }
  222. return nil
  223. }
  224. func (b *RecordBatch) addRecord(r *Record) {
  225. b.Records = append(b.Records, r)
  226. }