Browse Source

Allow to set the custom payload (#1182)

* Allow to set the custom payload

* Review comments
Jaume Marhuenda 7 years ago
parent
commit
4563d9e75d
5 changed files with 141 additions and 21 deletions
  1. 58 0
      cassandra_test.go
  2. 7 4
      conn.go
  3. 65 16
      frame.go
  4. 2 0
      integration.sh
  5. 9 1
      session.go

+ 58 - 0
cassandra_test.go

@@ -2422,6 +2422,64 @@ func TestTokenAwareConnPool(t *testing.T) {
 	// TODO add verification that the query went to the correct host
 	// TODO add verification that the query went to the correct host
 }
 }
 
 
+func TestCustomPayloadMessages(t *testing.T) {
+	cluster := createCluster()
+	session := createSessionFromCluster(cluster, t)
+	defer session.Close()
+
+	if err := createTable(session, "CREATE TABLE gocql_test.testCustomPayloadMessages (id int, value int, PRIMARY KEY (id))"); err != nil {
+		t.Fatal(err)
+	}
+
+	// QueryMessage
+	var customPayload = map[string][]byte{"a": []byte{10, 20}, "b": []byte{20, 30}}
+	query := session.Query("SELECT id FROM testCustomPayloadMessages where id = ?", 42).Consistency(One).CustomPayload(customPayload)
+	iter := query.Iter()
+	rCustomPayload := iter.GetCustomPayload()
+	if !reflect.DeepEqual(customPayload, rCustomPayload) {
+		t.Fatal("The received custom payload should match the sent")
+	}
+	iter.Close()
+
+	// Insert query
+	query = session.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)").Consistency(One).CustomPayload(customPayload)
+	iter = query.Iter()
+	rCustomPayload = iter.GetCustomPayload()
+	if !reflect.DeepEqual(customPayload, rCustomPayload) {
+		t.Fatal("The received custom payload should match the sent")
+	}
+	iter.Close()
+
+	// Batch Message
+	b := session.NewBatch(LoggedBatch)
+	b.CustomPayload = customPayload
+	b.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)")
+	if err := session.ExecuteBatch(b); err != nil {
+		t.Fatalf("query failed. %v", err)
+	}
+}
+
+func TestCustomPayloadValues(t *testing.T) {
+	cluster := createCluster()
+	session := createSessionFromCluster(cluster, t)
+	defer session.Close()
+
+	if err := createTable(session, "CREATE TABLE gocql_test.testCustomPayloadValues (id int, value int, PRIMARY KEY (id))"); err != nil {
+		t.Fatal(err)
+	}
+
+	values := []map[string][]byte{map[string][]byte{"a": []byte{10, 20}, "b": []byte{20, 30}}, nil, map[string][]byte{"a": []byte{10, 20}, "b": nil}}
+
+	for _, customPayload := range values {
+		query := session.Query("SELECT id FROM testCustomPayloadValues where id = ?", 42).Consistency(One).CustomPayload(customPayload)
+		iter := query.Iter()
+		rCustomPayload := iter.GetCustomPayload()
+		if !reflect.DeepEqual(customPayload, rCustomPayload) {
+			t.Fatal("The received custom payload should match the sent")
+		}
+	}
+}
+
 func TestNegativeStream(t *testing.T) {
 func TestNegativeStream(t *testing.T) {
 	session := createSession(t)
 	session := createSession(t)
 	defer session.Close()
 	defer session.Close()

+ 7 - 4
conn.go

@@ -946,13 +946,15 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata)
 		params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata)
 
 
 		frame = &writeExecuteFrame{
 		frame = &writeExecuteFrame{
-			preparedID: info.id,
-			params:     params,
+			preparedID:    info.id,
+			params:        params,
+			customPayload: qry.customPayload,
 		}
 		}
 	} else {
 	} else {
 		frame = &writeQueryFrame{
 		frame = &writeQueryFrame{
-			statement: qry.stmt,
-			params:    params,
+			statement:     qry.stmt,
+			params:        params,
+			customPayload: qry.customPayload,
 		}
 		}
 	}
 	}
 
 
@@ -1093,6 +1095,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 		serialConsistency:     batch.serialCons,
 		serialConsistency:     batch.serialCons,
 		defaultTimestamp:      batch.defaultTimestamp,
 		defaultTimestamp:      batch.defaultTimestamp,
 		defaultTimestampValue: batch.defaultTimestampValue,
 		defaultTimestampValue: batch.defaultTimestampValue,
+		customPayload:         batch.CustomPayload,
 	}
 	}
 
 
 	stmts := make(map[string]string, len(batch.Entries))
 	stmts := make(map[string]string, len(batch.Entries))

+ 65 - 16
frame.go

@@ -332,13 +332,12 @@ func readShort(p []byte) uint16 {
 }
 }
 
 
 type frameHeader struct {
 type frameHeader struct {
-	version       protoVersion
-	flags         byte
-	stream        int
-	op            frameOp
-	length        int
-	customPayload map[string][]byte
-	warnings      []string
+	version  protoVersion
+	flags    byte
+	stream   int
+	op       frameOp
+	length   int
+	warnings []string
 }
 }
 
 
 func (f frameHeader) String() string {
 func (f frameHeader) String() string {
@@ -398,6 +397,8 @@ type framer struct {
 
 
 	rbuf []byte
 	rbuf []byte
 	wbuf []byte
 	wbuf []byte
+
+	customPayload map[string][]byte
 }
 }
 
 
 func newFramer(r io.Reader, w io.Writer, compressor Compressor, version byte) *framer {
 func newFramer(r io.Reader, w io.Writer, compressor Compressor, version byte) *framer {
@@ -494,6 +495,11 @@ func (f *framer) trace() {
 	f.flags |= flagTracing
 	f.flags |= flagTracing
 }
 }
 
 
+// explicitly enables the custom payload flag
+func (f *framer) payload() {
+	f.flags |= flagCustomPayload
+}
+
 // reads a frame form the wire into the framers buffer
 // reads a frame form the wire into the framers buffer
 func (f *framer) readFrame(head *frameHeader) error {
 func (f *framer) readFrame(head *frameHeader) error {
 	if head.length < 0 {
 	if head.length < 0 {
@@ -558,7 +564,7 @@ func (f *framer) parseFrame() (frame frame, err error) {
 	}
 	}
 
 
 	if f.header.flags&flagCustomPayload == flagCustomPayload {
 	if f.header.flags&flagCustomPayload == flagCustomPayload {
-		f.header.customPayload = f.readBytesMap()
+		f.customPayload = f.readBytesMap()
 	}
 	}
 
 
 	// assumes that the frame body has been read into rbuf
 	// assumes that the frame body has been read into rbuf
@@ -827,12 +833,17 @@ func (w *writeStartupFrame) writeFrame(f *framer, streamID int) error {
 }
 }
 
 
 type writePrepareFrame struct {
 type writePrepareFrame struct {
-	statement string
-	keyspace  string
+	statement     string
+	keyspace      string
+	customPayload map[string][]byte
 }
 }
 
 
 func (w *writePrepareFrame) writeFrame(f *framer, streamID int) error {
 func (w *writePrepareFrame) writeFrame(f *framer, streamID int) error {
+	if len(w.customPayload) > 0 {
+		f.payload()
+	}
 	f.writeHeader(f.flags, opPrepare, streamID)
 	f.writeHeader(f.flags, opPrepare, streamID)
+	f.writeCustomPayload(&w.customPayload)
 	f.writeLongString(w.statement)
 	f.writeLongString(w.statement)
 
 
 	var flags uint32 = 0
 	var flags uint32 = 0
@@ -1540,6 +1551,9 @@ func (f *framer) writeQueryParams(opts *queryParams) {
 type writeQueryFrame struct {
 type writeQueryFrame struct {
 	statement string
 	statement string
 	params    queryParams
 	params    queryParams
+
+	// v4+
+	customPayload map[string][]byte
 }
 }
 
 
 func (w *writeQueryFrame) String() string {
 func (w *writeQueryFrame) String() string {
@@ -1547,11 +1561,15 @@ func (w *writeQueryFrame) String() string {
 }
 }
 
 
 func (w *writeQueryFrame) writeFrame(framer *framer, streamID int) error {
 func (w *writeQueryFrame) writeFrame(framer *framer, streamID int) error {
-	return framer.writeQueryFrame(streamID, w.statement, &w.params)
+	return framer.writeQueryFrame(streamID, w.statement, &w.params, w.customPayload)
 }
 }
 
 
-func (f *framer) writeQueryFrame(streamID int, statement string, params *queryParams) error {
+func (f *framer) writeQueryFrame(streamID int, statement string, params *queryParams, customPayload map[string][]byte) error {
+	if len(customPayload) > 0 {
+		f.payload()
+	}
 	f.writeHeader(f.flags, opQuery, streamID)
 	f.writeHeader(f.flags, opQuery, streamID)
+	f.writeCustomPayload(&customPayload)
 	f.writeLongString(statement)
 	f.writeLongString(statement)
 	f.writeQueryParams(params)
 	f.writeQueryParams(params)
 
 
@@ -1571,6 +1589,9 @@ func (f frameWriterFunc) writeFrame(framer *framer, streamID int) error {
 type writeExecuteFrame struct {
 type writeExecuteFrame struct {
 	preparedID []byte
 	preparedID []byte
 	params     queryParams
 	params     queryParams
+
+	// v4+
+	customPayload map[string][]byte
 }
 }
 
 
 func (e *writeExecuteFrame) String() string {
 func (e *writeExecuteFrame) String() string {
@@ -1578,11 +1599,15 @@ func (e *writeExecuteFrame) String() string {
 }
 }
 
 
 func (e *writeExecuteFrame) writeFrame(fr *framer, streamID int) error {
 func (e *writeExecuteFrame) writeFrame(fr *framer, streamID int) error {
-	return fr.writeExecuteFrame(streamID, e.preparedID, &e.params)
+	return fr.writeExecuteFrame(streamID, e.preparedID, &e.params, &e.customPayload)
 }
 }
 
 
-func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *queryParams) error {
+func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *queryParams, customPayload *map[string][]byte) error {
+	if len(*customPayload) > 0 {
+		f.payload()
+	}
 	f.writeHeader(f.flags, opExecute, streamID)
 	f.writeHeader(f.flags, opExecute, streamID)
+	f.writeCustomPayload(customPayload)
 	f.writeShortBytes(preparedID)
 	f.writeShortBytes(preparedID)
 	if f.proto > protoVersion1 {
 	if f.proto > protoVersion1 {
 		f.writeQueryParams(params)
 		f.writeQueryParams(params)
@@ -1619,14 +1644,21 @@ type writeBatchFrame struct {
 	serialConsistency     SerialConsistency
 	serialConsistency     SerialConsistency
 	defaultTimestamp      bool
 	defaultTimestamp      bool
 	defaultTimestampValue int64
 	defaultTimestampValue int64
+
+	//v4+
+	customPayload map[string][]byte
 }
 }
 
 
 func (w *writeBatchFrame) writeFrame(framer *framer, streamID int) error {
 func (w *writeBatchFrame) writeFrame(framer *framer, streamID int) error {
-	return framer.writeBatchFrame(streamID, w)
+	return framer.writeBatchFrame(streamID, w, w.customPayload)
 }
 }
 
 
-func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame) error {
+func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload map[string][]byte) error {
+	if len(customPayload) > 0 {
+		f.payload()
+	}
 	f.writeHeader(f.flags, opBatch, streamID)
 	f.writeHeader(f.flags, opBatch, streamID)
+	f.writeCustomPayload(&customPayload)
 	f.writeByte(byte(w.typ))
 	f.writeByte(byte(w.typ))
 
 
 	n := len(w.statements)
 	n := len(w.statements)
@@ -1962,6 +1994,15 @@ func appendLong(p []byte, n int64) []byte {
 	)
 	)
 }
 }
 
 
+func (f *framer) writeCustomPayload(customPayload *map[string][]byte) {
+	if len(*customPayload) > 0 {
+		if f.proto < protoVersion4 {
+			panic("Custom payload is not supported with version V3 or less")
+		}
+		f.writeBytesMap(*customPayload)
+	}
+}
+
 // these are protocol level binary types
 // these are protocol level binary types
 func (f *framer) writeInt(n int32) {
 func (f *framer) writeInt(n int32) {
 	f.wbuf = appendInt(f.wbuf, n)
 	f.wbuf = appendInt(f.wbuf, n)
@@ -2048,3 +2089,11 @@ func (f *framer) writeStringMap(m map[string]string) {
 		f.writeString(v)
 		f.writeString(v)
 	}
 	}
 }
 }
+
+func (f *framer) writeBytesMap(m map[string][]byte) {
+	f.writeShort(uint16(len(m)))
+	for k, v := range m {
+		f.writeString(k)
+		f.writeBytes(v)
+	}
+}

+ 2 - 0
integration.sh

@@ -50,9 +50,11 @@ function run_tests() {
 	elif [[ $version == 2.2.* || $version == 3.0.* ]]; then
 	elif [[ $version == 2.2.* || $version == 3.0.* ]]; then
 		proto=4
 		proto=4
 		ccm updateconf 'enable_user_defined_functions: true'
 		ccm updateconf 'enable_user_defined_functions: true'
+		export JVM_EXTRA_OPTS=" -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler"
 	elif [[ $version == 3.*.* ]]; then
 	elif [[ $version == 3.*.* ]]; then
 		proto=5
 		proto=5
 		ccm updateconf 'enable_user_defined_functions: true'
 		ccm updateconf 'enable_user_defined_functions: true'
+		export JVM_EXTRA_OPTS=" -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler"
 	fi
 	fi
 
 
 	sleep 1s
 	sleep 1s

+ 9 - 1
session.go

@@ -686,6 +686,7 @@ type Query struct {
 	cancelQuery           func()
 	cancelQuery           func()
 	idempotent            bool
 	idempotent            bool
 	metrics               map[string]*queryMetrics
 	metrics               map[string]*queryMetrics
+	customPayload         map[string][]byte
 
 
 	disableAutoPage bool
 	disableAutoPage bool
 }
 }
@@ -774,6 +775,12 @@ func (q *Query) SetConsistency(c Consistency) {
 	q.cons = c
 	q.cons = c
 }
 }
 
 
+// CustomPayload sets the custom payload level for this query.
+func (q *Query) CustomPayload(customPayload map[string][]byte) *Query {
+	q.customPayload = customPayload
+	return q
+}
+
 // Trace enables tracing of this query. Look at the documentation of the
 // Trace enables tracing of this query. Look at the documentation of the
 // Tracer interface to learn more about tracing.
 // Tracer interface to learn more about tracing.
 func (q *Query) Trace(trace Tracer) *Query {
 func (q *Query) Trace(trace Tracer) *Query {
@@ -1344,7 +1351,7 @@ func (iter *Iter) Scan(dest ...interface{}) bool {
 // custom QueryHandlers running in your C* cluster.
 // custom QueryHandlers running in your C* cluster.
 // See https://datastax.github.io/java-driver/manual/custom_payloads/
 // See https://datastax.github.io/java-driver/manual/custom_payloads/
 func (iter *Iter) GetCustomPayload() map[string][]byte {
 func (iter *Iter) GetCustomPayload() map[string][]byte {
-	return iter.framer.header.customPayload
+	return iter.framer.customPayload
 }
 }
 
 
 // Warnings returns any warnings generated if given in the response from Cassandra.
 // Warnings returns any warnings generated if given in the response from Cassandra.
@@ -1422,6 +1429,7 @@ type Batch struct {
 	Type                  BatchType
 	Type                  BatchType
 	Entries               []BatchEntry
 	Entries               []BatchEntry
 	Cons                  Consistency
 	Cons                  Consistency
+	CustomPayload         map[string][]byte
 	rt                    RetryPolicy
 	rt                    RetryPolicy
 	observer              BatchObserver
 	observer              BatchObserver
 	serialCons            SerialConsistency
 	serialCons            SerialConsistency