records.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. package sarama
  2. import "fmt"
  3. const (
  4. unknownRecords = iota
  5. legacyRecords
  6. defaultRecords
  7. magicOffset = 16
  8. magicLength = 1
  9. )
  10. // Records implements a union type containing either a RecordBatch or a legacy MessageSet.
  11. type Records struct {
  12. recordsType int
  13. MsgSet *MessageSet
  14. RecordBatch *RecordBatch
  15. }
  16. func newLegacyRecords(msgSet *MessageSet) Records {
  17. return Records{recordsType: legacyRecords, MsgSet: msgSet}
  18. }
  19. func newDefaultRecords(batch *RecordBatch) Records {
  20. return Records{recordsType: defaultRecords, RecordBatch: batch}
  21. }
  22. // setTypeFromFields sets type of Records depending on which of MsgSet or RecordBatch is not nil.
  23. // The first return value indicates whether both fields are nil (and the type is not set).
  24. // If both fields are not nil, it returns an error.
  25. func (r *Records) setTypeFromFields() (bool, error) {
  26. if r.MsgSet == nil && r.RecordBatch == nil {
  27. return true, nil
  28. }
  29. if r.MsgSet != nil && r.RecordBatch != nil {
  30. return false, fmt.Errorf("both MsgSet and RecordBatch are set, but record type is unknown")
  31. }
  32. r.recordsType = defaultRecords
  33. if r.MsgSet != nil {
  34. r.recordsType = legacyRecords
  35. }
  36. return false, nil
  37. }
  38. func (r *Records) encode(pe packetEncoder) error {
  39. if r.recordsType == unknownRecords {
  40. if empty, err := r.setTypeFromFields(); err != nil || empty {
  41. return err
  42. }
  43. }
  44. switch r.recordsType {
  45. case legacyRecords:
  46. if r.MsgSet == nil {
  47. return nil
  48. }
  49. return r.MsgSet.encode(pe)
  50. case defaultRecords:
  51. if r.RecordBatch == nil {
  52. return nil
  53. }
  54. return r.RecordBatch.encode(pe)
  55. }
  56. return fmt.Errorf("unknown records type: %v", r.recordsType)
  57. }
  58. func (r *Records) setTypeFromMagic(pd packetDecoder) error {
  59. magic, err := magicValue(pd)
  60. if err != nil {
  61. return err
  62. }
  63. r.recordsType = defaultRecords
  64. if magic < 2 {
  65. r.recordsType = legacyRecords
  66. }
  67. return nil
  68. }
  69. func (r *Records) decode(pd packetDecoder) error {
  70. if r.recordsType == unknownRecords {
  71. if err := r.setTypeFromMagic(pd); err != nil {
  72. return err
  73. }
  74. }
  75. switch r.recordsType {
  76. case legacyRecords:
  77. r.MsgSet = &MessageSet{}
  78. return r.MsgSet.decode(pd)
  79. case defaultRecords:
  80. r.RecordBatch = &RecordBatch{}
  81. return r.RecordBatch.decode(pd)
  82. }
  83. return fmt.Errorf("unknown records type: %v", r.recordsType)
  84. }
  85. func (r *Records) numRecords() (int, error) {
  86. if r.recordsType == unknownRecords {
  87. if empty, err := r.setTypeFromFields(); err != nil || empty {
  88. return 0, err
  89. }
  90. }
  91. switch r.recordsType {
  92. case legacyRecords:
  93. if r.MsgSet == nil {
  94. return 0, nil
  95. }
  96. return len(r.MsgSet.Messages), nil
  97. case defaultRecords:
  98. if r.RecordBatch == nil {
  99. return 0, nil
  100. }
  101. return len(r.RecordBatch.Records), nil
  102. }
  103. return 0, fmt.Errorf("unknown records type: %v", r.recordsType)
  104. }
  105. func (r *Records) isPartial() (bool, error) {
  106. if r.recordsType == unknownRecords {
  107. if empty, err := r.setTypeFromFields(); err != nil || empty {
  108. return false, err
  109. }
  110. }
  111. switch r.recordsType {
  112. case unknownRecords:
  113. return false, nil
  114. case legacyRecords:
  115. if r.MsgSet == nil {
  116. return false, nil
  117. }
  118. return r.MsgSet.PartialTrailingMessage, nil
  119. case defaultRecords:
  120. if r.RecordBatch == nil {
  121. return false, nil
  122. }
  123. return r.RecordBatch.PartialTrailingRecord, nil
  124. }
  125. return false, fmt.Errorf("unknown records type: %v", r.recordsType)
  126. }
  127. func (r *Records) isControl() (bool, error) {
  128. if r.recordsType == unknownRecords {
  129. if empty, err := r.setTypeFromFields(); err != nil || empty {
  130. return false, err
  131. }
  132. }
  133. switch r.recordsType {
  134. case legacyRecords:
  135. return false, nil
  136. case defaultRecords:
  137. if r.RecordBatch == nil {
  138. return false, nil
  139. }
  140. return r.RecordBatch.Control, nil
  141. }
  142. return false, fmt.Errorf("unknown records type: %v", r.recordsType)
  143. }
  144. func magicValue(pd packetDecoder) (int8, error) {
  145. dec, err := pd.peek(magicOffset, magicLength)
  146. if err != nil {
  147. return 0, err
  148. }
  149. return dec.getInt8()
  150. }