autocert_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. // Copyright 2016 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package autocert
  5. import (
  6. "crypto"
  7. "crypto/ecdsa"
  8. "crypto/elliptic"
  9. "crypto/rand"
  10. "crypto/rsa"
  11. "crypto/tls"
  12. "crypto/x509"
  13. "crypto/x509/pkix"
  14. "encoding/base64"
  15. "encoding/json"
  16. "fmt"
  17. "html/template"
  18. "io"
  19. "math/big"
  20. "net/http"
  21. "net/http/httptest"
  22. "reflect"
  23. "testing"
  24. "time"
  25. "golang.org/x/crypto/acme"
  26. "golang.org/x/net/context"
  27. )
  28. var discoTmpl = template.Must(template.New("disco").Parse(`{
  29. "new-reg": "{{.}}/new-reg",
  30. "new-authz": "{{.}}/new-authz",
  31. "new-cert": "{{.}}/new-cert"
  32. }`))
  33. var authzTmpl = template.Must(template.New("authz").Parse(`{
  34. "status": "pending",
  35. "challenges": [
  36. {
  37. "uri": "{{.}}/challenge/1",
  38. "type": "tls-sni-01",
  39. "token": "token-01"
  40. },
  41. {
  42. "uri": "{{.}}/challenge/2",
  43. "type": "tls-sni-02",
  44. "token": "token-02"
  45. }
  46. ]
  47. }`))
  48. type memCache map[string][]byte
  49. func (m memCache) Get(ctx context.Context, key string) ([]byte, error) {
  50. v, ok := m[key]
  51. if !ok {
  52. return nil, ErrCacheMiss
  53. }
  54. return v, nil
  55. }
  56. func (m memCache) Put(ctx context.Context, key string, data []byte) error {
  57. m[key] = data
  58. return nil
  59. }
  60. func (m memCache) Delete(ctx context.Context, key string) error {
  61. delete(m, key)
  62. return nil
  63. }
  64. func dummyCert(pub interface{}, san ...string) ([]byte, error) {
  65. return dateDummyCert(pub, time.Now(), time.Now().Add(90*24*time.Hour), san...)
  66. }
  67. func dateDummyCert(pub interface{}, start, end time.Time, san ...string) ([]byte, error) {
  68. // use EC key to run faster on 386
  69. key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  70. if err != nil {
  71. return nil, err
  72. }
  73. t := &x509.Certificate{
  74. SerialNumber: big.NewInt(1),
  75. NotBefore: start,
  76. NotAfter: end,
  77. BasicConstraintsValid: true,
  78. KeyUsage: x509.KeyUsageKeyEncipherment,
  79. DNSNames: san,
  80. }
  81. if pub == nil {
  82. pub = &key.PublicKey
  83. }
  84. return x509.CreateCertificate(rand.Reader, t, t, pub, key)
  85. }
  86. func decodePayload(v interface{}, r io.Reader) error {
  87. var req struct{ Payload string }
  88. if err := json.NewDecoder(r).Decode(&req); err != nil {
  89. return err
  90. }
  91. payload, err := base64.RawURLEncoding.DecodeString(req.Payload)
  92. if err != nil {
  93. return err
  94. }
  95. return json.Unmarshal(payload, v)
  96. }
  97. func TestGetCertificate(t *testing.T) {
  98. man := &Manager{Prompt: AcceptTOS}
  99. defer man.stopRenew()
  100. hello := &tls.ClientHelloInfo{ServerName: "example.org"}
  101. testGetCertificate(t, man, "example.org", hello)
  102. }
  103. func TestGetCertificate_trailingDot(t *testing.T) {
  104. man := &Manager{Prompt: AcceptTOS}
  105. defer man.stopRenew()
  106. hello := &tls.ClientHelloInfo{ServerName: "example.org."}
  107. testGetCertificate(t, man, "example.org", hello)
  108. }
  109. func TestGetCertificate_ForceRSA(t *testing.T) {
  110. man := &Manager{
  111. Prompt: AcceptTOS,
  112. Cache: make(memCache),
  113. ForceRSA: true,
  114. }
  115. defer man.stopRenew()
  116. hello := &tls.ClientHelloInfo{ServerName: "example.org"}
  117. testGetCertificate(t, man, "example.org", hello)
  118. cert, err := man.cacheGet("example.org")
  119. if err != nil {
  120. t.Fatalf("man.cacheGet: %v", err)
  121. }
  122. if _, ok := cert.PrivateKey.(*rsa.PrivateKey); !ok {
  123. t.Errorf("cert.PrivateKey is %T; want *rsa.PrivateKey", cert.PrivateKey)
  124. }
  125. }
  126. // tests man.GetCertificate flow using the provided hello argument.
  127. // The domain argument is the expected domain name of a certificate request.
  128. func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) {
  129. // echo token-02 | shasum -a 256
  130. // then divide result in 2 parts separated by dot
  131. tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid"
  132. verifyTokenCert := func() {
  133. hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
  134. _, err := man.GetCertificate(hello)
  135. if err != nil {
  136. t.Errorf("verifyTokenCert: GetCertificate(%q): %v", tokenCertName, err)
  137. return
  138. }
  139. }
  140. // ACME CA server stub
  141. var ca *httptest.Server
  142. ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  143. w.Header().Set("replay-nonce", "nonce")
  144. if r.Method == "HEAD" {
  145. // a nonce request
  146. return
  147. }
  148. switch r.URL.Path {
  149. // discovery
  150. case "/":
  151. if err := discoTmpl.Execute(w, ca.URL); err != nil {
  152. t.Fatalf("discoTmpl: %v", err)
  153. }
  154. // client key registration
  155. case "/new-reg":
  156. w.Write([]byte("{}"))
  157. // domain authorization
  158. case "/new-authz":
  159. w.Header().Set("location", ca.URL+"/authz/1")
  160. w.WriteHeader(http.StatusCreated)
  161. if err := authzTmpl.Execute(w, ca.URL); err != nil {
  162. t.Fatalf("authzTmpl: %v", err)
  163. }
  164. // accept tls-sni-02 challenge
  165. case "/challenge/2":
  166. verifyTokenCert()
  167. w.Write([]byte("{}"))
  168. // authorization status
  169. case "/authz/1":
  170. w.Write([]byte(`{"status": "valid"}`))
  171. // cert request
  172. case "/new-cert":
  173. var req struct {
  174. CSR string `json:"csr"`
  175. }
  176. decodePayload(&req, r.Body)
  177. b, _ := base64.RawURLEncoding.DecodeString(req.CSR)
  178. csr, err := x509.ParseCertificateRequest(b)
  179. if err != nil {
  180. t.Fatalf("new-cert: CSR: %v", err)
  181. }
  182. if csr.Subject.CommonName != domain {
  183. t.Errorf("CommonName in CSR = %q; want %q", csr.Subject.CommonName, domain)
  184. }
  185. der, err := dummyCert(csr.PublicKey, domain)
  186. if err != nil {
  187. t.Fatalf("new-cert: dummyCert: %v", err)
  188. }
  189. chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL)
  190. w.Header().Set("link", chainUp)
  191. w.WriteHeader(http.StatusCreated)
  192. w.Write(der)
  193. // CA chain cert
  194. case "/ca-cert":
  195. der, err := dummyCert(nil, "ca")
  196. if err != nil {
  197. t.Fatalf("ca-cert: dummyCert: %v", err)
  198. }
  199. w.Write(der)
  200. default:
  201. t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path)
  202. }
  203. }))
  204. defer ca.Close()
  205. // use EC key to run faster on 386
  206. key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  207. if err != nil {
  208. t.Fatal(err)
  209. }
  210. man.Client = &acme.Client{
  211. Key: key,
  212. DirectoryURL: ca.URL,
  213. }
  214. // simulate tls.Config.GetCertificate
  215. var tlscert *tls.Certificate
  216. done := make(chan struct{})
  217. go func() {
  218. tlscert, err = man.GetCertificate(hello)
  219. close(done)
  220. }()
  221. select {
  222. case <-time.After(time.Minute):
  223. t.Fatal("man.GetCertificate took too long to return")
  224. case <-done:
  225. }
  226. if err != nil {
  227. t.Fatalf("man.GetCertificate: %v", err)
  228. }
  229. // verify the tlscert is the same we responded with from the CA stub
  230. if len(tlscert.Certificate) == 0 {
  231. t.Fatal("len(tlscert.Certificate) is 0")
  232. }
  233. cert, err := x509.ParseCertificate(tlscert.Certificate[0])
  234. if err != nil {
  235. t.Fatalf("x509.ParseCertificate: %v", err)
  236. }
  237. if len(cert.DNSNames) == 0 || cert.DNSNames[0] != domain {
  238. t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain)
  239. }
  240. // make sure token cert was removed
  241. done = make(chan struct{})
  242. go func() {
  243. for {
  244. hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
  245. if _, err := man.GetCertificate(hello); err != nil {
  246. break
  247. }
  248. time.Sleep(100 * time.Millisecond)
  249. }
  250. close(done)
  251. }()
  252. select {
  253. case <-time.After(5 * time.Second):
  254. t.Error("token cert was not removed")
  255. case <-done:
  256. }
  257. }
  258. func TestAccountKeyCache(t *testing.T) {
  259. cache := make(memCache)
  260. m := Manager{Cache: cache}
  261. ctx := context.Background()
  262. k1, err := m.accountKey(ctx)
  263. if err != nil {
  264. t.Fatal(err)
  265. }
  266. k2, err := m.accountKey(ctx)
  267. if err != nil {
  268. t.Fatal(err)
  269. }
  270. if !reflect.DeepEqual(k1, k2) {
  271. t.Errorf("account keys don't match: k1 = %#v; k2 = %#v", k1, k2)
  272. }
  273. }
  274. func TestCache(t *testing.T) {
  275. privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  276. if err != nil {
  277. t.Fatal(err)
  278. }
  279. tmpl := &x509.Certificate{
  280. SerialNumber: big.NewInt(1),
  281. Subject: pkix.Name{CommonName: "example.org"},
  282. NotAfter: time.Now().Add(time.Hour),
  283. }
  284. pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privKey.PublicKey, privKey)
  285. if err != nil {
  286. t.Fatal(err)
  287. }
  288. tlscert := &tls.Certificate{
  289. Certificate: [][]byte{pub},
  290. PrivateKey: privKey,
  291. }
  292. cache := make(memCache)
  293. man := &Manager{Cache: cache}
  294. defer man.stopRenew()
  295. if err := man.cachePut("example.org", tlscert); err != nil {
  296. t.Fatalf("man.cachePut: %v", err)
  297. }
  298. res, err := man.cacheGet("example.org")
  299. if err != nil {
  300. t.Fatalf("man.cacheGet: %v", err)
  301. }
  302. if res == nil {
  303. t.Fatal("res is nil")
  304. }
  305. }
  306. func TestHostWhitelist(t *testing.T) {
  307. policy := HostWhitelist("example.com", "example.org", "*.example.net")
  308. tt := []struct {
  309. host string
  310. allow bool
  311. }{
  312. {"example.com", true},
  313. {"example.org", true},
  314. {"one.example.com", false},
  315. {"two.example.org", false},
  316. {"three.example.net", false},
  317. {"dummy", false},
  318. }
  319. for i, test := range tt {
  320. err := policy(nil, test.host)
  321. if err != nil && test.allow {
  322. t.Errorf("%d: policy(%q): %v; want nil", i, test.host, err)
  323. }
  324. if err == nil && !test.allow {
  325. t.Errorf("%d: policy(%q): nil; want an error", i, test.host)
  326. }
  327. }
  328. }
  329. func TestValidCert(t *testing.T) {
  330. key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  331. if err != nil {
  332. t.Fatal(err)
  333. }
  334. key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  335. if err != nil {
  336. t.Fatal(err)
  337. }
  338. key3, err := rsa.GenerateKey(rand.Reader, 512)
  339. if err != nil {
  340. t.Fatal(err)
  341. }
  342. cert1, err := dummyCert(key1.Public(), "example.org")
  343. if err != nil {
  344. t.Fatal(err)
  345. }
  346. cert2, err := dummyCert(key2.Public(), "example.org")
  347. if err != nil {
  348. t.Fatal(err)
  349. }
  350. cert3, err := dummyCert(key3.Public(), "example.org")
  351. if err != nil {
  352. t.Fatal(err)
  353. }
  354. now := time.Now()
  355. early, err := dateDummyCert(key1.Public(), now.Add(time.Hour), now.Add(2*time.Hour), "example.org")
  356. if err != nil {
  357. t.Fatal(err)
  358. }
  359. expired, err := dateDummyCert(key1.Public(), now.Add(-2*time.Hour), now.Add(-time.Hour), "example.org")
  360. if err != nil {
  361. t.Fatal(err)
  362. }
  363. tt := []struct {
  364. domain string
  365. key crypto.Signer
  366. cert [][]byte
  367. ok bool
  368. }{
  369. {"example.org", key1, [][]byte{cert1}, true},
  370. {"example.org", key3, [][]byte{cert3}, true},
  371. {"example.org", key1, [][]byte{cert1, cert2, cert3}, true},
  372. {"example.org", key1, [][]byte{cert1, {1}}, false},
  373. {"example.org", key1, [][]byte{{1}}, false},
  374. {"example.org", key1, [][]byte{cert2}, false},
  375. {"example.org", key2, [][]byte{cert1}, false},
  376. {"example.org", key1, [][]byte{cert3}, false},
  377. {"example.org", key3, [][]byte{cert1}, false},
  378. {"example.net", key1, [][]byte{cert1}, false},
  379. {"example.org", key1, [][]byte{early}, false},
  380. {"example.org", key1, [][]byte{expired}, false},
  381. }
  382. for i, test := range tt {
  383. leaf, err := validCert(test.domain, test.cert, test.key)
  384. if err != nil && test.ok {
  385. t.Errorf("%d: err = %v", i, err)
  386. }
  387. if err == nil && !test.ok {
  388. t.Errorf("%d: err is nil", i)
  389. }
  390. if err == nil && test.ok && leaf == nil {
  391. t.Errorf("%d: leaf is nil", i)
  392. }
  393. }
  394. }