Sfoglia il codice sorgente

改用sysutils.GetHostname分割域名列表

huangrf 6 anni fa
parent
commit
1dfa4cf17a
1 ha cambiato i file con 11 aggiunte e 47 eliminazioni
  1. 11 47
      controllers/partial/SsoController.go

+ 11 - 47
controllers/partial/SsoController.go

@@ -1,14 +1,12 @@
 package partial
 
 import (
-	"fmt"
 	"git.qianqiusoft.com/qianqiusoft/light-apiengine/config"
 	"git.qianqiusoft.com/qianqiusoft/light-apiengine/entitys"
 	"git.qianqiusoft.com/qianqiusoft/light-apiengine/logs"
 	sysmodel "git.qianqiusoft.com/qianqiusoft/light-apiengine/models"
 	sysutils "git.qianqiusoft.com/qianqiusoft/light-apiengine/utils"
 	"git.qianqiusoft.com/qianqiusoft/light-apiengine/utils/auth"
-	"regexp"
 	"strings"
 
 	//"git.qianqiusoft.com/qianqiusoft/light-apiengine/models"
@@ -22,7 +20,7 @@ import (
 // @Failure 403 :id is empty
 func Sso_Login(c *entitys.CtrlContext) {
 	iauth := getAuth(c)
-	if iauth == nil{
+	if iauth == nil {
 		hostname := sysutils.GetHostname(c.Ctx)
 		c.Ctx.JSON(500, sysmodel.SysReturn{500, "iauth of " + hostname + " is nil", nil})
 		return
@@ -37,7 +35,7 @@ func Sso_Login(c *entitys.CtrlContext) {
 // @Failure 403 :id is empty
 func Sso_Logout(c *entitys.CtrlContext) {
 	iauth := getAuth(c)
-	if iauth == nil{
+	if iauth == nil {
 		hostname := sysutils.GetHostname(c.Ctx)
 		c.Ctx.JSON(500, sysmodel.SysReturn{500, "iauth of " + hostname + " is nil", nil})
 		return
@@ -91,55 +89,21 @@ func Sso_TokenValidate(c *entitys.CtrlContext) {
 	c.Ctx.JSON(200, user)
 }
 
-
-
-func getAuth(c *entitys.CtrlContext)auth.IAuth{
+func getAuth(c *entitys.CtrlContext) auth.IAuth {
 	var iauth auth.IAuth = nil
 	authMode := config.AppConfig.GetKey("auth_mode")
-	if authMode == "local"{
+	if authMode == "local" {
 		iauth = auth.GetAuth("qianqiusoft.com")
-	}else{
-		hostname := sysutils.GetHostname(c.Ctx)
-		tld := getTLD(hostname)
-		fmt.Println("------>hostname", hostname, "tld", tld)
-		iauth = auth.GetAuth(hostname)
-		if iauth == nil{
-			iauth = auth.GetAuth(tld)
-		}
-	}
-	return iauth
-}
-
-/**
-* @brief: 获取一级域名
-× @param1 hostname: 请求名称
-*/
-func getTLD(hostname string)string{
-	patternstr := `(2(5[0-5]{1}|[0-4]\d{1})|[0-1]?\d{1,2})(\.(2(5[0-5]{1}|[0-4]\d{1})|[0-1]?\d{1,2})){3}`
-	reg := regexp.MustCompile(patternstr)
-	if res := reg.FindAllString(hostname, -1); res == nil {
-		size := 0
-		if strings.HasSuffix(hostname, "gov.cn") || strings.HasSuffix(hostname, "edu.cn"){
-			size = 3
-		}else{
-			size = 2
-		}
-		hnarr := strings.Split(hostname, ".")
-
-		if len(hnarr) >= size{
-			tld := hnarr[len(hnarr) - size]
-			for i := size - 1; i >= 1; i--{
-				tld += "." + hnarr[len(hnarr) - i]
+	} else {
+		hostnames := sysutils.GetHostnames(c.Ctx)
+		for i := range hostnames {
+			iauth = auth.GetAuth(hostnames[i])
+			if iauth != nil {
+				break
 			}
-			return tld
-		}else{
-			// 少于两个的直接返回
-			return hostname
 		}
-	} else {
-		// 直接返回ip
-		return hostname
 	}
+	return iauth
 }
 
 func __none_func_sso__(params ...interface{}) bool {