浏览代码

Add metadata for functions and aggregates (#1204)

* Add metadata for functions and aggregates

* Review comments

* Style nits

* Review comments
Jaume Marhuenda 7 年之前
父节点
当前提交
6832a79641
共有 5 个文件被更改,包括 351 次插入3 次删除
  1. 132 0
      cassandra_test.go
  2. 33 0
      common_test.go
  3. 14 0
      helpers.go
  4. 170 1
      metadata.go
  5. 2 2
      metadata_test.go

+ 132 - 0
cassandra_test.go

@@ -2184,6 +2184,126 @@ func TestGetColumnMetadata(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestAggregateMetadata(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+	createAggregate(t, session)
+
+	aggregates, err := getAggregatesMetadata(session, "gocql_test")
+	if err != nil {
+		t.Fatalf("failed to query aggregate metadata with err: %v", err)
+	}
+	if aggregates == nil {
+		t.Fatal("failed to query aggregate metadata, nil returned")
+	}
+	if len(aggregates) != 1 {
+		t.Fatal("expected only a single aggregate")
+	}
+	aggregate := aggregates[0]
+
+	expectedAggregrate := AggregateMetadata{
+		Keyspace:      "gocql_test",
+		Name:          "average",
+		ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt}},
+		InitCond:      "(0, 0)",
+		ReturnType:    NativeType{typ: TypeDouble},
+		StateType: TupleTypeInfo{
+			NativeType: NativeType{typ: TypeTuple},
+
+			Elems: []TypeInfo{
+				NativeType{typ: TypeInt},
+				NativeType{typ: TypeBigInt},
+			},
+		},
+		stateFunc: "avgstate",
+		finalFunc: "avgfinal",
+	}
+
+	// In this case cassandra is returning a blob
+	if flagCassVersion.Before(3, 0, 0) {
+		expectedAggregrate.InitCond = string([]byte{0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0})
+	}
+
+	if !reflect.DeepEqual(aggregate, expectedAggregrate) {
+		t.Fatalf("aggregate is %+v, but expected %+v", aggregate, expectedAggregrate)
+	}
+}
+
+func TestFunctionMetadata(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+	createFunctions(t, session)
+
+	functions, err := getFunctionsMetadata(session, "gocql_test")
+	if err != nil {
+		t.Fatalf("failed to query function metadata with err: %v", err)
+	}
+	if functions == nil {
+		t.Fatal("failed to query function metadata, nil returned")
+	}
+	if len(functions) != 2 {
+		t.Fatal("expected two functions")
+	}
+	avgState := functions[1]
+	avgFinal := functions[0]
+
+	avgStateBody := "if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;"
+	expectedAvgState := FunctionMetadata{
+		Keyspace: "gocql_test",
+		Name:     "avgstate",
+		ArgumentTypes: []TypeInfo{
+			TupleTypeInfo{
+				NativeType: NativeType{typ: TypeTuple},
+
+				Elems: []TypeInfo{
+					NativeType{typ: TypeInt},
+					NativeType{typ: TypeBigInt},
+				},
+			},
+			NativeType{typ: TypeInt},
+		},
+		ArgumentNames: []string{"state", "val"},
+		ReturnType: TupleTypeInfo{
+			NativeType: NativeType{typ: TypeTuple},
+
+			Elems: []TypeInfo{
+				NativeType{typ: TypeInt},
+				NativeType{typ: TypeBigInt},
+			},
+		},
+		CalledOnNullInput: true,
+		Language:          "java",
+		Body:              avgStateBody,
+	}
+	if !reflect.DeepEqual(avgState, expectedAvgState) {
+		t.Fatalf("function is %+v, but expected %+v", avgState, expectedAvgState)
+	}
+
+	finalStateBody := "double r = 0; if (state.getInt(0) == 0) return null; r = state.getLong(1); r/= state.getInt(0); return Double.valueOf(r);"
+	expectedAvgFinal := FunctionMetadata{
+		Keyspace: "gocql_test",
+		Name:     "avgfinal",
+		ArgumentTypes: []TypeInfo{
+			TupleTypeInfo{
+				NativeType: NativeType{typ: TypeTuple},
+
+				Elems: []TypeInfo{
+					NativeType{typ: TypeInt},
+					NativeType{typ: TypeBigInt},
+				},
+			},
+		},
+		ArgumentNames:     []string{"state"},
+		ReturnType:        NativeType{typ: TypeDouble},
+		CalledOnNullInput: true,
+		Language:          "java",
+		Body:              finalStateBody,
+	}
+	if !reflect.DeepEqual(avgFinal, expectedAvgFinal) {
+		t.Fatalf("function is %+v, but expected %+v", avgFinal, expectedAvgFinal)
+	}
+}
+
 // Integration test of querying and composition the keyspace metadata
 // Integration test of querying and composition the keyspace metadata
 func TestKeyspaceMetadata(t *testing.T) {
 func TestKeyspaceMetadata(t *testing.T) {
 	session := createSession(t)
 	session := createSession(t)
@@ -2192,6 +2312,7 @@ func TestKeyspaceMetadata(t *testing.T) {
 	if err := createTable(session, "CREATE TABLE gocql_test.test_metadata (first_id int, second_id int, third_id int, PRIMARY KEY (first_id, second_id))"); err != nil {
 	if err := createTable(session, "CREATE TABLE gocql_test.test_metadata (first_id int, second_id int, third_id int, PRIMARY KEY (first_id, second_id))"); err != nil {
 		t.Fatalf("failed to create table with error '%v'", err)
 		t.Fatalf("failed to create table with error '%v'", err)
 	}
 	}
+	createAggregate(t, session)
 
 
 	if err := session.Query("CREATE INDEX index_metadata ON test_metadata ( third_id )").Exec(); err != nil {
 	if err := session.Query("CREATE INDEX index_metadata ON test_metadata ( third_id )").Exec(); err != nil {
 		t.Fatalf("failed to create index with err: %v", err)
 		t.Fatalf("failed to create index with err: %v", err)
@@ -2246,6 +2367,17 @@ func TestKeyspaceMetadata(t *testing.T) {
 		// TODO(zariel): scan index info from system_schema
 		// TODO(zariel): scan index info from system_schema
 		t.Errorf("Expected column index named 'index_metadata' but was '%s'", thirdColumn.Index.Name)
 		t.Errorf("Expected column index named 'index_metadata' but was '%s'", thirdColumn.Index.Name)
 	}
 	}
+
+	aggregate, found := keyspaceMetadata.Aggregates["average"]
+	if !found {
+		t.Fatal("failed to find the aggreate in metadata")
+	}
+	if aggregate.FinalFunc.Name != "avgfinal" {
+		t.Fatalf("expected final function %s, but got %s", "avgFinal", aggregate.FinalFunc.Name)
+	}
+	if aggregate.StateFunc.Name != "avgstate" {
+		t.Fatalf("expected state function %s, but got %s", "avgstate", aggregate.StateFunc.Name)
+	}
 }
 }
 
 
 // Integration test of the routing key calculation
 // Integration test of the routing key calculation

+ 33 - 0
common_test.go

@@ -170,6 +170,39 @@ func createTestSession() *Session {
 	return session
 	return session
 }
 }
 
 
+func createFunctions(t *testing.T, session *Session) {
+	if err := session.Query(`
+		CREATE OR REPLACE FUNCTION gocql_test.avgState ( state tuple<int,bigint>, val int )
+		CALLED ON NULL INPUT
+		RETURNS tuple<int,bigint>
+		LANGUAGE java AS
+		$$if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;$$;	`).Exec(); err != nil {
+		t.Fatalf("failed to create function with err: %v", err)
+	}
+	if err := session.Query(`
+		CREATE OR REPLACE FUNCTION gocql_test.avgFinal ( state tuple<int,bigint> )
+		CALLED ON NULL INPUT
+		RETURNS double
+		LANGUAGE java AS
+		$$double r = 0; if (state.getInt(0) == 0) return null; r = state.getLong(1); r/= state.getInt(0); return Double.valueOf(r);$$ 
+	`).Exec(); err != nil {
+		t.Fatalf("failed to create function with err: %v", err)
+	}
+}
+
+func createAggregate(t *testing.T, session *Session) {
+	createFunctions(t, session)
+	if err := session.Query(`
+		CREATE OR REPLACE AGGREGATE gocql_test.average(int)
+		SFUNC avgState
+		STYPE tuple<int,bigint>
+		FINALFUNC avgFinal
+		INITCOND (0,0);
+	`).Exec(); err != nil {
+		t.Fatalf("failed to create aggregate with err: %v", err)
+	}
+}
+
 func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator {
 func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator {
 	return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
 	return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
 		return newAddr, newPort
 		return newAddr, newPort

+ 14 - 0
helpers.go

@@ -191,6 +191,20 @@ func splitCompositeTypes(name string) []string {
 	return parts
 	return parts
 }
 }
 
 
+func apacheToCassandraType(t string) string {
+	t = strings.Replace(t, apacheCassandraTypePrefix, "", -1)
+	t = strings.Replace(t, "(", "<", -1)
+	t = strings.Replace(t, ")", ">", -1)
+	types := strings.FieldsFunc(t, func(r rune) bool {
+		return r == '<' || r == '>' || r == ','
+	})
+	for _, typ := range types {
+		t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1)
+	}
+	// This is done so it exactly matches what Cassandra returns
+	return strings.Replace(t, ",", ", ", -1)
+}
+
 func getApacheCassandraType(class string) Type {
 func getApacheCassandraType(class string) Type {
 	switch strings.TrimPrefix(class, apacheCassandraTypePrefix) {
 	switch strings.TrimPrefix(class, apacheCassandraTypePrefix) {
 	case "AsciiType":
 	case "AsciiType":

+ 170 - 1
metadata.go

@@ -20,6 +20,8 @@ type KeyspaceMetadata struct {
 	StrategyClass   string
 	StrategyClass   string
 	StrategyOptions map[string]interface{}
 	StrategyOptions map[string]interface{}
 	Tables          map[string]*TableMetadata
 	Tables          map[string]*TableMetadata
+	Functions       map[string]*FunctionMetadata
+	Aggregates      map[string]*AggregateMetadata
 }
 }
 
 
 // schema metadata for a table (a.k.a. column family)
 // schema metadata for a table (a.k.a. column family)
@@ -52,6 +54,33 @@ type ColumnMetadata struct {
 	Index           ColumnIndexMetadata
 	Index           ColumnIndexMetadata
 }
 }
 
 
+// FunctionMetadata holds metadata for function constructs
+type FunctionMetadata struct {
+	Keyspace          string
+	Name              string
+	ArgumentTypes     []TypeInfo
+	ArgumentNames     []string
+	Body              string
+	CalledOnNullInput bool
+	Language          string
+	ReturnType        TypeInfo
+}
+
+// AggregateMetadata holds metadata for aggregate constructs
+type AggregateMetadata struct {
+	Keyspace      string
+	Name          string
+	ArgumentTypes []TypeInfo
+	FinalFunc     FunctionMetadata
+	InitCond      string
+	ReturnType    TypeInfo
+	StateFunc     FunctionMetadata
+	StateType     TypeInfo
+
+	stateFunc string
+	finalFunc string
+}
+
 // the ordering of the column with regard to its comparator
 // the ordering of the column with regard to its comparator
 type ColumnOrder bool
 type ColumnOrder bool
 
 
@@ -196,9 +225,17 @@ func (s *schemaDescriber) refreshSchema(keyspaceName string) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
+	functions, err := getFunctionsMetadata(s.session, keyspaceName)
+	if err != nil {
+		return err
+	}
+	aggregates, err := getAggregatesMetadata(s.session, keyspaceName)
+	if err != nil {
+		return err
+	}
 
 
 	// organize the schema data
 	// organize the schema data
-	compileMetadata(s.session.cfg.ProtoVersion, keyspace, tables, columns)
+	compileMetadata(s.session.cfg.ProtoVersion, keyspace, tables, columns, functions, aggregates)
 
 
 	// update the cache
 	// update the cache
 	s.cache[keyspaceName] = keyspace
 	s.cache[keyspaceName] = keyspace
@@ -216,6 +253,8 @@ func compileMetadata(
 	keyspace *KeyspaceMetadata,
 	keyspace *KeyspaceMetadata,
 	tables []TableMetadata,
 	tables []TableMetadata,
 	columns []ColumnMetadata,
 	columns []ColumnMetadata,
+	functions []FunctionMetadata,
+	aggregates []AggregateMetadata,
 ) {
 ) {
 	keyspace.Tables = make(map[string]*TableMetadata)
 	keyspace.Tables = make(map[string]*TableMetadata)
 	for i := range tables {
 	for i := range tables {
@@ -223,6 +262,16 @@ func compileMetadata(
 
 
 		keyspace.Tables[tables[i].Name] = &tables[i]
 		keyspace.Tables[tables[i].Name] = &tables[i]
 	}
 	}
+	keyspace.Functions = make(map[string]*FunctionMetadata, len(functions))
+	for i := range functions {
+		keyspace.Functions[functions[i].Name] = &functions[i]
+	}
+	keyspace.Aggregates = make(map[string]*AggregateMetadata, len(aggregates))
+	for _, aggregate := range aggregates {
+		aggregate.FinalFunc = *keyspace.Functions[aggregate.finalFunc]
+		aggregate.StateFunc = *keyspace.Functions[aggregate.stateFunc]
+		keyspace.Aggregates[aggregate.Name] = &aggregate
+	}
 
 
 	// add columns from the schema data
 	// add columns from the schema data
 	for i := range columns {
 	for i := range columns {
@@ -793,6 +842,126 @@ func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata,
 	return columns, nil
 	return columns, nil
 }
 }
 
 
+func getTypeInfo(t string) TypeInfo {
+	if strings.HasPrefix(t, apacheCassandraTypePrefix) {
+		t = apacheToCassandraType(t)
+	}
+	return getCassandraType(t)
+}
+
+func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMetadata, error) {
+	if session.cfg.ProtoVersion == protoVersion1 {
+		return nil, nil
+	}
+	var tableName string
+	if session.useSystemSchema {
+		tableName = "system_schema.functions"
+	} else {
+		tableName = "system.schema_functions"
+	}
+	stmt := fmt.Sprintf(`
+		SELECT
+			function_name,
+			argument_types,
+			argument_names,
+			body,
+			called_on_null_input,
+			language,
+			return_type
+		FROM %s
+		WHERE keyspace_name = ?`, tableName)
+
+	var functions []FunctionMetadata
+
+	rows := session.control.query(stmt, keyspaceName).Scanner()
+	for rows.Next() {
+		function := FunctionMetadata{Keyspace: keyspaceName}
+		var argumentTypes []string
+		var returnType string
+		err := rows.Scan(&function.Name,
+			&argumentTypes,
+			&function.ArgumentNames,
+			&function.Body,
+			&function.CalledOnNullInput,
+			&function.Language,
+			&returnType,
+		)
+		if err != nil {
+			return nil, err
+		}
+		function.ReturnType = getTypeInfo(returnType)
+		function.ArgumentTypes = make([]TypeInfo, len(argumentTypes))
+		for i, argumentType := range argumentTypes {
+			function.ArgumentTypes[i] = getTypeInfo(argumentType)
+		}
+		functions = append(functions, function)
+	}
+
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+
+	return functions, nil
+}
+
+func getAggregatesMetadata(session *Session, keyspaceName string) ([]AggregateMetadata, error) {
+	if session.cfg.ProtoVersion == protoVersion1 {
+		return nil, nil
+	}
+	var tableName string
+	if session.useSystemSchema {
+		tableName = "system_schema.aggregates"
+	} else {
+		tableName = "system.schema_aggregates"
+	}
+
+	stmt := fmt.Sprintf(`
+		SELECT
+			aggregate_name,
+			argument_types,
+			final_func,
+			initcond,
+			return_type,
+			state_func,
+			state_type
+		FROM %s
+		WHERE keyspace_name = ?`, tableName)
+
+	var aggregates []AggregateMetadata
+
+	rows := session.control.query(stmt, keyspaceName).Scanner()
+	for rows.Next() {
+		aggregate := AggregateMetadata{Keyspace: keyspaceName}
+		var argumentTypes []string
+		var returnType string
+		var stateType string
+		err := rows.Scan(&aggregate.Name,
+			&argumentTypes,
+			&aggregate.finalFunc,
+			&aggregate.InitCond,
+			&returnType,
+			&aggregate.stateFunc,
+			&stateType,
+		)
+		if err != nil {
+			return nil, err
+		}
+		aggregate.ReturnType = getTypeInfo(returnType)
+		aggregate.StateType = getTypeInfo(stateType)
+		aggregate.ArgumentTypes = make([]TypeInfo, len(argumentTypes))
+		for i, argumentType := range argumentTypes {
+			aggregate.ArgumentTypes[i] = getTypeInfo(argumentType)
+		}
+		aggregates = append(aggregates, aggregate)
+	}
+
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+
+	return aggregates, nil
+}
+
 // type definition parser state
 // type definition parser state
 type typeParser struct {
 type typeParser struct {
 	input string
 	input string

+ 2 - 2
metadata_test.go

@@ -94,7 +94,7 @@ func TestCompileMetadata(t *testing.T) {
 		{Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "schema_version", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UUIDType"},
 		{Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "schema_version", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UUIDType"},
 		{Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "tokens", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.UTF8Type)"},
 		{Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "tokens", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.UTF8Type)"},
 	}
 	}
-	compileMetadata(1, keyspace, tables, columns)
+	compileMetadata(1, keyspace, tables, columns, nil, nil)
 	assertKeyspaceMetadata(
 	assertKeyspaceMetadata(
 		t,
 		t,
 		keyspace,
 		keyspace,
@@ -375,7 +375,7 @@ func TestCompileMetadata(t *testing.T) {
 			Validator: "org.apache.cassandra.db.marshal.UTF8Type",
 			Validator: "org.apache.cassandra.db.marshal.UTF8Type",
 		},
 		},
 	}
 	}
-	compileMetadata(2, keyspace, tables, columns)
+	compileMetadata(2, keyspace, tables, columns, nil, nil)
 	assertKeyspaceMetadata(
 	assertKeyspaceMetadata(
 		t,
 		t,
 		keyspace,
 		keyspace,