Selaa lähdekoodia

Implement a sum type that can hold RecordBatch or MessageSet

Many request/response structures can contain either RecordBatches or
MessageSets depending on the version of Kafka the client is talking to.
This changeset implements a sum type that makes it more convenient to
work with these structures by abstracting away the type of the records.
Vlad Hanciuta 8 vuotta sitten
vanhempi
commit
e1067e3e2d
2 muutettua tiedostoa jossa 201 lisäystä ja 0 poistoa
  1. 96 0
      records.go
  2. 105 0
      records_test.go

+ 96 - 0
records.go

@@ -0,0 +1,96 @@
+package sarama
+
+import "fmt"
+
+const (
+	legacyRecords = iota
+	defaultRecords
+)
+
+// Records implements a union type containing either a RecordBatch or a legacy MessageSet.
+type Records struct {
+	recordsType int
+	msgSet      *MessageSet
+	recordBatch *RecordBatch
+}
+
+func newLegacyRecords(msgSet *MessageSet) Records {
+	return Records{recordsType: legacyRecords, msgSet: msgSet}
+}
+
+func newDefaultRecords(batch *RecordBatch) Records {
+	return Records{recordsType: defaultRecords, recordBatch: batch}
+}
+
+func (r *Records) encode(pe packetEncoder) error {
+	switch r.recordsType {
+	case legacyRecords:
+		if r.msgSet == nil {
+			return nil
+		}
+		return r.msgSet.encode(pe)
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return nil
+		}
+		return r.recordBatch.encode(pe)
+	}
+	return fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) decode(pd packetDecoder) error {
+	switch r.recordsType {
+	case legacyRecords:
+		r.msgSet = &MessageSet{}
+		return r.msgSet.decode(pd)
+	case defaultRecords:
+		r.recordBatch = &RecordBatch{}
+		return r.recordBatch.decode(pd)
+	}
+	return fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) numRecords() (int, error) {
+	switch r.recordsType {
+	case legacyRecords:
+		if r.msgSet == nil {
+			return 0, nil
+		}
+		return len(r.msgSet.Messages), nil
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return 0, nil
+		}
+		return len(r.recordBatch.Records), nil
+	}
+	return 0, fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) isPartial() (bool, error) {
+	switch r.recordsType {
+	case legacyRecords:
+		if r.msgSet == nil {
+			return false, nil
+		}
+		return r.msgSet.PartialTrailingMessage, nil
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return false, nil
+		}
+		return r.recordBatch.PartialTrailingRecord, nil
+	}
+	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) isControl() (bool, error) {
+	switch r.recordsType {
+	case legacyRecords:
+		return false, nil
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return false, nil
+		}
+		return r.recordBatch.Control, nil
+	}
+	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
+}

+ 105 - 0
records_test.go

@@ -0,0 +1,105 @@
+package sarama
+
+import (
+	"bytes"
+	"reflect"
+	"testing"
+)
+
+func TestLegacyRecords(t *testing.T) {
+	set := &MessageSet{
+		Messages: []*MessageBlock{
+			{
+				Msg: &Message{
+					Version: 1,
+				},
+			},
+		},
+	}
+	r := newLegacyRecords(set)
+
+	exp, err := encode(set, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	buf, err := encode(&r, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(buf, exp) {
+		t.Errorf("Wrong encoding for legacy records, wanted %v, got %v", exp, buf)
+	}
+
+	set = &MessageSet{}
+	r = newLegacyRecords(nil)
+
+	err = decode(exp, set)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = decode(buf, &r)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !reflect.DeepEqual(set, r.msgSet) {
+		t.Errorf("Wrong decoding for legacy records, wanted %#+v, got %#+v", set, r.msgSet)
+	}
+
+	n, err := r.numRecords()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if n != 1 {
+		t.Errorf("Wrong number of records, wanted 1, got %d", n)
+	}
+}
+
+func TestDefaultRecords(t *testing.T) {
+	batch := &RecordBatch{
+		Version: 2,
+		Records: []*Record{
+			{
+				Value: []byte{1},
+			},
+		},
+	}
+
+	r := newDefaultRecords(batch)
+
+	exp, err := encode(batch, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	buf, err := encode(&r, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(buf, exp) {
+		t.Errorf("Wrong encoding for default records, wanted %v, got %v", exp, buf)
+	}
+
+	batch = &RecordBatch{}
+	r = newDefaultRecords(nil)
+
+	err = decode(exp, batch)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = decode(buf, &r)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !reflect.DeepEqual(batch, r.recordBatch) {
+		t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatch)
+	}
+
+	n, err := r.numRecords()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if n != 1 {
+		t.Errorf("Wrong number of records, wanted 1, got %d", n)
+	}
+}