Browse Source

支持 通过 http 的方式获取 token 信息

paddy 3 years ago
parent
commit
027eb39247
5 changed files with 51 additions and 51 deletions
  1. 3 0
      .gitignore
  2. 1 1
      go.mod
  3. 3 3
      internal/logic/parse_token_logic.go
  4. 43 44
      internal/utils/token.go
  5. 1 3
      transform.go

+ 3 - 0
.gitignore

@@ -0,0 +1,3 @@
+
+*_gen.yaml
+

+ 1 - 1
go.mod

@@ -19,7 +19,7 @@ require (
 	github.com/tealeg/xlsx v1.0.5 // indirect
 	github.com/thoas/go-funk v0.8.0
 	github.com/xormplus/builder v0.0.0-20200331055651-240ff40009be // indirect
-	github.com/xormplus/xorm v0.0.0-20210512135344-8123d584d5f5 // indirect
+	github.com/xormplus/xorm v0.0.0-20210512135344-8123d584d5f5
 	go.uber.org/automaxprocs v1.4.0 // indirect
 	google.golang.org/grpc v1.29.1
 	google.golang.org/protobuf v1.26.0-rc.1

+ 3 - 3
internal/logic/parse_token_logic.go

@@ -36,11 +36,11 @@ func NewParseTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ParseT
 }
 
 func (l *ParseTokenLogic) ParseToken(in *transform.TokenRequest) (*transform.TokenResponse, error) {
-	erpToken := utils.GetGlobalTokenStore().Get(in.Token)
+	userId := utils.GetGlobalTokenStore().Get(in.Token)
 	privileges := ErpUserPrivileges{}
 	tr := transform.TokenResponse{}
-	if erpToken != nil {
-		tr.UserId = erpToken.UserId
+	if userId != "" {
+		tr.UserId = userId
 		l.svcCtx.SqlConn.QueryRowPartial(&privileges, `SELECT
 		GROUP_CONCAT(DISTINCT sys_user_role.role_id)user_role_ids
 		FROM

+ 43 - 44
internal/utils/token.go

@@ -8,10 +8,13 @@ import (
 	"crypto/md5"
 	"encoding/binary"
 	"encoding/hex"
+	"encoding/json"
 	"errors"
 	"fmt"
+	"io/ioutil"
 	"log"
 	"net"
+	"net/http"
 	"strconv"
 	"strings"
 	"sync"
@@ -36,6 +39,8 @@ type Token struct {
 }
 
 type TokenStore struct {
+	storeServiceUrl string
+
 	name   string
 	lock   *sync.RWMutex
 	tokens map[string]*Token
@@ -47,13 +52,21 @@ type IAuth interface {
 
 var globalTokenStore *TokenStore = nil
 
-func init() {
-	iauthMap = make(map[string]IAuth)
+func Init(c *config.Config) {
 	globalTokenStore = &TokenStore{name: "sso", lock: new(sync.RWMutex), tokens: make(map[string]*Token)}
-	go globalTokenStore.startTokenCheckProcess()
+	if strings.HasPrefix(c.Erp.AuthServer, "http") {
+		globalTokenStore.storeServiceUrl = c.Erp.AuthServer
+	} else {
+		iauthMap = make(map[string]IAuth)
+
+		go globalTokenStore.startTokenCheckProcess()
+
+		lightAuth := &LightAuth{}
+		RegisterAuth("qianqiusoft.com", lightAuth)
 
-	lightAuth := &LightAuth{}
-	RegisterAuth("qianqiusoft.com", lightAuth)
+		erpClient := NewTcpClient(c)
+		erpClient.Start()
+	}
 }
 
 type LightAuth struct {
@@ -89,14 +102,32 @@ func GetGlobalTokenStore() *TokenStore {
 	return globalTokenStore
 }
 
-func (t *TokenStore) Get(key string) *Token {
-	t.lock.RLock()
-	defer t.lock.RUnlock()
-	if val, ok := t.tokens[key]; ok {
-		//log.Println(key, "获取Token:", val.AccessToken, val.RefreshToken, val.LoginID)
-		return val
+func (t *TokenStore) Get(key string) string {
+	if t.storeServiceUrl == "" {
+		t.lock.RLock()
+		defer t.lock.RUnlock()
+		if val, ok := t.tokens[key]; ok {
+			//log.Println(key, "获取Token:", val.AccessToken, val.RefreshToken, val.LoginID)
+			return val.UserId
+		}
+	} else {
+		resp, err := http.Get(t.storeServiceUrl + key)
+		if err != nil {
+			return ""
+		}
+		b, err := ioutil.ReadAll(resp.Body)
+		if err != nil {
+			return ""
+		}
+		token := &Token{}
+		err = json.Unmarshal(b, &token)
+		if err != nil {
+			return ""
+		}
+		return token.UserId
 	}
-	return nil
+
+	return ""
 }
 
 func (t *TokenStore) Set(key string, v *Token) {
@@ -129,38 +160,6 @@ func (t *TokenStore) Refresh(key string) {
 func (t *TokenStore) startTokenCheckProcess() {
 }
 
-func Validate(accessToken, loginId string, domain string) (*Token, error) {
-	token := globalTokenStore.Get(loginId + domain)
-	if token != nil {
-		if strings.EqualFold(token.AccessToken, accessToken) {
-			logx.Info("get the token ", accessToken, " of id ", loginId+domain)
-			globalTokenStore.Refresh(loginId + domain)
-			return token, nil
-		} else {
-			logx.Error(token.AccessToken, "is not equal to", accessToken)
-			return token, errors.New(token.AccessToken + " is not equal to " + accessToken)
-		}
-	} else {
-		logx.Error("can not get the token of", loginId+domain)
-		return token, errors.New("can not get the token of " + loginId + domain)
-	}
-
-}
-
-func TokenValidate(token string) (*Token, error) {
-	user := globalTokenStore.Get(token)
-
-	if strings.EqualFold(user.AccessToken, token) {
-		logx.Info("get the token ", token, " of id ")
-		globalTokenStore.Refresh(token)
-		return user, nil
-	} else {
-		logx.Error(user.AccessToken, "is not equal to", token)
-		return user, errors.New(user.AccessToken + " is not equal to " + token)
-	}
-
-}
-
 const (
 	__KEY = "Light#dauth-@*I2"
 

+ 1 - 3
transform.go

@@ -22,9 +22,7 @@ func main() {
 	var c config.Config
 	conf.MustLoad(*configFile, &c)
 
-	// erp client
-	erpClient := utils.NewTcpClient(&c)
-	erpClient.Start()
+	utils.Init(&c)
 
 	ctx := svc.NewServiceContext(c)
 	srv := server.NewTransformServer(ctx)