srv.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. // Copyright 2015 The etcd Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package srv
  15. import (
  16. "fmt"
  17. "net"
  18. "net/url"
  19. "strings"
  20. "github.com/coreos/etcd/pkg/types"
  21. )
  // lookupSRV performs DNS SRV queries. It is a package-level variable rather
// than a direct call so tests can substitute a fake resolver.
// TODO: switch to net.DefaultResolver.LookupSRV when contexts don't conflict.
var lookupSRV = net.LookupSRV

// resolveTCPAddr resolves "host:port" strings to TCP addresses. Like
// lookupSRV, it exists as an indirection so tests can stub resolution.
var resolveTCPAddr = net.ResolveTCPAddr
  27. // GetCluster gets the cluster information via DNS discovery.
  28. // Also sees each entry as a separate instance.
  29. func GetCluster(service, name, dns string, apurls types.URLs) ([]string, error) {
  30. tempName := int(0)
  31. tcp2ap := make(map[string]url.URL)
  32. // First, resolve the apurls
  33. for _, url := range apurls {
  34. tcpAddr, err := resolveTCPAddr("tcp", url.Host)
  35. if err != nil {
  36. return nil, err
  37. }
  38. tcp2ap[tcpAddr.String()] = url
  39. }
  40. stringParts := []string{}
  41. updateNodeMap := func(service, scheme string) error {
  42. _, addrs, err := lookupSRV(service, "tcp", dns)
  43. if err != nil {
  44. return err
  45. }
  46. for _, srv := range addrs {
  47. port := fmt.Sprintf("%d", srv.Port)
  48. host := net.JoinHostPort(srv.Target, port)
  49. tcpAddr, terr := resolveTCPAddr("tcp", host)
  50. if terr != nil {
  51. terr = err
  52. continue
  53. }
  54. n := ""
  55. url, ok := tcp2ap[tcpAddr.String()]
  56. if ok {
  57. n = name
  58. }
  59. if n == "" {
  60. n = fmt.Sprintf("%d", tempName)
  61. tempName++
  62. }
  63. // SRV records have a trailing dot but URL shouldn't.
  64. shortHost := strings.TrimSuffix(srv.Target, ".")
  65. urlHost := net.JoinHostPort(shortHost, port)
  66. stringParts = append(stringParts, fmt.Sprintf("%s=%s://%s", n, scheme, urlHost))
  67. if ok && url.Scheme != scheme {
  68. err = fmt.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String())
  69. }
  70. }
  71. if len(stringParts) == 0 {
  72. return err
  73. }
  74. return nil
  75. }
  76. failCount := 0
  77. err := updateNodeMap(service+"-ssl", "https")
  78. srvErr := make([]string, 2)
  79. if err != nil {
  80. srvErr[0] = fmt.Sprintf("error querying DNS SRV records for _%s-ssl %s", service, err)
  81. failCount++
  82. }
  83. err = updateNodeMap(service, "http")
  84. if err != nil {
  85. srvErr[1] = fmt.Sprintf("error querying DNS SRV records for _%s %s", service, err)
  86. failCount++
  87. }
  88. if failCount == 2 {
  89. return nil, fmt.Errorf("srv: too many errors querying DNS SRV records (%q, %q)", srvErr[0], srvErr[1])
  90. }
  91. return stringParts, nil
  92. }
  // SRVClients holds the client endpoints discovered through DNS SRV records.
// As populated by GetClient, Endpoints[i] is the URL built from SRVs[i].
type SRVClients struct {
	// Endpoints are the client URLs ("scheme://host:port") built from the
	// SRV records.
	Endpoints []string
	// SRVs are the SRV records the endpoints were built from.
	SRVs []*net.SRV
}
  97. // GetClient looks up the client endpoints for a service and domain.
  98. func GetClient(service, domain string) (*SRVClients, error) {
  99. var urls []*url.URL
  100. var srvs []*net.SRV
  101. updateURLs := func(service, scheme string) error {
  102. _, addrs, err := lookupSRV(service, "tcp", domain)
  103. if err != nil {
  104. return err
  105. }
  106. for _, srv := range addrs {
  107. urls = append(urls, &url.URL{
  108. Scheme: scheme,
  109. Host: net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)),
  110. })
  111. }
  112. srvs = append(srvs, addrs...)
  113. return nil
  114. }
  115. errHTTPS := updateURLs(service+"-ssl", "https")
  116. errHTTP := updateURLs(service, "http")
  117. if errHTTPS != nil && errHTTP != nil {
  118. return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP)
  119. }
  120. endpoints := make([]string, len(urls))
  121. for i := range urls {
  122. endpoints[i] = urls[i].String()
  123. }
  124. return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil
  125. }