client.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. // Copyright 2018 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package socks
  5. import (
  6. "context"
  7. "errors"
  8. "io"
  9. "net"
  10. "strconv"
  11. "time"
  12. )
  // Deadline values passed to net.Conn.SetDeadline by connect.
var (
	// noDeadline is the zero time.Time. Passing it to SetDeadline removes
	// any deadline previously set on the connection.
	noDeadline time.Time

	// aLongTimeAgo is a fixed instant in the past (one second after the
	// Unix epoch). Using it as a deadline makes pending and future I/O on
	// the connection fail, which is how connect interrupts a handshake
	// once its context is done.
	aLongTimeAgo = time.Unix(1, 0)
)
  17. func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
  18. host, port, err := splitHostPort(address)
  19. if err != nil {
  20. return nil, err
  21. }
  22. if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
  23. c.SetDeadline(deadline)
  24. defer c.SetDeadline(noDeadline)
  25. }
  26. if ctx != context.Background() {
  27. errCh := make(chan error, 1)
  28. done := make(chan struct{})
  29. defer func() {
  30. close(done)
  31. if ctxErr == nil {
  32. ctxErr = <-errCh
  33. }
  34. }()
  35. go func() {
  36. select {
  37. case <-ctx.Done():
  38. c.SetDeadline(aLongTimeAgo)
  39. errCh <- ctx.Err()
  40. case <-done:
  41. errCh <- nil
  42. }
  43. }()
  44. }
  45. b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
  46. b = append(b, Version5)
  47. if len(d.AuthMethods) == 0 || d.Authenticate == nil {
  48. b = append(b, 1, byte(AuthMethodNotRequired))
  49. } else {
  50. ams := d.AuthMethods
  51. if len(ams) > 255 {
  52. return nil, errors.New("too many authentication methods")
  53. }
  54. b = append(b, byte(len(ams)))
  55. for _, am := range ams {
  56. b = append(b, byte(am))
  57. }
  58. }
  59. if _, ctxErr = c.Write(b); ctxErr != nil {
  60. return
  61. }
  62. if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
  63. return
  64. }
  65. if b[0] != Version5 {
  66. return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
  67. }
  68. am := AuthMethod(b[1])
  69. if am == AuthMethodNoAcceptableMethods {
  70. return nil, errors.New("no acceptable authentication methods")
  71. }
  72. if d.Authenticate != nil {
  73. if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
  74. return
  75. }
  76. }
  77. b = b[:0]
  78. b = append(b, Version5, byte(d.cmd), 0)
  79. if ip := net.ParseIP(host); ip != nil {
  80. if ip4 := ip.To4(); ip4 != nil {
  81. b = append(b, AddrTypeIPv4)
  82. b = append(b, ip4...)
  83. } else if ip6 := ip.To16(); ip6 != nil {
  84. b = append(b, AddrTypeIPv6)
  85. b = append(b, ip6...)
  86. } else {
  87. return nil, errors.New("unknown address type")
  88. }
  89. } else {
  90. if len(host) > 255 {
  91. return nil, errors.New("FQDN too long")
  92. }
  93. b = append(b, AddrTypeFQDN)
  94. b = append(b, byte(len(host)))
  95. b = append(b, host...)
  96. }
  97. b = append(b, byte(port>>8), byte(port))
  98. if _, ctxErr = c.Write(b); ctxErr != nil {
  99. return
  100. }
  101. if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
  102. return
  103. }
  104. if b[0] != Version5 {
  105. return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
  106. }
  107. if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
  108. return nil, errors.New("unknown error " + cmdErr.String())
  109. }
  110. if b[2] != 0 {
  111. return nil, errors.New("non-zero reserved field")
  112. }
  113. l := 2
  114. var a Addr
  115. switch b[3] {
  116. case AddrTypeIPv4:
  117. l += net.IPv4len
  118. a.IP = make(net.IP, net.IPv4len)
  119. case AddrTypeIPv6:
  120. l += net.IPv6len
  121. a.IP = make(net.IP, net.IPv6len)
  122. case AddrTypeFQDN:
  123. if _, err := io.ReadFull(c, b[:1]); err != nil {
  124. return nil, err
  125. }
  126. l += int(b[0])
  127. default:
  128. return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
  129. }
  130. if cap(b) < l {
  131. b = make([]byte, l)
  132. } else {
  133. b = b[:l]
  134. }
  135. if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
  136. return
  137. }
  138. if a.IP != nil {
  139. copy(a.IP, b)
  140. } else {
  141. a.Name = string(b[:len(b)-2])
  142. }
  143. a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
  144. return &a, nil
  145. }
  // splitHostPort splits address, in the "host:port" form accepted by
// net.SplitHostPort, into the host and a numeric port.
//
// The port must be a decimal integer between 1 and 65535 inclusive.
// Otherwise an error is returned together with an empty host and a zero
// port.
func splitHostPort(address string) (string, int, error) {
	h, p, err := net.SplitHostPort(address)
	if err != nil {
		return "", 0, err
	}
	n, err := strconv.Atoi(p)
	switch {
	case err != nil:
		return "", 0, err
	case n < 1 || n > 0xffff:
		// Port 0 and anything that does not fit in 16 bits cannot be
		// encoded in the two-byte port field of a request.
		return "", 0, errors.New("port number out of range " + p)
	}
	return h, n, nil
}