srv.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. // Copyright 2015 The etcd Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Package srv looks up DNS SRV records.
  15. package srv
  16. import (
  17. "fmt"
  18. "net"
  19. "net/url"
  20. "strings"
  21. "go.etcd.io/etcd/pkg/types"
  22. )
  23. var (
  24. // indirection for testing
  25. lookupSRV = net.LookupSRV // net.DefaultResolver.LookupSRV when ctxs don't conflict
  26. resolveTCPAddr = net.ResolveTCPAddr
  27. )
  28. // GetCluster gets the cluster information via DNS discovery.
  29. // Also sees each entry as a separate instance.
  30. func GetCluster(serviceScheme, service, name, dns string, apurls types.URLs) ([]string, error) {
  31. tempName := int(0)
  32. tcp2ap := make(map[string]url.URL)
  33. // First, resolve the apurls
  34. for _, url := range apurls {
  35. tcpAddr, err := resolveTCPAddr("tcp", url.Host)
  36. if err != nil {
  37. return nil, err
  38. }
  39. tcp2ap[tcpAddr.String()] = url
  40. }
  41. stringParts := []string{}
  42. updateNodeMap := func(service, scheme string) error {
  43. _, addrs, err := lookupSRV(service, "tcp", dns)
  44. if err != nil {
  45. return err
  46. }
  47. for _, srv := range addrs {
  48. port := fmt.Sprintf("%d", srv.Port)
  49. host := net.JoinHostPort(srv.Target, port)
  50. tcpAddr, terr := resolveTCPAddr("tcp", host)
  51. if terr != nil {
  52. err = terr
  53. continue
  54. }
  55. n := ""
  56. url, ok := tcp2ap[tcpAddr.String()]
  57. if ok {
  58. n = name
  59. }
  60. if n == "" {
  61. n = fmt.Sprintf("%d", tempName)
  62. tempName++
  63. }
  64. // SRV records have a trailing dot but URL shouldn't.
  65. shortHost := strings.TrimSuffix(srv.Target, ".")
  66. urlHost := net.JoinHostPort(shortHost, port)
  67. if ok && url.Scheme != scheme {
  68. err = fmt.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String())
  69. } else {
  70. stringParts = append(stringParts, fmt.Sprintf("%s=%s://%s", n, scheme, urlHost))
  71. }
  72. }
  73. if len(stringParts) == 0 {
  74. return err
  75. }
  76. return nil
  77. }
  78. err := updateNodeMap(service, serviceScheme)
  79. if err != nil {
  80. return nil, fmt.Errorf("error querying DNS SRV records for _%s %s", service, err)
  81. }
  82. return stringParts, nil
  83. }
// SRVClients holds the client endpoints that GetClient discovers from DNS SRV
// records.
type SRVClients struct {
	// Endpoints are URL strings of the form "scheme://target:port", where
	// scheme is "https" or "http", built from each record's Target and Port.
	Endpoints []string
	// SRVs are the raw SRV records the endpoints were built from, in the
	// same order as Endpoints.
	SRVs []*net.SRV
}
  88. // GetClient looks up the client endpoints for a service and domain.
  89. func GetClient(service, domain string, serviceName string) (*SRVClients, error) {
  90. var urls []*url.URL
  91. var srvs []*net.SRV
  92. updateURLs := func(service, scheme string) error {
  93. _, addrs, err := lookupSRV(service, "tcp", domain)
  94. if err != nil {
  95. return err
  96. }
  97. for _, srv := range addrs {
  98. urls = append(urls, &url.URL{
  99. Scheme: scheme,
  100. Host: net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)),
  101. })
  102. }
  103. srvs = append(srvs, addrs...)
  104. return nil
  105. }
  106. errHTTPS := updateURLs(GetSRVService(service, serviceName, "https"), "https")
  107. errHTTP := updateURLs(GetSRVService(service, serviceName, "http"), "http")
  108. if errHTTPS != nil && errHTTP != nil {
  109. return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP)
  110. }
  111. endpoints := make([]string, len(urls))
  112. for i := range urls {
  113. endpoints[i] = urls[i].String()
  114. }
  115. return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil
  116. }
  117. // GetSRVService generates a SRV service including an optional suffix.
  118. func GetSRVService(service, serviceName string, scheme string) (SRVService string) {
  119. if scheme == "https" {
  120. service = fmt.Sprintf("%s-ssl", service)
  121. }
  122. if serviceName != "" {
  123. return fmt.Sprintf("%s-%s", service, serviceName)
  124. }
  125. return service
  126. }