|
|
@@ -1,65 +1,166 @@
|
|
|
package sarama
|
|
|
|
|
|
+import "fmt"
|
|
|
+
|
|
|
const (
|
|
|
+ unknownRecords = iota
|
|
|
+ legacyRecords
|
|
|
+ defaultRecords
|
|
|
+
|
|
|
magicOffset = 16
|
|
|
magicLength = 1
|
|
|
)
|
|
|
|
|
|
// Records implements a union type containing either a RecordBatch or a legacy MessageSet.
|
|
|
type Records struct {
|
|
|
+ recordsType int
|
|
|
msgSet *MessageSet
|
|
|
recordBatch *RecordBatch
|
|
|
}
|
|
|
|
|
|
-func (c *Records) numRecords() int {
|
|
|
- if c.msgSet != nil {
|
|
|
- return len(c.msgSet.Messages)
|
|
|
- }
|
|
|
+func newLegacyRecords(msgSet *MessageSet) Records {
|
|
|
+ return Records{recordsType: legacyRecords, msgSet: msgSet}
|
|
|
+}
|
|
|
|
|
|
- if c.recordBatch != nil {
|
|
|
- return len(c.recordBatch.Records)
|
|
|
- }
|
|
|
+func newDefaultRecords(batch *RecordBatch) Records {
|
|
|
+ return Records{recordsType: defaultRecords, recordBatch: batch}
|
|
|
+}
|
|
|
|
|
|
- return 0
|
|
|
+// setTypeFromFields sets type of Records depending on which of msgSet or recordBatch is not nil.
|
|
|
+// The first return value indicates whether both fields are nil (and the type is not set).
|
|
|
+// If both fields are not nil, it returns an error.
|
|
|
+func (r *Records) setTypeFromFields() (bool, error) {
|
|
|
+ if r.msgSet == nil && r.recordBatch == nil {
|
|
|
+ return true, nil
|
|
|
+ }
|
|
|
+ if r.msgSet != nil && r.recordBatch != nil {
|
|
|
+ return false, fmt.Errorf("both msgSet and recordBatch are set, but record type is unknown")
|
|
|
+ }
|
|
|
+ r.recordsType = defaultRecords
|
|
|
+ if r.msgSet != nil {
|
|
|
+ r.recordsType = legacyRecords
|
|
|
+ }
|
|
|
+ return false, nil
|
|
|
}
|
|
|
|
|
|
-func (c *Records) isPartial() bool {
|
|
|
- if c.msgSet != nil {
|
|
|
- return c.msgSet.PartialTrailingMessage
|
|
|
+func (r *Records) encode(pe packetEncoder) error {
|
|
|
+ if r.recordsType == unknownRecords {
|
|
|
+ if empty, err := r.setTypeFromFields(); err != nil || empty {
|
|
|
+ return err
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- if c.recordBatch != nil {
|
|
|
- return c.recordBatch.PartialTrailingRecord
|
|
|
+ switch r.recordsType {
|
|
|
+ case legacyRecords:
|
|
|
+ if r.msgSet == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return r.msgSet.encode(pe)
|
|
|
+ case defaultRecords:
|
|
|
+ if r.recordBatch == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return r.recordBatch.encode(pe)
|
|
|
}
|
|
|
|
|
|
- return false
|
|
|
+ return fmt.Errorf("unknown records type: %v", r.recordsType)
|
|
|
}
|
|
|
|
|
|
-func (c *Records) decode(pd packetDecoder) (err error) {
|
|
|
+func (r *Records) setTypeFromMagic(pd packetDecoder) error {
|
|
|
magic, err := magicValue(pd)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
+ r.recordsType = defaultRecords
|
|
|
if magic < 2 {
|
|
|
- c.msgSet = &MessageSet{}
|
|
|
- return c.msgSet.decode(pd)
|
|
|
+ r.recordsType = legacyRecords
|
|
|
}
|
|
|
|
|
|
- c.recordBatch = &RecordBatch{}
|
|
|
- return c.recordBatch.decode(pd)
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
-func (c *Records) encode(pe packetEncoder) (err error) {
|
|
|
- if c.msgSet != nil {
|
|
|
- return c.msgSet.encode(pe)
|
|
|
+func (r *Records) decode(pd packetDecoder) error {
|
|
|
+ if r.recordsType == unknownRecords {
|
|
|
+ if err := r.setTypeFromMagic(pd); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- if c.recordBatch != nil {
|
|
|
- return c.recordBatch.encode(pe)
|
|
|
+ switch r.recordsType {
|
|
|
+ case legacyRecords:
|
|
|
+ r.msgSet = &MessageSet{}
|
|
|
+ return r.msgSet.decode(pd)
|
|
|
+ case defaultRecords:
|
|
|
+ r.recordBatch = &RecordBatch{}
|
|
|
+ return r.recordBatch.decode(pd)
|
|
|
}
|
|
|
+ return fmt.Errorf("unknown records type: %v", r.recordsType)
|
|
|
+}
|
|
|
|
|
|
- return nil
|
|
|
+func (r *Records) numRecords() (int, error) {
|
|
|
+ if r.recordsType == unknownRecords {
|
|
|
+ if empty, err := r.setTypeFromFields(); err != nil || empty {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ switch r.recordsType {
|
|
|
+ case legacyRecords:
|
|
|
+ if r.msgSet == nil {
|
|
|
+ return 0, nil
|
|
|
+ }
|
|
|
+ return len(r.msgSet.Messages), nil
|
|
|
+ case defaultRecords:
|
|
|
+ if r.recordBatch == nil {
|
|
|
+ return 0, nil
|
|
|
+ }
|
|
|
+ return len(r.recordBatch.Records), nil
|
|
|
+ }
|
|
|
+ return 0, fmt.Errorf("unknown records type: %v", r.recordsType)
|
|
|
+}
|
|
|
+
|
|
|
+func (r *Records) isPartial() (bool, error) {
|
|
|
+ if r.recordsType == unknownRecords {
|
|
|
+ if empty, err := r.setTypeFromFields(); err != nil || empty {
|
|
|
+ return false, err
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ switch r.recordsType {
|
|
|
+ case unknownRecords:
|
|
|
+ return false, nil
|
|
|
+ case legacyRecords:
|
|
|
+ if r.msgSet == nil {
|
|
|
+ return false, nil
|
|
|
+ }
|
|
|
+ return r.msgSet.PartialTrailingMessage, nil
|
|
|
+ case defaultRecords:
|
|
|
+ if r.recordBatch == nil {
|
|
|
+ return false, nil
|
|
|
+ }
|
|
|
+ return r.recordBatch.PartialTrailingRecord, nil
|
|
|
+ }
|
|
|
+ return false, fmt.Errorf("unknown records type: %v", r.recordsType)
|
|
|
+}
|
|
|
+
|
|
|
+func (r *Records) isControl() (bool, error) {
|
|
|
+ if r.recordsType == unknownRecords {
|
|
|
+ if empty, err := r.setTypeFromFields(); err != nil || empty {
|
|
|
+ return false, err
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ switch r.recordsType {
|
|
|
+ case legacyRecords:
|
|
|
+ return false, nil
|
|
|
+ case defaultRecords:
|
|
|
+ if r.recordBatch == nil {
|
|
|
+ return false, nil
|
|
|
+ }
|
|
|
+ return r.recordBatch.Control, nil
|
|
|
+ }
|
|
|
+ return false, fmt.Errorf("unknown records type: %v", r.recordsType)
|
|
|
}
|
|
|
|
|
|
func magicValue(pd packetDecoder) (int8, error) {
|