| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522 |
- package server
- import (
- "encoding/json"
- "fmt"
- "net/http"
- "net/url"
- "strings"
- "time"
- "github.com/2637309949/dolphin/packages/oauth2"
- "github.com/2637309949/dolphin/packages/oauth2/errors"
- )
- // NewDefaultServer create a default authorization server
- func NewDefaultServer(manager oauth2.Manager) *Server {
- return NewServer(NewConfig(), manager)
- }
- // NewServer create authorization server
- func NewServer(cfg *Config, manager oauth2.Manager) *Server {
- srv := &Server{
- Config: cfg,
- Manager: manager,
- }
- // default handler
- srv.ClientInfoHandler = ClientBasicHandler
- srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, string, error) {
- return "", "", errors.ErrAccessDenied
- }
- srv.PasswordAuthorizationHandler = func(username, password string) (string, error) {
- return "", errors.ErrAccessDenied
- }
- return srv
- }
- // Server Provide authorization server
- type Server struct {
- Config *Config
- Manager oauth2.Manager
- ClientInfoHandler ClientInfoHandler
- ClientAuthorizedHandler ClientAuthorizedHandler
- ClientScopeHandler ClientScopeHandler
- UserAuthorizationHandler UserAuthorizationHandler
- PasswordAuthorizationHandler PasswordAuthorizationHandler
- RefreshingScopeHandler RefreshingScopeHandler
- ResponseErrorHandler ResponseErrorHandler
- InternalErrorHandler InternalErrorHandler
- ExtensionFieldsHandler ExtensionFieldsHandler
- AccessTokenExpHandler AccessTokenExpHandler
- AuthorizeScopeHandler AuthorizeScopeHandler
- }
- func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
- if req == nil {
- return err
- }
- data, _, _ := s.GetErrorData(err)
- return s.redirect(w, req, data)
- }
- func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error {
- uri, err := s.GetRedirectURI(req, data)
- if err != nil {
- return err
- }
- w.Header().Set("Location", uri)
- w.WriteHeader(302)
- return nil
- }
- func (s *Server) tokenError(w http.ResponseWriter, err error) error {
- data, statusCode, header := s.GetErrorData(err)
- return s.token(w, data, header, statusCode)
- }
- func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error {
- w.Header().Set("Content-Type", "application/json;charset=UTF-8")
- w.Header().Set("Cache-Control", "no-store")
- w.Header().Set("Pragma", "no-cache")
- for key := range header {
- w.Header().Set(key, header.Get(key))
- }
- status := http.StatusOK
- if len(statusCode) > 0 && statusCode[0] > 0 {
- status = statusCode[0]
- }
- w.WriteHeader(status)
- return json.NewEncoder(w).Encode(data)
- }
- // GetRedirectURI get redirect uri
- func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) {
- u, err := url.Parse(req.RedirectURI)
- if err != nil {
- return "", err
- }
- q := u.Query()
- if req.State != "" {
- q.Set("state", req.State)
- }
- for k, v := range data {
- q.Set(k, fmt.Sprint(v))
- }
- switch req.ResponseType {
- case oauth2.Code:
- u.RawQuery = q.Encode()
- case oauth2.Token:
- u.RawQuery = ""
- fragment, err := url.QueryUnescape(q.Encode())
- if err != nil {
- return "", err
- }
- u.Fragment = fragment
- }
- return u.String(), nil
- }
- // CheckResponseType check allows response type
- func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool {
- for _, art := range s.Config.AllowedResponseTypes {
- if art == rt {
- return true
- }
- }
- return false
- }
- // ValidationAuthorizeRequest the authorization request validation
- func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) {
- redirectURI := r.FormValue("redirect_uri")
- clientID := r.FormValue("client_id")
- if !(r.Method == "GET" || r.Method == "POST") ||
- clientID == "" {
- return nil, errors.ErrInvalidRequest
- }
- resType := oauth2.ResponseType(r.FormValue("response_type"))
- if resType.String() == "" {
- return nil, errors.ErrUnsupportedResponseType
- } else if allowed := s.CheckResponseType(resType); !allowed {
- return nil, errors.ErrUnauthorizedClient
- }
- req := &AuthorizeRequest{
- RedirectURI: redirectURI,
- ResponseType: resType,
- ClientID: clientID,
- State: r.FormValue("state"),
- Scope: r.FormValue("scope"),
- Request: r,
- }
- return req, nil
- }
- // GetAuthorizeToken get authorization token(code)
- func (s *Server) GetAuthorizeToken(req *AuthorizeRequest) (oauth2.TokenInfo, error) {
- // check the client allows the grant type
- if fn := s.ClientAuthorizedHandler; fn != nil {
- gt := oauth2.AuthorizationCode
- if req.ResponseType == oauth2.Token {
- gt = oauth2.Implicit
- }
- allowed, err := fn(req.ClientID, gt)
- if err != nil {
- return nil, err
- } else if !allowed {
- return nil, errors.ErrUnauthorizedClient
- }
- }
- // check the client allows the authorized scope
- if fn := s.ClientScopeHandler; fn != nil {
- allowed, err := fn(req.ClientID, req.Scope)
- if err != nil {
- return nil, err
- } else if !allowed {
- return nil, errors.ErrInvalidScope
- }
- }
- tgr := &oauth2.TokenGenerateRequest{
- ClientID: req.ClientID,
- UserID: req.UserID,
- Domain: req.Domain,
- RedirectURI: req.RedirectURI,
- Scope: req.Scope,
- AccessTokenExp: req.AccessTokenExp,
- Request: req.Request,
- }
- return s.Manager.GenerateAuthToken(req.ResponseType, tgr)
- }
- // GetAuthorizeData get authorization response data
- func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} {
- if rt == oauth2.Code {
- return map[string]interface{}{
- "code": ti.GetCode(),
- }
- }
- return s.GetTokenData(ti)
- }
- // HandleAuthorizeRequest the authorization request handling
- func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
- req, err := s.ValidationAuthorizeRequest(r)
- if err != nil {
- return s.redirectError(w, req, err)
- }
- // user authorization
- userID, domain, err := s.UserAuthorizationHandler(w, r)
- if err != nil {
- return s.redirectError(w, req, err)
- } else if userID == "" {
- return nil
- }
- req.UserID = userID
- req.Domain = domain
- // specify the scope of authorization
- if fn := s.AuthorizeScopeHandler; fn != nil {
- scope, err := fn(w, r)
- if err != nil {
- return err
- } else if scope != "" {
- req.Scope = scope
- }
- }
- // specify the expiration time of access token
- if fn := s.AccessTokenExpHandler; fn != nil {
- exp, err := fn(w, r)
- if err != nil {
- return err
- }
- req.AccessTokenExp = exp
- }
- ti, err := s.GetAuthorizeToken(req)
- if err != nil {
- return s.redirectError(w, req, err)
- }
- // If the redirect URI is empty, the default domain provided by the client is used.
- if req.RedirectURI == "" {
- client, err := s.Manager.GetClient(req.ClientID)
- if err != nil {
- return err
- }
- req.RedirectURI = client.GetDomain()
- }
- return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti))
- }
- // ValidationTokenRequest the token request validation
- func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) {
- if v := r.Method; !(v == "POST" ||
- (s.Config.AllowGetAccessRequest && v == "GET")) {
- return "", nil, errors.ErrInvalidRequest
- }
- gt := oauth2.GrantType(r.FormValue("grant_type"))
- if gt.String() == "" {
- return "", nil, errors.ErrUnsupportedGrantType
- }
- clientID, clientSecret, err := s.ClientInfoHandler(r)
- if err != nil {
- return "", nil, err
- }
- tgr := &oauth2.TokenGenerateRequest{
- ClientID: clientID,
- ClientSecret: clientSecret,
- Request: r,
- }
- switch gt {
- case oauth2.AuthorizationCode:
- tgr.RedirectURI = r.FormValue("redirect_uri")
- tgr.Code = r.FormValue("code")
- if tgr.RedirectURI == "" ||
- tgr.Code == "" {
- return "", nil, errors.ErrInvalidRequest
- }
- case oauth2.PasswordCredentials:
- tgr.Scope = r.FormValue("scope")
- username, password := r.FormValue("username"), r.FormValue("password")
- if username == "" || password == "" {
- return "", nil, errors.ErrInvalidRequest
- }
- userID, err := s.PasswordAuthorizationHandler(username, password)
- if err != nil {
- return "", nil, err
- } else if userID == "" {
- return "", nil, errors.ErrInvalidGrant
- }
- tgr.UserID = userID
- case oauth2.ClientCredentials:
- tgr.Scope = r.FormValue("scope")
- case oauth2.Refreshing:
- tgr.Refresh = r.FormValue("refresh_token")
- tgr.Scope = r.FormValue("scope")
- if tgr.Refresh == "" {
- return "", nil, errors.ErrInvalidRequest
- }
- }
- return gt, tgr, nil
- }
- // CheckGrantType check allows grant type
- func (s *Server) CheckGrantType(gt oauth2.GrantType) bool {
- for _, agt := range s.Config.AllowedGrantTypes {
- if agt == gt {
- return true
- }
- }
- return false
- }
- // GetAccessToken access token
- func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
- if allowed := s.CheckGrantType(gt); !allowed {
- return nil, errors.ErrUnauthorizedClient
- }
- if fn := s.ClientAuthorizedHandler; fn != nil {
- allowed, err := fn(tgr.ClientID, gt)
- if err != nil {
- return nil, err
- } else if !allowed {
- return nil, errors.ErrUnauthorizedClient
- }
- }
- switch gt {
- case oauth2.AuthorizationCode:
- ti, err := s.Manager.GenerateAccessToken(gt, tgr)
- if err != nil {
- switch err {
- case errors.ErrInvalidAuthorizeCode:
- return nil, errors.ErrInvalidGrant
- case errors.ErrInvalidClient:
- return nil, errors.ErrInvalidClient
- default:
- return nil, err
- }
- }
- return ti, nil
- case oauth2.PasswordCredentials, oauth2.ClientCredentials:
- if fn := s.ClientScopeHandler; fn != nil {
- allowed, err := fn(tgr.ClientID, tgr.Scope)
- if err != nil {
- return nil, err
- } else if !allowed {
- return nil, errors.ErrInvalidScope
- }
- }
- return s.Manager.GenerateAccessToken(gt, tgr)
- case oauth2.Refreshing:
- // check scope
- if scope, scopeFn := tgr.Scope, s.RefreshingScopeHandler; scope != "" && scopeFn != nil {
- rti, err := s.Manager.LoadRefreshToken(tgr.Refresh)
- if err != nil {
- if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
- return nil, errors.ErrInvalidGrant
- }
- return nil, err
- }
- allowed, err := scopeFn(scope, rti.GetScope())
- if err != nil {
- return nil, err
- } else if !allowed {
- return nil, errors.ErrInvalidScope
- }
- }
- ti, err := s.Manager.RefreshAccessToken(tgr)
- if err != nil {
- if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
- return nil, errors.ErrInvalidGrant
- }
- return nil, err
- }
- return ti, nil
- }
- return nil, errors.ErrUnsupportedGrantType
- }
- // GetTokenData token data
- func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} {
- data := map[string]interface{}{
- "access_token": ti.GetAccess(),
- "token_type": s.Config.TokenType,
- "expires_in": int64(ti.GetAccessExpiresIn() / time.Second),
- }
- if scope := ti.GetScope(); scope != "" {
- data["scope"] = scope
- }
- if refresh := ti.GetRefresh(); refresh != "" {
- data["refresh_token"] = refresh
- }
- if fn := s.ExtensionFieldsHandler; fn != nil {
- ext := fn(ti)
- for k, v := range ext {
- if _, ok := data[k]; ok {
- continue
- }
- data[k] = v
- }
- }
- return data
- }
- // HandleTokenRequest token request handling
- func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
- gt, tgr, err := s.ValidationTokenRequest(r)
- if err != nil {
- return s.tokenError(w, err)
- }
- ti, err := s.GetAccessToken(gt, tgr)
- if err != nil {
- return s.tokenError(w, err)
- }
- return s.token(w, s.GetTokenData(ti), nil)
- }
- // GetErrorData get error response data
- func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) {
- var re errors.Response
- if v, ok := errors.Descriptions[err]; ok {
- re.Error = err
- re.Description = v
- re.StatusCode = errors.StatusCodes[err]
- } else {
- if fn := s.InternalErrorHandler; fn != nil {
- if v := fn(err); v != nil {
- re = *v
- }
- }
- if re.Error == nil {
- re.Error = errors.ErrServerError
- re.Description = errors.Descriptions[errors.ErrServerError]
- re.StatusCode = errors.StatusCodes[errors.ErrServerError]
- }
- }
- if fn := s.ResponseErrorHandler; fn != nil {
- fn(&re)
- }
- data := make(map[string]interface{})
- if err := re.Error; err != nil {
- data["error"] = err.Error()
- }
- if v := re.ErrorCode; v != 0 {
- data["error_code"] = v
- }
- if v := re.Description; v != "" {
- data["error_description"] = v
- }
- if v := re.URI; v != "" {
- data["error_uri"] = v
- }
- statusCode := http.StatusInternalServerError
- if v := re.StatusCode; v > 0 {
- statusCode = v
- }
- return data, statusCode, re.Header
- }
- // BearerAuth parse bearer token
- func (s *Server) BearerAuth(r *http.Request) (string, bool) {
- auth := r.Header.Get("Authorization")
- prefix := "Bearer "
- token := ""
- if auth != "" && strings.HasPrefix(auth, prefix) {
- token = auth[len(prefix):]
- } else {
- token = r.Header.Get("token")
- }
- return token, token != ""
- }
- // ValidationBearerToken validation the bearer tokens
- // https://tools.ietf.org/html/rfc6750
- func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
- accessToken, ok := s.BearerAuth(r)
- if !ok {
- return nil, errors.ErrInvalidAccessToken
- }
- return s.Manager.LoadAccessToken(accessToken)
- }
|