瀏覽代碼

fix: QueryRowPartial

2637309949 4 年之前
父節點
當前提交
d9bbde61d5
共有 3 個文件被更改,包括 23 次插入5 次删除
  1. 1 1
      internal/logic/getuserlogic.go
  2. 2 2
      internal/logic/loginbyweixinlogic.go
  3. 20 2
      internal/svc/servicecontext.go

+ 1 - 1
internal/logic/getuserlogic.go

@@ -29,7 +29,7 @@ func NewGetUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) GetUserLog
 
 func (l *GetUserLogic) GetUser() (*types.InfoResponse, error) {
 	var user model.User
-	err := l.svcCtx.SqlConn.QueryRow(&user, "select id, erp_id, avatar, birthday, username, nickname, gender, mobile from user where `id` = ? limit 1", l.UserId)
+	err := l.svcCtx.SqlConn.QueryRowPartial(&user, "select id, erp_id, avatar, birthday, username, nickname, gender, mobile from user where `id` = ? limit 1", l.UserId)
 	if err != nil {
 		logx.Error(err)
 		return nil, err

+ 2 - 2
internal/logic/loginbyweixinlogic.go

@@ -38,7 +38,7 @@ func (l *LoginByWeixinLogic) LoginByWeixin(req types.LoginByWeixinRequest) (*typ
 	}
 	err = l.svcCtx.SqlConn.Transact(func(session sqlx.Session) error {
 		var user model.User
-		err := session.QueryRow(&user, fmt.Sprintf("select %s from user where `weixin_openid` = ? limit 1", model.UserRows), userInfo.OpenID)
+		err := session.QueryRowPartial(&user, fmt.Sprintf("select %s from user where `weixin_openid` = ? limit 1", model.UserRows), userInfo.OpenID)
 		if err == sqlc.ErrNotFound {
 			user.Username = utils.GetUUID()
 			user.Password = ""
@@ -57,7 +57,7 @@ func (l *LoginByWeixinLogic) LoginByWeixin(req types.LoginByWeixinRequest) (*typ
 				logx.Error(err)
 				return err
 			}
-			err = session.QueryRow(&user, fmt.Sprintf("select %s from user where `weixin_openid` = ? limit 1", model.UserRows), userInfo.OpenID)
+			err = session.QueryRowPartial(&user, fmt.Sprintf("select %s from user where `weixin_openid` = ? limit 1", model.UserRows), userInfo.OpenID)
 			if err != nil {
 				logx.Error(err)
 				return err

+ 20 - 2
internal/svc/servicecontext.go

@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"net/http"
 	"path"
+	"strconv"
 	"time"
 
 	"git.i2edu.net/i2/go-zero/core/stores/sqlx"
@@ -162,7 +163,16 @@ func (sc *ServiceContext) GetUserId(r *http.Request) (int64, error) {
 	if err != nil {
 		return 0, err
 	}
-	return tok.Claims.(jwt.MapClaims)["userId"].(int64), err
+	m := tok.Claims.(jwt.MapClaims)
+	switch nbf := m["userId"].(type) {
+	case string:
+		i, _ := strconv.ParseInt(nbf, 10, 64)
+		return i, err
+	case json.Number:
+		v, _ := nbf.Int64()
+		return v, err
+	}
+	return 0, err
 }
 
 // GetSessionKey defined TODO
@@ -172,7 +182,15 @@ func (sc *ServiceContext) GetSessionKey(r *http.Request) (string, error) {
 	if err != nil {
 		return "", err
 	}
-	return tok.Claims.(jwt.MapClaims)["sessionKey"].(string), err
+	m := tok.Claims.(jwt.MapClaims)
+	switch nbf := m["sessionKey"].(type) {
+	case string:
+		return nbf, err
+	case json.Number:
+		v, _ := nbf.Int64()
+		return fmt.Sprintf("%v", v), err
+	}
+	return "", err
 }
 
 func NewServiceContext(c config.Config) *ServiceContext {