浏览代码

support named query values (#971)

Add NamedValue to support naming query values
Chris Bannister 8 年之前
父节点
当前提交
2badfed26d
共有 3 个文件被更改,包括 63 次插入17 次删除
  1. 22 0
      cassandra_test.go
  2. 28 17
      conn.go
  3. 13 0
      frame.go

+ 22 - 0
cassandra_test.go

@@ -2695,3 +2695,25 @@ func TestUnsetColBatch(t *testing.T) {
 		t.Fatalf("expected id, my_int to be 1, got %v and %v", id, mInt)
 	}
 }
+
+func TestQuery_NamedValues(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if session.cfg.ProtoVersion < 3 {
+		t.Skip("named Values are not supported in protocol < 3")
+	}
+
+	if err := createTable(session, "CREATE TABLE gocql_test.named_query(id int, value text, PRIMARY KEY (id))"); err != nil {
+		t.Fatal(err)
+	}
+
+	err := session.Query("INSERT INTO gocql_test.named_query(id, value) VALUES(:id, :value)", NamedValue("id", 1), NamedValue("value", "i am a value")).Exec()
+	if err != nil {
+		t.Fatal(err)
+	}
+	var value string
+	if err := session.Query("SELECT VALUE from gocql_test.named_query WHERE id = :id", NamedValue("id", 1)).Scan(&value); err != nil {
+		t.Fatal(err)
+	}
+}

+ 28 - 17
conn.go

@@ -756,6 +756,26 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer)
 	return flight.preparedStatment, flight.err
 }
 
+func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error {
+	if named, ok := value.(*namedValue); ok {
+		dst.name = named.name
+		value = named.value
+	}
+
+	if _, ok := value.(unsetColumn); !ok {
+		val, err := Marshal(typ, value)
+		if err != nil {
+			return err
+		}
+
+		dst.value = val
+	} else {
+		dst.isUnset = true
+	}
+
+	return nil
+}
+
 func (c *Conn) executeQuery(qry *Query) *Iter {
 	params := queryParams{
 		consistency: qry.cons,
@@ -809,17 +829,12 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 
 		params.values = make([]queryValues, len(values))
 		for i := 0; i < len(values); i++ {
-			val, err := Marshal(info.request.columns[i].TypeInfo, values[i])
-			if err != nil {
-				return &Iter{err: err}
-			}
-
 			v := &params.values[i]
-			v.value = val
-			if _, ok := values[i].(unsetColumn); ok {
-				v.isUnset = true
+			value := values[i]
+			typ := info.request.columns[i].TypeInfo
+			if err := marshalQueryValue(typ, value, v); err != nil {
+				return &Iter{err: err}
 			}
-			// TODO: handle query binding names
 		}
 
 		params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata)
@@ -1009,16 +1024,12 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 			b.values = make([]queryValues, info.request.actualColCount)
 
 			for j := 0; j < info.request.actualColCount; j++ {
-				val, err := Marshal(info.request.columns[j].TypeInfo, values[j])
-				if err != nil {
+				v := &b.values[j]
+				value := values[j]
+				typ := info.request.columns[j].TypeInfo
+				if err := marshalQueryValue(typ, value, v); err != nil {
 					return &Iter{err: err}
 				}
-
-				b.values[j].value = val
-				if _, ok := values[j].(unsetColumn); ok {
-					b.values[j].isUnset = true
-				}
-				// TODO: add names
 			}
 		} else {
 			b.statement = entry.Stmt

+ 13 - 0
frame.go

@@ -20,6 +20,19 @@ type unsetColumn struct{}
 
 var UnsetValue = unsetColumn{}
 
+type namedValue struct {
+	name  string
+	value interface{}
+}
+
+// NamedValue produce a value which will bind to the named parameter in a query
+func NamedValue(name string, value interface{}) interface{} {
+	return &namedValue{
+		name:  name,
+		value: value,
+	}
+}
+
 const (
 	protoDirectionMask = 0x80
 	protoVersionMask   = 0x7F