Pārlūkot izejas kodu

support user-defined interface to get AK

taowei.wtw 6 gadi atpakaļ
vecāks
revīzija
f8b8c753a2
4 mainītis faili ar 53 papildinājumiem un 12 dzēšanām
  1. 3 3
      oss/auth.go
  2. 9 0
      oss/client.go
  3. 32 0
      oss/conf.go
  4. 9 9
      oss/conn.go

+ 3 - 3
oss/auth.go

@@ -23,7 +23,7 @@ type headerSorter struct {
 // signHeader signs the header and sets it as the authorization header.
 func (conn Conn) signHeader(req *http.Request, canonicalizedResource string) {
 	// Get the final authorization string
-	authorizationStr := "OSS " + conn.config.AccessKeyID + ":" + conn.getSignedStr(req, canonicalizedResource)
+	authorizationStr := "OSS " + conn.config.GetAccessKeyID() + ":" + conn.getSignedStr(req, canonicalizedResource)
 
 	// Give the parameter "Authorization" value
 	req.Header.Set(HTTPHeaderAuthorization, authorizationStr)
@@ -59,7 +59,7 @@ func (conn Conn) getSignedStr(req *http.Request, canonicalizedResource string) s
 
 	conn.config.WriteLog(Debug, "[Req:%p]signStr:%s.\n", req, signStr)
 
-	h := hmac.New(func() hash.Hash { return sha1.New() }, []byte(conn.config.AccessKeySecret))
+	h := hmac.New(func() hash.Hash { return sha1.New() }, []byte(conn.config.GetAccessKeySecret()))
 	io.WriteString(h, signStr)
 	signedStr := base64.StdEncoding.EncodeToString(h.Sum(nil))
 
@@ -88,7 +88,7 @@ func (conn Conn) getRtmpSignedStr(bucketName, channelName, playlistName string,
 	expireStr := strconv.FormatInt(expiration, 10)
 	signStr := expireStr + "\n" + canonParamsStr + canonResource
 
-	h := hmac.New(func() hash.Hash { return sha1.New() }, []byte(conn.config.AccessKeySecret))
+	h := hmac.New(func() hash.Hash { return sha1.New() }, []byte(conn.config.GetAccessKeySecret()))
 	io.WriteString(h, signStr)
 	signedStr := base64.StdEncoding.EncodeToString(h.Sum(nil))
 	return signedStr

+ 9 - 0
oss/client.go

@@ -1216,6 +1216,15 @@ func SetLogger(Logger *log.Logger) ClientOption {
 	}
 }
 
+//
+// SetAKInterface sets funciton for get the user's ak
+//
+func SetAKInterface(akIf AKInterface) ClientOption {
+	return func(client *Client) {
+		client.Config.UserAKInf = akIf
+	}
+}
+
 // Private
 func (client Client) do(method, bucketName string, params map[string]interface{},
 	headers map[string]string, data io.Reader) (*Response, error) {

+ 32 - 0
oss/conf.go

@@ -35,6 +35,13 @@ type HTTPMaxConns struct {
 	MaxIdleConnsPerHost int
 }
 
+// AKInterface is interface for getting AccessKeyID, AccessKeySecret, SecurityToken
+type AKInterface interface {
+	GetAccessKeyID() string
+	GetAccessKeySecret() string
+	GetSecurityToken() string
+}
+
 // Config defines oss configuration
 type Config struct {
 	Endpoint         string       // OSS endpoint
@@ -60,6 +67,7 @@ type Config struct {
 	Logger           *log.Logger  // For write log
 	UploadLimitSpeed int          // Upload limit speed:KB/s, 0 is unlimited
 	UploadLimiter    *OssLimiter  // Bandwidth limit reader for upload
+	UserAKInf        AKInterface  // User provides interface to get AccessKeyID, AccessKeySecret, SecurityToken
 }
 
 // LimitUploadSpeed uploadSpeed:KB/s, 0 is unlimited,default is 0
@@ -92,6 +100,30 @@ func (config *Config) WriteLog(LogLevel int, format string, a ...interface{}) {
 	config.Logger.Printf("%s", logBuffer.String())
 }
 
+// for get AccessKeyID
+func (config *Config) GetAccessKeyID() string {
+	if config.UserAKInf != nil {
+		return config.UserAKInf.GetAccessKeyID()
+	}
+	return config.AccessKeyID
+}
+
+// for get AccessKeySecret
+func (config *Config) GetAccessKeySecret() string {
+	if config.UserAKInf != nil {
+		return config.UserAKInf.GetAccessKeySecret()
+	}
+	return config.AccessKeySecret
+}
+
+// for get SecurityToken
+func (config *Config) GetSecurityToken() string {
+	if config.UserAKInf != nil {
+		return config.UserAKInf.GetSecurityToken()
+	}
+	return config.SecurityToken
+}
+
 // getDefaultOssConfig gets the default configuration.
 func getDefaultOssConfig() *Config {
 	config := Config{}

+ 9 - 9
oss/conn.go

@@ -239,8 +239,8 @@ func (conn Conn) doRequest(method string, uri *url.URL, canonicalizedResource st
 	req.Header.Set(HTTPHeaderDate, date)
 	req.Header.Set(HTTPHeaderHost, conn.config.Endpoint)
 	req.Header.Set(HTTPHeaderUserAgent, conn.config.UserAgent)
-	if conn.config.SecurityToken != "" {
-		req.Header.Set(HTTPHeaderOssSecurityToken, conn.config.SecurityToken)
+	if conn.config.GetSecurityToken() != "" {
+		req.Header.Set(HTTPHeaderOssSecurityToken, conn.config.GetSecurityToken())
 	}
 
 	if headers != nil {
@@ -281,8 +281,8 @@ func (conn Conn) doRequest(method string, uri *url.URL, canonicalizedResource st
 }
 
 func (conn Conn) signURL(method HTTPMethod, bucketName, objectName string, expiration int64, params map[string]interface{}, headers map[string]string) string {
-	if conn.config.SecurityToken != "" {
-		params[HTTPParamSecurityToken] = conn.config.SecurityToken
+	if conn.config.GetSecurityToken() != "" {
+		params[HTTPParamSecurityToken] = conn.config.GetSecurityToken()
 	}
 	subResource := conn.getSubResource(params)
 	canonicalizedResource := conn.url.getResource(bucketName, objectName, subResource)
@@ -312,7 +312,7 @@ func (conn Conn) signURL(method HTTPMethod, bucketName, objectName string, expir
 	signedStr := conn.getSignedStr(req, canonicalizedResource)
 
 	params[HTTPParamExpires] = strconv.FormatInt(expiration, 10)
-	params[HTTPParamAccessKeyID] = conn.config.AccessKeyID
+	params[HTTPParamAccessKeyID] = conn.config.GetAccessKeyID()
 	params[HTTPParamSignature] = signedStr
 
 	urlParams := conn.getURLParams(params)
@@ -327,10 +327,10 @@ func (conn Conn) signRtmpURL(bucketName, channelName, playlistName string, expir
 	expireStr := strconv.FormatInt(expiration, 10)
 	params[HTTPParamExpires] = expireStr
 
-	if conn.config.AccessKeyID != "" {
-		params[HTTPParamAccessKeyID] = conn.config.AccessKeyID
-		if conn.config.SecurityToken != "" {
-			params[HTTPParamSecurityToken] = conn.config.SecurityToken
+	if conn.config.GetAccessKeyID() != "" {
+		params[HTTPParamAccessKeyID] = conn.config.GetAccessKeyID()
+		if conn.config.GetSecurityToken() != "" {
+			params[HTTPParamSecurityToken] = conn.config.GetSecurityToken()
 		}
 		signedStr := conn.getRtmpSignedStr(bucketName, channelName, playlistName, expiration, params)
 		params[HTTPParamSignature] = signedStr