| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475 |
- package ldap
- import (
- "crypto/tls"
- "io"
- "log"
- "net"
- "strings"
- "sync"
- "github.com/nmcclain/asn1-ber"
- )
- type Binder interface {
- Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error)
- }
- type Searcher interface {
- Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error)
- }
- type Adder interface {
- Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error)
- }
- type Modifier interface {
- Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error)
- }
- type Deleter interface {
- Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error)
- }
- type ModifyDNr interface {
- ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error)
- }
- type Comparer interface {
- Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error)
- }
- type Abandoner interface {
- Abandon(boundDN string, conn net.Conn) error
- }
- type Extender interface {
- Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error)
- }
- type Unbinder interface {
- Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error)
- }
- type Closer interface {
- Close(boundDN string, conn net.Conn) error
- }
- //
- type Server struct {
- BindFns map[string]Binder
- SearchFns map[string]Searcher
- AddFns map[string]Adder
- ModifyFns map[string]Modifier
- DeleteFns map[string]Deleter
- ModifyDNFns map[string]ModifyDNr
- CompareFns map[string]Comparer
- AbandonFns map[string]Abandoner
- ExtendedFns map[string]Extender
- UnbindFns map[string]Unbinder
- CloseFns map[string]Closer
- Quit chan bool
- EnforceLDAP bool
- Stats *Stats
- }
// Stats holds per-server operation counters. It is allocated only when
// statistics are enabled via Server.SetStats. statsMutex guards every counter.
type Stats struct {
	Conns, Binds, Unbinds, Searches int

	statsMutex sync.Mutex
}
- type ServerSearchResult struct {
- Entries []*Entry
- Referrals []string
- Controls []Control
- ResultCode LDAPResultCode
- }
- //
- func NewServer() *Server {
- s := new(Server)
- s.Quit = make(chan bool)
- d := defaultHandler{}
- s.BindFns = make(map[string]Binder)
- s.SearchFns = make(map[string]Searcher)
- s.AddFns = make(map[string]Adder)
- s.ModifyFns = make(map[string]Modifier)
- s.DeleteFns = make(map[string]Deleter)
- s.ModifyDNFns = make(map[string]ModifyDNr)
- s.CompareFns = make(map[string]Comparer)
- s.AbandonFns = make(map[string]Abandoner)
- s.ExtendedFns = make(map[string]Extender)
- s.UnbindFns = make(map[string]Unbinder)
- s.CloseFns = make(map[string]Closer)
- s.BindFunc("", d)
- s.SearchFunc("", d)
- s.AddFunc("", d)
- s.ModifyFunc("", d)
- s.DeleteFunc("", d)
- s.ModifyDNFunc("", d)
- s.CompareFunc("", d)
- s.AbandonFunc("", d)
- s.ExtendedFunc("", d)
- s.UnbindFunc("", d)
- s.CloseFunc("", d)
- s.Stats = nil
- return s
- }
- func (server *Server) BindFunc(baseDN string, f Binder) {
- server.BindFns[baseDN] = f
- }
- func (server *Server) SearchFunc(baseDN string, f Searcher) {
- server.SearchFns[baseDN] = f
- }
- func (server *Server) AddFunc(baseDN string, f Adder) {
- server.AddFns[baseDN] = f
- }
- func (server *Server) ModifyFunc(baseDN string, f Modifier) {
- server.ModifyFns[baseDN] = f
- }
- func (server *Server) DeleteFunc(baseDN string, f Deleter) {
- server.DeleteFns[baseDN] = f
- }
- func (server *Server) ModifyDNFunc(baseDN string, f ModifyDNr) {
- server.ModifyDNFns[baseDN] = f
- }
- func (server *Server) CompareFunc(baseDN string, f Comparer) {
- server.CompareFns[baseDN] = f
- }
- func (server *Server) AbandonFunc(baseDN string, f Abandoner) {
- server.AbandonFns[baseDN] = f
- }
- func (server *Server) ExtendedFunc(baseDN string, f Extender) {
- server.ExtendedFns[baseDN] = f
- }
- func (server *Server) UnbindFunc(baseDN string, f Unbinder) {
- server.UnbindFns[baseDN] = f
- }
- func (server *Server) CloseFunc(baseDN string, f Closer) {
- server.CloseFns[baseDN] = f
- }
- func (server *Server) QuitChannel(quit chan bool) {
- server.Quit = quit
- }
- func (server *Server) ListenAndServeTLS(listenString string, certFile string, keyFile string) error {
- cert, err := tls.LoadX509KeyPair(certFile, keyFile)
- if err != nil {
- return err
- }
- tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}}
- tlsConfig.ServerName = "localhost"
- ln, err := tls.Listen("tcp", listenString, &tlsConfig)
- if err != nil {
- return err
- }
- err = server.Serve(ln)
- if err != nil {
- return err
- }
- return nil
- }
- func (server *Server) SetStats(enable bool) {
- if enable {
- server.Stats = &Stats{}
- } else {
- server.Stats = nil
- }
- }
- func (server *Server) GetStats() Stats {
- defer func() {
- server.Stats.statsMutex.Unlock()
- }()
- server.Stats.statsMutex.Lock()
- return *server.Stats
- }
- func (server *Server) ListenAndServe(listenString string) error {
- ln, err := net.Listen("tcp", listenString)
- if err != nil {
- return err
- }
- err = server.Serve(ln)
- if err != nil {
- return err
- }
- return nil
- }
- func (server *Server) Serve(ln net.Listener) error {
- newConn := make(chan net.Conn)
- go func() {
- for {
- conn, err := ln.Accept()
- if err != nil {
- if !strings.HasSuffix(err.Error(), "use of closed network connection") {
- log.Printf("Error accepting network connection: %s", err.Error())
- }
- break
- }
- newConn <- conn
- }
- }()
- listener:
- for {
- select {
- case c := <-newConn:
- server.Stats.countConns(1)
- go server.handleConnection(c)
- case <-server.Quit:
- ln.Close()
- break listener
- }
- }
- return nil
- }
- //
- func (server *Server) handleConnection(conn net.Conn) {
- boundDN := "" // "" == anonymous
- handler:
- for {
- // read incoming LDAP packet
- packet, err := ber.ReadPacket(conn)
- log.Println(packet)
- if err == io.EOF { // Client closed connection
- break
- } else if err != nil {
- log.Printf("handleConnection ber.ReadPacket ERROR: %s", err.Error())
- break
- }
- // sanity check this packet
- if len(packet.Children) < 2 {
- log.Print("len(packet.Children) < 2")
- break
- }
- // check the message ID and ClassType
- messageID, ok := packet.Children[0].Value.(uint64)
- if !ok {
- log.Print("malformed messageID")
- break
- }
- req := packet.Children[1]
- if req.ClassType != ber.ClassApplication {
- log.Print("req.ClassType != ber.ClassApplication")
- break
- }
- // handle controls if present
- controls := []Control{}
- if len(packet.Children) > 2 {
- for _, child := range packet.Children[2].Children {
- controls = append(controls, DecodeControl(child))
- }
- }
- //log.Printf("DEBUG: handling operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
- //ber.PrintPacket(packet) // DEBUG
- // dispatch the LDAP operation
- switch req.Tag { // ldap op code
- default:
- responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, LDAPResultOperationsError, "Unsupported operation: add")
- if err = sendPacket(conn, responsePacket); err != nil {
- log.Printf("sendPacket error %s", err.Error())
- }
- log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
- break handler
- case ApplicationBindRequest:
- server.Stats.countBinds(1)
- ldapResultCode := HandleBindRequest(req, server.BindFns, conn)
- if ldapResultCode == LDAPResultSuccess {
- boundDN, ok = req.Children[1].Value.(string)
- if !ok {
- log.Printf("Malformed Bind DN")
- break handler
- }
- }
- responsePacket := encodeBindResponse(messageID, ldapResultCode)
- if err = sendPacket(conn, responsePacket); err != nil {
- log.Printf("sendPacket error %s", err.Error())
- break handler
- }
- case ApplicationSearchRequest:
- server.Stats.countSearches(1)
- if err := HandleSearchRequest(req, &controls, messageID, boundDN, server, conn); err != nil {
- log.Printf("handleSearchRequest error %s", err.Error()) // TODO: make this more testable/better err handling - stop using log, stop using breaks?
- e := err.(*Error)
- if err = sendPacket(conn, encodeSearchDone(messageID, e.ResultCode)); err != nil {
- log.Printf("sendPacket error %s", err.Error())
- break handler
- }
- break handler
- } else {
- if err = sendPacket(conn, encodeSearchDone(messageID, LDAPResultSuccess)); err != nil {
- log.Printf("sendPacket error %s", err.Error())
- break handler
- }
- }
- case ApplicationUnbindRequest:
- server.Stats.countUnbinds(1)
- break handler // simply disconnect
- case ApplicationExtendedRequest:
- ldapResultCode := HandleExtendedRequest(req, boundDN, server.ExtendedFns, conn)
- responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
- if err = sendPacket(conn, responsePacket); err != nil {
- log.Printf("sendPacket error %s", err.Error())
- break handler
- }
- case ApplicationAbandonRequest:
- HandleAbandonRequest(req, boundDN, server.AbandonFns, conn)
- break handler
- case ApplicationAddRequest:
- ldapResultCode := HandleAddRequest(req, boundDN, server.AddFns, conn)
- responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
- if err = sendPacket(conn, responsePacket); err != nil {
- log.Printf("sendPacket error %s", err.Error())
- break handler
- }
- case ApplicationModifyRequest:
- ldapResultCode := HandleModifyRequest(req, boundDN, server.ModifyFns, conn)
- responsePacket := encodeLDAPResponse(messageID, ApplicationModifyResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
- if err = sendPacket(conn, responsePacket); err != nil {
- log.Printf("sendPacket error %s", err.Error())
- break handler
- }
- case ApplicationDelRequest:
- ldapResultCode := HandleDeleteRequest(req, boundDN, server.DeleteFns, conn)
- responsePacket := encodeLDAPResponse(messageID, ApplicationDelResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
- if err = sendPacket(conn, responsePacket); err != nil {
- log.Printf("sendPacket error %s", err.Error())
- break handler
- }
- case ApplicationModifyDNRequest:
- ldapResultCode := HandleModifyDNRequest(req, boundDN, server.ModifyDNFns, conn)
- responsePacket := encodeLDAPResponse(messageID, ApplicationModifyDNResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
- if err = sendPacket(conn, responsePacket); err != nil {
- log.Printf("sendPacket error %s", err.Error())
- break handler
- }
- case ApplicationCompareRequest:
- ldapResultCode := HandleCompareRequest(req, boundDN, server.CompareFns, conn)
- responsePacket := encodeLDAPResponse(messageID, ApplicationCompareResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
- if err = sendPacket(conn, responsePacket); err != nil {
- log.Printf("sendPacket error %s", err.Error())
- break handler
- }
- }
- }
- for _, c := range server.CloseFns {
- c.Close(boundDN, conn)
- }
- conn.Close()
- }
- //
- func sendPacket(conn net.Conn, packet *ber.Packet) error {
- _, err := conn.Write(packet.Bytes())
- if err != nil {
- log.Printf("Error Sending Message: %s", err.Error())
- return err
- }
- return nil
- }
// routeFunc picks, from funcNames, the registered base DN that most
// specifically contains dn, and returns it. It returns "" when none does;
// "" is the key NewServer registers the default handlers under.
//
// A name matches when it equals dn or names one of dn's ancestors. That
// means dn ends with the name and the match starts on an RDN boundary.
// A plain strings.HasSuffix is not enough: "o=evildc=example,dc=com" ends
// with "dc=example,dc=com" but is not below it.
//
// Among the matches, the one with the most RDN components wins. Ties keep
// the earliest entry in funcNames. Comparison is case-sensitive.
func routeFunc(dn string, funcNames []string) string {
	bestPick := ""
	bestDepth := 0
	for _, fn := range funcNames {
		// "" is the fallback and can never outrank a real match.
		if fn == "" || !isDNSuffix(dn, fn) {
			continue
		}
		if depth := strings.Count(fn, ",") + 1; depth > bestDepth {
			bestPick, bestDepth = fn, depth
		}
	}
	return bestPick
}

// isDNSuffix reports whether suffix equals dn or is an ancestor of dn.
// The text preceding suffix must end in an unescaped ',' separator;
// spaces after that comma are tolerated.
func isDNSuffix(dn, suffix string) bool {
	if !strings.HasSuffix(dn, suffix) {
		return false
	}
	rest := dn[:len(dn)-len(suffix)]
	if rest == "" {
		return true // exact match
	}
	rest = strings.TrimRight(rest, " ")
	if !strings.HasSuffix(rest, ",") {
		return false // match begins in the middle of an RDN
	}
	// An odd number of backslashes before the comma means it is escaped
	// ("\,") and therefore part of an attribute value, not a separator.
	backslashes := 0
	for i := len(rest) - 2; i >= 0 && rest[i] == '\\'; i-- {
		backslashes++
	}
	return backslashes%2 == 0
}
- //
- func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode LDAPResultCode, message string) *ber.Packet {
- responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
- responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
- reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType])
- reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
- reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
- reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: "))
- responsePacket.AppendChild(reponse)
- return responsePacket
- }
- //
- type defaultHandler struct {
- }
- func (h defaultHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
- return LDAPResultInvalidCredentials, nil
- }
- func (h defaultHandler) Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error) {
- return ServerSearchResult{make([]*Entry, 0), []string{}, []Control{}, LDAPResultSuccess}, nil
- }
- func (h defaultHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) {
- return LDAPResultInsufficientAccessRights, nil
- }
- func (h defaultHandler) Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error) {
- return LDAPResultInsufficientAccessRights, nil
- }
- func (h defaultHandler) Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error) {
- return LDAPResultInsufficientAccessRights, nil
- }
- func (h defaultHandler) ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error) {
- return LDAPResultInsufficientAccessRights, nil
- }
- func (h defaultHandler) Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error) {
- return LDAPResultInsufficientAccessRights, nil
- }
- func (h defaultHandler) Abandon(boundDN string, conn net.Conn) error {
- return nil
- }
- func (h defaultHandler) Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error) {
- return LDAPResultProtocolError, nil
- }
- func (h defaultHandler) Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error) {
- return LDAPResultSuccess, nil
- }
- func (h defaultHandler) Close(boundDN string, conn net.Conn) error {
- conn.Close()
- return nil
- }
- //
- func (stats *Stats) countConns(delta int) {
- if stats != nil {
- stats.statsMutex.Lock()
- stats.Conns += delta
- stats.statsMutex.Unlock()
- }
- }
- func (stats *Stats) countBinds(delta int) {
- if stats != nil {
- stats.statsMutex.Lock()
- stats.Binds += delta
- stats.statsMutex.Unlock()
- }
- }
- func (stats *Stats) countUnbinds(delta int) {
- if stats != nil {
- stats.statsMutex.Lock()
- stats.Unbinds += delta
- stats.statsMutex.Unlock()
- }
- }
- func (stats *Stats) countSearches(delta int) {
- if stats != nil {
- stats.statsMutex.Lock()
- stats.Searches += delta
- stats.statsMutex.Unlock()
- }
- }
- //
|