123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324 |
- // Copyright 2017 The Xorm Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package main
- import (
- "errors"
- "fmt"
- "go/format"
- "reflect"
- "sort"
- "strings"
- "text/template"
- "github.com/go-xorm/core"
- )
- var (
- supportComment bool
- GoLangTmpl LangTmpl = LangTmpl{
- template.FuncMap{"Mapper": mapper.Table2Obj,
- "Type": typestring,
- "Tag": tag,
- "UnTitle": unTitle,
- "gt": gt,
- "getCol": getCol,
- "UpperTitle": upTitle,
- },
- formatGo,
- genGoImports,
- }
- )
// errBadComparisonType is returned when an operand of eq/lt/le/gt is not a
// basic kind, or is a kind that has no ordering (bool, complex) for lt.
var errBadComparisonType = errors.New("invalid type for comparison")

// errBadComparison is returned when the operands have incompatible kinds.
var errBadComparison = errors.New("incompatible types for comparison")

// errNoComparison is returned when eq is called without a second operand.
var errNoComparison = errors.New("missing argument for comparison")
// kind is the coarse category basicKind assigns to a reflect.Value so the
// comparison helpers know which accessor (Bool, Int, Uint, ...) to use.
type kind int

// Categories produced by basicKind. The numeric values follow declaration
// order via iota.
const (
	invalidKind kind = iota // not comparable by these helpers
	boolKind                // reflect.Bool
	complexKind             // reflect.Complex64, reflect.Complex128
	intKind                 // signed integer kinds
	floatKind               // reflect.Float32, reflect.Float64
	integerKind             // NOTE(review): never produced by basicKind in this file
	stringKind              // reflect.String
	uintKind                // unsigned integer kinds and reflect.Uintptr
)
- func basicKind(v reflect.Value) (kind, error) {
- switch v.Kind() {
- case reflect.Bool:
- return boolKind, nil
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- return intKind, nil
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- return uintKind, nil
- case reflect.Float32, reflect.Float64:
- return floatKind, nil
- case reflect.Complex64, reflect.Complex128:
- return complexKind, nil
- case reflect.String:
- return stringKind, nil
- }
- return invalidKind, errBadComparisonType
- }
- // eq evaluates the comparison a == b || a == c || ...
- func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
- v1 := reflect.ValueOf(arg1)
- k1, err := basicKind(v1)
- if err != nil {
- return false, err
- }
- if len(arg2) == 0 {
- return false, errNoComparison
- }
- for _, arg := range arg2 {
- v2 := reflect.ValueOf(arg)
- k2, err := basicKind(v2)
- if err != nil {
- return false, err
- }
- if k1 != k2 {
- return false, errBadComparison
- }
- truth := false
- switch k1 {
- case boolKind:
- truth = v1.Bool() == v2.Bool()
- case complexKind:
- truth = v1.Complex() == v2.Complex()
- case floatKind:
- truth = v1.Float() == v2.Float()
- case intKind:
- truth = v1.Int() == v2.Int()
- case stringKind:
- truth = v1.String() == v2.String()
- case uintKind:
- truth = v1.Uint() == v2.Uint()
- default:
- panic("invalid kind")
- }
- if truth {
- return true, nil
- }
- }
- return false, nil
- }
- // lt evaluates the comparison a < b.
- func lt(arg1, arg2 interface{}) (bool, error) {
- v1 := reflect.ValueOf(arg1)
- k1, err := basicKind(v1)
- if err != nil {
- return false, err
- }
- v2 := reflect.ValueOf(arg2)
- k2, err := basicKind(v2)
- if err != nil {
- return false, err
- }
- if k1 != k2 {
- return false, errBadComparison
- }
- truth := false
- switch k1 {
- case boolKind, complexKind:
- return false, errBadComparisonType
- case floatKind:
- truth = v1.Float() < v2.Float()
- case intKind:
- truth = v1.Int() < v2.Int()
- case stringKind:
- truth = v1.String() < v2.String()
- case uintKind:
- truth = v1.Uint() < v2.Uint()
- default:
- panic("invalid kind")
- }
- return truth, nil
- }
- // le evaluates the comparison <= b.
- func le(arg1, arg2 interface{}) (bool, error) {
- // <= is < or ==.
- lessThan, err := lt(arg1, arg2)
- if lessThan || err != nil {
- return lessThan, err
- }
- return eq(arg1, arg2)
- }
- // gt evaluates the comparison a > b.
- func gt(arg1, arg2 interface{}) (bool, error) {
- // > is the inverse of <=.
- lessOrEqual, err := le(arg1, arg2)
- if err != nil {
- return false, err
- }
- return !lessOrEqual, nil
- }
- func getCol(cols map[string]*core.Column, name string) *core.Column {
- return cols[strings.ToLower(name)]
- }
// formatGo runs src through gofmt (go/format.Source). It returns the
// formatted text, or "" plus the formatter's error when src does not parse.
func formatGo(src string) (string, error) {
	formatted, err := format.Source([]byte(src))
	if err == nil {
		return string(formatted), nil
	}
	return "", err
}
- func genGoImports(tables []*core.Table) map[string]string {
- imports := make(map[string]string)
- for _, table := range tables {
- for _, col := range table.Columns() {
- if typestring(col) == "time.Time" {
- imports["time"] = "time"
- }
- }
- }
- return imports
- }
- func typestring(col *core.Column) string {
- st := col.SQLType
- t := core.SQLType2Type(st)
- s := t.String()
- if s == "[]uint8" {
- return "[]byte"
- }
- return s
- }
- func tag(table *core.Table, col *core.Column) string {
- isNameId := (mapper.Table2Obj(col.Name) == "Id")
- isIdPk := isNameId && typestring(col) == "int64"
- var res []string
- if !col.Nullable {
- if !isIdPk {
- res = append(res, "not null")
- }
- }
- if col.IsPrimaryKey {
- res = append(res, "pk")
- }
- if col.Default != "" {
- res = append(res, "default "+col.Default)
- }
- if col.IsAutoIncrement {
- res = append(res, "autoincr")
- }
- if col.SQLType.IsTime() && include(created, col.Name) {
- res = append(res, "created")
- }
- if col.SQLType.IsTime() && include(updated, col.Name) {
- res = append(res, "updated")
- }
- if col.SQLType.IsTime() && include(deleted, col.Name) {
- res = append(res, "deleted")
- }
- if supportComment && col.Comment != "" {
- res = append(res, fmt.Sprintf("comment('%s')", col.Comment))
- }
- names := make([]string, 0, len(col.Indexes))
- for name := range col.Indexes {
- names = append(names, name)
- }
- sort.Strings(names)
- for _, name := range names {
- index := table.Indexes[name]
- var uistr string
- if index.Type == core.UniqueType {
- uistr = "unique"
- } else if index.Type == core.IndexType {
- uistr = "index"
- }
- if len(index.Cols) > 1 {
- uistr += "(" + index.Name + ")"
- }
- res = append(res, uistr)
- }
- nstr := col.SQLType.Name
- if col.Length != 0 {
- if col.Length2 != 0 {
- nstr += fmt.Sprintf("(%v,%v)", col.Length, col.Length2)
- } else {
- nstr += fmt.Sprintf("(%v)", col.Length)
- }
- } else if len(col.EnumOptions) > 0 { //enum
- nstr += "("
- opts := ""
- enumOptions := make([]string, 0, len(col.EnumOptions))
- for enumOption := range col.EnumOptions {
- enumOptions = append(enumOptions, enumOption)
- }
- sort.Strings(enumOptions)
- for _, v := range enumOptions {
- opts += fmt.Sprintf(",'%v'", v)
- }
- nstr += strings.TrimLeft(opts, ",")
- nstr += ")"
- } else if len(col.SetOptions) > 0 { //enum
- nstr += "("
- opts := ""
- setOptions := make([]string, 0, len(col.SetOptions))
- for setOption := range col.SetOptions {
- setOptions = append(setOptions, setOption)
- }
- sort.Strings(setOptions)
- for _, v := range setOptions {
- opts += fmt.Sprintf(",'%v'", v)
- }
- nstr += strings.TrimLeft(opts, ",")
- nstr += ")"
- }
- res = append(res, nstr)
- var tags []string
- if genJson {
- if include(ignoreColumnsJSON, col.Name) {
- tags = append(tags, "json:\"-\"")
- } else {
- tags = append(tags, "json:\""+col.Name+"\"")
- }
- }
- if len(res) > 0 {
- tags = append(tags, "xorm:\""+strings.Join(res, " ")+"\"")
- }
- if len(tags) > 0 {
- return "`" + strings.Join(tags, " ") + "`"
- } else {
- return ""
- }
- }
// include reports whether target appears in source, using exact string
// equality. A nil or empty source never contains anything.
func include(source []string, target string) bool {
	found := false
	for _, candidate := range source {
		if candidate == target {
			found = true
			break
		}
	}
	return found
}
|