contentsecurityhandler.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. package handler
  2. import (
  3. "net/http"
  4. "time"
  5. "github.com/tal-tech/go-zero/core/codec"
  6. "github.com/tal-tech/go-zero/core/logx"
  7. "github.com/tal-tech/go-zero/rest/httpx"
  8. "github.com/tal-tech/go-zero/rest/internal/security"
  9. )
  // contentSecurity is the name of the HTTP request header whose value is
  // written to the error log when signature parsing or verification fails.
  // NOTE(review): presumably the same header that security.ParseContentSecurity
  // reads — confirm against that package.
  10. const contentSecurity = "X-Content-Security"
  // UnsignedCallback is the signature of the functions that are invoked when a
  // request's content-security header cannot be parsed or its signature does
  // not verify. A callback receives the response writer, the request, the next
  // handler in the chain, the strict flag configured on the middleware, and the
  // code describing the failure.
  11. type UnsignedCallback func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int)
  12. func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance time.Duration,
  13. strict bool, callbacks ...UnsignedCallback) func(http.Handler) http.Handler {
  14. if len(callbacks) == 0 {
  15. callbacks = append(callbacks, handleVerificationFailure)
  16. }
  17. return func(next http.Handler) http.Handler {
  18. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  19. switch r.Method {
  20. case http.MethodDelete, http.MethodGet, http.MethodPost, http.MethodPut:
  21. header, err := security.ParseContentSecurity(decrypters, r)
  22. if err != nil {
  23. logx.Errorf("Signature parse failed, X-Content-Security: %s, error: %s",
  24. r.Header.Get(contentSecurity), err.Error())
  25. executeCallbacks(w, r, next, strict, httpx.CodeSignatureInvalidHeader, callbacks)
  26. } else if code := security.VerifySignature(r, header, tolerance); code != httpx.CodeSignaturePass {
  27. logx.Errorf("Signature verification failed, X-Content-Security: %s",
  28. r.Header.Get(contentSecurity))
  29. executeCallbacks(w, r, next, strict, code, callbacks)
  30. } else if r.ContentLength > 0 && header.Encrypted() {
  31. CryptionHandler(header.Key)(next).ServeHTTP(w, r)
  32. } else {
  33. next.ServeHTTP(w, r)
  34. }
  35. default:
  36. next.ServeHTTP(w, r)
  37. }
  38. })
  39. }
  40. }
  41. func executeCallbacks(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool,
  42. code int, callbacks []UnsignedCallback) {
  43. for _, callback := range callbacks {
  44. callback(w, r, next, strict, code)
  45. }
  46. }
  47. func handleVerificationFailure(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
  48. if strict {
  49. w.WriteHeader(http.StatusForbidden)
  50. } else {
  51. next.ServeHTTP(w, r)
  52. }
  53. }