Explorar el Código

Handle xorm tags via tagHandler prepared for customerize tag support

xormplus hace 8 años
padre
commit
0084808e69
Se han modificado 4 ficheros con 355 adiciones y 159 borrados
  1. 69 158
      engine.go
  2. 4 1
      session_schema.go
  3. 281 0
      tag.go
  4. 1 0
      xorm.go

+ 69 - 158
engine.go

@@ -783,13 +783,18 @@ func (engine *Engine) Having(conditions string) *Session {
 	return session.Having(conditions)
 }
 
-func (engine *Engine) autoMapType(v reflect.Value) *core.Table {
+func (engine *Engine) autoMapType(v reflect.Value) (*core.Table, error) {
 	t := v.Type()
 	engine.mutex.Lock()
 	defer engine.mutex.Unlock()
 	table, ok := engine.Tables[t]
 	if !ok {
-		table = engine.mapType(v)
+		var err error
+		table, err = engine.mapType(v)
+		if err != nil {
+			return nil, err
+		}
+
 		engine.Tables[t] = table
 		if engine.Cacher != nil {
 			if v.CanAddr() {
@@ -799,12 +804,11 @@ func (engine *Engine) autoMapType(v reflect.Value) *core.Table {
 			}
 		}
 	}
-	return table
+	return table, nil
 }
 
 // GobRegister register one struct to gob for cache use
 func (engine *Engine) GobRegister(v interface{}) *Engine {
-	//fmt.Printf("Type: %[1]T => Data: %[1]#v\n", v)
 	gob.Register(v)
 	return engine
 }
@@ -851,7 +855,7 @@ var (
 	tpTableName = reflect.TypeOf((*TableName)(nil)).Elem()
 )
 
-func (engine *Engine) mapType(v reflect.Value) *core.Table {
+func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
 	t := v.Type()
 	table := engine.newTable()
 	if tb, ok := v.Interface().(TableName); ok {
@@ -870,7 +874,6 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
 	table.Type = t
 
 	var idFieldColName string
-	var err error
 	var hasCacheTag, hasNoCacheTag bool
 
 	for i := 0; i < t.NumField(); i++ {
@@ -890,186 +893,94 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
 				if tags[0] == "-" {
 					continue
 				}
+
+				var ctx = tagContext{
+					table:      table,
+					col:        col,
+					fieldValue: fieldValue,
+					indexNames: make(map[string]int),
+					engine:     engine,
+				}
+
 				if strings.ToUpper(tags[0]) == "EXTENDS" {
-					switch fieldValue.Kind() {
-					case reflect.Ptr:
-						f := fieldValue.Type().Elem()
-						if f.Kind() == reflect.Struct {
-							fieldPtr := fieldValue
-							fieldValue = fieldValue.Elem()
-							if !fieldValue.IsValid() || fieldPtr.IsNil() {
-								fieldValue = reflect.New(f).Elem()
-							}
-						}
-						fallthrough
-					case reflect.Struct:
-						parentTable := engine.mapType(fieldValue)
-						for _, col := range parentTable.Columns() {
-							col.FieldName = fmt.Sprintf("%v.%v", t.Field(i).Name, col.FieldName)
-							table.AddColumn(col)
-							for indexName, indexType := range col.Indexes {
-								addIndex(indexName, table, col, indexType)
-							}
-						}
-						continue
-					default:
-						//TODO: warning
+					if err := ExtendsTagHandler(&ctx); err != nil {
+						return nil, err
 					}
+					continue
 				}
 
-				indexNames := make(map[string]int)
-				var isIndex, isUnique bool
-				var preKey string
 				for j, key := range tags {
+					if ctx.ignoreNext {
+						ctx.ignoreNext = false
+						continue
+					}
+
 					k := strings.ToUpper(key)
-					switch {
-					case k == "<-":
-						col.MapType = core.ONLYFROMDB
-					case k == "->":
-						col.MapType = core.ONLYTODB
-					case k == "PK":
-						col.IsPrimaryKey = true
-						col.Nullable = false
-					case k == "NULL":
-						if j == 0 {
-							col.Nullable = true
-						} else {
-							col.Nullable = (strings.ToUpper(tags[j-1]) != "NOT")
-						}
-					// TODO: for postgres how add autoincr?
-					/*case strings.HasPrefix(k, "AUTOINCR(") && strings.HasSuffix(k, ")"):
-					col.IsAutoIncrement = true
+					ctx.tagName = k
 
-					autoStart := k[len("AUTOINCR")+1 : len(k)-1]
-					autoStartInt, err := strconv.Atoi(autoStart)
-					if err != nil {
-						engine.LogError(err)
+					pStart := strings.Index(k, "(")
+					if pStart == 0 {
+						return nil, errors.New("( could not be the first charactor")
 					}
-					col.AutoIncrStart = autoStartInt*/
-					case k == "AUTOINCR":
-						col.IsAutoIncrement = true
-						//col.AutoIncrStart = 1
-					case k == "DEFAULT":
-						col.Default = tags[j+1]
-					case k == "CREATED":
-						col.IsCreated = true
-					case k == "VERSION":
-						col.IsVersion = true
-						col.Default = "1"
-					case k == "UTC":
-						col.TimeZone = time.UTC
-					case k == "LOCAL":
-						col.TimeZone = time.Local
-					case strings.HasPrefix(k, "LOCALE(") && strings.HasSuffix(k, ")"):
-						location := k[len("LOCALE")+1 : len(k)-1]
-						col.TimeZone, err = time.LoadLocation(location)
-						if err != nil {
-							engine.logger.Error(err)
-						}
-					case k == "UPDATED":
-						col.IsUpdated = true
-					case k == "DELETED":
-						col.IsDeleted = true
-					case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"):
-						indexName := k[len("INDEX")+1 : len(k)-1]
-						indexNames[indexName] = core.IndexType
-					case k == "INDEX":
-						isIndex = true
-					case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"):
-						indexName := k[len("UNIQUE")+1 : len(k)-1]
-						indexNames[indexName] = core.UniqueType
-					case k == "UNIQUE":
-						isUnique = true
-					case k == "NOTNULL":
-						col.Nullable = false
-					case k == "CACHE":
-						if !hasCacheTag {
-							hasCacheTag = true
+					if pStart > -1 {
+						if !strings.HasSuffix(k, ")") {
+							return nil, errors.New("cannot match ) charactor")
 						}
-					case k == "NOCACHE":
-						if !hasNoCacheTag {
-							hasNoCacheTag = true
+
+						ctx.tagName = k[:pStart]
+						ctx.params = strings.Split(k[pStart+1:len(k)-1], ",")
+					}
+
+					if j > 0 {
+						ctx.preTag = strings.ToUpper(tags[j-1])
+					}
+					if j < len(tags)-1 {
+						ctx.nextTag = strings.ToUpper(tags[j+1])
+					} else {
+						ctx.nextTag = ""
+					}
+
+					if h, ok := engine.tagHandlers[ctx.tagName]; ok {
+						if err := h(&ctx); err != nil {
+							return nil, err
 						}
-					case k == "NOT":
-					default:
-						if strings.HasPrefix(k, "'") && strings.HasSuffix(k, "'") {
-							if preKey != "DEFAULT" {
-								col.Name = key[1 : len(key)-1]
-							}
-						} else if strings.Contains(k, "(") && strings.HasSuffix(k, ")") {
-							fs := strings.Split(k, "(")
-
-							if _, ok := core.SqlTypes[fs[0]]; !ok {
-								preKey = k
-								continue
-							}
-							col.SQLType = core.SQLType{Name: fs[0]}
-							if fs[0] == core.Enum && fs[1][0] == '\'' { //enum
-								options := strings.Split(fs[1][0:len(fs[1])-1], ",")
-								col.EnumOptions = make(map[string]int)
-								for k, v := range options {
-									v = strings.TrimSpace(v)
-									v = strings.Trim(v, "'")
-									col.EnumOptions[v] = k
-								}
-							} else if fs[0] == core.Set && fs[1][0] == '\'' { //set
-								options := strings.Split(fs[1][0:len(fs[1])-1], ",")
-								col.SetOptions = make(map[string]int)
-								for k, v := range options {
-									v = strings.TrimSpace(v)
-									v = strings.Trim(v, "'")
-									col.SetOptions[v] = k
-								}
-							} else {
-								fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",")
-								if len(fs2) == 2 {
-									col.Length, err = strconv.Atoi(fs2[0])
-									if err != nil {
-										engine.logger.Error(err)
-									}
-									col.Length2, err = strconv.Atoi(fs2[1])
-									if err != nil {
-										engine.logger.Error(err)
-									}
-								} else if len(fs2) == 1 {
-									col.Length, err = strconv.Atoi(fs2[0])
-									if err != nil {
-										engine.logger.Error(err)
-									}
-								}
-							}
+					} else {
+						if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") {
+							col.Name = key[1 : len(key)-1]
 						} else {
-							if _, ok := core.SqlTypes[k]; ok {
-								col.SQLType = core.SQLType{Name: k}
-							} else if key != col.Default {
-								col.Name = key
-							}
+							col.Name = key
 						}
-						engine.dialect.SqlType(col)
 					}
-					preKey = k
+
+					if ctx.hasCacheTag {
+						hasCacheTag = true
+					}
+					if ctx.hasNoCacheTag {
+						hasNoCacheTag = true
+					}
 				}
+
 				if col.SQLType.Name == "" {
 					col.SQLType = core.Type2SQLType(fieldType)
 				}
+				engine.dialect.SqlType(col)
 				if col.Length == 0 {
 					col.Length = col.SQLType.DefaultLength
 				}
 				if col.Length2 == 0 {
 					col.Length2 = col.SQLType.DefaultLength2
 				}
-
 				if col.Name == "" {
 					col.Name = engine.ColumnMapper.Obj2Table(t.Field(i).Name)
 				}
 
-				if isUnique {
-					indexNames[col.Name] = core.UniqueType
-				} else if isIndex {
-					indexNames[col.Name] = core.IndexType
+				if ctx.isUnique {
+					ctx.indexNames[col.Name] = core.UniqueType
+				} else if ctx.isIndex {
+					ctx.indexNames[col.Name] = core.IndexType
 				}
 
-				for indexName, indexType := range indexNames {
+				for indexName, indexType := range ctx.indexNames {
 					addIndex(indexName, table, col, indexType)
 				}
 			}
@@ -1123,7 +1034,7 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
 		table.Cacher = nil
 	}
 
-	return table
+	return table, nil
 }
 
 // IsTableEmpty if a table has any reocrd

+ 4 - 1
session_schema.go

@@ -306,7 +306,10 @@ func (session *Session) Sync2(beans ...interface{}) error {
 
 	for _, bean := range beans {
 		v := rValue(bean)
-		table := engine.mapType(v)
+		table, err := engine.mapType(v)
+		if err != nil {
+			return err
+		}
 		structTables = append(structTables, table)
 		var tbName = session.tbNameNoSchema(table)
 

+ 281 - 0
tag.go

@@ -0,0 +1,281 @@
+// Copyright 2017 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xorm
+
+import (
+	"fmt"
+	"reflect"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/go-xorm/core"
+)
+
+type tagContext struct {
+	tagName         string
+	params          []string
+	preTag, nextTag string
+	table           *core.Table
+	col             *core.Column
+	fieldValue      reflect.Value
+	isIndex         bool
+	isUnique        bool
+	indexNames      map[string]int
+	engine          *Engine
+	hasCacheTag     bool
+	hasNoCacheTag   bool
+	ignoreNext      bool
+}
+
+// tagHandler describes tag handler for XORM
+type tagHandler func(ctx *tagContext) error
+
+var (
+	// defaultTagHandlers enumerates all the default tag handler
+	defaultTagHandlers = map[string]tagHandler{
+		"<-":       OnlyFromDBTagHandler,
+		"->":       OnlyToDBTagHandler,
+		"PK":       PKTagHandler,
+		"NULL":     NULLTagHandler,
+		"NOT":      IgnoreTagHandler,
+		"AUTOINCR": AutoIncrTagHandler,
+		"DEFAULT":  DefaultTagHandler,
+		"CREATED":  CreatedTagHandler,
+		"UPDATED":  UpdatedTagHandler,
+		"DELETED":  DeletedTagHandler,
+		"VERSION":  VersionTagHandler,
+		"UTC":      UTCTagHandler,
+		"LOCAL":    LocalTagHandler,
+		"NOTNULL":  NotNullTagHandler,
+		"INDEX":    IndexTagHandler,
+		"UNIQUE":   UniqueTagHandler,
+		"CACHE":    CacheTagHandler,
+		"NOCACHE":  NoCacheTagHandler,
+	}
+)
+
+func init() {
+	for k := range core.SqlTypes {
+		defaultTagHandlers[k] = SQLTypeTagHandler
+	}
+}
+
+// IgnoreTagHandler describes ignored tag handler
+func IgnoreTagHandler(ctx *tagContext) error {
+	return nil
+}
+
+// OnlyFromDBTagHandler describes mapping direction tag handler
+func OnlyFromDBTagHandler(ctx *tagContext) error {
+	ctx.col.MapType = core.ONLYFROMDB
+	return nil
+}
+
+// OnlyToDBTagHandler describes mapping direction tag handler
+func OnlyToDBTagHandler(ctx *tagContext) error {
+	ctx.col.MapType = core.ONLYTODB
+	return nil
+}
+
+// PKTagHandler decribes primary key tag handler
+func PKTagHandler(ctx *tagContext) error {
+	ctx.col.IsPrimaryKey = true
+	ctx.col.Nullable = false
+	return nil
+}
+
+// NULLTagHandler describes null tag handler
+func NULLTagHandler(ctx *tagContext) error {
+	ctx.col.Nullable = (strings.ToUpper(ctx.preTag) != "NOT")
+	return nil
+}
+
+// NotNullTagHandler describes notnull tag handler
+func NotNullTagHandler(ctx *tagContext) error {
+	ctx.col.Nullable = false
+	return nil
+}
+
+// AutoIncrTagHandler describes autoincr tag handler
+func AutoIncrTagHandler(ctx *tagContext) error {
+	ctx.col.IsAutoIncrement = true
+	/*
+		if len(ctx.params) > 0 {
+			autoStartInt, err := strconv.Atoi(ctx.params[0])
+			if err != nil {
+				return err
+			}
+			ctx.col.AutoIncrStart = autoStartInt
+		} else {
+			ctx.col.AutoIncrStart = 1
+		}
+	*/
+	return nil
+}
+
+// DefaultTagHandler describes default tag handler
+func DefaultTagHandler(ctx *tagContext) error {
+	if len(ctx.params) > 0 {
+		ctx.col.Default = ctx.params[0]
+	} else {
+		ctx.col.Default = ctx.nextTag
+		ctx.ignoreNext = true
+	}
+	return nil
+}
+
+// CreatedTagHandler describes created tag handler
+func CreatedTagHandler(ctx *tagContext) error {
+	ctx.col.IsCreated = true
+	return nil
+}
+
+// VersionTagHandler describes version tag handler
+func VersionTagHandler(ctx *tagContext) error {
+	ctx.col.IsVersion = true
+	ctx.col.Default = "1"
+	return nil
+}
+
+// UTCTagHandler describes utc tag handler
+func UTCTagHandler(ctx *tagContext) error {
+	ctx.col.TimeZone = time.UTC
+	return nil
+}
+
+// LocalTagHandler describes local tag handler
+func LocalTagHandler(ctx *tagContext) error {
+	if len(ctx.params) == 0 {
+		ctx.col.TimeZone = time.Local
+	} else {
+		var err error
+		ctx.col.TimeZone, err = time.LoadLocation(ctx.params[0])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// UpdatedTagHandler describes updated tag handler
+func UpdatedTagHandler(ctx *tagContext) error {
+	ctx.col.IsUpdated = true
+	return nil
+}
+
+// DeletedTagHandler describes deleted tag handler
+func DeletedTagHandler(ctx *tagContext) error {
+	ctx.col.IsDeleted = true
+	return nil
+}
+
+// IndexTagHandler describes index tag handler
+func IndexTagHandler(ctx *tagContext) error {
+	if len(ctx.params) > 0 {
+		ctx.indexNames[ctx.params[0]] = core.IndexType
+	} else {
+		ctx.isIndex = true
+	}
+	return nil
+}
+
+// UniqueTagHandler describes unique tag handler
+func UniqueTagHandler(ctx *tagContext) error {
+	if len(ctx.params) > 0 {
+		ctx.indexNames[ctx.params[0]] = core.UniqueType
+	} else {
+		ctx.isUnique = true
+	}
+	return nil
+}
+
+// SQLTypeTagHandler describes SQL Type tag handler
+func SQLTypeTagHandler(ctx *tagContext) error {
+	ctx.col.SQLType = core.SQLType{Name: ctx.tagName}
+	if len(ctx.params) > 0 {
+		if ctx.tagName == core.Enum {
+			ctx.col.EnumOptions = make(map[string]int)
+			for k, v := range ctx.params {
+				v = strings.TrimSpace(v)
+				v = strings.Trim(v, "'")
+				ctx.col.EnumOptions[v] = k
+			}
+		} else if ctx.tagName == core.Set {
+			ctx.col.SetOptions = make(map[string]int)
+			for k, v := range ctx.params {
+				v = strings.TrimSpace(v)
+				v = strings.Trim(v, "'")
+				ctx.col.SetOptions[v] = k
+			}
+		} else {
+			var err error
+			if len(ctx.params) == 2 {
+				ctx.col.Length, err = strconv.Atoi(ctx.params[0])
+				if err != nil {
+					return err
+				}
+				ctx.col.Length2, err = strconv.Atoi(ctx.params[1])
+				if err != nil {
+					return err
+				}
+			} else if len(ctx.params) == 1 {
+				ctx.col.Length, err = strconv.Atoi(ctx.params[0])
+				if err != nil {
+					return err
+				}
+			}
+		}
+	}
+	return nil
+}
+
+// ExtendsTagHandler describes extends tag handler
+func ExtendsTagHandler(ctx *tagContext) error {
+	var fieldValue = ctx.fieldValue
+	switch fieldValue.Kind() {
+	case reflect.Ptr:
+		f := fieldValue.Type().Elem()
+		if f.Kind() == reflect.Struct {
+			fieldPtr := fieldValue
+			fieldValue = fieldValue.Elem()
+			if !fieldValue.IsValid() || fieldPtr.IsNil() {
+				fieldValue = reflect.New(f).Elem()
+			}
+		}
+		fallthrough
+	case reflect.Struct:
+		parentTable, err := ctx.engine.mapType(fieldValue)
+		if err != nil {
+			return err
+		}
+		for _, col := range parentTable.Columns() {
+			col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName)
+			ctx.table.AddColumn(col)
+			for indexName, indexType := range col.Indexes {
+				addIndex(indexName, ctx.table, col, indexType)
+			}
+		}
+	default:
+		//TODO: warning
+	}
+	return nil
+}
+
+// CacheTagHandler describes cache tag handler
+func CacheTagHandler(ctx *tagContext) error {
+	if !ctx.hasCacheTag {
+		ctx.hasCacheTag = true
+	}
+	return nil
+}
+
+// NoCacheTagHandler describes nocache tag handler
+func NoCacheTagHandler(ctx *tagContext) error {
+	if !ctx.hasNoCacheTag {
+		ctx.hasNoCacheTag = true
+	}
+	return nil
+}

+ 1 - 0
xorm.go

@@ -86,6 +86,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
 		mutex:         &sync.RWMutex{},
 		TagIdentifier: "xorm",
 		TZLocation:    time.Local,
+		tagHandlers:   defaultTagHandlers,
 	}
 
 	logger := NewSimpleLogger(os.Stdout)