Explorar el Código

Support for UnSet columns (#903)

* Support for UnSet columns

- Adding support for the UNSET_VALUE in gocql

References:
- Cassandra: https://issues.apache.org/jira/browse/CASSANDRA-7304
- gocql/gocql#861

>     Protocol version 4 specifies that bind variables do not require having a
>     value when executing a statement. Bind variables without a value are
>     called 'unset'. The 'unset' bind variable is serialized as the int
>     value '-2' without following bytes.

* Adding tests

* Skipping tests if not in protocol v4

* Code review changes

- Moved Unmarshal to individual unmarshal functions
- Added error handling for tuples and UDTs
Michael Highstead hace 8 años
padre
commit
25729d9d26
Se han modificado 4 ficheros con 148 adiciones y 8 borrados
  1. 71 0
      cassandra_test.go
  2. 6 0
      conn.go
  3. 31 4
      frame.go
  4. 40 4
      marshal.go

+ 71 - 0
cassandra_test.go

@@ -2586,3 +2586,74 @@ func TestControl_DiscoverProtocol(t *testing.T) {
 		t.Fatal("did not discovery protocol")
 		t.Fatal("did not discovery protocol")
 	}
 	}
 }
 }
+
+// TestUnsetCol verify unset column will not replace an existing column
+func TestUnsetCol(t *testing.T) {
+	if *flagProto < 4 {
+		t.Skip("Unset Values are not supported in protocol < 4")
+	}
+	session := createSession(t)
+	defer session.Close()
+
+	if err := createTable(session, "CREATE TABLE gocql_test.testUnsetInsert (id int, my_int int, my_text text, PRIMARY KEY (id))"); err != nil {
+		t.Fatalf("failed to create table with error '%v'", err)
+	}
+	if err := session.Query("INSERT INTO testUnSetInsert (id,my_int,my_text) VALUES (?,?,?)", 1, 2, "3").Exec(); err != nil {
+		t.Fatalf("failed to insert with err: %v", err)
+	}
+	if err := session.Query("INSERT INTO testUnSetInsert (id,my_int,my_text) VALUES (?,?,?)", 1, UnsetValue, UnsetValue).Exec(); err != nil {
+		t.Fatalf("failed to insert with err: %v", err)
+	}
+
+	var id, mInt int
+	var mText string
+
+	if err := session.Query("SELECT id, my_int ,my_text FROM testUnsetInsert").Scan(&id, &mInt, &mText); err != nil {
+		t.Fatalf("failed to select with err: %v", err)
+	} else if id != 1 || mInt != 2 || mText != "3" {
+		t.Fatalf("Expected results: 1, 2, \"3\", got %v, %v, %v", id, mInt, mText)
+	}
+}
+
+// TestUnsetColBatch verify unset column will not replace a column in batch
+func TestUnsetColBatch(t *testing.T) {
+	if *flagProto < 4 {
+		t.Skip("Unset Values are not supported in protocol < 4")
+	}
+	session := createSession(t)
+	defer session.Close()
+
+	if err := createTable(session, "CREATE TABLE gocql_test.batchUnsetInsert (id int, my_int int, my_text text, PRIMARY KEY (id))"); err != nil {
+		t.Fatalf("failed to create table with error '%v'", err)
+	}
+
+	b := session.NewBatch(LoggedBatch)
+	b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, 1, UnsetValue)
+	b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "")
+	b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue)
+
+	if err := session.ExecuteBatch(b); err != nil {
+		t.Fatalf("query failed. %v", err)
+	} else {
+		if b.Attempts() < 1 {
+			t.Fatal("expected at least 1 attempt, but got 0")
+		}
+		if b.Latency() <= 0 {
+			t.Fatalf("expected latency to be greater than 0, but got %v instead.", b.Latency())
+		}
+	}
+	var id, mInt, count int
+	var mText string
+
+	if err := session.Query("SELECT count(*) FROM gocql_test.batchUnsetInsert;").Scan(&count); err != nil {
+		t.Fatalf("Failed to select with err: %v", err)
+	} else if count != 2 {
+		t.Fatalf("Expected Batch Insert count 2, got %v", count)
+	}
+
+	if err := session.Query("SELECT id, my_int ,my_text FROM gocql_test.batchUnsetInsert where id=1;").Scan(&id, &mInt, &mText); err != nil {
+		t.Fatalf("failed to select with err: %v", err)
+	} else if id != mInt {
+		t.Fatalf("expected id, my_int to be 1, got %v and %v", id, mInt)
+	}
+}

+ 6 - 0
conn.go

@@ -816,6 +816,9 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 
 
 			v := &params.values[i]
 			v := &params.values[i]
 			v.value = val
 			v.value = val
+			if _, ok := values[i].(unsetColumn); ok {
+				v.isUnset = true
+			}
 			// TODO: handle query binding names
 			// TODO: handle query binding names
 		}
 		}
 
 
@@ -1012,6 +1015,9 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 				}
 				}
 
 
 				b.values[j].value = val
 				b.values[j].value = val
+				if _, ok := values[j].(unsetColumn); ok {
+					b.values[j].isUnset = true
+				}
 				// TODO: add names
 				// TODO: add names
 			}
 			}
 		} else {
 		} else {

+ 31 - 4
frame.go

@@ -16,6 +16,10 @@ import (
 	"time"
 	"time"
 )
 )
 
 
+type unsetColumn struct{}
+
+var UnsetValue = unsetColumn{}
+
 const (
 const (
 	protoDirectionMask = 0x80
 	protoDirectionMask = 0x80
 	protoVersionMask   = 0x7F
 	protoVersionMask   = 0x7F
@@ -1255,8 +1259,10 @@ func (f *framer) writeAuthResponseFrame(streamID int, data []byte) error {
 
 
 type queryValues struct {
 type queryValues struct {
 	value []byte
 	value []byte
+
 	// optional name, will set With names for values flag
 	// optional name, will set With names for values flag
-	name string
+	name    string
+	isUnset bool
 }
 }
 
 
 type queryParams struct {
 type queryParams struct {
@@ -1319,11 +1325,16 @@ func (f *framer) writeQueryParams(opts *queryParams) {
 
 
 	if n := len(opts.values); n > 0 {
 	if n := len(opts.values); n > 0 {
 		f.writeShort(uint16(n))
 		f.writeShort(uint16(n))
+
 		for i := 0; i < n; i++ {
 		for i := 0; i < n; i++ {
 			if names {
 			if names {
 				f.writeString(opts.values[i].name)
 				f.writeString(opts.values[i].name)
 			}
 			}
-			f.writeBytes(opts.values[i].value)
+			if opts.values[i].isUnset {
+				f.writeUnset()
+			} else {
+				f.writeBytes(opts.values[i].value)
+			}
 		}
 		}
 	}
 	}
 
 
@@ -1404,7 +1415,11 @@ func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *quer
 		n := len(params.values)
 		n := len(params.values)
 		f.writeShort(uint16(n))
 		f.writeShort(uint16(n))
 		for i := 0; i < n; i++ {
 		for i := 0; i < n; i++ {
-			f.writeBytes(params.values[i].value)
+			if params.values[i].isUnset {
+				f.writeUnset()
+			} else {
+				f.writeBytes(params.values[i].value)
+			}
 		}
 		}
 		f.writeConsistency(params.consistency)
 		f.writeConsistency(params.consistency)
 	}
 	}
@@ -1463,7 +1478,11 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame) error {
 				flags |= flagWithNameValues
 				flags |= flagWithNameValues
 				f.writeString(col.name)
 				f.writeString(col.name)
 			}
 			}
-			f.writeBytes(col.value)
+			if col.isUnset {
+				f.writeUnset()
+			} else {
+				f.writeBytes(col.value)
+			}
 		}
 		}
 	}
 	}
 
 
@@ -1786,6 +1805,14 @@ func (f *framer) writeStringList(l []string) {
 	}
 	}
 }
 }
 
 
+func (f *framer) writeUnset() {
+	// Protocol version 4 specifies that bind variables do not require having a
+	// value when executing a statement.   Bind variables without a value are
+	// called 'unset'. The 'unset' bind variable is serialized as the int
+	// value '-2' without following bytes.
+	f.writeInt(-2)
+}
+
 func (f *framer) writeBytes(p []byte) {
 func (f *framer) writeBytes(p []byte) {
 	// TODO: handle null case correctly,
 	// TODO: handle null case correctly,
 	//     [bytes]        A [int] n, followed by n bytes if n >= 0. If n < 0,
 	//     [bytes]        A [int] n, followed by n bytes if n >= 0. If n < 0,

+ 40 - 4
marshal.go

@@ -197,6 +197,8 @@ func marshalVarchar(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Marshaler:
 	case Marshaler:
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
+	case unsetColumn:
+		return nil, nil
 	case string:
 	case string:
 		return []byte(v), nil
 		return []byte(v), nil
 	case []byte:
 	case []byte:
@@ -262,6 +264,8 @@ func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Marshaler:
 	case Marshaler:
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
+	case unsetColumn:
+		return nil, nil
 	case int16:
 	case int16:
 		return encShort(v), nil
 		return encShort(v), nil
 	case uint16:
 	case uint16:
@@ -338,6 +342,8 @@ func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Marshaler:
 	case Marshaler:
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
+	case unsetColumn:
+		return nil, nil
 	case int8:
 	case int8:
 		return []byte{byte(v)}, nil
 		return []byte{byte(v)}, nil
 	case uint8:
 	case uint8:
@@ -420,6 +426,8 @@ func marshalInt(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Marshaler:
 	case Marshaler:
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
+	case unsetColumn:
+		return nil, nil
 	case int:
 	case int:
 		if v > math.MaxInt32 || v < math.MinInt32 {
 		if v > math.MaxInt32 || v < math.MinInt32 {
 			return nil, marshalErrorf("marshal int: value %d out of range", v)
 			return nil, marshalErrorf("marshal int: value %d out of range", v)
@@ -522,6 +530,8 @@ func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Marshaler:
 	case Marshaler:
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
+	case unsetColumn:
+		return nil, nil
 	case int:
 	case int:
 		return encBigInt(int64(v)), nil
 		return encBigInt(int64(v)), nil
 	case uint:
 	case uint:
@@ -638,6 +648,8 @@ func marshalVarint(info TypeInfo, value interface{}) ([]byte, error) {
 	)
 	)
 
 
 	switch v := value.(type) {
 	switch v := value.(type) {
+	case unsetColumn:
+		return nil, nil
 	case uint64:
 	case uint64:
 		if v > uint64(math.MaxInt64) {
 		if v > uint64(math.MaxInt64) {
 			retBytes = make([]byte, 9)
 			retBytes = make([]byte, 9)
@@ -857,6 +869,8 @@ func marshalBool(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Marshaler:
 	case Marshaler:
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
+	case unsetColumn:
+		return nil, nil
 	case bool:
 	case bool:
 		return encBool(v), nil
 		return encBool(v), nil
 	}
 	}
@@ -912,6 +926,8 @@ func marshalFloat(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Marshaler:
 	case Marshaler:
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
+	case unsetColumn:
+		return nil, nil
 	case float32:
 	case float32:
 		return encInt(int32(math.Float32bits(v))), nil
 		return encInt(int32(math.Float32bits(v))), nil
 	}
 	}
@@ -953,6 +969,8 @@ func marshalDouble(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Marshaler:
 	case Marshaler:
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
+	case unsetColumn:
+		return nil, nil
 	case float64:
 	case float64:
 		return encBigInt(int64(math.Float64bits(v))), nil
 		return encBigInt(int64(math.Float64bits(v))), nil
 	}
 	}
@@ -996,6 +1014,8 @@ func marshalDecimal(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Marshaler:
 	case Marshaler:
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
+	case unsetColumn:
+		return nil, nil
 	case inf.Dec:
 	case inf.Dec:
 		unscaled := encBigInt2C(v.UnscaledBig())
 		unscaled := encBigInt2C(v.UnscaledBig())
 		if unscaled == nil {
 		if unscaled == nil {
@@ -1067,6 +1087,8 @@ func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Marshaler:
 	case Marshaler:
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
+	case unsetColumn:
+		return nil, nil
 	case int64:
 	case int64:
 		return encBigInt(v), nil
 		return encBigInt(v), nil
 	case time.Time:
 	case time.Time:
@@ -1125,23 +1147,25 @@ func marshalDate(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Marshaler:
 	case Marshaler:
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
+	case unsetColumn:
+		return nil, nil
 	case int64:
 	case int64:
 		timestamp = v
 		timestamp = v
-		x := timestamp/86400000 + int64(1 << 31)
+		x := timestamp/86400000 + int64(1<<31)
 		return encInt(int32(x)), nil
 		return encInt(int32(x)), nil
 	case time.Time:
 	case time.Time:
 		if v.IsZero() {
 		if v.IsZero() {
 			return []byte{}, nil
 			return []byte{}, nil
 		}
 		}
 		timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6)
 		timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6)
-		x := timestamp/86400000 + int64(1 << 31)
+		x := timestamp/86400000 + int64(1<<31)
 		return encInt(int32(x)), nil
 		return encInt(int32(x)), nil
 	case *time.Time:
 	case *time.Time:
 		if v.IsZero() {
 		if v.IsZero() {
 			return []byte{}, nil
 			return []byte{}, nil
 		}
 		}
 		timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6)
 		timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6)
-		x := timestamp/86400000 + int64(1 << 31)
+		x := timestamp/86400000 + int64(1<<31)
 		return encInt(int32(x)), nil
 		return encInt(int32(x)), nil
 	case string:
 	case string:
 		if v == "" {
 		if v == "" {
@@ -1152,7 +1176,7 @@ func marshalDate(info TypeInfo, value interface{}) ([]byte, error) {
 			return nil, marshalErrorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info)
 			return nil, marshalErrorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info)
 		}
 		}
 		timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6)
 		timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6)
-		x := timestamp/86400000 + int64(1 << 31)
+		x := timestamp/86400000 + int64(1<<31)
 		return encInt(int32(x)), nil
 		return encInt(int32(x)), nil
 	}
 	}
 
 
@@ -1210,6 +1234,8 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) {
 
 
 	if value == nil {
 	if value == nil {
 		return nil, nil
 		return nil, nil
+	} else if _, ok := value.(unsetColumn); ok {
+		return nil, nil
 	}
 	}
 
 
 	rv := reflect.ValueOf(value)
 	rv := reflect.ValueOf(value)
@@ -1326,6 +1352,8 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) {
 
 
 	if value == nil {
 	if value == nil {
 		return nil, nil
 		return nil, nil
+	} else if _, ok := value.(unsetColumn); ok {
+		return nil, nil
 	}
 	}
 
 
 	rv := reflect.ValueOf(value)
 	rv := reflect.ValueOf(value)
@@ -1420,6 +1448,8 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error {
 
 
 func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) {
 func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) {
 	switch val := value.(type) {
 	switch val := value.(type) {
+	case unsetColumn:
+		return nil, nil
 	case UUID:
 	case UUID:
 		return val.Bytes(), nil
 		return val.Bytes(), nil
 	case []byte:
 	case []byte:
@@ -1500,6 +1530,8 @@ func marshalInet(info TypeInfo, value interface{}) ([]byte, error) {
 	// ip address here otherwise the db value will be prefixed
 	// ip address here otherwise the db value will be prefixed
 	// with the remaining byte values e.g. ::ffff:127.0.0.1 and not 127.0.0.1
 	// with the remaining byte values e.g. ::ffff:127.0.0.1 and not 127.0.0.1
 	switch val := value.(type) {
 	switch val := value.(type) {
+	case unsetColumn:
+		return nil, nil
 	case net.IP:
 	case net.IP:
 		t := val.To4()
 		t := val.To4()
 		if t == nil {
 		if t == nil {
@@ -1560,6 +1592,8 @@ func unmarshalInet(info TypeInfo, data []byte, value interface{}) error {
 func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) {
 func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) {
 	tuple := info.(TupleTypeInfo)
 	tuple := info.(TupleTypeInfo)
 	switch v := value.(type) {
 	switch v := value.(type) {
+	case unsetColumn:
+		return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for tuples")
 	case []interface{}:
 	case []interface{}:
 		var buf []byte
 		var buf []byte
 
 
@@ -1638,6 +1672,8 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Marshaler:
 	case Marshaler:
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
+	case unsetColumn:
+		return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for user defined types")
 	case UDTMarshaler:
 	case UDTMarshaler:
 		var buf []byte
 		var buf []byte
 		for _, e := range udt.Elements {
 		for _, e := range udt.Elements {