per_host.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proxy
  5. import (
  6. "context"
  7. "net"
  8. "strings"
  9. )
  10. // A PerHost directs connections to a default Dialer unless the host name
  11. // requested matches one of a number of exceptions.
  12. type PerHost struct {
  13. def, bypass Dialer
  14. bypassNetworks []*net.IPNet
  15. bypassIPs []net.IP
  16. bypassZones []string
  17. bypassHosts []string
  18. }
  19. // NewPerHost returns a PerHost Dialer that directs connections to either
  20. // defaultDialer or bypass, depending on whether the connection matches one of
  21. // the configured rules.
  22. func NewPerHost(defaultDialer, bypass Dialer) *PerHost {
  23. return &PerHost{
  24. def: defaultDialer,
  25. bypass: bypass,
  26. }
  27. }
  28. // Dial connects to the address addr on the given network through either
  29. // defaultDialer or bypass.
  30. func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
  31. host, _, err := net.SplitHostPort(addr)
  32. if err != nil {
  33. return nil, err
  34. }
  35. return p.dialerForRequest(host).Dial(network, addr)
  36. }
  37. // DialContext connects to the address addr on the given network through either
  38. // defaultDialer or bypass.
  39. func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) {
  40. host, _, err := net.SplitHostPort(addr)
  41. if err != nil {
  42. return nil, err
  43. }
  44. d := p.dialerForRequest(host)
  45. if x, ok := d.(ContextDialer); ok {
  46. return x.DialContext(ctx, network, addr)
  47. }
  48. return dialContext(ctx, d, network, addr)
  49. }
  50. func (p *PerHost) dialerForRequest(host string) Dialer {
  51. if ip := net.ParseIP(host); ip != nil {
  52. for _, net := range p.bypassNetworks {
  53. if net.Contains(ip) {
  54. return p.bypass
  55. }
  56. }
  57. for _, bypassIP := range p.bypassIPs {
  58. if bypassIP.Equal(ip) {
  59. return p.bypass
  60. }
  61. }
  62. return p.def
  63. }
  64. for _, zone := range p.bypassZones {
  65. if strings.HasSuffix(host, zone) {
  66. return p.bypass
  67. }
  68. if host == zone[1:] {
  69. // For a zone ".example.com", we match "example.com"
  70. // too.
  71. return p.bypass
  72. }
  73. }
  74. for _, bypassHost := range p.bypassHosts {
  75. if bypassHost == host {
  76. return p.bypass
  77. }
  78. }
  79. return p.def
  80. }
  81. // AddFromString parses a string that contains comma-separated values
  82. // specifying hosts that should use the bypass proxy. Each value is either an
  83. // IP address, a CIDR range, a zone (*.example.com) or a host name
  84. // (localhost). A best effort is made to parse the string and errors are
  85. // ignored.
  86. func (p *PerHost) AddFromString(s string) {
  87. hosts := strings.Split(s, ",")
  88. for _, host := range hosts {
  89. host = strings.TrimSpace(host)
  90. if len(host) == 0 {
  91. continue
  92. }
  93. if strings.Contains(host, "/") {
  94. // We assume that it's a CIDR address like 127.0.0.0/8
  95. if _, net, err := net.ParseCIDR(host); err == nil {
  96. p.AddNetwork(net)
  97. }
  98. continue
  99. }
  100. if ip := net.ParseIP(host); ip != nil {
  101. p.AddIP(ip)
  102. continue
  103. }
  104. if strings.HasPrefix(host, "*.") {
  105. p.AddZone(host[1:])
  106. continue
  107. }
  108. p.AddHost(host)
  109. }
  110. }
  111. // AddIP specifies an IP address that will use the bypass proxy. Note that
  112. // this will only take effect if a literal IP address is dialed. A connection
  113. // to a named host will never match an IP.
  114. func (p *PerHost) AddIP(ip net.IP) {
  115. p.bypassIPs = append(p.bypassIPs, ip)
  116. }
  117. // AddNetwork specifies an IP range that will use the bypass proxy. Note that
  118. // this will only take effect if a literal IP address is dialed. A connection
  119. // to a named host will never match.
  120. func (p *PerHost) AddNetwork(net *net.IPNet) {
  121. p.bypassNetworks = append(p.bypassNetworks, net)
  122. }
  123. // AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
  124. // "example.com" matches "example.com" and all of its subdomains.
  125. func (p *PerHost) AddZone(zone string) {
  126. if strings.HasSuffix(zone, ".") {
  127. zone = zone[:len(zone)-1]
  128. }
  129. if !strings.HasPrefix(zone, ".") {
  130. zone = "." + zone
  131. }
  132. p.bypassZones = append(p.bypassZones, zone)
  133. }
  134. // AddHost specifies a host name that will use the bypass proxy.
  135. func (p *PerHost) AddHost(host string) {
  136. if strings.HasSuffix(host, ".") {
  137. host = host[:len(host)-1]
  138. }
  139. p.bypassHosts = append(p.bypassHosts, host)
  140. }