瀏覽代碼

add test cases for credential_updater.go/signer_key_pair.go (#125)

* add test cases for credential_updater.go/signer_key_pair.go
Jackson Tian 7 年之前
父節點
當前提交
02da354e52

+ 2 - 1
sdk/auth/signers/credential_updater.go

@@ -15,9 +15,10 @@
 package signers
 
 import (
+	"time"
+
 	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/requests"
 	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/responses"
-	"time"
 )
 
 const defaultInAdvanceScale = 0.8

+ 80 - 0
sdk/auth/signers/credential_updater_test.go

@@ -0,0 +1,80 @@
+package signers
+
+import (
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/requests"
+	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/responses"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCredentialUpdater_NeedUpdateCredential(t *testing.T) {
+	// default
+	updater := &credentialUpdater{}
+	assert.NotNil(t, updater)
+	assert.True(t, updater.needUpdateCredential())
+
+	// no need update
+	updater = &credentialUpdater{
+		inAdvanceScale:       1.0,
+		lastUpdateTimestamp:  time.Now().Unix() - 4000,
+		credentialExpiration: 5000,
+	}
+	assert.NotNil(t, updater)
+	assert.False(t, updater.needUpdateCredential())
+
+	// need update
+	updater = &credentialUpdater{
+		inAdvanceScale:       1.0,
+		lastUpdateTimestamp:  time.Now().Unix() - 10000,
+		credentialExpiration: 5000,
+	}
+	assert.NotNil(t, updater)
+	assert.True(t, updater.needUpdateCredential())
+}
+
+func TestCredentialUpdater_UpdateCredential(t *testing.T) {
+	updater := &credentialUpdater{}
+	assert.NotNil(t, updater)
+	updater.buildRequestMethod = func() (*requests.CommonRequest, error) {
+		return nil, fmt.Errorf("build request method failed")
+	}
+
+	err := updater.updateCredential()
+	assert.NotNil(t, err)
+	assert.Equal(t, "build request method failed", err.Error())
+
+	updater.buildRequestMethod = func() (*requests.CommonRequest, error) {
+		return requests.NewCommonRequest(), nil
+	}
+	updater.refreshApi = func(request *requests.CommonRequest) (response *responses.CommonResponse, err error) {
+		return nil, fmt.Errorf("refresh api executed fail")
+	}
+
+	err = updater.updateCredential()
+	assert.NotNil(t, err)
+	assert.Equal(t, "refresh api executed fail", err.Error())
+
+	updater.refreshApi = func(request *requests.CommonRequest) (response *responses.CommonResponse, err error) {
+		return responses.NewCommonResponse(), nil
+	}
+
+	updater.responseCallBack = func(response *responses.CommonResponse) error {
+		return fmt.Errorf("response callback fail")
+	}
+
+	err = updater.updateCredential()
+	assert.NotNil(t, err)
+	// update timestamp
+	assert.True(t, time.Now().Unix()-updater.lastUpdateTimestamp < 10)
+	assert.Equal(t, "response callback fail", err.Error())
+
+	updater.responseCallBack = func(response *responses.CommonResponse) error {
+		return nil
+	}
+
+	err = updater.updateCredential()
+	assert.Nil(t, err)
+}

+ 21 - 19
sdk/auth/signers/signer_key_pair.go

@@ -71,23 +71,28 @@ func (*SignerKeyPair) GetVersion() string {
 	return "1.0"
 }
 
-func (signer *SignerKeyPair) GetAccessKeyId() (accessKeyId string, err error) {
+func (signer *SignerKeyPair) ensureCredential() error {
 	if signer.sessionCredential == nil || signer.needUpdateCredential() {
-		err = signer.updateCredential()
-	}
-	if err != nil && (signer.sessionCredential == nil || len(signer.sessionCredential.AccessKeyId) <= 0) {
-		return "", err
+		return signer.updateCredential()
 	}
-	return signer.sessionCredential.AccessKeyId, err
+	return nil
 }
 
-func (signer *SignerKeyPair) GetExtraParam() map[string]string {
-	if signer.sessionCredential == nil || signer.needUpdateCredential() {
-		signer.updateCredential()
+func (signer *SignerKeyPair) GetAccessKeyId() (accessKeyId string, err error) {
+	err = signer.ensureCredential()
+	if err != nil {
+		return
 	}
 	if signer.sessionCredential == nil || len(signer.sessionCredential.AccessKeyId) <= 0 {
-		return make(map[string]string)
+		accessKeyId = ""
+		return
 	}
+
+	accessKeyId = signer.sessionCredential.AccessKeyId
+	return
+}
+
+func (signer *SignerKeyPair) GetExtraParam() map[string]string {
 	return make(map[string]string)
 }
 
@@ -107,9 +112,9 @@ func (signer *SignerKeyPair) buildCommonRequest() (request *requests.CommonReque
 	return
 }
 
-func (signerKeyPair *SignerKeyPair) refreshApi(request *requests.CommonRequest) (response *responses.CommonResponse, err error) {
-	signerV2 := NewSignerV2(signerKeyPair.credential)
-	return signerKeyPair.commonApi(request, signerV2)
+func (signer *SignerKeyPair) refreshApi(request *requests.CommonRequest) (response *responses.CommonResponse, err error) {
+	signerV2 := NewSignerV2(signer.credential)
+	return signer.commonApi(request, signerV2)
 }
 
 func (signer *SignerKeyPair) refreshCredential(response *responses.CommonResponse) (err error) {
@@ -121,18 +126,15 @@ func (signer *SignerKeyPair) refreshCredential(response *responses.CommonRespons
 	var data interface{}
 	err = json.Unmarshal(response.GetHttpContentBytes(), &data)
 	if err != nil {
-		fmt.Println("refresh KeyPair err, json.Unmarshal fail", err)
-		return
+		return fmt.Errorf("refresh KeyPair err, json.Unmarshal fail: %s", err.Error())
 	}
 	accessKeyId, err := jmespath.Search("SessionAccessKey.SessionAccessKeyId", data)
 	if err != nil {
-		fmt.Println("refresh KeyPair err, fail to get SessionAccessKeyId", err)
-		return
+		return fmt.Errorf("refresh KeyPair err, fail to get SessionAccessKeyId: %s", err.Error())
 	}
 	accessKeySecret, err := jmespath.Search("SessionAccessKey.SessionAccessKeySecret", data)
 	if err != nil {
-		fmt.Println("refresh KeyPair err, fail to get SessionAccessKeySecret", err)
-		return
+		return fmt.Errorf("refresh KeyPair err, fail to get SessionAccessKeySecret: %s", err.Error())
 	}
 	if accessKeyId == nil || accessKeySecret == nil {
 		return

+ 176 - 0
sdk/auth/signers/signer_key_pair_test.go

@@ -0,0 +1,176 @@
+package signers
+
+import (
+	"bytes"
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"strconv"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/auth/credentials"
+	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/requests"
+	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/responses"
+)
+
+func TestKeyPairError(t *testing.T) {
+	c := credentials.NewRsaKeyPairCredential("privateKey", "publicKey", 1)
+	_, err := NewSignerKeyPair(c, nil)
+	assert.NotNil(t, err)
+	assert.Equal(t, "[SDK.InvalidParam] Key Pair session duration should be in the range of 15min - 1Hr", err.Error())
+}
+
+func TestKeyPairOk(t *testing.T) {
+	c := credentials.NewRsaKeyPairCredential("privateKey", "publicKey", 0)
+	s, err := NewSignerKeyPair(c, nil)
+	assert.Nil(t, err)
+	assert.NotNil(t, s)
+	assert.Equal(t, 3600, s.credentialExpiration)
+	c = credentials.NewRsaKeyPairCredential("privateKey", "publicKey", 3500)
+	s, err = NewSignerKeyPair(c, nil)
+	assert.Nil(t, err)
+	assert.NotNil(t, s)
+	assert.Equal(t, 3500, s.credentialExpiration)
+	assert.Equal(t, "HMAC-SHA1", s.GetName())
+	assert.Equal(t, "1.0", s.GetVersion())
+	assert.Equal(t, "", s.GetType())
+	assert.Len(t, s.GetExtraParam(), 0)
+	// nothing
+	s.Shutdown()
+}
+
+func Test_buildCommonRequest(t *testing.T) {
+	c := credentials.NewRsaKeyPairCredential("privateKey", "publicKey", 0)
+	s, err := NewSignerKeyPair(c, func(*requests.CommonRequest, interface{}) (response *responses.CommonResponse, err error) {
+		return nil, fmt.Errorf("common api fails")
+	})
+	assert.Nil(t, err)
+	request, err := s.buildCommonRequest()
+	assert.Nil(t, err)
+	assert.NotNil(t, request)
+	assert.Equal(t, "Sts", request.Product)
+	assert.Equal(t, "2015-04-01", request.Version)
+	assert.Equal(t, "GenerateSessionAccessKey", request.ApiName)
+	assert.Equal(t, "HTTPS", request.Scheme)
+	assert.Equal(t, "publicKey", request.QueryParams["PublicKeyId"])
+	assert.Equal(t, "3600", request.QueryParams["DurationSeconds"])
+}
+
+func TestGetAccessKeyId(t *testing.T) {
+	c := credentials.NewRsaKeyPairCredential("privateKey", "publicKey", 0)
+	s, err := NewSignerKeyPair(c, func(*requests.CommonRequest, interface{}) (response *responses.CommonResponse, err error) {
+		return nil, fmt.Errorf("common api fails")
+	})
+	assert.Nil(t, err)
+	assert.NotNil(t, s)
+	accessKeyId, err := s.GetAccessKeyId()
+	assert.Equal(t, "common api fails", err.Error())
+	assert.Equal(t, "", accessKeyId)
+}
+
+func TestGetAccessKeyId2(t *testing.T) {
+	// default response is not OK
+	c := credentials.NewRsaKeyPairCredential("privateKey", "publicKey", 0)
+	s, err := NewSignerKeyPair(c, func(*requests.CommonRequest, interface{}) (response *responses.CommonResponse, err error) {
+		return responses.NewCommonResponse(), nil
+	})
+	assert.Nil(t, err)
+	assert.NotNil(t, s)
+	// s.lastUpdateTimestamp = time.Now().Unix() - 1000
+	accessKeyId, err := s.GetAccessKeyId()
+	assert.Equal(t, "SDK.ServerError\nErrorCode: \nRecommend: refresh session AccessKey failed\nRequestId: \nMessage: ", err.Error())
+	assert.Equal(t, "", accessKeyId)
+}
+
+func TestGetAccessKeyId3(t *testing.T) {
+	c := credentials.NewRsaKeyPairCredential("privateKey", "publicKey", 0)
+	// Mock the 200 response and invalid json
+	s, err := NewSignerKeyPair(c, func(*requests.CommonRequest, interface{}) (response *responses.CommonResponse, err error) {
+		res := responses.NewCommonResponse()
+		statusCode := 200
+		header := make(http.Header)
+		status := strconv.Itoa(statusCode)
+		httpresp := &http.Response{
+			Proto:      "HTTP/1.1",
+			ProtoMajor: 1,
+			Header:     header,
+			StatusCode: statusCode,
+			Status:     status + " " + http.StatusText(statusCode),
+		}
+		httpresp.Header = make(http.Header)
+		httpresp.Body = ioutil.NopCloser(bytes.NewReader([]byte("invalid json")))
+		responses.Unmarshal(res, httpresp, "JSON")
+		return res, nil
+	})
+	assert.Nil(t, err)
+	assert.NotNil(t, s)
+	// s.lastUpdateTimestamp = time.Now().Unix() - 1000
+	accessKeyId, err := s.GetAccessKeyId()
+	assert.NotNil(t, err)
+	assert.Equal(t, "refresh KeyPair err, json.Unmarshal fail: invalid character 'i' looking for beginning of value", err.Error())
+	assert.Equal(t, "", accessKeyId)
+}
+
+func TestGetAccessKeyId4(t *testing.T) {
+	c := credentials.NewRsaKeyPairCredential("privateKey", "publicKey", 0)
+	// mock 200 response and valid json, but no data
+	s, err := NewSignerKeyPair(c, func(*requests.CommonRequest, interface{}) (response *responses.CommonResponse, err error) {
+		res := responses.NewCommonResponse()
+		statusCode := 200
+		header := make(http.Header)
+		status := strconv.Itoa(statusCode)
+		httpresp := &http.Response{
+			Proto:      "HTTP/1.1",
+			ProtoMajor: 1,
+			Header:     header,
+			StatusCode: statusCode,
+			Status:     status + " " + http.StatusText(statusCode),
+		}
+		httpresp.Header = make(http.Header)
+		httpresp.Body = ioutil.NopCloser(bytes.NewReader([]byte("{}")))
+		responses.Unmarshal(res, httpresp, "JSON")
+		return res, nil
+	})
+	assert.Nil(t, err)
+	assert.NotNil(t, s)
+	// s.lastUpdateTimestamp = time.Now().Unix() - 1000
+	accessKeyId, err := s.GetAccessKeyId()
+	assert.Nil(t, err)
+	assert.Equal(t, "", accessKeyId)
+}
+
+func TestGetAccessKeyIdAndSign(t *testing.T) {
+	c := credentials.NewRsaKeyPairCredential("privateKey", "publicKey", 0)
+	// mock 200 response and valid json and valid result
+	s, err := NewSignerKeyPair(c, func(*requests.CommonRequest, interface{}) (response *responses.CommonResponse, err error) {
+		res := responses.NewCommonResponse()
+		statusCode := 200
+		header := make(http.Header)
+		status := strconv.Itoa(statusCode)
+		httpresp := &http.Response{
+			Proto:      "HTTP/1.1",
+			ProtoMajor: 1,
+			Header:     header,
+			StatusCode: statusCode,
+			Status:     status + " " + http.StatusText(statusCode),
+		}
+		httpresp.Header = make(http.Header)
+		json := `{"SessionAccessKey":{"SessionAccessKeyId":"session access key id","SessionAccessKeySecret": "session access key secret"}}`
+		httpresp.Body = ioutil.NopCloser(bytes.NewReader([]byte(json)))
+		responses.Unmarshal(res, httpresp, "JSON")
+		return res, nil
+	})
+	assert.Nil(t, err)
+	assert.NotNil(t, s)
+	// s.lastUpdateTimestamp = time.Now().Unix() - 1000
+	accessKeyId, err := s.GetAccessKeyId()
+	assert.Nil(t, err)
+	assert.Equal(t, "session access key id", accessKeyId)
+	// no need update
+	err = s.ensureCredential()
+	assert.Nil(t, err)
+	signature := s.Sign("string to sign", "/")
+	assert.Equal(t, "a3pLxd685VW4u078cdBKVh/Qf/A=", signature)
+}

+ 4 - 3
sdk/endpoints/location_resolver.go

@@ -15,9 +15,10 @@ package endpoints
 
 import (
 	"encoding/json"
-	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/requests"
 	"sync"
 	"time"
+
+	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/requests"
 )
 
 const (
@@ -49,8 +50,8 @@ type LocationResolver struct {
 }
 
 func (resolver *LocationResolver) GetName() (name string) {
-  name = "location resolver"
-  return
+	name = "location resolver"
+	return
 }
 
 func (resolver *LocationResolver) TryResolve(param *ResolveParam) (endpoint string, support bool, err error) {