Przeglądaj źródła

add: 自定义获取 access token 方法

zdpdpdp 6 lat temu
rodzic
commit
81f26cd6dc
3 zmienionych plików z 43 dodań i 0 usunięć
  1. 10 0
      context/access_token.go
  2. 30 0
      context/access_token_test.go
  3. 3 0
      context/context.go

+ 10 - 0
context/access_token.go

@@ -22,16 +22,26 @@ type ResAccessToken struct {
 	ExpiresIn   int64  `json:"expires_in"`
 }
 
+type CustomGetAccessToken func(ctx *Context) (accessToken string, err error)
+
 //SetAccessTokenLock 设置读写锁(一个appID一个读写锁)
 func (ctx *Context) SetAccessTokenLock(l *sync.RWMutex) {
 	ctx.accessTokenLock = l
 }
 
+//SetGetAccessTokenFunc 设置自定义获取accessToken的方式, 需要自己实现缓存
+func (ctx *Context) SetGetAccessTokenFunc(f CustomGetAccessToken) {
+	ctx.accessTokenFunc = f
+}
+
 //GetAccessToken 获取access_token
 func (ctx *Context) GetAccessToken() (accessToken string, err error) {
 	ctx.accessTokenLock.Lock()
 	defer ctx.accessTokenLock.Unlock()
 
+	if ctx.accessTokenFunc != nil {
+		return ctx.accessTokenFunc(ctx)
+	}
 	accessTokenCacheKey := fmt.Sprintf("access_token_%s", ctx.AppID)
 	val := ctx.Cache.Get(accessTokenCacheKey)
 	if val != nil {

+ 30 - 0
context/access_token_test.go

@@ -0,0 +1,30 @@
+package context
+
+import (
+	"sync"
+	"testing"
+)
+
+func TestContext_SetCustomAccessTokenFunc(t *testing.T) {
+	ctx := Context{
+		accessTokenLock: new(sync.RWMutex),
+	}
+	f := func(ctx *Context) (accessToken string, err error) {
+		return "fake token", nil
+	}
+	ctx.SetGetAccessTokenFunc(f)
+	res, err := ctx.GetAccessToken()
+	if res != "fake token" || err != nil {
+		t.Error("expect fake token but error")
+	}
+}
+
+func TestContext_NoSetCustomAccessTokenFunc(t *testing.T) {
+	ctx := Context{
+		accessTokenLock: new(sync.RWMutex),
+	}
+
+	if ctx.accessTokenFunc != nil {
+		t.Error("error accessTokenFunc")
+	}
+}

+ 3 - 0
context/context.go

@@ -27,6 +27,9 @@ type Context struct {
 
 	//jsAPITicket 读写锁 同一个AppID一个
 	jsAPITicketLock *sync.RWMutex
+
+	//自定义获取 access token 的方法
+	accessTokenFunc CustomGetAccessToken
 }
 
 // Query returns the keyed url query value if it exists