소스 검색

Add missing whitespace in StringNoPk in column.go

1.Add missing whitespace in StringNoPk in column.go
2.Remove func QuoteStr() in interface Dialect
Unknown 6 년 전
부모
커밋
02dee4a64a
14개의 변경된 파일110개의 추가작업 그리고 48개의 파일을 삭제
  1. 2 2
      column.go
  2. 4 0
      db.go
  3. 2 1
      db_test.go
  4. 15 6
      dialect.go
  5. 3 1
      error.go
  6. 18 2
      filter.go
  7. 25 0
      filter_test.go
  8. 3 1
      ilogger.go
  9. 4 3
      index.go
  10. 4 4
      mapper.go
  11. 1 1
      rows.go
  12. 1 0
      stmt.go
  13. 5 4
      table.go
  14. 23 23
      type.go

+ 2 - 2
column.go

@@ -73,7 +73,7 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable
 
 // String generate column description string according dialect
 func (col *Column) String(d Dialect) string {
-	sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
+	sql := d.Quote(col.Name) + " "
 
 	sql += d.SqlType(col) + " "
 
@@ -101,7 +101,7 @@ func (col *Column) String(d Dialect) string {
 
 // StringNoPk generate column description string according dialect without primary keys
 func (col *Column) StringNoPk(d Dialect) string {
-	sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
+	sql := d.Quote(col.Name) + " "
 
 	sql += d.SqlType(col) + " "
 

+ 4 - 0
db.go

@@ -15,6 +15,7 @@ import (
 )
 
 var (
+	// DefaultCacheSize sets the default cache size
 	DefaultCacheSize = 200
 )
 
@@ -132,6 +133,7 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
 	return db.QueryContext(context.Background(), query, args...)
 }
 
+// QueryMapContext executes query with parameters via map and context
 func (db *DB) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) {
 	query, args, err := MapToSlice(query, mp)
 	if err != nil {
@@ -140,6 +142,7 @@ func (db *DB) QueryMapContext(ctx context.Context, query string, mp interface{})
 	return db.QueryContext(ctx, query, args...)
 }
 
+// QueryMap executes query with parameters via map
 func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
 	return db.QueryMapContext(context.Background(), query, mp)
 }
@@ -196,6 +199,7 @@ var (
 	re = regexp.MustCompile(`[?](\w+)`)
 )
 
+// ExecMapContext exec map with context.Context
 // insert into (name) values (?)
 // insert into (name) values (?name)
 func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) {

+ 2 - 1
db_test.go

@@ -17,6 +17,7 @@ import (
 
 var (
 	dbtype         = flag.String("dbtype", "mysql", "database type")
+	dbConn         = flag.String("dbConn", "root:@/core_test?charset=utf8", "database connect string")
 	createTableSql string
 )
 
@@ -50,7 +51,7 @@ func testOpen() (*DB, error) {
 		os.Remove("./test.db")
 		return Open("sqlite3", "./test.db")
 	case "mysql":
-		return Open("mysql", "root:@/core_test?charset=utf8")
+		return Open("mysql", *dbConn)
 	default:
 		panic("no db type")
 	}

+ 15 - 6
dialect.go

@@ -40,9 +40,10 @@ type Dialect interface {
 	DriverName() string
 	DataSourceName() string
 
-	QuoteStr() string
 	IsReserved(string) bool
 	Quote(string) string
+	// Deprecated: use Quote(string) string instead
+	QuoteStr() string
 	AndStr() string
 	OrStr() string
 	EqStr() string
@@ -70,8 +71,8 @@ type Dialect interface {
 
 	ForUpdateSql(query string) string
 
-	//CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error
-	//MustDropTable(tableName string) error
+	// CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error
+	// MustDropTable(tableName string) error
 
 	GetColumns(tableName string) ([]string, map[string]*Column, error)
 	GetTables() ([]*Table, error)
@@ -85,6 +86,7 @@ func OpenDialect(dialect Dialect) (*DB, error) {
 	return Open(dialect.DriverName(), dialect.DataSourceName())
 }
 
+// Base represents a basic dialect and all real dialects could embed this struct
 type Base struct {
 	db             *DB
 	dialect        Dialect
@@ -172,8 +174,15 @@ func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) {
 }
 
 func (db *Base) IsColumnExist(tableName, colName string) (bool, error) {
-	query := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
-	query = strings.Replace(query, "`", db.dialect.QuoteStr(), -1)
+	query := fmt.Sprintf(
+		"SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?",
+		db.dialect.Quote("COLUMN_NAME"),
+		db.dialect.Quote("INFORMATION_SCHEMA"),
+		db.dialect.Quote("COLUMNS"),
+		db.dialect.Quote("TABLE_SCHEMA"),
+		db.dialect.Quote("TABLE_NAME"),
+		db.dialect.Quote("COLUMN_NAME"),
+	)
 	return db.HasRecords(query, db.DbName, tableName, colName)
 }
 
@@ -310,7 +319,7 @@ func RegisterDialect(dbName DbType, dialectFunc func() Dialect) {
 	dialects[strings.ToLower(string(dbName))] = dialectFunc // !nashtsai! allow override dialect
 }
 
-// QueryDialect query if registed database dialect
+// QueryDialect query if registered database dialect
 func QueryDialect(dbName DbType) Dialect {
 	if d, ok := dialects[strings.ToLower(string(dbName))]; ok {
 		return d()

+ 3 - 1
error.go

@@ -7,6 +7,8 @@ package core
 import "errors"
 
 var (
-	ErrNoMapPointer    = errors.New("mp should be a map's pointer")
+	// ErrNoMapPointer represents error when no map pointer
+	ErrNoMapPointer = errors.New("mp should be a map's pointer")
+	// ErrNoStructPointer represents error when no struct pointer
 	ErrNoStructPointer = errors.New("mp should be a struct's pointer")
 )

+ 18 - 2
filter.go

@@ -19,7 +19,23 @@ type QuoteFilter struct {
 }
 
 func (s *QuoteFilter) Do(sql string, dialect Dialect, table *Table) string {
-	return strings.Replace(sql, "`", dialect.QuoteStr(), -1)
+	dummy := dialect.Quote("")
+	if len(dummy) != 2 {
+		return sql
+	}
+	prefix, suffix := dummy[0], dummy[1]
+	raw := []byte(sql)
+	for i, cnt := 0, 0; i < len(raw); i = i + 1 {
+		if raw[i] == '`' {
+			if cnt%2 == 0 {
+				raw[i] = prefix
+			} else {
+				raw[i] = suffix
+			}
+			cnt++
+		}
+	}
+	return string(raw)
 }
 
 // IdFilter filter SQL replace (id) to primary key column name
@@ -35,7 +51,7 @@ func NewQuoter(dialect Dialect) *Quoter {
 }
 
 func (q *Quoter) Quote(content string) string {
-	return q.dialect.QuoteStr() + content + q.dialect.QuoteStr()
+	return q.dialect.Quote(content)
 }
 
 func (i *IdFilter) Do(sql string, dialect Dialect, table *Table) string {

+ 25 - 0
filter_test.go

@@ -0,0 +1,25 @@
+package core
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+type quoterOnly struct {
+	Dialect
+}
+
+func (q *quoterOnly) Quote(item string) string {
+	return "[" + item + "]"
+}
+
+func TestQuoteFilter_Do(t *testing.T) {
+	f := QuoteFilter{}
+	sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
+	res := f.Do(sql, new(quoterOnly), nil)
+	assert.EqualValues(t,
+		"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
+		res,
+	)
+}

+ 3 - 1
ilogger.go

@@ -4,8 +4,10 @@
 
 package core
 
+// LogLevel defines a log level
 type LogLevel int
 
+// enumerate all LogLevels
 const (
 	// !nashtsai! following level also match syslog.Priority value
 	LOG_DEBUG LogLevel = iota
@@ -16,7 +18,7 @@ const (
 	LOG_UNKNOWN
 )
 
-// logger interface
+// ILogger is a logger interface
 type ILogger interface {
 	Debug(v ...interface{})
 	Debugf(format string, v ...interface{})

+ 4 - 3
index.go

@@ -9,12 +9,13 @@ import (
 	"strings"
 )
 
+// enumerate all index types
 const (
 	IndexType = iota + 1
 	UniqueType
 )
 
-// database index
+// Index represents a database index
 type Index struct {
 	IsRegular bool
 	Name      string
@@ -35,7 +36,7 @@ func (index *Index) XName(tableName string) string {
 	return index.Name
 }
 
-// add columns which will be composite index
+// AddColumn add columns which will be composite index
 func (index *Index) AddColumn(cols ...string) {
 	for _, col := range cols {
 		index.Cols = append(index.Cols, col)
@@ -65,7 +66,7 @@ func (index *Index) Equal(dst *Index) bool {
 	return true
 }
 
-// new an index
+// NewIndex new an index object
 func NewIndex(name string, indexType int) *Index {
 	return &Index{true, name, indexType, make([]string, 0)}
 }

+ 4 - 4
mapper.go

@@ -9,7 +9,7 @@ import (
 	"sync"
 )
 
-// name translation between struct, fields names and table, column names
+// IMapper represents a name convertation between struct's fields name and table's column name
 type IMapper interface {
 	Obj2Table(string) string
 	Table2Obj(string) string
@@ -184,7 +184,7 @@ func (mapper GonicMapper) Table2Obj(name string) string {
 	return string(newstr)
 }
 
-// A GonicMapper that contains a list of common initialisms taken from golang/lint
+// LintGonicMapper is A GonicMapper that contains a list of common initialisms taken from golang/lint
 var LintGonicMapper = GonicMapper{
 	"API":   true,
 	"ASCII": true,
@@ -221,7 +221,7 @@ var LintGonicMapper = GonicMapper{
 	"XSS":   true,
 }
 
-// provide prefix table name support
+// PrefixMapper provides prefix table name support
 type PrefixMapper struct {
 	Mapper IMapper
 	Prefix string
@@ -239,7 +239,7 @@ func NewPrefixMapper(mapper IMapper, prefix string) PrefixMapper {
 	return PrefixMapper{mapper, prefix}
 }
 
-// provide suffix table name support
+// SuffixMapper provides suffix table name support
 type SuffixMapper struct {
 	Mapper IMapper
 	Suffix string

+ 1 - 1
rows.go

@@ -170,7 +170,7 @@ func (rs *Rows) ScanMap(dest interface{}) error {
 	newDest := make([]interface{}, len(cols))
 	vvv := vv.Elem()
 
-	for i, _ := range cols {
+	for i := range cols {
 		newDest[i] = rs.db.reflectNew(vvv.Type().Elem()).Interface()
 	}
 

+ 1 - 0
stmt.go

@@ -11,6 +11,7 @@ import (
 	"reflect"
 )
 
+// Stmt reprents a stmt objects
 type Stmt struct {
 	*sql.Stmt
 	db    *DB

+ 5 - 4
table.go

@@ -9,7 +9,7 @@ import (
 	"strings"
 )
 
-// database table
+// Table represents a database table
 type Table struct {
 	Name          string
 	Type          reflect.Type
@@ -41,6 +41,7 @@ func NewEmptyTable() *Table {
 	return NewTable("", nil)
 }
 
+// NewTable creates a new Table object
 func NewTable(name string, t reflect.Type) *Table {
 	return &Table{Name: name, Type: t,
 		columnsSeq:  make([]string, 0),
@@ -87,7 +88,7 @@ func (table *Table) GetColumnIdx(name string, idx int) *Column {
 	return nil
 }
 
-// if has primary key, return column
+// PKColumns reprents all primary key columns
 func (table *Table) PKColumns() []*Column {
 	columns := make([]*Column, len(table.PrimaryKeys))
 	for i, name := range table.PrimaryKeys {
@@ -117,7 +118,7 @@ func (table *Table) DeletedColumn() *Column {
 	return table.GetColumn(table.Deleted)
 }
 
-// add a column to table
+// AddColumn adds a column to table
 func (table *Table) AddColumn(col *Column) {
 	table.columnsSeq = append(table.columnsSeq, col.Name)
 	table.columns = append(table.columns, col)
@@ -148,7 +149,7 @@ func (table *Table) AddColumn(col *Column) {
 	}
 }
 
-// add an index or an unique to table
+// AddIndex adds an index or an unique to table
 func (table *Table) AddIndex(index *Index) {
 	table.Indexes[index.Name] = index
 }

+ 23 - 23
type.go

@@ -87,16 +87,16 @@ var (
 	UniqueIdentifier = "UNIQUEIDENTIFIER"
 	SysName          = "SYSNAME"
 
-	Date       = "DATE"
-	DateTime   = "DATETIME"
-	SmallDateTime   = "SMALLDATETIME"
-	Time       = "TIME"
-	TimeStamp  = "TIMESTAMP"
-	TimeStampz = "TIMESTAMPZ"
-
-	Decimal = "DECIMAL"
-	Numeric = "NUMERIC"
-	Money   = "MONEY"
+	Date          = "DATE"
+	DateTime      = "DATETIME"
+	SmallDateTime = "SMALLDATETIME"
+	Time          = "TIME"
+	TimeStamp     = "TIMESTAMP"
+	TimeStampz    = "TIMESTAMPZ"
+
+	Decimal    = "DECIMAL"
+	Numeric    = "NUMERIC"
+	Money      = "MONEY"
 	SmallMoney = "SMALLMONEY"
 
 	Real   = "REAL"
@@ -147,19 +147,19 @@ var (
 		Clob:       TEXT_TYPE,
 		SysName:    TEXT_TYPE,
 
-		Date:       TIME_TYPE,
-		DateTime:   TIME_TYPE,
-		Time:       TIME_TYPE,
-		TimeStamp:  TIME_TYPE,
-		TimeStampz: TIME_TYPE,
-		SmallDateTime:   TIME_TYPE,
-
-		Decimal: NUMERIC_TYPE,
-		Numeric: NUMERIC_TYPE,
-		Real:    NUMERIC_TYPE,
-		Float:   NUMERIC_TYPE,
-		Double:  NUMERIC_TYPE,
-		Money:   NUMERIC_TYPE,
+		Date:          TIME_TYPE,
+		DateTime:      TIME_TYPE,
+		Time:          TIME_TYPE,
+		TimeStamp:     TIME_TYPE,
+		TimeStampz:    TIME_TYPE,
+		SmallDateTime: TIME_TYPE,
+
+		Decimal:    NUMERIC_TYPE,
+		Numeric:    NUMERIC_TYPE,
+		Real:       NUMERIC_TYPE,
+		Float:      NUMERIC_TYPE,
+		Double:     NUMERIC_TYPE,
+		Money:      NUMERIC_TYPE,
 		SmallMoney: NUMERIC_TYPE,
 
 		Binary:    BLOB_TYPE,