package utils import ( "bytes" "compress/gzip" "errors" "fmt" "go/format" "html/template" "net/url" "os" "path/filepath" "strconv" "strings" "time" "unicode" "unsafe" ) var DSNError = errors.New("dsn string error") func FirstCharacter(name string) string { return strings.ToLower(name)[:1] } func CamelizeStr(s string, upperCase bool) string { if len(s) == 0 { return s } s = replaceInvalidChars(s) var result string words := strings.Split(s, "_") for i, word := range words { //if upper := strings.ToUpper(word); commonInitialisms[upper] { // result += upper // continue //} if i > 0 || upperCase { result += camelizeWord(word) } else { result += word } } return result } func camelizeWord(word string) string { runes := []rune(word) for i, r := range runes { if i == 0 { runes[i] = unicode.ToUpper(r) } else { runes[i] = r //runes[i] = unicode.ToLower(r) } } return string(runes) } func replaceInvalidChars(str string) string { str = strings.ReplaceAll(str, "-", "_") str = strings.ReplaceAll(str, " ", "_") return strings.ReplaceAll(str, ".", "_") } // https://github.com/golang/lint/blob/206c0f020eba0f7fbcfbc467a5eb808037df2ed6/lint.go#L731 var commonInitialisms = map[string]bool{ "ACL": true, "API": true, "ASCII": true, "CPU": true, "CSS": true, "DNS": true, "EOF": true, "ETA": true, "GPU": true, "GUID": true, "HTML": true, "HTTP": true, "HTTPS": true, "ID": true, "IP": true, "JSON": true, "LHS": true, "OS": true, "QPS": true, "RAM": true, "RHS": true, "RPC": true, "SLA": true, "SMTP": true, "SQL": true, "SSH": true, "TCP": true, "TLS": true, "TTL": true, "UDP": true, "UI": true, "UID": true, "UUID": true, "URI": true, "URL": true, "UTF8": true, "VM": true, "XML": true, "XMPP": true, "XSRF": true, "XSS": true, "OAuth": true, } func GetDbNameFromDSN(dsn string) (string, error) { if len(strings.Split(dsn, " ")) > 1 { return getDbNameFromDsn(dsn) } index := strings.LastIndex(dsn, "/") if index <= 0 { return getDbNameFromDsn(dsn) } str := dsn[index:] urlStr, err := url.Parse(str) if err != nil { return "", err } return strings.Trim(urlStr.Path, "/"), nil } // host=127.0.0.1 dbname=test sslmode=disable Timezone=Asia/Shanghai const dbNamePrefix = "dbname=" func getDbNameFromDsn(dsn string) (string, error) { strArray := strings.Split(dsn, " ") for _, item := range strArray { if strings.HasPrefix(item, dbNamePrefix) { return strings.TrimPrefix(item, dbNamePrefix), nil } } return "", DSNError } func SaveFile(dirPath, fileName string, text []byte) error { file, err := os.Create(filepath.Join(dirPath, fileName)) if err != nil { return err } defer file.Close() p, err := format.Source(text) if err != nil { return err } _, err = file.Write(p) return err } func MkdirPathIfNotExist(dirPath string) error { if _, err := os.Stat(dirPath); os.IsNotExist(err) { return os.MkdirAll(dirPath, os.ModePerm) } return nil } func CleanUpGenFiles(dir string) error { exist, err := FileExists(dir) if err != nil { return err } if exist { return os.RemoveAll(dir) } return nil } // FileExists reports whether the named file or directory exists. func FileExists(name string) (bool, error) { if _, err := os.Stat(name); err != nil { if os.IsNotExist(err) { return false, err } } return true, nil } func ZipBytes(data []byte) []byte { var out bytes.Buffer w := gzip.NewWriter(&out) defer w.Close() w.Write(data) w.Flush() return out.Bytes() } func UnzipBytes(data []byte) []byte { var in bytes.Buffer in.Write(data) r, _ := gzip.NewReader(&in) var out bytes.Buffer defer r.Close() out.ReadFrom(r) return out.Bytes() } func Str2bytes(s string) []byte { x := (*[2]uintptr)(unsafe.Pointer(&s)) h := [3]uintptr{x[0], x[1], x[1]} return *(*[]byte)(unsafe.Pointer(&h)) } func Bytes2str(b []byte) string { return *(*string)(unsafe.Pointer(&b)) } func If(condition bool, trueVal, falseVal interface{}) interface{} { if condition { return trueVal } return falseVal } func NewResult(rType string, rName string, apigen func(templateData interface{}, params map[string]interface{}) ([]byte, error), templateData interface{}, isCompress bool) *GenerateResult { gr := GenerateResult{} gr.Type = rType gr.Name = rName genBytes, err := apigen(templateData, nil) if err != nil { fmt.Println(err.Error()) return nil } //fmt.Println(string(genBytes)) if isCompress { gr.Content = ZipBytes(genBytes) } else { gr.Content = genBytes } return &gr } func ParamName(index int, param XmlApiParam) string { str := param.Name if str == "" { str = "paramObj" + strconv.Itoa(index) } return str } func GenTemplate(templateText string, templateData interface{}, params map[string]interface{}) ([]byte, error) { t, err := template.New("template").Funcs(template.FuncMap{ "CamelizeStr": CamelizeStr, "FirstCharacter": FirstCharacter, "Replace": func(old, new, src string) string { return strings.ReplaceAll(src, old, new) }, "Add": func(a, b int) int { return a + b }, "Now": func() string { return time.Now().Format(time.RFC3339) }, "Xorm": func(col XmlColumn) string { str := fmt.Sprintf(`xorm:"%s %s %s %s %s %s" `, "'"+col.Name+"'", col.DbType, If(col.IsPK, "pk", ""), If(col.AutoIncrement, "autoincr", ""), If(col.IsNull, "null", "notnull"), If(col.IsUnique, "unique", If(col.IsIndex, "index", ""))) return str }, "XormTime": func(time string) interface{} { return If(time == "local_time", "sysmodel.LocalTime", time) }, "Backquote": func() string { return "`" }, // 定义函数unescaped "Unescaped": func(x string) interface{} { return template.HTML(x) }, "TrimPrefix": strings.TrimPrefix, "TrimSuffix": strings.TrimSuffix, "Contains": strings.Contains, "ParamName": ParamName, "ParamsContainDT": func(params []XmlColumn) bool { isContain := false for _, v := range params { if strings.ToLower(v.DbType) == "datetime" { isContain = true break } } return isContain }, "AllParams": func(params []XmlApiParam) string { str := "" for i, v := range params { if i > 0 { str += ", " } str += ParamName(i, v) } return str }, "SqlAllColumns": func(cols []XmlColumn, isValues bool) string { str := "" for i, v := range cols { if i > 0 { str += ", " } if isValues { str += "?" + v.Name } else { str += "`" + v.Name + "`" } } return str }, "SqlNoPKUpdate": func(cols []XmlColumn) string { str := "" for _, v := range cols { if !v.IsPK { str += "`" + v.Name + "` = ?" + v.Name str += ", " } } out := strings.TrimSuffix(str, ", ") return out }, "SqlPKWhere": func(cols []XmlColumn) string { str := "" for _, v := range cols { if v.IsPK { str += "`" + v.Name + "` = ?" + v.Name str += " AND " } } out := strings.TrimSuffix(str, " AND ") return out }, "Param": func(name string) interface{} { if v, ok := params[name]; ok { return v } return "" }, }).Parse(templateText) if err != nil { fmt.Println(err.Error()) return nil, err } var buf bytes.Buffer if err := t.Execute(&buf, templateData); err != nil { fmt.Println(err.Error()) return nil, err } return buf.Bytes(), nil }