瀏覽代碼

Correctly handle missing UDT fields in Marshal

When we receive a struct which does not have all the fields for
that UDT then we should encode them as missing values, null on
the wire.

Fixes #515
Chris Bannister 10 年之前
父節點
當前提交
69051b37de
共有 3 個文件被更改,包括 77 次插入31 次删除
  1. 21 11
      frame.go
  2. 7 20
      marshal.go
  3. 49 0
      udt_test.go

+ 21 - 11
frame.go

@@ -1579,25 +1579,22 @@ func (f *framer) writeByte(b byte) {
 	f.wbuf = append(f.wbuf, b)
 }
 
-// these are protocol level binary types
-func (f *framer) writeInt(n int32) {
-	f.wbuf = append(f.wbuf,
-		byte(n>>24),
-		byte(n>>16),
+func appendShort(p []byte, n uint16) []byte {
+	return append(p,
 		byte(n>>8),
 		byte(n),
 	)
 }
 
-func (f *framer) writeShort(n uint16) {
-	f.wbuf = append(f.wbuf,
+func appendInt(p []byte, n int32) []byte {
+	return append(p, byte(n>>24),
+		byte(n>>16),
 		byte(n>>8),
-		byte(n),
-	)
+		byte(n))
 }
 
-func (f *framer) writeLong(n int64) {
-	f.wbuf = append(f.wbuf,
+func appendLong(p []byte, n int64) []byte {
+	return append(p,
 		byte(n>>56),
 		byte(n>>48),
 		byte(n>>40),
@@ -1609,6 +1606,19 @@ func (f *framer) writeLong(n int64) {
 	)
 }
 
+// these are protocol level binary types
+func (f *framer) writeInt(n int32) {
+	f.wbuf = appendInt(f.wbuf, n)
+}
+
+func (f *framer) writeShort(n uint16) {
+	f.wbuf = appendShort(f.wbuf, n)
+}
+
+func (f *framer) writeLong(n int64) {
+	f.wbuf = appendLong(f.wbuf, n)
+}
+
 func (f *framer) writeString(s string) {
 	f.writeShort(uint16(len(s)))
 	f.wbuf = append(f.wbuf, s...)

+ 7 - 20
marshal.go

@@ -1290,11 +1290,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 			}
 
 			n := len(data)
-			buf = append(buf, byte(n>>24),
-				byte(n>>16),
-				byte(n>>8),
-				byte(n))
-
+			buf = appendInt(buf, int32(n))
 			buf = append(buf, data...)
 		}
 
@@ -1313,11 +1309,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 			}
 
 			n := len(data)
-			buf = append(buf, byte(n>>24),
-				byte(n>>16),
-				byte(n>>8),
-				byte(n))
-
+			buf = appendInt(buf, int32(n))
 			buf = append(buf, data...)
 		}
 
@@ -1354,14 +1346,13 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 		}
 
 		if !f.IsValid() {
-			return nil, marshalErrorf("cannot marshal %T into %s", value, info)
+			n := -1
+			buf = appendInt(buf, int32(n))
+			continue
 		} else if f.Kind() == reflect.Ptr {
 			if f.IsNil() {
 				n := -1
-				buf = append(buf, byte(n>>24),
-					byte(n>>16),
-					byte(n>>8),
-					byte(n))
+				buf = appendInt(buf, int32(n))
 				continue
 			} else {
 				f = f.Elem()
@@ -1374,11 +1365,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 		}
 
 		n := len(data)
-		buf = append(buf, byte(n>>24),
-			byte(n>>16),
-			byte(n>>8),
-			byte(n))
-
+		buf = appendInt(buf, int32(n))
 		buf = append(buf, data...)
 	}
 

+ 49 - 0
udt_test.go

@@ -339,5 +339,54 @@ func TestMapScanUDT(t *testing.T) {
 			t.Errorf("message was not string got: %T", message)
 		}
 	}
+}
+
+func TestUDT_MissingField(t *testing.T) {
+	if *flagProto < protoVersion3 {
+		t.Skip("UDT are only available on protocol >= 3")
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	err := createTable(session, `CREATE TYPE gocql_test.missing_field(
+		name text,
+		owner text);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = createTable(session, `CREATE TABLE gocql_test.missing_field(
+		id uuid,
+		udt_col frozen<udt_null_type>,
+
+		primary key(id)
+	);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	type col struct {
+		Name string `cql:"name"`
+	}
 
+	writeCol := &col{
+		Name: "test",
+	}
+
+	id := TimeUUID()
+	err = session.Query("INSERT INTO missing_field(id, udt_col) VALUES(?, ?)", id, writeCol).Exec()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	readCol := &col{}
+	err = session.Query("SELECT udt_col FROM missing_field WHERE id = ?", id).Scan(readCol)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if readCol.Name != writeCol.Name {
+		t.Errorf("expected %q: got %q", writeCol.Name, readCol.Name)
+	}
 }