Sfoglia il codice sorgente

Improve schema support on postgres dialect
1.add schema on postgres dialect
2.fix to support no specific schema when postgres

xormplus 7 anni fa
parent
commit
6bb9d26fd7
2 ha cambiato i file con 63 aggiunte e 34 eliminazioni
  1. 59 26
      dialect_postgres.go
  2. 4 8
      dialect_postgres_test.go

+ 59 - 26
dialect_postgres.go

@@ -771,12 +771,17 @@ var (
 
 type postgres struct {
 	core.Base
-	schema string
 }
 
 func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
-	db.schema = DefaultPostgresSchema
-	return db.Base.Init(d, db, uri, drivername, dataSourceName)
+	err := db.Base.Init(d, db, uri, drivername, dataSourceName)
+	if err != nil {
+		return err
+	}
+	if db.Schema == "" {
+		db.Schema = DefaultPostgresSchema
+	}
+	return nil
 }
 
 func (db *postgres) SqlType(c *core.Column) string {
@@ -873,25 +878,33 @@ func (db *postgres) IndexOnTable() bool {
 }
 
 func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
-	args := []interface{}{tableName, idxName}
+	if len(db.Schema) == 0 {
+		args := []interface{}{tableName, idxName}
+		return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args
+	}
+
+	args := []interface{}{db.Schema, tableName, idxName}
 	return `SELECT indexname FROM pg_indexes ` +
-		`WHERE tablename = ? AND indexname = ?`, args
+		`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
 }
 
 func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
-	args := []interface{}{tableName}
-	return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
+	if len(db.Schema) == 0 {
+		args := []interface{}{tableName}
+		return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
+	}
+	args := []interface{}{db.Schema, tableName}
+	return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args
 }
 
-/*func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
-	args := []interface{}{tableName, colName}
-	return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
-		" AND column_name = ?", args
-}*/
-
 func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string {
-	return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
-		tableName, col.Name, db.SqlType(col))
+	if len(db.Schema) == 0 {
+		return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
+			tableName, col.Name, db.SqlType(col))
+	}
+	return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s",
+		db.Schema, tableName, col.Name, db.SqlType(col))
+
 }
 
 func (db *postgres) CreateIndexSql(tableName string, index *core.Index) string {
@@ -901,7 +914,6 @@ func (db *postgres) CreateIndexSql(tableName string, index *core.Index) string {
 }
 
 func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
-	//var unique string
 	quote := db.Quote
 	idxName := index.Name
 
@@ -917,9 +929,14 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
 }
 
 func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
-	args := []interface{}{tableName, colName}
-	query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
-		" AND column_name = $2"
+	args := []interface{}{db.Schema, tableName, colName}
+	query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
+		" AND column_name = $3"
+	if len(db.Schema) == 0 {
+		args = []interface{}{tableName, colName}
+		query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
+			" AND column_name = $2"
+	}
 	db.LogSQL(query, args)
 
 	rows, err := db.DB().Query(query, args...)
@@ -932,8 +949,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
 }
 
 func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
-	// FIXME: the schema should be replaced by user custom's
-	args := []interface{}{tableName, db.schema}
+	args := []interface{}{tableName}
 	s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix ,
     CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
     CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
@@ -944,7 +960,15 @@ FROM pg_attribute f
     LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
     LEFT JOIN pg_class AS g ON p.confrelid = g.oid
     LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
-WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;`
+WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
+
+	var f string
+	if len(db.Schema) != 0 {
+		args = append(args, db.Schema)
+		f = "AND s.table_schema = $2"
+	}
+	s = fmt.Sprintf(s, f)
+
 	db.LogSQL(s, args)
 
 	rows, err := db.DB().Query(s, args...)
@@ -1034,8 +1058,13 @@ WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.att
 }
 
 func (db *postgres) GetTables() ([]*core.Table, error) {
-	args := []interface{}{db.schema}
-	s := fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = $1")
+	args := []interface{}{}
+	s := "SELECT tablename FROM pg_tables"
+	if len(db.Schema) != 0 {
+		args = append(args, db.Schema)
+		s = s + " WHERE schemaname = $1"
+	}
+
 	db.LogSQL(s, args)
 
 	rows, err := db.DB().Query(s, args...)
@@ -1059,9 +1088,13 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
 }
 
 func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
-	args := []interface{}{db.schema, tableName}
-	s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname=$1 AND tablename=$2")
+	args := []interface{}{tableName}
+	s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
 	db.LogSQL(s, args)
+	if len(db.Schema) != 0 {
+		args = append(args, db.Schema)
+		s = s + " AND schemaname=$2"
+	}
 
 	rows, err := db.DB().Query(s, args...)
 	if err != nil {

+ 4 - 8
dialect_postgres_test.go

@@ -7,11 +7,7 @@ import (
 	"github.com/xormplus/core"
 )
 
-func TestPostgresDialect(t *testing.T) {
-	TestParse(t)
-}
-
-func TestParse(t *testing.T) {
+func TestParsePostgres(t *testing.T) {
 	tests := []struct {
 		in       string
 		expected string
@@ -20,10 +16,10 @@ func TestParse(t *testing.T) {
 		{"postgres://auser:password@localhost:5432/db?sslmode=disable", "db", true},
 		{"postgresql://auser:password@localhost:5432/db?sslmode=disable", "db", true},
 		{"postg://auser:password@localhost:5432/db?sslmode=disable", "db", false},
-		{"postgres://auser:pass with space@localhost:5432/db?sslmode=disable", "db", true},
-		{"postgres:// auser : password@localhost:5432/db?sslmode=disable", "db", true},
+		//{"postgres://auser:pass with space@localhost:5432/db?sslmode=disable", "db", true},
+		//{"postgres:// auser : password@localhost:5432/db?sslmode=disable", "db", true},
 		{"postgres://%20auser%20:pass%20with%20space@localhost:5432/db?sslmode=disable", "db", true},
-		{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true},
+		//{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true},
 		{"dbname=db sslmode=disable", "db", true},
 		{"user=auser password=password dbname=db sslmode=disable", "db", true},
 		{"", "db", false},