Przeglądaj źródła

Change pgx driver

* Change pgx driver

* Add test case

* Fix bug
xormplus 7 lat temu
rodzic
commit
337639a4d6
3 zmienionych plików z 61 dodań i 2 usunięć
  1. 13 0
      dialect_postgres.go
  2. 46 0
      dialect_postgres_test.go
  3. 2 2
      xorm.go

+ 13 - 0
dialect_postgres.go

@@ -895,6 +895,7 @@ func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
 		args := []interface{}{tableName}
 		return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
 	}
+
 	args := []interface{}{db.Schema, tableName}
 	return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args
 }
@@ -1237,3 +1238,15 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
 
 	return db, nil
 }
+
+type pqDriverPgx struct {
+	pqDriver
+}
+
+func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*core.Uri, error) {
+	// Remove the leading characters for driver to work
+	if len(dataSourceName) >= 9 && dataSourceName[0] == 0 {
+		dataSourceName = dataSourceName[9:]
+	}
+	return pgx.pqDriver.Parse(driverName, dataSourceName)
+}

+ 46 - 0
dialect_postgres_test.go

@@ -4,6 +4,7 @@ import (
 	"reflect"
 	"testing"
 
+	"github.com/jackc/pgx/stdlib"
 	"github.com/xormplus/core"
 )
 
@@ -37,3 +38,48 @@ func TestParsePostgres(t *testing.T) {
 		}
 	}
 }
+
+func TestParsePgx(t *testing.T) {
+	tests := []struct {
+		in       string
+		expected string
+		valid    bool
+	}{
+		{"postgres://auser:password@localhost:5432/db?sslmode=disable", "db", true},
+		{"postgresql://auser:password@localhost:5432/db?sslmode=disable", "db", true},
+		{"postg://auser:password@localhost:5432/db?sslmode=disable", "db", false},
+		//{"postgres://auser:pass with space@localhost:5432/db?sslmode=disable", "db", true},
+		//{"postgres:// auser : password@localhost:5432/db?sslmode=disable", "db", true},
+		{"postgres://%20auser%20:pass%20with%20space@localhost:5432/db?sslmode=disable", "db", true},
+		//{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true},
+		{"dbname=db sslmode=disable", "db", true},
+		{"user=auser password=password dbname=db sslmode=disable", "db", true},
+		{"", "db", false},
+		{"dbname=db =disable", "db", false},
+	}
+
+	driver := core.QueryDriver("pgx")
+
+	for _, test := range tests {
+		uri, err := driver.Parse("pgx", test.in)
+
+		if err != nil && test.valid {
+			t.Errorf("%q got unexpected error: %s", test.in, err)
+		} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) {
+			t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected)
+		}
+
+		// Register DriverConfig
+		drvierConfig := stdlib.DriverConfig{}
+		stdlib.RegisterDriverConfig(&drvierConfig)
+		uri, err = driver.Parse("pgx",
+			drvierConfig.ConnectionString(test.in))
+		if err != nil && test.valid {
+			t.Errorf("%q got unexpected error: %s", test.in, err)
+		} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) {
+			t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected)
+		}
+
+	}
+
+}

+ 2 - 2
xorm.go

@@ -17,7 +17,7 @@ import (
 
 const (
 	// Version show the xorm's version
-	Version string = "0.7.0.0504"
+	Version string = "0.7.0.0608"
 )
 
 func regDrvsNDialects() bool {
@@ -31,7 +31,7 @@ func regDrvsNDialects() bool {
 		"mysql":    {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }},
 		"mymysql":  {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }},
 		"postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }},
-		"pgx":      {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }},
+		"pgx":      {"postgres", func() core.Driver { return &pqDriverPgx{} }, func() core.Dialect { return &postgres{} }},
 		"sqlite3":  {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }},
 		"oci8":     {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }},
 		"goracle":  {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }},