server.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. package ldap
  2. import (
  3. "crypto/tls"
  4. "io"
  5. "log"
  6. "net"
  7. "strings"
  8. "sync"
  9. "github.com/nmcclain/asn1-ber"
  10. )
  11. type Binder interface {
  12. Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error)
  13. }
  14. type Searcher interface {
  15. Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error)
  16. }
  17. type Adder interface {
  18. Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error)
  19. }
  20. type Modifier interface {
  21. Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error)
  22. }
  23. type Deleter interface {
  24. Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error)
  25. }
  26. type ModifyDNr interface {
  27. ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error)
  28. }
  29. type Comparer interface {
  30. Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error)
  31. }
  32. type Abandoner interface {
  33. Abandon(boundDN string, conn net.Conn) error
  34. }
  35. type Extender interface {
  36. Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error)
  37. }
  38. type Unbinder interface {
  39. Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error)
  40. }
  41. type Closer interface {
  42. Close(boundDN string, conn net.Conn) error
  43. }
// Server is an LDAP server. Each operation has its own handler table keyed
// by base DN, populated through the matching *Func method (BindFunc,
// SearchFunc, ...). The tables are plain maps, so a zero Server cannot
// register handlers; create one with NewServer.
type Server struct {
	BindFns     map[string]Binder
	SearchFns   map[string]Searcher
	AddFns      map[string]Adder
	ModifyFns   map[string]Modifier
	DeleteFns   map[string]Deleter
	ModifyDNFns map[string]ModifyDNr
	CompareFns  map[string]Comparer
	AbandonFns  map[string]Abandoner
	ExtendedFns map[string]Extender
	UnbindFns   map[string]Unbinder
	// CloseFns: every entry is invoked when a connection ends, regardless of base DN.
	CloseFns map[string]Closer
	// Quit: a receive makes Serve close its listener and return.
	Quit chan bool
	// EnforceLDAP is not read in this file.
	// NOTE(review): presumably consulted by the request handlers — confirm.
	EnforceLDAP bool
	// Stats holds the counters; nil disables counting (see SetStats).
	Stats *Stats
}
// Stats holds a server's activity counters. The counters are updated by the
// count* methods while holding statsMutex; read them through
// Server.GetStats rather than directly.
type Stats struct {
	Conns      int        // connections accepted by Serve
	Binds      int        // Bind requests received
	Unbinds    int        // Unbind requests received
	Searches   int        // Search requests received
	statsMutex sync.Mutex // guards the counters above
}
// ServerSearchResult is the value a Searcher returns for one Search request.
// NOTE(review): how each field is turned into protocol messages lives in
// HandleSearchRequest, which is not in this file.
type ServerSearchResult struct {
	Entries    []*Entry       // entries matching the search
	Referrals  []string       // referrals to return to the client
	Controls   []Control      // response controls
	ResultCode LDAPResultCode // result code for the search
}
  74. //
  75. func NewServer() *Server {
  76. s := new(Server)
  77. s.Quit = make(chan bool)
  78. d := defaultHandler{}
  79. s.BindFns = make(map[string]Binder)
  80. s.SearchFns = make(map[string]Searcher)
  81. s.AddFns = make(map[string]Adder)
  82. s.ModifyFns = make(map[string]Modifier)
  83. s.DeleteFns = make(map[string]Deleter)
  84. s.ModifyDNFns = make(map[string]ModifyDNr)
  85. s.CompareFns = make(map[string]Comparer)
  86. s.AbandonFns = make(map[string]Abandoner)
  87. s.ExtendedFns = make(map[string]Extender)
  88. s.UnbindFns = make(map[string]Unbinder)
  89. s.CloseFns = make(map[string]Closer)
  90. s.BindFunc("", d)
  91. s.SearchFunc("", d)
  92. s.AddFunc("", d)
  93. s.ModifyFunc("", d)
  94. s.DeleteFunc("", d)
  95. s.ModifyDNFunc("", d)
  96. s.CompareFunc("", d)
  97. s.AbandonFunc("", d)
  98. s.ExtendedFunc("", d)
  99. s.UnbindFunc("", d)
  100. s.CloseFunc("", d)
  101. s.Stats = nil
  102. return s
  103. }
  104. func (server *Server) BindFunc(baseDN string, f Binder) {
  105. server.BindFns[baseDN] = f
  106. }
  107. func (server *Server) SearchFunc(baseDN string, f Searcher) {
  108. server.SearchFns[baseDN] = f
  109. }
  110. func (server *Server) AddFunc(baseDN string, f Adder) {
  111. server.AddFns[baseDN] = f
  112. }
  113. func (server *Server) ModifyFunc(baseDN string, f Modifier) {
  114. server.ModifyFns[baseDN] = f
  115. }
  116. func (server *Server) DeleteFunc(baseDN string, f Deleter) {
  117. server.DeleteFns[baseDN] = f
  118. }
  119. func (server *Server) ModifyDNFunc(baseDN string, f ModifyDNr) {
  120. server.ModifyDNFns[baseDN] = f
  121. }
  122. func (server *Server) CompareFunc(baseDN string, f Comparer) {
  123. server.CompareFns[baseDN] = f
  124. }
  125. func (server *Server) AbandonFunc(baseDN string, f Abandoner) {
  126. server.AbandonFns[baseDN] = f
  127. }
  128. func (server *Server) ExtendedFunc(baseDN string, f Extender) {
  129. server.ExtendedFns[baseDN] = f
  130. }
  131. func (server *Server) UnbindFunc(baseDN string, f Unbinder) {
  132. server.UnbindFns[baseDN] = f
  133. }
  134. func (server *Server) CloseFunc(baseDN string, f Closer) {
  135. server.CloseFns[baseDN] = f
  136. }
  137. func (server *Server) QuitChannel(quit chan bool) {
  138. server.Quit = quit
  139. }
  140. func (server *Server) ListenAndServeTLS(listenString string, certFile string, keyFile string) error {
  141. cert, err := tls.LoadX509KeyPair(certFile, keyFile)
  142. if err != nil {
  143. return err
  144. }
  145. tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}}
  146. tlsConfig.ServerName = "localhost"
  147. ln, err := tls.Listen("tcp", listenString, &tlsConfig)
  148. if err != nil {
  149. return err
  150. }
  151. err = server.Serve(ln)
  152. if err != nil {
  153. return err
  154. }
  155. return nil
  156. }
  157. func (server *Server) SetStats(enable bool) {
  158. if enable {
  159. server.Stats = &Stats{}
  160. } else {
  161. server.Stats = nil
  162. }
  163. }
  164. func (server *Server) GetStats() Stats {
  165. defer func() {
  166. server.Stats.statsMutex.Unlock()
  167. }()
  168. server.Stats.statsMutex.Lock()
  169. return *server.Stats
  170. }
  171. func (server *Server) ListenAndServe(listenString string) error {
  172. ln, err := net.Listen("tcp", listenString)
  173. if err != nil {
  174. return err
  175. }
  176. err = server.Serve(ln)
  177. if err != nil {
  178. return err
  179. }
  180. return nil
  181. }
  182. func (server *Server) Serve(ln net.Listener) error {
  183. newConn := make(chan net.Conn)
  184. go func() {
  185. for {
  186. conn, err := ln.Accept()
  187. if err != nil {
  188. if !strings.HasSuffix(err.Error(), "use of closed network connection") {
  189. log.Printf("Error accepting network connection: %s", err.Error())
  190. }
  191. break
  192. }
  193. newConn <- conn
  194. }
  195. }()
  196. listener:
  197. for {
  198. select {
  199. case c := <-newConn:
  200. server.Stats.countConns(1)
  201. go server.handleConnection(c)
  202. case <-server.Quit:
  203. ln.Close()
  204. break listener
  205. }
  206. }
  207. return nil
  208. }
  209. //
  210. func (server *Server) handleConnection(conn net.Conn) {
  211. boundDN := "" // "" == anonymous
  212. handler:
  213. for {
  214. // read incoming LDAP packet
  215. packet, err := ber.ReadPacket(conn)
  216. log.Println(packet)
  217. if err == io.EOF { // Client closed connection
  218. break
  219. } else if err != nil {
  220. log.Printf("handleConnection ber.ReadPacket ERROR: %s", err.Error())
  221. break
  222. }
  223. // sanity check this packet
  224. if len(packet.Children) < 2 {
  225. log.Print("len(packet.Children) < 2")
  226. break
  227. }
  228. // check the message ID and ClassType
  229. messageID, ok := packet.Children[0].Value.(uint64)
  230. if !ok {
  231. log.Print("malformed messageID")
  232. break
  233. }
  234. req := packet.Children[1]
  235. if req.ClassType != ber.ClassApplication {
  236. log.Print("req.ClassType != ber.ClassApplication")
  237. break
  238. }
  239. // handle controls if present
  240. controls := []Control{}
  241. if len(packet.Children) > 2 {
  242. for _, child := range packet.Children[2].Children {
  243. controls = append(controls, DecodeControl(child))
  244. }
  245. }
  246. //log.Printf("DEBUG: handling operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
  247. //ber.PrintPacket(packet) // DEBUG
  248. // dispatch the LDAP operation
  249. switch req.Tag { // ldap op code
  250. default:
  251. responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, LDAPResultOperationsError, "Unsupported operation: add")
  252. if err = sendPacket(conn, responsePacket); err != nil {
  253. log.Printf("sendPacket error %s", err.Error())
  254. }
  255. log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
  256. break handler
  257. case ApplicationBindRequest:
  258. server.Stats.countBinds(1)
  259. ldapResultCode := HandleBindRequest(req, server.BindFns, conn)
  260. if ldapResultCode == LDAPResultSuccess {
  261. boundDN, ok = req.Children[1].Value.(string)
  262. if !ok {
  263. log.Printf("Malformed Bind DN")
  264. break handler
  265. }
  266. }
  267. responsePacket := encodeBindResponse(messageID, ldapResultCode)
  268. if err = sendPacket(conn, responsePacket); err != nil {
  269. log.Printf("sendPacket error %s", err.Error())
  270. break handler
  271. }
  272. case ApplicationSearchRequest:
  273. server.Stats.countSearches(1)
  274. if err := HandleSearchRequest(req, &controls, messageID, boundDN, server, conn); err != nil {
  275. log.Printf("handleSearchRequest error %s", err.Error()) // TODO: make this more testable/better err handling - stop using log, stop using breaks?
  276. e := err.(*Error)
  277. if err = sendPacket(conn, encodeSearchDone(messageID, e.ResultCode)); err != nil {
  278. log.Printf("sendPacket error %s", err.Error())
  279. break handler
  280. }
  281. break handler
  282. } else {
  283. if err = sendPacket(conn, encodeSearchDone(messageID, LDAPResultSuccess)); err != nil {
  284. log.Printf("sendPacket error %s", err.Error())
  285. break handler
  286. }
  287. }
  288. case ApplicationUnbindRequest:
  289. server.Stats.countUnbinds(1)
  290. break handler // simply disconnect
  291. case ApplicationExtendedRequest:
  292. ldapResultCode := HandleExtendedRequest(req, boundDN, server.ExtendedFns, conn)
  293. responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
  294. if err = sendPacket(conn, responsePacket); err != nil {
  295. log.Printf("sendPacket error %s", err.Error())
  296. break handler
  297. }
  298. case ApplicationAbandonRequest:
  299. HandleAbandonRequest(req, boundDN, server.AbandonFns, conn)
  300. break handler
  301. case ApplicationAddRequest:
  302. ldapResultCode := HandleAddRequest(req, boundDN, server.AddFns, conn)
  303. responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
  304. if err = sendPacket(conn, responsePacket); err != nil {
  305. log.Printf("sendPacket error %s", err.Error())
  306. break handler
  307. }
  308. case ApplicationModifyRequest:
  309. ldapResultCode := HandleModifyRequest(req, boundDN, server.ModifyFns, conn)
  310. responsePacket := encodeLDAPResponse(messageID, ApplicationModifyResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
  311. if err = sendPacket(conn, responsePacket); err != nil {
  312. log.Printf("sendPacket error %s", err.Error())
  313. break handler
  314. }
  315. case ApplicationDelRequest:
  316. ldapResultCode := HandleDeleteRequest(req, boundDN, server.DeleteFns, conn)
  317. responsePacket := encodeLDAPResponse(messageID, ApplicationDelResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
  318. if err = sendPacket(conn, responsePacket); err != nil {
  319. log.Printf("sendPacket error %s", err.Error())
  320. break handler
  321. }
  322. case ApplicationModifyDNRequest:
  323. ldapResultCode := HandleModifyDNRequest(req, boundDN, server.ModifyDNFns, conn)
  324. responsePacket := encodeLDAPResponse(messageID, ApplicationModifyDNResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
  325. if err = sendPacket(conn, responsePacket); err != nil {
  326. log.Printf("sendPacket error %s", err.Error())
  327. break handler
  328. }
  329. case ApplicationCompareRequest:
  330. ldapResultCode := HandleCompareRequest(req, boundDN, server.CompareFns, conn)
  331. responsePacket := encodeLDAPResponse(messageID, ApplicationCompareResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
  332. if err = sendPacket(conn, responsePacket); err != nil {
  333. log.Printf("sendPacket error %s", err.Error())
  334. break handler
  335. }
  336. }
  337. }
  338. for _, c := range server.CloseFns {
  339. c.Close(boundDN, conn)
  340. }
  341. conn.Close()
  342. }
  343. //
  344. func sendPacket(conn net.Conn, packet *ber.Packet) error {
  345. _, err := conn.Write(packet.Bytes())
  346. if err != nil {
  347. log.Printf("Error Sending Message: %s", err.Error())
  348. return err
  349. }
  350. return nil
  351. }
  352. //
  353. func routeFunc(dn string, funcNames []string) string {
  354. bestPick := ""
  355. for _, fn := range funcNames {
  356. if strings.HasSuffix(dn, fn) {
  357. l := len(strings.Split(bestPick, ","))
  358. if bestPick == "" {
  359. l = 0
  360. }
  361. if len(strings.Split(fn, ",")) > l {
  362. bestPick = fn
  363. }
  364. }
  365. }
  366. return bestPick
  367. }
  368. //
  369. func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode LDAPResultCode, message string) *ber.Packet {
  370. responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
  371. responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
  372. reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType])
  373. reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
  374. reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
  375. reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: "))
  376. responsePacket.AppendChild(reponse)
  377. return responsePacket
  378. }
  379. //
  380. type defaultHandler struct {
  381. }
  382. func (h defaultHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
  383. return LDAPResultInvalidCredentials, nil
  384. }
  385. func (h defaultHandler) Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error) {
  386. return ServerSearchResult{make([]*Entry, 0), []string{}, []Control{}, LDAPResultSuccess}, nil
  387. }
  388. func (h defaultHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) {
  389. return LDAPResultInsufficientAccessRights, nil
  390. }
  391. func (h defaultHandler) Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error) {
  392. return LDAPResultInsufficientAccessRights, nil
  393. }
  394. func (h defaultHandler) Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error) {
  395. return LDAPResultInsufficientAccessRights, nil
  396. }
  397. func (h defaultHandler) ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error) {
  398. return LDAPResultInsufficientAccessRights, nil
  399. }
  400. func (h defaultHandler) Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error) {
  401. return LDAPResultInsufficientAccessRights, nil
  402. }
  403. func (h defaultHandler) Abandon(boundDN string, conn net.Conn) error {
  404. return nil
  405. }
  406. func (h defaultHandler) Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error) {
  407. return LDAPResultProtocolError, nil
  408. }
  409. func (h defaultHandler) Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error) {
  410. return LDAPResultSuccess, nil
  411. }
  412. func (h defaultHandler) Close(boundDN string, conn net.Conn) error {
  413. conn.Close()
  414. return nil
  415. }
  416. //
  417. func (stats *Stats) countConns(delta int) {
  418. if stats != nil {
  419. stats.statsMutex.Lock()
  420. stats.Conns += delta
  421. stats.statsMutex.Unlock()
  422. }
  423. }
  424. func (stats *Stats) countBinds(delta int) {
  425. if stats != nil {
  426. stats.statsMutex.Lock()
  427. stats.Binds += delta
  428. stats.statsMutex.Unlock()
  429. }
  430. }
  431. func (stats *Stats) countUnbinds(delta int) {
  432. if stats != nil {
  433. stats.statsMutex.Lock()
  434. stats.Unbinds += delta
  435. stats.statsMutex.Unlock()
  436. }
  437. }
  438. func (stats *Stats) countSearches(delta int) {
  439. if stats != nil {
  440. stats.statsMutex.Lock()
  441. stats.Searches += delta
  442. stats.statsMutex.Unlock()
  443. }
  444. }
  445. //