server.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. // Copyright 2018 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Package sockstest provides utilities for SOCKS testing.
  5. package sockstest
  6. import (
  7. "errors"
  8. "io"
  9. "net"
  10. "golang.org/x/net/internal/socks"
  11. "golang.org/x/net/nettest"
  12. )
  13. // An AuthRequest represents an authentication request.
  14. type AuthRequest struct {
  15. Version int
  16. Methods []socks.AuthMethod
  17. }
  18. // ParseAuthRequest parses an authentication request.
  19. func ParseAuthRequest(b []byte) (*AuthRequest, error) {
  20. if len(b) < 2 {
  21. return nil, errors.New("short auth request")
  22. }
  23. if b[0] != socks.Version5 {
  24. return nil, errors.New("unexpected protocol version")
  25. }
  26. if len(b)-2 < int(b[1]) {
  27. return nil, errors.New("short auth request")
  28. }
  29. req := &AuthRequest{Version: int(b[0])}
  30. if b[1] > 0 {
  31. req.Methods = make([]socks.AuthMethod, b[1])
  32. for i, m := range b[2 : 2+b[1]] {
  33. req.Methods[i] = socks.AuthMethod(m)
  34. }
  35. }
  36. return req, nil
  37. }
  38. // MarshalAuthReply returns an authentication reply in wire format.
  39. func MarshalAuthReply(ver int, m socks.AuthMethod) ([]byte, error) {
  40. return []byte{byte(ver), byte(m)}, nil
  41. }
  42. // A CmdRequest repesents a command request.
  43. type CmdRequest struct {
  44. Version int
  45. Cmd socks.Command
  46. Addr socks.Addr
  47. }
  48. // ParseCmdRequest parses a command request.
  49. func ParseCmdRequest(b []byte) (*CmdRequest, error) {
  50. if len(b) < 7 {
  51. return nil, errors.New("short cmd request")
  52. }
  53. if b[0] != socks.Version5 {
  54. return nil, errors.New("unexpected protocol version")
  55. }
  56. if socks.Command(b[1]) != socks.CmdConnect {
  57. return nil, errors.New("unexpected command")
  58. }
  59. if b[2] != 0 {
  60. return nil, errors.New("non-zero reserved field")
  61. }
  62. req := &CmdRequest{Version: int(b[0]), Cmd: socks.Command(b[1])}
  63. l := 2
  64. off := 4
  65. switch b[3] {
  66. case socks.AddrTypeIPv4:
  67. l += net.IPv4len
  68. req.Addr.IP = make(net.IP, net.IPv4len)
  69. case socks.AddrTypeIPv6:
  70. l += net.IPv6len
  71. req.Addr.IP = make(net.IP, net.IPv6len)
  72. case socks.AddrTypeFQDN:
  73. l += int(b[4])
  74. off = 5
  75. default:
  76. return nil, errors.New("unknown address type")
  77. }
  78. if len(b[off:]) < l {
  79. return nil, errors.New("short cmd request")
  80. }
  81. if req.Addr.IP != nil {
  82. copy(req.Addr.IP, b[off:])
  83. } else {
  84. req.Addr.Name = string(b[off : off+l-2])
  85. }
  86. req.Addr.Port = int(b[off+l-2])<<8 | int(b[off+l-1])
  87. return req, nil
  88. }
  89. // MarshalCmdReply returns a command reply in wire format.
  90. func MarshalCmdReply(ver int, reply socks.Reply, a *socks.Addr) ([]byte, error) {
  91. b := make([]byte, 4)
  92. b[0] = byte(ver)
  93. b[1] = byte(reply)
  94. if a.Name != "" {
  95. if len(a.Name) > 255 {
  96. return nil, errors.New("fqdn too long")
  97. }
  98. b[3] = socks.AddrTypeFQDN
  99. b = append(b, byte(len(a.Name)))
  100. b = append(b, a.Name...)
  101. } else if ip4 := a.IP.To4(); ip4 != nil {
  102. b[3] = socks.AddrTypeIPv4
  103. b = append(b, ip4...)
  104. } else if ip6 := a.IP.To16(); ip6 != nil {
  105. b[3] = socks.AddrTypeIPv6
  106. b = append(b, ip6...)
  107. } else {
  108. return nil, errors.New("unknown address type")
  109. }
  110. b = append(b, byte(a.Port>>8), byte(a.Port))
  111. return b, nil
  112. }
  113. // A Server repesents a server for handshake testing.
  114. type Server struct {
  115. ln net.Listener
  116. }
  117. // Addr rerurns a server address.
  118. func (s *Server) Addr() net.Addr {
  119. return s.ln.Addr()
  120. }
  121. // TargetAddr returns a fake final destination address.
  122. //
  123. // The returned address is only valid for testing with Server.
  124. func (s *Server) TargetAddr() net.Addr {
  125. a := s.ln.Addr()
  126. switch a := a.(type) {
  127. case *net.TCPAddr:
  128. if a.IP.To4() != nil {
  129. return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 5963}
  130. }
  131. if a.IP.To16() != nil && a.IP.To4() == nil {
  132. return &net.TCPAddr{IP: net.IPv6loopback, Port: 5963}
  133. }
  134. }
  135. return nil
  136. }
  137. // Close closes the server.
  138. func (s *Server) Close() error {
  139. return s.ln.Close()
  140. }
  141. func (s *Server) serve(authFunc, cmdFunc func(io.ReadWriter, []byte) error) {
  142. c, err := s.ln.Accept()
  143. if err != nil {
  144. return
  145. }
  146. defer c.Close()
  147. go s.serve(authFunc, cmdFunc)
  148. b := make([]byte, 512)
  149. n, err := c.Read(b)
  150. if err != nil {
  151. return
  152. }
  153. if err := authFunc(c, b[:n]); err != nil {
  154. return
  155. }
  156. n, err = c.Read(b)
  157. if err != nil {
  158. return
  159. }
  160. if err := cmdFunc(c, b[:n]); err != nil {
  161. return
  162. }
  163. }
  164. // NewServer returns a new server.
  165. //
  166. // The provided authFunc and cmdFunc must parse requests and return
  167. // appropriate replies to clients.
  168. func NewServer(authFunc, cmdFunc func(io.ReadWriter, []byte) error) (*Server, error) {
  169. var err error
  170. s := new(Server)
  171. s.ln, err = nettest.NewLocalListener("tcp")
  172. if err != nil {
  173. return nil, err
  174. }
  175. go s.serve(authFunc, cmdFunc)
  176. return s, nil
  177. }
  178. // NoAuthRequired handles a no-authentication-required signaling.
  179. func NoAuthRequired(rw io.ReadWriter, b []byte) error {
  180. req, err := ParseAuthRequest(b)
  181. if err != nil {
  182. return err
  183. }
  184. b, err = MarshalAuthReply(req.Version, socks.AuthMethodNotRequired)
  185. if err != nil {
  186. return err
  187. }
  188. n, err := rw.Write(b)
  189. if err != nil {
  190. return err
  191. }
  192. if n != len(b) {
  193. return errors.New("short write")
  194. }
  195. return nil
  196. }
  197. // NoProxyRequired handles a command signaling without constructing a
  198. // proxy connection to the final destination.
  199. func NoProxyRequired(rw io.ReadWriter, b []byte) error {
  200. req, err := ParseCmdRequest(b)
  201. if err != nil {
  202. return err
  203. }
  204. req.Addr.Port += 1
  205. if req.Addr.Name != "" {
  206. req.Addr.Name = "boundaddr.doesnotexist"
  207. } else if req.Addr.IP.To4() != nil {
  208. req.Addr.IP = net.IPv4(127, 0, 0, 1)
  209. } else {
  210. req.Addr.IP = net.IPv6loopback
  211. }
  212. b, err = MarshalCmdReply(socks.Version5, socks.StatusSucceeded, &req.Addr)
  213. if err != nil {
  214. return err
  215. }
  216. n, err := rw.Write(b)
  217. if err != nil {
  218. return err
  219. }
  220. if n != len(b) {
  221. return errors.New("short write")
  222. }
  223. return nil
  224. }