server.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. package server
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net/http"
  6. "net/url"
  7. "strings"
  8. "time"
  9. "github.com/2637309949/dolphin/packages/oauth2"
  10. "github.com/2637309949/dolphin/packages/oauth2/errors"
  11. )
  12. // NewDefaultServer create a default authorization server
  13. func NewDefaultServer(manager oauth2.Manager) *Server {
  14. return NewServer(NewConfig(), manager)
  15. }
  16. // NewServer create authorization server
  17. func NewServer(cfg *Config, manager oauth2.Manager) *Server {
  18. srv := &Server{
  19. Config: cfg,
  20. Manager: manager,
  21. }
  22. // default handler
  23. srv.ClientInfoHandler = ClientBasicHandler
  24. srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, string, error) {
  25. return "", "", errors.ErrAccessDenied
  26. }
  27. srv.PasswordAuthorizationHandler = func(username, password string) (string, error) {
  28. return "", errors.ErrAccessDenied
  29. }
  30. return srv
  31. }
  32. // Server Provide authorization server
  33. type Server struct {
  34. Config *Config
  35. Manager oauth2.Manager
  36. ClientInfoHandler ClientInfoHandler
  37. ClientAuthorizedHandler ClientAuthorizedHandler
  38. ClientScopeHandler ClientScopeHandler
  39. UserAuthorizationHandler UserAuthorizationHandler
  40. PasswordAuthorizationHandler PasswordAuthorizationHandler
  41. RefreshingScopeHandler RefreshingScopeHandler
  42. ResponseErrorHandler ResponseErrorHandler
  43. InternalErrorHandler InternalErrorHandler
  44. ExtensionFieldsHandler ExtensionFieldsHandler
  45. AccessTokenExpHandler AccessTokenExpHandler
  46. AuthorizeScopeHandler AuthorizeScopeHandler
  47. }
  48. func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
  49. if req == nil {
  50. return err
  51. }
  52. data, _, _ := s.GetErrorData(err)
  53. return s.redirect(w, req, data)
  54. }
  55. func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error {
  56. uri, err := s.GetRedirectURI(req, data)
  57. if err != nil {
  58. return err
  59. }
  60. w.Header().Set("Location", uri)
  61. w.WriteHeader(302)
  62. return nil
  63. }
  64. func (s *Server) tokenError(w http.ResponseWriter, err error) error {
  65. data, statusCode, header := s.GetErrorData(err)
  66. return s.token(w, data, header, statusCode)
  67. }
  68. func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error {
  69. w.Header().Set("Content-Type", "application/json;charset=UTF-8")
  70. w.Header().Set("Cache-Control", "no-store")
  71. w.Header().Set("Pragma", "no-cache")
  72. for key := range header {
  73. w.Header().Set(key, header.Get(key))
  74. }
  75. status := http.StatusOK
  76. if len(statusCode) > 0 && statusCode[0] > 0 {
  77. status = statusCode[0]
  78. }
  79. w.WriteHeader(status)
  80. return json.NewEncoder(w).Encode(data)
  81. }
  82. // GetRedirectURI get redirect uri
  83. func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) {
  84. u, err := url.Parse(req.RedirectURI)
  85. if err != nil {
  86. return "", err
  87. }
  88. q := u.Query()
  89. if req.State != "" {
  90. q.Set("state", req.State)
  91. }
  92. for k, v := range data {
  93. q.Set(k, fmt.Sprint(v))
  94. }
  95. switch req.ResponseType {
  96. case oauth2.Code:
  97. u.RawQuery = q.Encode()
  98. case oauth2.Token:
  99. u.RawQuery = ""
  100. fragment, err := url.QueryUnescape(q.Encode())
  101. if err != nil {
  102. return "", err
  103. }
  104. u.Fragment = fragment
  105. }
  106. return u.String(), nil
  107. }
  108. // CheckResponseType check allows response type
  109. func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool {
  110. for _, art := range s.Config.AllowedResponseTypes {
  111. if art == rt {
  112. return true
  113. }
  114. }
  115. return false
  116. }
  117. // ValidationAuthorizeRequest the authorization request validation
  118. func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) {
  119. redirectURI := r.FormValue("redirect_uri")
  120. clientID := r.FormValue("client_id")
  121. if !(r.Method == "GET" || r.Method == "POST") ||
  122. clientID == "" {
  123. return nil, errors.ErrInvalidRequest
  124. }
  125. resType := oauth2.ResponseType(r.FormValue("response_type"))
  126. if resType.String() == "" {
  127. return nil, errors.ErrUnsupportedResponseType
  128. } else if allowed := s.CheckResponseType(resType); !allowed {
  129. return nil, errors.ErrUnauthorizedClient
  130. }
  131. req := &AuthorizeRequest{
  132. RedirectURI: redirectURI,
  133. ResponseType: resType,
  134. ClientID: clientID,
  135. State: r.FormValue("state"),
  136. Scope: r.FormValue("scope"),
  137. Request: r,
  138. }
  139. return req, nil
  140. }
  141. // GetAuthorizeToken get authorization token(code)
  142. func (s *Server) GetAuthorizeToken(req *AuthorizeRequest) (oauth2.TokenInfo, error) {
  143. // check the client allows the grant type
  144. if fn := s.ClientAuthorizedHandler; fn != nil {
  145. gt := oauth2.AuthorizationCode
  146. if req.ResponseType == oauth2.Token {
  147. gt = oauth2.Implicit
  148. }
  149. allowed, err := fn(req.ClientID, gt)
  150. if err != nil {
  151. return nil, err
  152. } else if !allowed {
  153. return nil, errors.ErrUnauthorizedClient
  154. }
  155. }
  156. // check the client allows the authorized scope
  157. if fn := s.ClientScopeHandler; fn != nil {
  158. allowed, err := fn(req.ClientID, req.Scope)
  159. if err != nil {
  160. return nil, err
  161. } else if !allowed {
  162. return nil, errors.ErrInvalidScope
  163. }
  164. }
  165. tgr := &oauth2.TokenGenerateRequest{
  166. ClientID: req.ClientID,
  167. UserID: req.UserID,
  168. Domain: req.Domain,
  169. RedirectURI: req.RedirectURI,
  170. Scope: req.Scope,
  171. AccessTokenExp: req.AccessTokenExp,
  172. Request: req.Request,
  173. }
  174. return s.Manager.GenerateAuthToken(req.ResponseType, tgr)
  175. }
  176. // GetAuthorizeData get authorization response data
  177. func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} {
  178. if rt == oauth2.Code {
  179. return map[string]interface{}{
  180. "code": ti.GetCode(),
  181. }
  182. }
  183. return s.GetTokenData(ti)
  184. }
  185. // HandleAuthorizeRequest the authorization request handling
  186. func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
  187. req, err := s.ValidationAuthorizeRequest(r)
  188. if err != nil {
  189. return s.redirectError(w, req, err)
  190. }
  191. // user authorization
  192. userID, domain, err := s.UserAuthorizationHandler(w, r)
  193. if err != nil {
  194. return s.redirectError(w, req, err)
  195. } else if userID == "" {
  196. return nil
  197. }
  198. req.UserID = userID
  199. req.Domain = domain
  200. // specify the scope of authorization
  201. if fn := s.AuthorizeScopeHandler; fn != nil {
  202. scope, err := fn(w, r)
  203. if err != nil {
  204. return err
  205. } else if scope != "" {
  206. req.Scope = scope
  207. }
  208. }
  209. // specify the expiration time of access token
  210. if fn := s.AccessTokenExpHandler; fn != nil {
  211. exp, err := fn(w, r)
  212. if err != nil {
  213. return err
  214. }
  215. req.AccessTokenExp = exp
  216. }
  217. ti, err := s.GetAuthorizeToken(req)
  218. if err != nil {
  219. return s.redirectError(w, req, err)
  220. }
  221. // If the redirect URI is empty, the default domain provided by the client is used.
  222. if req.RedirectURI == "" {
  223. client, err := s.Manager.GetClient(req.ClientID)
  224. if err != nil {
  225. return err
  226. }
  227. req.RedirectURI = client.GetDomain()
  228. }
  229. return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti))
  230. }
  231. // ValidationTokenRequest the token request validation
  232. func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) {
  233. if v := r.Method; !(v == "POST" ||
  234. (s.Config.AllowGetAccessRequest && v == "GET")) {
  235. return "", nil, errors.ErrInvalidRequest
  236. }
  237. gt := oauth2.GrantType(r.FormValue("grant_type"))
  238. if gt.String() == "" {
  239. return "", nil, errors.ErrUnsupportedGrantType
  240. }
  241. clientID, clientSecret, err := s.ClientInfoHandler(r)
  242. if err != nil {
  243. return "", nil, err
  244. }
  245. tgr := &oauth2.TokenGenerateRequest{
  246. ClientID: clientID,
  247. ClientSecret: clientSecret,
  248. Request: r,
  249. }
  250. switch gt {
  251. case oauth2.AuthorizationCode:
  252. tgr.RedirectURI = r.FormValue("redirect_uri")
  253. tgr.Code = r.FormValue("code")
  254. if tgr.RedirectURI == "" ||
  255. tgr.Code == "" {
  256. return "", nil, errors.ErrInvalidRequest
  257. }
  258. case oauth2.PasswordCredentials:
  259. tgr.Scope = r.FormValue("scope")
  260. username, password := r.FormValue("username"), r.FormValue("password")
  261. if username == "" || password == "" {
  262. return "", nil, errors.ErrInvalidRequest
  263. }
  264. userID, err := s.PasswordAuthorizationHandler(username, password)
  265. if err != nil {
  266. return "", nil, err
  267. } else if userID == "" {
  268. return "", nil, errors.ErrInvalidGrant
  269. }
  270. tgr.UserID = userID
  271. case oauth2.ClientCredentials:
  272. tgr.Scope = r.FormValue("scope")
  273. case oauth2.Refreshing:
  274. tgr.Refresh = r.FormValue("refresh_token")
  275. tgr.Scope = r.FormValue("scope")
  276. if tgr.Refresh == "" {
  277. return "", nil, errors.ErrInvalidRequest
  278. }
  279. }
  280. return gt, tgr, nil
  281. }
  282. // CheckGrantType check allows grant type
  283. func (s *Server) CheckGrantType(gt oauth2.GrantType) bool {
  284. for _, agt := range s.Config.AllowedGrantTypes {
  285. if agt == gt {
  286. return true
  287. }
  288. }
  289. return false
  290. }
  291. // GetAccessToken access token
  292. func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
  293. if allowed := s.CheckGrantType(gt); !allowed {
  294. return nil, errors.ErrUnauthorizedClient
  295. }
  296. if fn := s.ClientAuthorizedHandler; fn != nil {
  297. allowed, err := fn(tgr.ClientID, gt)
  298. if err != nil {
  299. return nil, err
  300. } else if !allowed {
  301. return nil, errors.ErrUnauthorizedClient
  302. }
  303. }
  304. switch gt {
  305. case oauth2.AuthorizationCode:
  306. ti, err := s.Manager.GenerateAccessToken(gt, tgr)
  307. if err != nil {
  308. switch err {
  309. case errors.ErrInvalidAuthorizeCode:
  310. return nil, errors.ErrInvalidGrant
  311. case errors.ErrInvalidClient:
  312. return nil, errors.ErrInvalidClient
  313. default:
  314. return nil, err
  315. }
  316. }
  317. return ti, nil
  318. case oauth2.PasswordCredentials, oauth2.ClientCredentials:
  319. if fn := s.ClientScopeHandler; fn != nil {
  320. allowed, err := fn(tgr.ClientID, tgr.Scope)
  321. if err != nil {
  322. return nil, err
  323. } else if !allowed {
  324. return nil, errors.ErrInvalidScope
  325. }
  326. }
  327. return s.Manager.GenerateAccessToken(gt, tgr)
  328. case oauth2.Refreshing:
  329. // check scope
  330. if scope, scopeFn := tgr.Scope, s.RefreshingScopeHandler; scope != "" && scopeFn != nil {
  331. rti, err := s.Manager.LoadRefreshToken(tgr.Refresh)
  332. if err != nil {
  333. if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
  334. return nil, errors.ErrInvalidGrant
  335. }
  336. return nil, err
  337. }
  338. allowed, err := scopeFn(scope, rti.GetScope())
  339. if err != nil {
  340. return nil, err
  341. } else if !allowed {
  342. return nil, errors.ErrInvalidScope
  343. }
  344. }
  345. ti, err := s.Manager.RefreshAccessToken(tgr)
  346. if err != nil {
  347. if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
  348. return nil, errors.ErrInvalidGrant
  349. }
  350. return nil, err
  351. }
  352. return ti, nil
  353. }
  354. return nil, errors.ErrUnsupportedGrantType
  355. }
  356. // GetTokenData token data
  357. func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} {
  358. data := map[string]interface{}{
  359. "access_token": ti.GetAccess(),
  360. "token_type": s.Config.TokenType,
  361. "expires_in": int64(ti.GetAccessExpiresIn() / time.Second),
  362. }
  363. if scope := ti.GetScope(); scope != "" {
  364. data["scope"] = scope
  365. }
  366. if refresh := ti.GetRefresh(); refresh != "" {
  367. data["refresh_token"] = refresh
  368. }
  369. if fn := s.ExtensionFieldsHandler; fn != nil {
  370. ext := fn(ti)
  371. for k, v := range ext {
  372. if _, ok := data[k]; ok {
  373. continue
  374. }
  375. data[k] = v
  376. }
  377. }
  378. return data
  379. }
  380. // HandleTokenRequest token request handling
  381. func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
  382. gt, tgr, err := s.ValidationTokenRequest(r)
  383. if err != nil {
  384. return s.tokenError(w, err)
  385. }
  386. ti, err := s.GetAccessToken(gt, tgr)
  387. if err != nil {
  388. return s.tokenError(w, err)
  389. }
  390. return s.token(w, s.GetTokenData(ti), nil)
  391. }
  392. // GetErrorData get error response data
  393. func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) {
  394. var re errors.Response
  395. if v, ok := errors.Descriptions[err]; ok {
  396. re.Error = err
  397. re.Description = v
  398. re.StatusCode = errors.StatusCodes[err]
  399. } else {
  400. if fn := s.InternalErrorHandler; fn != nil {
  401. if v := fn(err); v != nil {
  402. re = *v
  403. }
  404. }
  405. if re.Error == nil {
  406. re.Error = errors.ErrServerError
  407. re.Description = errors.Descriptions[errors.ErrServerError]
  408. re.StatusCode = errors.StatusCodes[errors.ErrServerError]
  409. }
  410. }
  411. if fn := s.ResponseErrorHandler; fn != nil {
  412. fn(&re)
  413. }
  414. data := make(map[string]interface{})
  415. if err := re.Error; err != nil {
  416. data["error"] = err.Error()
  417. }
  418. if v := re.ErrorCode; v != 0 {
  419. data["error_code"] = v
  420. }
  421. if v := re.Description; v != "" {
  422. data["error_description"] = v
  423. }
  424. if v := re.URI; v != "" {
  425. data["error_uri"] = v
  426. }
  427. statusCode := http.StatusInternalServerError
  428. if v := re.StatusCode; v > 0 {
  429. statusCode = v
  430. }
  431. return data, statusCode, re.Header
  432. }
  433. // BearerAuth parse bearer token
  434. func (s *Server) BearerAuth(r *http.Request) (string, bool) {
  435. auth := r.Header.Get("Authorization")
  436. prefix := "Bearer "
  437. token := ""
  438. if auth != "" && strings.HasPrefix(auth, prefix) {
  439. token = auth[len(prefix):]
  440. } else {
  441. token = r.Header.Get("token")
  442. }
  443. return token, token != ""
  444. }
  445. // ValidationBearerToken validation the bearer tokens
  446. // https://tools.ietf.org/html/rfc6750
  447. func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
  448. accessToken, ok := s.BearerAuth(r)
  449. if !ok {
  450. return nil, errors.ErrInvalidAccessToken
  451. }
  452. return s.Manager.LoadAccessToken(accessToken)
  453. }