srv.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. // Copyright 2015 The etcd Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Package srv looks up DNS SRV records.
  15. package srv
  16. import (
  17. "fmt"
  18. "net"
  19. "net/url"
  20. "strings"
  21. "github.com/coreos/etcd/pkg/types"
  22. )
  23. var (
  24. // indirection for testing
  25. lookupSRV = net.LookupSRV // net.DefaultResolver.LookupSRV when ctxs don't conflict
  26. resolveTCPAddr = net.ResolveTCPAddr
  27. )
  28. // GetCluster gets the cluster information via DNS discovery.
  29. // Also sees each entry as a separate instance.
  30. func GetCluster(service, name, dns string, apurls types.URLs) ([]string, error) {
  31. tempName := int(0)
  32. tcp2ap := make(map[string]url.URL)
  33. // First, resolve the apurls
  34. for _, url := range apurls {
  35. tcpAddr, err := resolveTCPAddr("tcp", url.Host)
  36. if err != nil {
  37. return nil, err
  38. }
  39. tcp2ap[tcpAddr.String()] = url
  40. }
  41. stringParts := []string{}
  42. updateNodeMap := func(service, scheme string) error {
  43. _, addrs, err := lookupSRV(service, "tcp", dns)
  44. if err != nil {
  45. return err
  46. }
  47. for _, srv := range addrs {
  48. port := fmt.Sprintf("%d", srv.Port)
  49. host := net.JoinHostPort(srv.Target, port)
  50. tcpAddr, terr := resolveTCPAddr("tcp", host)
  51. if terr != nil {
  52. err = terr
  53. continue
  54. }
  55. n := ""
  56. url, ok := tcp2ap[tcpAddr.String()]
  57. if ok {
  58. n = name
  59. }
  60. if n == "" {
  61. n = fmt.Sprintf("%d", tempName)
  62. tempName++
  63. }
  64. // SRV records have a trailing dot but URL shouldn't.
  65. shortHost := strings.TrimSuffix(srv.Target, ".")
  66. urlHost := net.JoinHostPort(shortHost, port)
  67. if ok && url.Scheme != scheme {
  68. err = fmt.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String())
  69. } else {
  70. stringParts = append(stringParts, fmt.Sprintf("%s=%s://%s", n, scheme, urlHost))
  71. }
  72. }
  73. if len(stringParts) == 0 {
  74. return err
  75. }
  76. return nil
  77. }
  78. failCount := 0
  79. err := updateNodeMap(service+"-ssl", "https")
  80. srvErr := make([]string, 2)
  81. if err != nil {
  82. srvErr[0] = fmt.Sprintf("error querying DNS SRV records for _%s-ssl %s", service, err)
  83. failCount++
  84. }
  85. err = updateNodeMap(service, "http")
  86. if err != nil {
  87. srvErr[1] = fmt.Sprintf("error querying DNS SRV records for _%s %s", service, err)
  88. failCount++
  89. }
  90. if failCount == 2 {
  91. return nil, fmt.Errorf("srv: too many errors querying DNS SRV records (%q, %q)", srvErr[0], srvErr[1])
  92. }
  93. return stringParts, nil
  94. }
  95. type SRVClients struct {
  96. Endpoints []string
  97. SRVs []*net.SRV
  98. }
  99. // GetClient looks up the client endpoints for a service and domain.
  100. func GetClient(service, domain string) (*SRVClients, error) {
  101. var urls []*url.URL
  102. var srvs []*net.SRV
  103. updateURLs := func(service, scheme string) error {
  104. _, addrs, err := lookupSRV(service, "tcp", domain)
  105. if err != nil {
  106. return err
  107. }
  108. for _, srv := range addrs {
  109. urls = append(urls, &url.URL{
  110. Scheme: scheme,
  111. Host: net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)),
  112. })
  113. }
  114. srvs = append(srvs, addrs...)
  115. return nil
  116. }
  117. errHTTPS := updateURLs(service+"-ssl", "https")
  118. errHTTP := updateURLs(service, "http")
  119. if errHTTPS != nil && errHTTP != nil {
  120. return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP)
  121. }
  122. endpoints := make([]string, len(urls))
  123. for i := range urls {
  124. endpoints[i] = urls[i].String()
  125. }
  126. return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil
  127. }