zhangjq пре 6 година
родитељ
комит
5f19d6317f
19 измењених фајлова са 4307 додато и 0 уклоњено
  1. 99 0
      ldap/bind.go
  2. 340 0
      ldap/conn.go
  3. 160 0
      ldap/control.go
  4. 24 0
      ldap/debug.go
  5. 402 0
      ldap/filter.go
  6. 137 0
      ldap/filter_test.go
  7. 340 0
      ldap/ldap.go
  8. 123 0
      ldap/ldap_test.go
  9. 1 0
      ldap/light-ldap.go
  10. 162 0
      ldap/modify.go
  11. 350 0
      ldap/search.go
  12. 475 0
      ldap/server.go
  13. 73 0
      ldap/server_bind.go
  14. 232 0
      ldap/server_modify.go
  15. 191 0
      ldap/server_modify_test.go
  16. 218 0
      ldap/server_search.go
  17. 505 0
      ldap/server_search_test.go
  18. 410 0
      ldap/server_test.go
  19. 65 0
      utils/auth/ldap_auth.go

+ 99 - 0
ldap/bind.go

@@ -0,0 +1,99 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ldap
+
+import (
+	"errors"
+
+	"github.com/nmcclain/asn1-ber"
+)
+
+func (l *Conn) Bind(username, password string) error {
+	messageID := l.nextMessageID()
+
+	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+	bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
+	bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
+	bindRequest.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, username, "User Name"))
+	bindRequest.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, password, "Password"))
+	packet.AppendChild(bindRequest)
+
+	if l.Debug {
+		ber.PrintPacket(packet)
+	}
+
+	channel, err := l.sendMessage(packet)
+	if err != nil {
+		return err
+	}
+	if channel == nil {
+		return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
+	}
+	defer l.finishMessage(messageID)
+
+	packet = <-channel
+	if packet == nil {
+		return NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
+	}
+
+	if l.Debug {
+		if err := addLDAPDescriptions(packet); err != nil {
+			return err
+		}
+		ber.PrintPacket(packet)
+	}
+
+	resultCode, resultDescription := getLDAPResultCode(packet)
+	if resultCode != 0 {
+		return NewError(resultCode, errors.New(resultDescription))
+	}
+
+	return nil
+}
+
+func (l *Conn) Unbind() error {
+  defer l.Close()
+
+  messageID := l.nextMessageID()
+
+  packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+  packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+  unbindRequest := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationUnbindRequest, nil, "Unbind Request")
+  packet.AppendChild(unbindRequest)
+
+  if l.Debug {
+    ber.PrintPacket(packet)
+  }
+
+  channel, err := l.sendMessage(packet)
+  if err != nil {
+    return err
+  }
+  if channel == nil {
+    return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
+  }
+  defer l.finishMessage(messageID)
+
+  packet = <-channel
+  if packet == nil {
+    return NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
+  }
+
+  if l.Debug {
+    if err := addLDAPDescriptions(packet); err != nil {
+      return err
+    }
+    ber.PrintPacket(packet)
+  }
+
+  resultCode, resultDescription := getLDAPResultCode(packet)
+  if resultCode != 0 {
+    return NewError(resultCode, errors.New(resultDescription))
+  }
+
+  return nil
+}
+

+ 340 - 0
ldap/conn.go

@@ -0,0 +1,340 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ldap
+
+import (
+	"crypto/tls"
+	"errors"
+	"log"
+	"net"
+	"sync"
+	"time"
+
+	"github.com/nmcclain/asn1-ber"
+)
+
+const (
+	MessageQuit     = 0
+	MessageRequest  = 1
+	MessageResponse = 2
+	MessageFinish   = 3
+)
+
+type messagePacket struct {
+	Op        int
+	MessageID uint64
+	Packet    *ber.Packet
+	Channel   chan *ber.Packet
+}
+
+// Conn represents an LDAP Connection
+type Conn struct {
+	conn          net.Conn
+	isTLS         bool
+	Debug         debugging
+	chanConfirm   chan bool
+	chanResults   map[uint64]chan *ber.Packet
+	chanMessage   chan *messagePacket
+	chanMessageID chan uint64
+	wgSender      sync.WaitGroup
+	chanDone      chan struct{}
+	once          sync.Once
+}
+
+// Dial connects to the given address on the given network using net.Dial
+// and then returns a new Conn for the connection.
+func Dial(network, addr string) (*Conn, error) {
+	c, err := net.Dial(network, addr)
+	if err != nil {
+		return nil, NewError(ErrorNetwork, err)
+	}
+	conn := NewConn(c)
+	conn.start()
+	return conn, nil
+}
+
+// DialTimeout connects to the given address on the given network using net.DialTimeout
+// and then returns a new Conn for the connection. Acts like Dial but takes a timeout.
+func DialTimeout(network, addr string, timeout time.Duration) (*Conn, error) {
+	c, err := net.DialTimeout(network, addr, timeout)
+	if err != nil {
+		return nil, NewError(ErrorNetwork, err)
+	}
+	conn := NewConn(c)
+	conn.start()
+	return conn, nil
+}
+
+// DialTLS connects to the given address on the given network using tls.Dial
+// and then returns a new Conn for the connection.
+func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
+	c, err := tls.Dial(network, addr, config)
+	if err != nil {
+		return nil, NewError(ErrorNetwork, err)
+	}
+	conn := NewConn(c)
+	conn.isTLS = true
+	conn.start()
+	return conn, nil
+}
+
+// DialTLSDialer connects to the given address on the given network using tls.DialWithDialer
+// and then returns a new Conn for the connection.
+func DialTLSDialer(network, addr string, config *tls.Config, dialer *net.Dialer) (*Conn, error) {
+	c, err := tls.DialWithDialer(dialer, network, addr, config)
+	if err != nil {
+		return nil, NewError(ErrorNetwork, err)
+	}
+	conn := NewConn(c)
+	conn.isTLS = true
+	conn.start()
+	return conn, nil
+}
+
+// NewConn returns a new Conn using conn for network I/O.
+func NewConn(conn net.Conn) *Conn {
+	return &Conn{
+		conn:          conn,
+		chanConfirm:   make(chan bool),
+		chanMessageID: make(chan uint64),
+		chanMessage:   make(chan *messagePacket, 10),
+		chanResults:   map[uint64]chan *ber.Packet{},
+		chanDone:      make(chan struct{}),
+	}
+}
+
+func (l *Conn) start() {
+	go l.reader()
+	go l.processMessages()
+}
+
+// Close closes the connection.
+func (l *Conn) Close() {
+	l.once.Do(func() {
+		close(l.chanDone)
+		l.wgSender.Wait()
+
+		l.Debug.Printf("Sending quit message and waiting for confirmation")
+		l.chanMessage <- &messagePacket{Op: MessageQuit}
+		<-l.chanConfirm
+		close(l.chanMessage)
+
+		l.Debug.Printf("Closing network connection")
+		if err := l.conn.Close(); err != nil {
+			log.Print(err)
+		}
+	})
+	<-l.chanDone
+}
+
+// Returns the next available messageID
+func (l *Conn) nextMessageID() uint64 {
+	if l.chanMessageID != nil {
+		if messageID, ok := <-l.chanMessageID; ok {
+			return messageID
+		}
+	}
+	return 0
+}
+
+// StartTLS sends the command to start a TLS session and then creates a new TLS Client
+func (l *Conn) StartTLS(config *tls.Config) error {
+	messageID := l.nextMessageID()
+
+	if l.isTLS {
+		return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
+	}
+
+	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+	request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
+	request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
+	packet.AppendChild(request)
+	l.Debug.PrintPacket(packet)
+
+	_, err := l.conn.Write(packet.Bytes())
+	if err != nil {
+		return NewError(ErrorNetwork, err)
+	}
+
+	packet, err = ber.ReadPacket(l.conn)
+	if err != nil {
+		return NewError(ErrorNetwork, err)
+	}
+
+	if l.Debug {
+		if err := addLDAPDescriptions(packet); err != nil {
+			return err
+		}
+		ber.PrintPacket(packet)
+	}
+
+	if packet.Children[1].Children[0].Value.(uint64) == 0 {
+		conn := tls.Client(l.conn, config)
+		l.isTLS = true
+		l.conn = conn
+	}
+
+	return nil
+}
+
+func (l *Conn) closing() bool {
+	select {
+	case <-l.chanDone:
+		return true
+	default:
+		return false
+	}
+}
+
+func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, error) {
+	if l.closing() {
+		return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
+	}
+	out := make(chan *ber.Packet)
+	message := &messagePacket{
+		Op:        MessageRequest,
+		MessageID: packet.Children[0].Value.(uint64),
+		Packet:    packet,
+		Channel:   out,
+	}
+	l.sendProcessMessage(message)
+	return out, nil
+}
+
+func (l *Conn) finishMessage(messageID uint64) {
+	if l.closing() {
+		return
+	}
+	message := &messagePacket{
+		Op:        MessageFinish,
+		MessageID: messageID,
+	}
+	l.sendProcessMessage(message)
+}
+
+func (l *Conn) sendProcessMessage(message *messagePacket) bool {
+	l.wgSender.Add(1)
+	defer l.wgSender.Done()
+
+	if l.closing() {
+		return false
+	}
+	l.chanMessage <- message
+	return true
+}
+
+func (l *Conn) processMessages() {
+	defer func() {
+		for messageID, channel := range l.chanResults {
+			l.Debug.Printf("Closing channel for MessageID %d", messageID)
+			close(channel)
+			delete(l.chanResults, messageID)
+		}
+		close(l.chanMessageID)
+		l.chanConfirm <- true
+		close(l.chanConfirm)
+	}()
+
+	var messageID uint64 = 1
+	for {
+		select {
+		case l.chanMessageID <- messageID:
+			messageID++
+		case messagePacket, ok := <-l.chanMessage:
+			if !ok {
+				l.Debug.Printf("Shutting down - message channel is closed")
+				return
+			}
+			switch messagePacket.Op {
+			case MessageQuit:
+				l.Debug.Printf("Shutting down - quit message received")
+				return
+			case MessageRequest:
+				// Add to message list and write to network
+				l.Debug.Printf("Sending message %d", messagePacket.MessageID)
+				l.chanResults[messagePacket.MessageID] = messagePacket.Channel
+				// go routine
+				buf := messagePacket.Packet.Bytes()
+
+				_, err := l.conn.Write(buf)
+				if err != nil {
+					l.Debug.Printf("Error Sending Message: %s", err.Error())
+					break
+				}
+			case MessageResponse:
+				l.Debug.Printf("Receiving message %d", messagePacket.MessageID)
+				if chanResult, ok := l.chanResults[messagePacket.MessageID]; ok {
+					chanResult <- messagePacket.Packet
+				} else {
+					log.Printf("Received unexpected message %d", messagePacket.MessageID)
+					ber.PrintPacket(messagePacket.Packet)
+				}
+			case MessageFinish:
+				// Remove from message list
+				l.Debug.Printf("Finished message %d", messagePacket.MessageID)
+				close(l.chanResults[messagePacket.MessageID])
+				delete(l.chanResults, messagePacket.MessageID)
+			}
+		}
+	}
+}
+
+func (l *Conn) reader() {
+	defer func() {
+		l.Close()
+	}()
+
+	for {
+		packet, err := ber.ReadPacket(l.conn)
+		if err != nil {
+			l.Debug.Printf("reader: %s", err.Error())
+			return
+		}
+		addLDAPDescriptions(packet)
+		message := &messagePacket{
+			Op:        MessageResponse,
+			MessageID: packet.Children[0].Value.(uint64),
+			Packet:    packet,
+		}
+		if !l.sendProcessMessage(message) {
+			return
+		}
+
+	}
+}
+
+// Use Abandon operation to perform connection keepalives
+func (l *Conn) Ping() error {
+
+	messageID := l.nextMessageID()
+
+	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+	abandonRequest := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationAbandonRequest, nil, "Abandon Request")
+	packet.AppendChild(abandonRequest)
+
+	if l.Debug {
+		ber.PrintPacket(packet)
+	}
+
+	channel, err := l.sendMessage(packet)
+	if err != nil {
+		return err
+	}
+	if channel == nil {
+		return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
+	}
+	defer l.finishMessage(messageID)
+
+	if l.Debug {
+		if err := addLDAPDescriptions(packet); err != nil {
+			return err
+		}
+		ber.PrintPacket(packet)
+	}
+
+	return nil
+}

+ 160 - 0
ldap/control.go

@@ -0,0 +1,160 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ldap
+
+import (
+	"strings"
+	"fmt"
+	"github.com/nmcclain/asn1-ber"
+)
+
+const (
+	ControlTypePaging = "1.2.840.113556.1.4.319"
+)
+
+var ControlTypeMap = map[string]string{
+	ControlTypePaging: "Paging",
+}
+
+type Control interface {
+	GetControlType() string
+	Encode() *ber.Packet
+	String() string
+}
+
+type ControlString struct {
+	ControlType  string
+	Criticality  bool
+	ControlValue string
+}
+
+func (c *ControlString) GetControlType() string {
+	return c.ControlType
+}
+
+func (c *ControlString) Encode() *ber.Packet {
+	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
+	packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.ControlType, "Control Type ("+ControlTypeMap[c.ControlType]+")"))
+	if c.Criticality {
+		packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.Criticality, "Criticality"))
+	}
+	if strings.TrimSpace(c.ControlValue) != "" {
+		packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.ControlValue, "Control Value"))
+	}
+	return packet
+}
+
+func (c *ControlString) String() string {
+	return fmt.Sprintf("Control Type: %s (%q)  Criticality: %t  Control Value: %s", ControlTypeMap[c.ControlType], c.ControlType, c.Criticality, c.ControlValue)
+}
+
+type ControlPaging struct {
+	PagingSize uint32
+	Cookie     []byte
+}
+
+func (c *ControlPaging) GetControlType() string {
+	return ControlTypePaging
+}
+
+func (c *ControlPaging) Encode() *ber.Packet {
+	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
+	packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypePaging, "Control Type ("+ControlTypeMap[ControlTypePaging]+")"))
+
+	p2 := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Control Value (Paging)")
+	seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Search Control Value")
+	seq.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(c.PagingSize), "Paging Size"))
+	cookie := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Cookie")
+	cookie.Value = c.Cookie
+	cookie.Data.Write(c.Cookie)
+	seq.AppendChild(cookie)
+	p2.AppendChild(seq)
+
+	packet.AppendChild(p2)
+	return packet
+}
+
+func (c *ControlPaging) String() string {
+	return fmt.Sprintf(
+		"Control Type: %s (%q)  Criticality: %t  PagingSize: %d  Cookie: %q",
+		ControlTypeMap[ControlTypePaging],
+		ControlTypePaging,
+		false,
+		c.PagingSize,
+		c.Cookie)
+}
+
+func (c *ControlPaging) SetCookie(cookie []byte) {
+	c.Cookie = cookie
+}
+
+func FindControl(controls []Control, controlType string) Control {
+	for _, c := range controls {
+		if c.GetControlType() == controlType {
+			return c
+		}
+	}
+	return nil
+}
+
+func DecodeControl(packet *ber.Packet) Control {
+	ControlType := packet.Children[0].Value.(string)
+	packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")"
+	c := new(ControlString)
+	c.ControlType = ControlType
+	c.Criticality = false
+
+	if len(packet.Children) > 1 {
+		value := packet.Children[1]
+		if len(packet.Children) == 3 {
+			value = packet.Children[2]
+			packet.Children[1].Description = "Criticality"
+			c.Criticality = packet.Children[1].Value.(bool)
+		}
+
+		value.Description = "Control Value"
+		switch ControlType {
+		case ControlTypePaging:
+			value.Description += " (Paging)"
+			c := new(ControlPaging)
+			if value.Value != nil {
+				valueChildren := ber.DecodePacket(value.Data.Bytes())
+				value.Data.Truncate(0)
+				value.Value = nil
+				value.AppendChild(valueChildren)
+			}
+			value = value.Children[0]
+			value.Description = "Search Control Value"
+			value.Children[0].Description = "Paging Size"
+			value.Children[1].Description = "Cookie"
+			c.PagingSize = uint32(value.Children[0].Value.(uint64))
+			c.Cookie = value.Children[1].Data.Bytes()
+			value.Children[1].Value = c.Cookie
+			return c
+		}
+		c.ControlValue = value.Value.(string)
+	}
+	return c
+}
+
+func NewControlString(controlType string, criticality bool, controlValue string) *ControlString {
+	return &ControlString{
+		ControlType:  controlType,
+		Criticality:  criticality,
+		ControlValue: controlValue,
+	}
+}
+
+func NewControlPaging(pagingSize uint32) *ControlPaging {
+	return &ControlPaging{PagingSize: pagingSize}
+}
+
+func encodeControls(controls []Control) *ber.Packet {
+	packet := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0, nil, "Controls")
+	for _, control := range controls {
+		packet.AppendChild(control.Encode())
+	}
+	return packet
+}

+ 24 - 0
ldap/debug.go

@@ -0,0 +1,24 @@
+package ldap
+
+import (
+	"log"
+
+	"github.com/nmcclain/asn1-ber"
+)
+
+// debbuging type
+//     - has a Printf method to write the debug output
+type debugging bool
+
+// write debug output
+func (debug debugging) Printf(format string, args ...interface{}) {
+	if debug {
+		log.Printf(format, args...)
+	}
+}
+
+func (debug debugging) PrintPacket(packet *ber.Packet) {
+	if debug {
+		ber.PrintPacket(packet)
+	}
+}

+ 402 - 0
ldap/filter.go

@@ -0,0 +1,402 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ldap
+
+import (
+	"errors"
+	"fmt"
+	"github.com/nmcclain/asn1-ber"
+	"strings"
+)
+
+const (
+	FilterAnd             = 0
+	FilterOr              = 1
+	FilterNot             = 2
+	FilterEqualityMatch   = 3
+	FilterSubstrings      = 4
+	FilterGreaterOrEqual  = 5
+	FilterLessOrEqual     = 6
+	FilterPresent         = 7
+	FilterApproxMatch     = 8
+	FilterExtensibleMatch = 9
+)
+
+var FilterMap = map[uint8]string{
+	FilterAnd:             "And",
+	FilterOr:              "Or",
+	FilterNot:             "Not",
+	FilterEqualityMatch:   "Equality Match",
+	FilterSubstrings:      "Substrings",
+	FilterGreaterOrEqual:  "Greater Or Equal",
+	FilterLessOrEqual:     "Less Or Equal",
+	FilterPresent:         "Present",
+	FilterApproxMatch:     "Approx Match",
+	FilterExtensibleMatch: "Extensible Match",
+}
+
+const (
+	FilterSubstringsInitial = 0
+	FilterSubstringsAny     = 1
+	FilterSubstringsFinal   = 2
+)
+
+func CompileFilter(filter string) (*ber.Packet, error) {
+	if len(filter) == 0 || filter[0] != '(' {
+		return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('"))
+	}
+	packet, pos, err := compileFilter(filter, 1)
+	if err != nil {
+		return nil, err
+	}
+	if pos != len(filter) {
+		return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:])))
+	}
+	return packet, nil
+}
+
+func DecompileFilter(packet *ber.Packet) (ret string, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = NewError(ErrorFilterDecompile, errors.New("ldap: error decompiling filter"))
+		}
+	}()
+	ret = "("
+	err = nil
+	childStr := ""
+
+	switch packet.Tag {
+	case FilterAnd:
+		ret += "&"
+		for _, child := range packet.Children {
+			childStr, err = DecompileFilter(child)
+			if err != nil {
+				return
+			}
+			ret += childStr
+		}
+	case FilterOr:
+		ret += "|"
+		for _, child := range packet.Children {
+			childStr, err = DecompileFilter(child)
+			if err != nil {
+				return
+			}
+			ret += childStr
+		}
+	case FilterNot:
+		ret += "!"
+		childStr, err = DecompileFilter(packet.Children[0])
+		if err != nil {
+			return
+		}
+		ret += childStr
+
+	case FilterSubstrings:
+		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
+		ret += "="
+		switch packet.Children[1].Children[0].Tag {
+		case FilterSubstringsInitial:
+			ret += ber.DecodeString(packet.Children[1].Children[0].Data.Bytes()) + "*"
+		case FilterSubstringsAny:
+			ret += "*" + ber.DecodeString(packet.Children[1].Children[0].Data.Bytes()) + "*"
+		case FilterSubstringsFinal:
+			ret += "*" + ber.DecodeString(packet.Children[1].Children[0].Data.Bytes())
+		}
+	case FilterEqualityMatch:
+		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
+		ret += "="
+		ret += ber.DecodeString(packet.Children[1].Data.Bytes())
+	case FilterGreaterOrEqual:
+		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
+		ret += ">="
+		ret += ber.DecodeString(packet.Children[1].Data.Bytes())
+	case FilterLessOrEqual:
+		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
+		ret += "<="
+		ret += ber.DecodeString(packet.Children[1].Data.Bytes())
+	case FilterPresent:
+		ret += ber.DecodeString(packet.Data.Bytes())
+		ret += "=*"
+	case FilterApproxMatch:
+		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
+		ret += "~="
+		ret += ber.DecodeString(packet.Children[1].Data.Bytes())
+	}
+
+	ret += ")"
+	return
+}
+
+func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, error) {
+	for pos < len(filter) && filter[pos] == '(' {
+		child, newPos, err := compileFilter(filter, pos+1)
+		if err != nil {
+			return pos, err
+		}
+		pos = newPos
+		parent.AppendChild(child)
+	}
+	if pos == len(filter) {
+		return pos, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
+	}
+
+	return pos + 1, nil
+}
+
+func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
+	var packet *ber.Packet
+	var err error
+
+	defer func() {
+		if r := recover(); r != nil {
+			err = NewError(ErrorFilterCompile, errors.New("ldap: error compiling filter"))
+		}
+	}()
+
+	newPos := pos
+	switch filter[pos] {
+	case '(':
+		packet, newPos, err = compileFilter(filter, pos+1)
+		newPos++
+		return packet, newPos, err
+	case '&':
+		packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd])
+		newPos, err = compileFilterSet(filter, pos+1, packet)
+		return packet, newPos, err
+	case '|':
+		packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr])
+		newPos, err = compileFilterSet(filter, pos+1, packet)
+		return packet, newPos, err
+	case '!':
+		packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot])
+		var child *ber.Packet
+		child, newPos, err = compileFilter(filter, pos+1)
+		packet.AppendChild(child)
+		return packet, newPos, err
+	default:
+		attribute := ""
+		condition := ""
+		for newPos < len(filter) && filter[newPos] != ')' {
+			switch {
+			case packet != nil:
+				condition += fmt.Sprintf("%c", filter[newPos])
+			case filter[newPos] == '=':
+				packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch])
+			case filter[newPos] == '>' && filter[newPos+1] == '=':
+				packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual])
+				newPos++
+			case filter[newPos] == '<' && filter[newPos+1] == '=':
+				packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual])
+				newPos++
+			case filter[newPos] == '~' && filter[newPos+1] == '=':
+				packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterLessOrEqual])
+				newPos++
+			case packet == nil:
+				attribute += fmt.Sprintf("%c", filter[newPos])
+			}
+			newPos++
+		}
+		if newPos == len(filter) {
+			err = NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
+			return packet, newPos, err
+		}
+		if packet == nil {
+			err = NewError(ErrorFilterCompile, errors.New("ldap: error parsing filter"))
+			return packet, newPos, err
+		}
+		// Handle FilterEqualityMatch as a separate case (is primitive, not constructed like the other filters)
+		if packet.Tag == FilterEqualityMatch && condition == "*" {
+			packet.TagType = ber.TypePrimitive
+			packet.Tag = FilterPresent
+			packet.Description = FilterMap[packet.Tag]
+			packet.Data.WriteString(attribute)
+			return packet, newPos + 1, nil
+		}
+		packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
+		switch {
+		case packet.Tag == FilterEqualityMatch && condition[0] == '*' && condition[len(condition)-1] == '*':
+			// Any
+			packet.Tag = FilterSubstrings
+			packet.Description = FilterMap[packet.Tag]
+			seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
+			seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsAny, condition[1:len(condition)-1], "Any Substring"))
+			packet.AppendChild(seq)
+		case packet.Tag == FilterEqualityMatch && condition[0] == '*':
+			// Final
+			packet.Tag = FilterSubstrings
+			packet.Description = FilterMap[packet.Tag]
+			seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
+			seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsFinal, condition[1:], "Final Substring"))
+			packet.AppendChild(seq)
+		case packet.Tag == FilterEqualityMatch && condition[len(condition)-1] == '*':
+			// Initial
+			packet.Tag = FilterSubstrings
+			packet.Description = FilterMap[packet.Tag]
+			seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
+			seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsInitial, condition[:len(condition)-1], "Initial Substring"))
+			packet.AppendChild(seq)
+		default:
+			packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, condition, "Condition"))
+		}
+		newPos++
+		return packet, newPos, err
+	}
+}
+
+func ServerApplyFilter(f *ber.Packet, entry *Entry) (bool, LDAPResultCode) {
+	switch FilterMap[f.Tag] {
+	default:
+		//log.Fatalf("Unknown LDAP filter code: %d", f.Tag)
+		return false, LDAPResultOperationsError
+	case "Equality Match":
+		if len(f.Children) != 2 {
+			return false, LDAPResultOperationsError
+		}
+		attribute := f.Children[0].Value.(string)
+		value := f.Children[1].Value.(string)
+		for _, a := range entry.Attributes {
+			if strings.ToLower(a.Name) == strings.ToLower(attribute) {
+				for _, v := range a.Values {
+					if strings.ToLower(v) == strings.ToLower(value) {
+						return true, LDAPResultSuccess
+					}
+				}
+			}
+		}
+	case "Present":
+		for _, a := range entry.Attributes {
+			if strings.ToLower(a.Name) == strings.ToLower(f.Data.String()) {
+				return true, LDAPResultSuccess
+			}
+		}
+	case "And":
+		for _, child := range f.Children {
+			ok, exitCode := ServerApplyFilter(child, entry)
+			if exitCode != LDAPResultSuccess {
+				return false, exitCode
+			}
+			if !ok {
+				return false, LDAPResultSuccess
+			}
+		}
+		return true, LDAPResultSuccess
+	case "Or":
+		anyOk := false
+		for _, child := range f.Children {
+			ok, exitCode := ServerApplyFilter(child, entry)
+			if exitCode != LDAPResultSuccess {
+				return false, exitCode
+			} else if ok {
+				anyOk = true
+			}
+		}
+		if anyOk {
+			return true, LDAPResultSuccess
+		}
+	case "Not":
+		if len(f.Children) != 1 {
+			return false, LDAPResultOperationsError
+		}
+		ok, exitCode := ServerApplyFilter(f.Children[0], entry)
+		if exitCode != LDAPResultSuccess {
+			return false, exitCode
+		} else if !ok {
+			return true, LDAPResultSuccess
+		}
+	case "Substrings":
+		if len(f.Children) != 2 {
+			return false, LDAPResultOperationsError
+		}
+		attribute := f.Children[0].Value.(string)
+		bytes := f.Children[1].Children[0].Data.Bytes()
+		value := string(bytes[:])
+		for _, a := range entry.Attributes {
+			if strings.ToLower(a.Name) == strings.ToLower(attribute) {
+				for _, v := range a.Values {
+					switch f.Children[1].Children[0].Tag {
+					case FilterSubstringsInitial:
+						if strings.HasPrefix(v, value) {
+							return true, LDAPResultSuccess
+						}
+					case FilterSubstringsAny:
+						if strings.Contains(v, value) {
+							return true, LDAPResultSuccess
+						}
+					case FilterSubstringsFinal:
+						if strings.HasSuffix(v, value) {
+							return true, LDAPResultSuccess
+						}
+					}
+				}
+			}
+		}
+	case "FilterGreaterOrEqual": // TODO
+		return false, LDAPResultOperationsError
+	case "FilterLessOrEqual": // TODO
+		return false, LDAPResultOperationsError
+	case "FilterApproxMatch": // TODO
+		return false, LDAPResultOperationsError
+	case "FilterExtensibleMatch": // TODO
+		return false, LDAPResultOperationsError
+	}
+
+	return false, LDAPResultSuccess
+}
+
+func GetFilterObjectClass(filter string) (string, error) {
+	f, err := CompileFilter(filter)
+	if err != nil {
+		return "", err
+	}
+	return parseFilterObjectClass(f)
+}
+func parseFilterObjectClass(f *ber.Packet) (string, error) {
+	objectClass := ""
+	switch FilterMap[f.Tag] {
+	case "Equality Match":
+		if len(f.Children) != 2 {
+			return "", errors.New("Equality match must have only two children")
+		}
+		attribute := strings.ToLower(f.Children[0].Value.(string))
+		value := f.Children[1].Value.(string)
+		if attribute == "objectclass" {
+			objectClass = strings.ToLower(value)
+		}
+	case "And":
+		for _, child := range f.Children {
+			subType, err := parseFilterObjectClass(child)
+			if err != nil {
+				return "", err
+			}
+			if len(subType) > 0 {
+				objectClass = subType
+			}
+		}
+	case "Or":
+		for _, child := range f.Children {
+			subType, err := parseFilterObjectClass(child)
+			if err != nil {
+				return "", err
+			}
+			if len(subType) > 0 {
+				objectClass = subType
+			}
+		}
+	case "Not":
+		if len(f.Children) != 1 {
+			return "", errors.New("Not filter must have only one child")
+		}
+		subType, err := parseFilterObjectClass(f.Children[0])
+		if err != nil {
+			return "", err
+		}
+		if len(subType) > 0 {
+			objectClass = subType
+		}
+
+	}
+	return strings.ToLower(objectClass), nil
+}

+ 137 - 0
ldap/filter_test.go

@@ -0,0 +1,137 @@
+package ldap
+
+import (
+	"reflect"
+	"testing"
+
+	"github.com/nmcclain/asn1-ber"
+)
+
+type compileTest struct {
+	filterStr  string
+	filterType uint8
+}
+
+var testFilters = []compileTest{
+	compileTest{filterStr: "(&(sn=Miller)(givenName=Bob))", filterType: FilterAnd},
+	compileTest{filterStr: "(|(sn=Miller)(givenName=Bob))", filterType: FilterOr},
+	compileTest{filterStr: "(!(sn=Miller))", filterType: FilterNot},
+	compileTest{filterStr: "(sn=Miller)", filterType: FilterEqualityMatch},
+	compileTest{filterStr: "(sn=Mill*)", filterType: FilterSubstrings},
+	compileTest{filterStr: "(sn=*Mill)", filterType: FilterSubstrings},
+	compileTest{filterStr: "(sn=*Mill*)", filterType: FilterSubstrings},
+	compileTest{filterStr: "(sn>=Miller)", filterType: FilterGreaterOrEqual},
+	compileTest{filterStr: "(sn<=Miller)", filterType: FilterLessOrEqual},
+	compileTest{filterStr: "(sn=*)", filterType: FilterPresent},
+	compileTest{filterStr: "(sn~=Miller)", filterType: FilterApproxMatch},
+	// compileTest{ filterStr: "()", filterType: FilterExtensibleMatch },
+}
+
+func TestFilter(t *testing.T) {
+	// Test Compiler and Decompiler
+	for _, i := range testFilters {
+		filter, err := CompileFilter(i.filterStr)
+		if err != nil {
+			t.Errorf("Problem compiling %s - %s", i.filterStr, err.Error())
+		} else if filter.Tag != uint8(i.filterType) {
+			t.Errorf("%q Expected %q got %q", i.filterStr, FilterMap[i.filterType], FilterMap[filter.Tag])
+		} else {
+			o, err := DecompileFilter(filter)
+			if err != nil {
+				t.Errorf("Problem compiling %s - %s", i.filterStr, err.Error())
+			} else if i.filterStr != o {
+				t.Errorf("%q expected, got %q", i.filterStr, o)
+			}
+		}
+	}
+}
+
+type binTestFilter struct {
+	bin []byte
+	str string
+}
+
+var binTestFilters = []binTestFilter{
+	{bin: []byte{0x87, 0x06, 0x6d, 0x65, 0x6d, 0x62, 0x65, 0x72}, str: "(member=*)"},
+}
+
+func TestFiltersDecode(t *testing.T) {
+	for i, test := range binTestFilters {
+		p := ber.DecodePacket(test.bin)
+		if filter, err := DecompileFilter(p); err != nil {
+			t.Errorf("binTestFilters[%d], DecompileFilter returned : %s", i, err)
+		} else if filter != test.str {
+			t.Errorf("binTestFilters[%d], %q expected, got %q", i, test.str, filter)
+		}
+	}
+}
+
+func TestFiltersEncode(t *testing.T) {
+	for i, test := range binTestFilters {
+		p, err := CompileFilter(test.str)
+		if err != nil {
+			t.Errorf("binTestFilters[%d], CompileFilter returned : %s", i, err)
+			continue
+		}
+		b := p.Bytes()
+		if !reflect.DeepEqual(b, test.bin) {
+			t.Errorf("binTestFilters[%d], %q expected for CompileFilter(%q), got %q", i, test.bin, test.str, b)
+		}
+	}
+}
+
+func BenchmarkFilterCompile(b *testing.B) {
+	b.StopTimer()
+	filters := make([]string, len(testFilters))
+
+	// Test Compiler and Decompiler
+	for idx, i := range testFilters {
+		filters[idx] = i.filterStr
+	}
+
+	maxIdx := len(filters)
+	b.StartTimer()
+	for i := 0; i < b.N; i++ {
+		CompileFilter(filters[i%maxIdx])
+	}
+}
+
+func BenchmarkFilterDecompile(b *testing.B) {
+	b.StopTimer()
+	filters := make([]*ber.Packet, len(testFilters))
+
+	// Test Compiler and Decompiler
+	for idx, i := range testFilters {
+		filters[idx], _ = CompileFilter(i.filterStr)
+	}
+
+	maxIdx := len(filters)
+	b.StartTimer()
+	for i := 0; i < b.N; i++ {
+		DecompileFilter(filters[i%maxIdx])
+	}
+}
+
+func TestGetFilterObjectClass(t *testing.T) {
+	c, err := GetFilterObjectClass("(objectClass=*)")
+	if err != nil {
+		t.Errorf("GetFilterObjectClass failed")
+	}
+	if c != "" {
+		t.Errorf("GetFilterObjectClass failed")
+	}
+	c, err = GetFilterObjectClass("(objectClass=posixAccount)")
+	if err != nil {
+		t.Errorf("GetFilterObjectClass failed")
+	}
+	if c != "posixaccount" {
+		t.Errorf("GetFilterObjectClass failed")
+	}
+	c, err = GetFilterObjectClass("(&(cn=awesome)(objectClass=posixGroup))")
+	if err != nil {
+		t.Errorf("GetFilterObjectClass failed")
+	}
+	if c != "posixgroup" {
+		t.Errorf("GetFilterObjectClass failed")
+	}
+}

+ 340 - 0
ldap/ldap.go

@@ -0,0 +1,340 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ldap
+
+import (
+	"errors"
+	"fmt"
+	"io/ioutil"
+
+	"github.com/nmcclain/asn1-ber"
+)
+
+// LDAP Application Codes
+const (
+	ApplicationBindRequest           = 0
+	ApplicationBindResponse          = 1
+	ApplicationUnbindRequest         = 2
+	ApplicationSearchRequest         = 3
+	ApplicationSearchResultEntry     = 4
+	ApplicationSearchResultDone      = 5
+	ApplicationModifyRequest         = 6
+	ApplicationModifyResponse        = 7
+	ApplicationAddRequest            = 8
+	ApplicationAddResponse           = 9
+	ApplicationDelRequest            = 10
+	ApplicationDelResponse           = 11
+	ApplicationModifyDNRequest       = 12
+	ApplicationModifyDNResponse      = 13
+	ApplicationCompareRequest        = 14
+	ApplicationCompareResponse       = 15
+	ApplicationAbandonRequest        = 16
+	ApplicationSearchResultReference = 19
+	ApplicationExtendedRequest       = 23
+	ApplicationExtendedResponse      = 24
+)
+
+var ApplicationMap = map[uint8]string{
+	ApplicationBindRequest:           "Bind Request",
+	ApplicationBindResponse:          "Bind Response",
+	ApplicationUnbindRequest:         "Unbind Request",
+	ApplicationSearchRequest:         "Search Request",
+	ApplicationSearchResultEntry:     "Search Result Entry",
+	ApplicationSearchResultDone:      "Search Result Done",
+	ApplicationModifyRequest:         "Modify Request",
+	ApplicationModifyResponse:        "Modify Response",
+	ApplicationAddRequest:            "Add Request",
+	ApplicationAddResponse:           "Add Response",
+	ApplicationDelRequest:            "Del Request",
+	ApplicationDelResponse:           "Del Response",
+	ApplicationModifyDNRequest:       "Modify DN Request",
+	ApplicationModifyDNResponse:      "Modify DN Response",
+	ApplicationCompareRequest:        "Compare Request",
+	ApplicationCompareResponse:       "Compare Response",
+	ApplicationAbandonRequest:        "Abandon Request",
+	ApplicationSearchResultReference: "Search Result Reference",
+	ApplicationExtendedRequest:       "Extended Request",
+	ApplicationExtendedResponse:      "Extended Response",
+}
+
+// LDAP Result Codes
+const (
+	LDAPResultSuccess                      = 0
+	LDAPResultOperationsError              = 1
+	LDAPResultProtocolError                = 2
+	LDAPResultTimeLimitExceeded            = 3
+	LDAPResultSizeLimitExceeded            = 4
+	LDAPResultCompareFalse                 = 5
+	LDAPResultCompareTrue                  = 6
+	LDAPResultAuthMethodNotSupported       = 7
+	LDAPResultStrongAuthRequired           = 8
+	LDAPResultReferral                     = 10
+	LDAPResultAdminLimitExceeded           = 11
+	LDAPResultUnavailableCriticalExtension = 12
+	LDAPResultConfidentialityRequired      = 13
+	LDAPResultSaslBindInProgress           = 14
+	LDAPResultNoSuchAttribute              = 16
+	LDAPResultUndefinedAttributeType       = 17
+	LDAPResultInappropriateMatching        = 18
+	LDAPResultConstraintViolation          = 19
+	LDAPResultAttributeOrValueExists       = 20
+	LDAPResultInvalidAttributeSyntax       = 21
+	LDAPResultNoSuchObject                 = 32
+	LDAPResultAliasProblem                 = 33
+	LDAPResultInvalidDNSyntax              = 34
+	LDAPResultAliasDereferencingProblem    = 36
+	LDAPResultInappropriateAuthentication  = 48
+	LDAPResultInvalidCredentials           = 49
+	LDAPResultInsufficientAccessRights     = 50
+	LDAPResultBusy                         = 51
+	LDAPResultUnavailable                  = 52
+	LDAPResultUnwillingToPerform           = 53
+	LDAPResultLoopDetect                   = 54
+	LDAPResultNamingViolation              = 64
+	LDAPResultObjectClassViolation         = 65
+	LDAPResultNotAllowedOnNonLeaf          = 66
+	LDAPResultNotAllowedOnRDN              = 67
+	LDAPResultEntryAlreadyExists           = 68
+	LDAPResultObjectClassModsProhibited    = 69
+	LDAPResultAffectsMultipleDSAs          = 71
+	LDAPResultOther                        = 80
+
+	ErrorNetwork         = 200
+	ErrorFilterCompile   = 201
+	ErrorFilterDecompile = 202
+	ErrorDebugging       = 203
+)
+
+var LDAPResultCodeMap = map[LDAPResultCode]string{
+	LDAPResultSuccess:                      "Success",
+	LDAPResultOperationsError:              "Operations Error",
+	LDAPResultProtocolError:                "Protocol Error",
+	LDAPResultTimeLimitExceeded:            "Time Limit Exceeded",
+	LDAPResultSizeLimitExceeded:            "Size Limit Exceeded",
+	LDAPResultCompareFalse:                 "Compare False",
+	LDAPResultCompareTrue:                  "Compare True",
+	LDAPResultAuthMethodNotSupported:       "Auth Method Not Supported",
+	LDAPResultStrongAuthRequired:           "Strong Auth Required",
+	LDAPResultReferral:                     "Referral",
+	LDAPResultAdminLimitExceeded:           "Admin Limit Exceeded",
+	LDAPResultUnavailableCriticalExtension: "Unavailable Critical Extension",
+	LDAPResultConfidentialityRequired:      "Confidentiality Required",
+	LDAPResultSaslBindInProgress:           "Sasl Bind In Progress",
+	LDAPResultNoSuchAttribute:              "No Such Attribute",
+	LDAPResultUndefinedAttributeType:       "Undefined Attribute Type",
+	LDAPResultInappropriateMatching:        "Inappropriate Matching",
+	LDAPResultConstraintViolation:          "Constraint Violation",
+	LDAPResultAttributeOrValueExists:       "Attribute Or Value Exists",
+	LDAPResultInvalidAttributeSyntax:       "Invalid Attribute Syntax",
+	LDAPResultNoSuchObject:                 "No Such Object",
+	LDAPResultAliasProblem:                 "Alias Problem",
+	LDAPResultInvalidDNSyntax:              "Invalid DN Syntax",
+	LDAPResultAliasDereferencingProblem:    "Alias Dereferencing Problem",
+	LDAPResultInappropriateAuthentication:  "Inappropriate Authentication",
+	LDAPResultInvalidCredentials:           "Invalid Credentials",
+	LDAPResultInsufficientAccessRights:     "Insufficient Access Rights",
+	LDAPResultBusy:                         "Busy",
+	LDAPResultUnavailable:                  "Unavailable",
+	LDAPResultUnwillingToPerform:           "Unwilling To Perform",
+	LDAPResultLoopDetect:                   "Loop Detect",
+	LDAPResultNamingViolation:              "Naming Violation",
+	LDAPResultObjectClassViolation:         "Object Class Violation",
+	LDAPResultNotAllowedOnNonLeaf:          "Not Allowed On Non Leaf",
+	LDAPResultNotAllowedOnRDN:              "Not Allowed On RDN",
+	LDAPResultEntryAlreadyExists:           "Entry Already Exists",
+	LDAPResultObjectClassModsProhibited:    "Object Class Mods Prohibited",
+	LDAPResultAffectsMultipleDSAs:          "Affects Multiple DSAs",
+	LDAPResultOther:                        "Other",
+}
+
+// Other LDAP constants
+const (
+	LDAPBindAuthSimple = 0
+	LDAPBindAuthSASL   = 3
+)
+
+type LDAPResultCode uint8
+
+type Attribute struct {
+	attrType string
+	attrVals []string
+}
+type AddRequest struct {
+	dn         string
+	attributes []Attribute
+}
+type DeleteRequest struct {
+	dn string
+}
+type ModifyDNRequest struct {
+	dn           string
+	newrdn       string
+	deleteoldrdn bool
+	newSuperior  string
+}
+type AttributeValueAssertion struct {
+	attributeDesc  string
+	assertionValue string
+}
+type CompareRequest struct {
+	dn  string
+	ava []AttributeValueAssertion
+}
+type ExtendedRequest struct {
+	requestName  string
+	requestValue string
+}
+
+// Adds descriptions to an LDAP Response packet for debugging
+func addLDAPDescriptions(packet *ber.Packet) (err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = NewError(ErrorDebugging, errors.New("ldap: cannot process packet to add descriptions"))
+		}
+	}()
+	packet.Description = "LDAP Response"
+	packet.Children[0].Description = "Message ID"
+
+	application := packet.Children[1].Tag
+	packet.Children[1].Description = ApplicationMap[application]
+
+	switch application {
+	case ApplicationBindRequest:
+		addRequestDescriptions(packet)
+	case ApplicationBindResponse:
+		addDefaultLDAPResponseDescriptions(packet)
+	case ApplicationUnbindRequest:
+		addRequestDescriptions(packet)
+	case ApplicationSearchRequest:
+		addRequestDescriptions(packet)
+	case ApplicationSearchResultEntry:
+		packet.Children[1].Children[0].Description = "Object Name"
+		packet.Children[1].Children[1].Description = "Attributes"
+		for _, child := range packet.Children[1].Children[1].Children {
+			child.Description = "Attribute"
+			child.Children[0].Description = "Attribute Name"
+			child.Children[1].Description = "Attribute Values"
+			for _, grandchild := range child.Children[1].Children {
+				grandchild.Description = "Attribute Value"
+			}
+		}
+		if len(packet.Children) == 3 {
+			addControlDescriptions(packet.Children[2])
+		}
+	case ApplicationSearchResultDone:
+		addDefaultLDAPResponseDescriptions(packet)
+	case ApplicationModifyRequest:
+		addRequestDescriptions(packet)
+	case ApplicationModifyResponse:
+	case ApplicationAddRequest:
+		addRequestDescriptions(packet)
+	case ApplicationAddResponse:
+	case ApplicationDelRequest:
+		addRequestDescriptions(packet)
+	case ApplicationDelResponse:
+	case ApplicationModifyDNRequest:
+		addRequestDescriptions(packet)
+	case ApplicationModifyDNResponse:
+	case ApplicationCompareRequest:
+		addRequestDescriptions(packet)
+	case ApplicationCompareResponse:
+	case ApplicationAbandonRequest:
+		addRequestDescriptions(packet)
+	case ApplicationSearchResultReference:
+	case ApplicationExtendedRequest:
+		addRequestDescriptions(packet)
+	case ApplicationExtendedResponse:
+	}
+
+	return nil
+}
+
+func addControlDescriptions(packet *ber.Packet) {
+	packet.Description = "Controls"
+	for _, child := range packet.Children {
+		child.Description = "Control"
+		child.Children[0].Description = "Control Type (" + ControlTypeMap[child.Children[0].Value.(string)] + ")"
+		value := child.Children[1]
+		if len(child.Children) == 3 {
+			child.Children[1].Description = "Criticality"
+			value = child.Children[2]
+		}
+		value.Description = "Control Value"
+
+		switch child.Children[0].Value.(string) {
+		case ControlTypePaging:
+			value.Description += " (Paging)"
+			if value.Value != nil {
+				valueChildren := ber.DecodePacket(value.Data.Bytes())
+				value.Data.Truncate(0)
+				value.Value = nil
+				valueChildren.Children[1].Value = valueChildren.Children[1].Data.Bytes()
+				value.AppendChild(valueChildren)
+			}
+			value.Children[0].Description = "Real Search Control Value"
+			value.Children[0].Children[0].Description = "Paging Size"
+			value.Children[0].Children[1].Description = "Cookie"
+		}
+	}
+}
+
+func addRequestDescriptions(packet *ber.Packet) {
+	packet.Description = "LDAP Request"
+	packet.Children[0].Description = "Message ID"
+	packet.Children[1].Description = ApplicationMap[packet.Children[1].Tag]
+	if len(packet.Children) == 3 {
+		addControlDescriptions(packet.Children[2])
+	}
+}
+
+func addDefaultLDAPResponseDescriptions(packet *ber.Packet) {
+	resultCode := packet.Children[1].Children[0].Value.(uint64)
+	packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[LDAPResultCode(resultCode)] + ")"
+	packet.Children[1].Children[1].Description = "Matched DN"
+	packet.Children[1].Children[2].Description = "Error Message"
+	if len(packet.Children[1].Children) > 3 {
+		packet.Children[1].Children[3].Description = "Referral"
+	}
+	if len(packet.Children) == 3 {
+		addControlDescriptions(packet.Children[2])
+	}
+}
+
+func DebugBinaryFile(fileName string) error {
+	file, err := ioutil.ReadFile(fileName)
+	if err != nil {
+		return NewError(ErrorDebugging, err)
+	}
+	ber.PrintBytes(file, "")
+	packet := ber.DecodePacket(file)
+	addLDAPDescriptions(packet)
+	ber.PrintPacket(packet)
+
+	return nil
+}
+
+type Error struct {
+	Err        error
+	ResultCode LDAPResultCode
+}
+
+func (e *Error) Error() string {
+	return fmt.Sprintf("LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[e.ResultCode], e.Err.Error())
+}
+
+func NewError(resultCode LDAPResultCode, err error) error {
+	return &Error{ResultCode: resultCode, Err: err}
+}
+
+func getLDAPResultCode(packet *ber.Packet) (code LDAPResultCode, description string) {
+	if len(packet.Children) >= 2 {
+		response := packet.Children[1]
+		if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) == 3 {
+			return LDAPResultCode(response.Children[0].Value.(uint64)), response.Children[2].Value.(string)
+		}
+	}
+
+	return ErrorNetwork, "Invalid packet format"
+}

+ 123 - 0
ldap/ldap_test.go

@@ -0,0 +1,123 @@
+package ldap
+
+import (
+	"fmt"
+	"testing"
+)
+
+var ldapServer = "ldap.itd.umich.edu"
+var ldapPort = uint16(389)
+var baseDN = "dc=umich,dc=edu"
+var filter = []string{
+	"(cn=cis-fac)",
+	"(&(objectclass=rfc822mailgroup)(cn=*Computer*))",
+	"(&(objectclass=rfc822mailgroup)(cn=*Mathematics*))"}
+var attributes = []string{
+	"cn",
+	"description"}
+
+func TestConnect(t *testing.T) {
+	fmt.Printf("TestConnect: starting...\n")
+	l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
+	if err != nil {
+		t.Errorf(err.Error())
+		return
+	}
+	defer l.Close()
+	fmt.Printf("TestConnect: finished...\n")
+}
+
+func TestSearch(t *testing.T) {
+	fmt.Printf("TestSearch: starting...\n")
+	l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
+	if err != nil {
+		t.Errorf(err.Error())
+		return
+	}
+	defer l.Close()
+
+	searchRequest := NewSearchRequest(
+		baseDN,
+		ScopeWholeSubtree, DerefAlways, 0, 0, false,
+		filter[0],
+		attributes,
+		nil)
+
+	sr, err := l.Search(searchRequest)
+	if err != nil {
+		t.Errorf(err.Error())
+		return
+	}
+
+	fmt.Printf("TestSearch: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries))
+}
+
+func TestSearchWithPaging(t *testing.T) {
+	fmt.Printf("TestSearchWithPaging: starting...\n")
+	l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
+	if err != nil {
+		t.Errorf(err.Error())
+		return
+	}
+	defer l.Close()
+
+	err = l.Bind("", "")
+	if err != nil {
+		t.Errorf(err.Error())
+		return
+	}
+
+	searchRequest := NewSearchRequest(
+		baseDN,
+		ScopeWholeSubtree, DerefAlways, 0, 0, false,
+		filter[1],
+		attributes,
+		nil)
+	sr, err := l.SearchWithPaging(searchRequest, 5)
+	if err != nil {
+		t.Errorf(err.Error())
+		return
+	}
+
+	fmt.Printf("TestSearchWithPaging: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries))
+}
+
+func testMultiGoroutineSearch(t *testing.T, l *Conn, results chan *SearchResult, i int) {
+	searchRequest := NewSearchRequest(
+		baseDN,
+		ScopeWholeSubtree, DerefAlways, 0, 0, false,
+		filter[i],
+		attributes,
+		nil)
+	sr, err := l.Search(searchRequest)
+	if err != nil {
+		t.Errorf(err.Error())
+		results <- nil
+		return
+	}
+	results <- sr
+}
+
+func TestMultiGoroutineSearch(t *testing.T) {
+	fmt.Printf("TestMultiGoroutineSearch: starting...\n")
+	l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
+	if err != nil {
+		t.Errorf(err.Error())
+		return
+	}
+	defer l.Close()
+
+	results := make([]chan *SearchResult, len(filter))
+	for i := range filter {
+		results[i] = make(chan *SearchResult)
+		go testMultiGoroutineSearch(t, l, results[i], i)
+	}
+	for i := range filter {
+		sr := <-results[i]
+		if sr == nil {
+			t.Errorf("Did not receive results from goroutine for %q", filter[i])
+		} else {
+			fmt.Printf("TestMultiGoroutineSearch(%d): %s -> num of entries = %d\n", i, filter[i], len(sr.Entries))
+		}
+	}
+}

+ 1 - 0
ldap/light-ldap.go

@@ -0,0 +1 @@
+package ldap

+ 162 - 0
ldap/modify.go

@@ -0,0 +1,162 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+//
+// File contains Modify functionality
+//
+// https://tools.ietf.org/html/rfc4511
+//
+// ModifyRequest ::= [APPLICATION 6] SEQUENCE {
+//      object          LDAPDN,
+//      changes         SEQUENCE OF change SEQUENCE {
+//           operation       ENUMERATED {
+//                add     (0),
+//                delete  (1),
+//                replace (2),
+//                ...  },
+//           modification    PartialAttribute } }
+//
+// PartialAttribute ::= SEQUENCE {
+//      type       AttributeDescription,
+//      vals       SET OF value AttributeValue }
+//
+// AttributeDescription ::= LDAPString
+//                         -- Constrained to <attributedescription>
+//                         -- [RFC4512]
+//
+// AttributeValue ::= OCTET STRING
+//
+
+package ldap
+
+import (
+	"errors"
+	"log"
+
+	"github.com/nmcclain/asn1-ber"
+)
+
+const (
+	AddAttribute     = 0
+	DeleteAttribute  = 1
+	ReplaceAttribute = 2
+)
+
+var LDAPModifyAttributeMap = map[uint64]string{
+	AddAttribute:     "Add",
+	DeleteAttribute:  "Delete",
+	ReplaceAttribute: "Replace",
+}
+
+type PartialAttribute struct {
+	AttrType string
+	AttrVals []string
+}
+
+func (p *PartialAttribute) encode() *ber.Packet {
+	seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "PartialAttribute")
+	seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, p.AttrType, "Type"))
+	set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue")
+	for _, value := range p.AttrVals {
+		set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals"))
+	}
+	seq.AppendChild(set)
+	return seq
+}
+
+type ModifyRequest struct {
+	Dn                string
+	AddAttributes     []PartialAttribute
+	DeleteAttributes  []PartialAttribute
+	ReplaceAttributes []PartialAttribute
+}
+
+func (m *ModifyRequest) Add(attrType string, attrVals []string) {
+	m.AddAttributes = append(m.AddAttributes, PartialAttribute{AttrType: attrType, AttrVals: attrVals})
+}
+
+func (m *ModifyRequest) Delete(attrType string, attrVals []string) {
+	m.DeleteAttributes = append(m.DeleteAttributes, PartialAttribute{AttrType: attrType, AttrVals: attrVals})
+}
+
+func (m *ModifyRequest) Replace(attrType string, attrVals []string) {
+	m.ReplaceAttributes = append(m.ReplaceAttributes, PartialAttribute{AttrType: attrType, AttrVals: attrVals})
+}
+
+func (m ModifyRequest) encode() *ber.Packet {
+	request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request")
+	request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.Dn, "DN"))
+	changes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Changes")
+	for _, attribute := range m.AddAttributes {
+		change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
+		change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(AddAttribute), "Operation"))
+		change.AppendChild(attribute.encode())
+		changes.AppendChild(change)
+	}
+	for _, attribute := range m.DeleteAttributes {
+		change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
+		change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(DeleteAttribute), "Operation"))
+		change.AppendChild(attribute.encode())
+		changes.AppendChild(change)
+	}
+	for _, attribute := range m.ReplaceAttributes {
+		change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
+		change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ReplaceAttribute), "Operation"))
+		change.AppendChild(attribute.encode())
+		changes.AppendChild(change)
+	}
+	request.AppendChild(changes)
+	return request
+}
+
+func NewModifyRequest(
+	dn string,
+) *ModifyRequest {
+	return &ModifyRequest{
+		Dn: dn,
+	}
+}
+
+func (l *Conn) Modify(modifyRequest *ModifyRequest) error {
+	messageID := l.nextMessageID()
+	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+	packet.AppendChild(modifyRequest.encode())
+
+	l.Debug.PrintPacket(packet)
+
+	channel, err := l.sendMessage(packet)
+	if err != nil {
+		return err
+	}
+	if channel == nil {
+		return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
+	}
+	defer l.finishMessage(messageID)
+
+	l.Debug.Printf("%d: waiting for response", messageID)
+	packet = <-channel
+	l.Debug.Printf("%d: got response %p", messageID, packet)
+	if packet == nil {
+		return NewError(ErrorNetwork, errors.New("ldap: could not retrieve message"))
+	}
+
+	if l.Debug {
+		if err := addLDAPDescriptions(packet); err != nil {
+			return err
+		}
+		ber.PrintPacket(packet)
+	}
+
+	if packet.Children[1].Tag == ApplicationModifyResponse {
+		resultCode, resultDescription := getLDAPResultCode(packet)
+		if resultCode != 0 {
+			return NewError(resultCode, errors.New(resultDescription))
+		}
+	} else {
+		log.Printf("Unexpected Response: %d", packet.Children[1].Tag)
+	}
+
+	l.Debug.Printf("%d: returning", messageID)
+	return nil
+}

+ 350 - 0
ldap/search.go

@@ -0,0 +1,350 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+//
+// File contains Search functionality
+//
+// https://tools.ietf.org/html/rfc4511
+//
+//         SearchRequest ::= [APPLICATION 3] SEQUENCE {
+//              baseObject      LDAPDN,
+//              scope           ENUMERATED {
+//                   baseObject              (0),
+//                   singleLevel             (1),
+//                   wholeSubtree            (2),
+//                   ...  },
+//              derefAliases    ENUMERATED {
+//                   neverDerefAliases       (0),
+//                   derefInSearching        (1),
+//                   derefFindingBaseObj     (2),
+//                   derefAlways             (3) },
+//              sizeLimit       INTEGER (0 ..  maxInt),
+//              timeLimit       INTEGER (0 ..  maxInt),
+//              typesOnly       BOOLEAN,
+//              filter          Filter,
+//              attributes      AttributeSelection }
+//
+//         AttributeSelection ::= SEQUENCE OF selector LDAPString
+//                         -- The LDAPString is constrained to
+//                         -- <attributeSelector> in Section 4.5.1.8
+//
+//         Filter ::= CHOICE {
+//              and             [0] SET SIZE (1..MAX) OF filter Filter,
+//              or              [1] SET SIZE (1..MAX) OF filter Filter,
+//              not             [2] Filter,
+//              equalityMatch   [3] AttributeValueAssertion,
+//              substrings      [4] SubstringFilter,
+//              greaterOrEqual  [5] AttributeValueAssertion,
+//              lessOrEqual     [6] AttributeValueAssertion,
+//              present         [7] AttributeDescription,
+//              approxMatch     [8] AttributeValueAssertion,
+//              extensibleMatch [9] MatchingRuleAssertion,
+//              ...  }
+//
+//         SubstringFilter ::= SEQUENCE {
+//              type           AttributeDescription,
+//              substrings     SEQUENCE SIZE (1..MAX) OF substring CHOICE {
+//                   initial [0] AssertionValue,  -- can occur at most once
+//                   any     [1] AssertionValue,
+//                   final   [2] AssertionValue } -- can occur at most once
+//              }
+//
+//         MatchingRuleAssertion ::= SEQUENCE {
+//              matchingRule    [1] MatchingRuleId OPTIONAL,
+//              type            [2] AttributeDescription OPTIONAL,
+//              matchValue      [3] AssertionValue,
+//              dnAttributes    [4] BOOLEAN DEFAULT FALSE }
+//
+//
+
+package ldap
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+
+	"github.com/nmcclain/asn1-ber"
+)
+
+const (
+	ScopeBaseObject   = 0
+	ScopeSingleLevel  = 1
+	ScopeWholeSubtree = 2
+)
+
+var ScopeMap = map[int]string{
+	ScopeBaseObject:   "Base Object",
+	ScopeSingleLevel:  "Single Level",
+	ScopeWholeSubtree: "Whole Subtree",
+}
+
+const (
+	NeverDerefAliases   = 0
+	DerefInSearching    = 1
+	DerefFindingBaseObj = 2
+	DerefAlways         = 3
+)
+
+var DerefMap = map[int]string{
+	NeverDerefAliases:   "NeverDerefAliases",
+	DerefInSearching:    "DerefInSearching",
+	DerefFindingBaseObj: "DerefFindingBaseObj",
+	DerefAlways:         "DerefAlways",
+}
+
+type Entry struct {
+	DN         string
+	Attributes []*EntryAttribute
+}
+
+func (e *Entry) GetAttributeValues(attribute string) []string {
+	for _, attr := range e.Attributes {
+		if attr.Name == attribute {
+			return attr.Values
+		}
+	}
+	return []string{}
+}
+
+func (e *Entry) GetAttributeValue(attribute string) string {
+	values := e.GetAttributeValues(attribute)
+	if len(values) == 0 {
+		return ""
+	}
+	return values[0]
+}
+
+func (e *Entry) Print() {
+	fmt.Printf("DN: %s\n", e.DN)
+	for _, attr := range e.Attributes {
+		attr.Print()
+	}
+}
+
+func (e *Entry) PrettyPrint(indent int) {
+	fmt.Printf("%sDN: %s\n", strings.Repeat(" ", indent), e.DN)
+	for _, attr := range e.Attributes {
+		attr.PrettyPrint(indent + 2)
+	}
+}
+
+type EntryAttribute struct {
+	Name   string
+	Values []string
+}
+
+func (e *EntryAttribute) Print() {
+	fmt.Printf("%s: %s\n", e.Name, e.Values)
+}
+
+func (e *EntryAttribute) PrettyPrint(indent int) {
+	fmt.Printf("%s%s: %s\n", strings.Repeat(" ", indent), e.Name, e.Values)
+}
+
+type SearchResult struct {
+	Entries   []*Entry
+	Referrals []string
+	Controls  []Control
+}
+
+func (s *SearchResult) Print() {
+	for _, entry := range s.Entries {
+		entry.Print()
+	}
+}
+
+func (s *SearchResult) PrettyPrint(indent int) {
+	for _, entry := range s.Entries {
+		entry.PrettyPrint(indent)
+	}
+}
+
+type SearchRequest struct {
+	BaseDN       string
+	Scope        int
+	DerefAliases int
+	SizeLimit    int
+	TimeLimit    int
+	TypesOnly    bool
+	Filter       string
+	Attributes   []string
+	Controls     []Control
+}
+
+func (s *SearchRequest) encode() (*ber.Packet, error) {
+	request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request")
+	request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, s.BaseDN, "Base DN"))
+	request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(s.Scope), "Scope"))
+	request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(s.DerefAliases), "Deref Aliases"))
+	request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(s.SizeLimit), "Size Limit"))
+	request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(s.TimeLimit), "Time Limit"))
+	request.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, s.TypesOnly, "Types Only"))
+	// compile and encode filter
+	filterPacket, err := CompileFilter(s.Filter)
+	if err != nil {
+		return nil, err
+	}
+	request.AppendChild(filterPacket)
+	// encode attributes
+	attributesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes")
+	for _, attribute := range s.Attributes {
+		attributesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
+	}
+	request.AppendChild(attributesPacket)
+	return request, nil
+}
+
+func NewSearchRequest(
+	BaseDN string,
+	Scope, DerefAliases, SizeLimit, TimeLimit int,
+	TypesOnly bool,
+	Filter string,
+	Attributes []string,
+	Controls []Control,
+) *SearchRequest {
+	return &SearchRequest{
+		BaseDN:       BaseDN,
+		Scope:        Scope,
+		DerefAliases: DerefAliases,
+		SizeLimit:    SizeLimit,
+		TimeLimit:    TimeLimit,
+		TypesOnly:    TypesOnly,
+		Filter:       Filter,
+		Attributes:   Attributes,
+		Controls:     Controls,
+	}
+}
+
+func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) {
+	if searchRequest.Controls == nil {
+		searchRequest.Controls = make([]Control, 0)
+	}
+
+	pagingControl := NewControlPaging(pagingSize)
+	searchRequest.Controls = append(searchRequest.Controls, pagingControl)
+	searchResult := new(SearchResult)
+	for {
+		result, err := l.Search(searchRequest)
+		l.Debug.Printf("Looking for Paging Control...")
+		if err != nil {
+			return searchResult, err
+		}
+		if result == nil {
+			return searchResult, NewError(ErrorNetwork, errors.New("ldap: packet not received"))
+		}
+
+		for _, entry := range result.Entries {
+			searchResult.Entries = append(searchResult.Entries, entry)
+		}
+		for _, referral := range result.Referrals {
+			searchResult.Referrals = append(searchResult.Referrals, referral)
+		}
+		for _, control := range result.Controls {
+			searchResult.Controls = append(searchResult.Controls, control)
+		}
+
+		l.Debug.Printf("Looking for Paging Control...")
+		pagingResult := FindControl(result.Controls, ControlTypePaging)
+		if pagingResult == nil {
+			pagingControl = nil
+			l.Debug.Printf("Could not find paging control.  Breaking...")
+			break
+		}
+
+		cookie := pagingResult.(*ControlPaging).Cookie
+		if len(cookie) == 0 {
+			pagingControl = nil
+			l.Debug.Printf("Could not find cookie.  Breaking...")
+			break
+		}
+		pagingControl.SetCookie(cookie)
+	}
+
+	if pagingControl != nil {
+		l.Debug.Printf("Abandoning Paging...")
+		pagingControl.PagingSize = 0
+		l.Search(searchRequest)
+	}
+
+	return searchResult, nil
+}
+
+func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
+	messageID := l.nextMessageID()
+	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+	// encode search request
+	encodedSearchRequest, err := searchRequest.encode()
+	if err != nil {
+		return nil, err
+	}
+	packet.AppendChild(encodedSearchRequest)
+	// encode search controls
+	if searchRequest.Controls != nil {
+		packet.AppendChild(encodeControls(searchRequest.Controls))
+	}
+
+	l.Debug.PrintPacket(packet)
+
+	channel, err := l.sendMessage(packet)
+	if err != nil {
+		return nil, err
+	}
+	if channel == nil {
+		return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message"))
+	}
+	defer l.finishMessage(messageID)
+
+	result := &SearchResult{
+		Entries:   make([]*Entry, 0),
+		Referrals: make([]string, 0),
+		Controls:  make([]Control, 0)}
+
+	foundSearchResultDone := false
+	for !foundSearchResultDone {
+		l.Debug.Printf("%d: waiting for response", messageID)
+		packet = <-channel
+		l.Debug.Printf("%d: got response %p", messageID, packet)
+		if packet == nil {
+			return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve message"))
+		}
+
+		if l.Debug {
+			if err := addLDAPDescriptions(packet); err != nil {
+				return nil, err
+			}
+			ber.PrintPacket(packet)
+		}
+
+		switch packet.Children[1].Tag {
+		case 4:
+			entry := new(Entry)
+			entry.DN = packet.Children[1].Children[0].Value.(string)
+			for _, child := range packet.Children[1].Children[1].Children {
+				attr := new(EntryAttribute)
+				attr.Name = child.Children[0].Value.(string)
+				for _, value := range child.Children[1].Children {
+					attr.Values = append(attr.Values, value.Value.(string))
+				}
+				entry.Attributes = append(entry.Attributes, attr)
+			}
+			result.Entries = append(result.Entries, entry)
+		case 5:
+			resultCode, resultDescription := getLDAPResultCode(packet)
+			if resultCode != 0 {
+				return result, NewError(resultCode, errors.New(resultDescription))
+			}
+			if len(packet.Children) == 3 {
+				for _, child := range packet.Children[2].Children {
+					result.Controls = append(result.Controls, DecodeControl(child))
+				}
+			}
+			foundSearchResultDone = true
+		case 19:
+			result.Referrals = append(result.Referrals, packet.Children[1].Children[0].Value.(string))
+		}
+	}
+	l.Debug.Printf("%d: returning", messageID)
+	return result, nil
+}

+ 475 - 0
ldap/server.go

@@ -0,0 +1,475 @@
+package ldap
+
+import (
+	"crypto/tls"
+	"io"
+	"log"
+	"net"
+	"strings"
+	"sync"
+
+	"github.com/nmcclain/asn1-ber"
+)
+
+type Binder interface {
+	Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error)
+}
+type Searcher interface {
+	Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error)
+}
+type Adder interface {
+	Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Modifier interface {
+	Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Deleter interface {
+	Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error)
+}
+type ModifyDNr interface {
+	ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Comparer interface {
+	Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Abandoner interface {
+	Abandon(boundDN string, conn net.Conn) error
+}
+type Extender interface {
+	Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Unbinder interface {
+	Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error)
+}
+type Closer interface {
+	Close(boundDN string, conn net.Conn) error
+}
+
+//
+type Server struct {
+	BindFns     map[string]Binder
+	SearchFns   map[string]Searcher
+	AddFns      map[string]Adder
+	ModifyFns   map[string]Modifier
+	DeleteFns   map[string]Deleter
+	ModifyDNFns map[string]ModifyDNr
+	CompareFns  map[string]Comparer
+	AbandonFns  map[string]Abandoner
+	ExtendedFns map[string]Extender
+	UnbindFns   map[string]Unbinder
+	CloseFns    map[string]Closer
+	Quit        chan bool
+	EnforceLDAP bool
+	Stats       *Stats
+}
+
+type Stats struct {
+	Conns      int
+	Binds      int
+	Unbinds    int
+	Searches   int
+	statsMutex sync.Mutex
+}
+
+type ServerSearchResult struct {
+	Entries    []*Entry
+	Referrals  []string
+	Controls   []Control
+	ResultCode LDAPResultCode
+}
+
+//
+func NewServer() *Server {
+	s := new(Server)
+	s.Quit = make(chan bool)
+
+	d := defaultHandler{}
+	s.BindFns = make(map[string]Binder)
+	s.SearchFns = make(map[string]Searcher)
+	s.AddFns = make(map[string]Adder)
+	s.ModifyFns = make(map[string]Modifier)
+	s.DeleteFns = make(map[string]Deleter)
+	s.ModifyDNFns = make(map[string]ModifyDNr)
+	s.CompareFns = make(map[string]Comparer)
+	s.AbandonFns = make(map[string]Abandoner)
+	s.ExtendedFns = make(map[string]Extender)
+	s.UnbindFns = make(map[string]Unbinder)
+	s.CloseFns = make(map[string]Closer)
+	s.BindFunc("", d)
+	s.SearchFunc("", d)
+	s.AddFunc("", d)
+	s.ModifyFunc("", d)
+	s.DeleteFunc("", d)
+	s.ModifyDNFunc("", d)
+	s.CompareFunc("", d)
+	s.AbandonFunc("", d)
+	s.ExtendedFunc("", d)
+	s.UnbindFunc("", d)
+	s.CloseFunc("", d)
+	s.Stats = nil
+	return s
+}
+func (server *Server) BindFunc(baseDN string, f Binder) {
+	server.BindFns[baseDN] = f
+}
+func (server *Server) SearchFunc(baseDN string, f Searcher) {
+	server.SearchFns[baseDN] = f
+}
+func (server *Server) AddFunc(baseDN string, f Adder) {
+	server.AddFns[baseDN] = f
+}
+func (server *Server) ModifyFunc(baseDN string, f Modifier) {
+	server.ModifyFns[baseDN] = f
+}
+func (server *Server) DeleteFunc(baseDN string, f Deleter) {
+	server.DeleteFns[baseDN] = f
+}
+func (server *Server) ModifyDNFunc(baseDN string, f ModifyDNr) {
+	server.ModifyDNFns[baseDN] = f
+}
+func (server *Server) CompareFunc(baseDN string, f Comparer) {
+	server.CompareFns[baseDN] = f
+}
+func (server *Server) AbandonFunc(baseDN string, f Abandoner) {
+	server.AbandonFns[baseDN] = f
+}
+func (server *Server) ExtendedFunc(baseDN string, f Extender) {
+	server.ExtendedFns[baseDN] = f
+}
+func (server *Server) UnbindFunc(baseDN string, f Unbinder) {
+	server.UnbindFns[baseDN] = f
+}
+func (server *Server) CloseFunc(baseDN string, f Closer) {
+	server.CloseFns[baseDN] = f
+}
+func (server *Server) QuitChannel(quit chan bool) {
+	server.Quit = quit
+}
+
+func (server *Server) ListenAndServeTLS(listenString string, certFile string, keyFile string) error {
+	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
+	if err != nil {
+		return err
+	}
+	tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}}
+	tlsConfig.ServerName = "localhost"
+	ln, err := tls.Listen("tcp", listenString, &tlsConfig)
+	if err != nil {
+		return err
+	}
+	err = server.Serve(ln)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (server *Server) SetStats(enable bool) {
+	if enable {
+		server.Stats = &Stats{}
+	} else {
+		server.Stats = nil
+	}
+}
+
+func (server *Server) GetStats() Stats {
+	defer func() {
+		server.Stats.statsMutex.Unlock()
+	}()
+	server.Stats.statsMutex.Lock()
+	return *server.Stats
+}
+
+func (server *Server) ListenAndServe(listenString string) error {
+	ln, err := net.Listen("tcp", listenString)
+	if err != nil {
+		return err
+	}
+	err = server.Serve(ln)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (server *Server) Serve(ln net.Listener) error {
+	newConn := make(chan net.Conn)
+	go func() {
+		for {
+			conn, err := ln.Accept()
+			if err != nil {
+				if !strings.HasSuffix(err.Error(), "use of closed network connection") {
+					log.Printf("Error accepting network connection: %s", err.Error())
+				}
+				break
+			}
+			newConn <- conn
+		}
+	}()
+
+listener:
+	for {
+		select {
+		case c := <-newConn:
+			server.Stats.countConns(1)
+			go server.handleConnection(c)
+		case <-server.Quit:
+			ln.Close()
+			break listener
+		}
+	}
+	return nil
+}
+
+//
+func (server *Server) handleConnection(conn net.Conn) {
+	boundDN := "" // "" == anonymous
+
+handler:
+	for {
+		// read incoming LDAP packet
+		packet, err := ber.ReadPacket(conn)
+		log.Println(packet)
+		if err == io.EOF { // Client closed connection
+			break
+		} else if err != nil {
+			log.Printf("handleConnection ber.ReadPacket ERROR: %s", err.Error())
+			break
+		}
+
+		// sanity check this packet
+		if len(packet.Children) < 2 {
+			log.Print("len(packet.Children) < 2")
+			break
+		}
+		// check the message ID and ClassType
+		messageID, ok := packet.Children[0].Value.(uint64)
+		if !ok {
+			log.Print("malformed messageID")
+			break
+		}
+		req := packet.Children[1]
+		if req.ClassType != ber.ClassApplication {
+			log.Print("req.ClassType != ber.ClassApplication")
+			break
+		}
+		// handle controls if present
+		controls := []Control{}
+		if len(packet.Children) > 2 {
+			for _, child := range packet.Children[2].Children {
+				controls = append(controls, DecodeControl(child))
+			}
+		}
+
+		//log.Printf("DEBUG: handling operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
+		//ber.PrintPacket(packet) // DEBUG
+
+		// dispatch the LDAP operation
+		switch req.Tag { // ldap op code
+		default:
+			responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, LDAPResultOperationsError, "Unsupported operation: add")
+			if err = sendPacket(conn, responsePacket); err != nil {
+				log.Printf("sendPacket error %s", err.Error())
+			}
+			log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
+			break handler
+
+		case ApplicationBindRequest:
+			server.Stats.countBinds(1)
+			ldapResultCode := HandleBindRequest(req, server.BindFns, conn)
+			if ldapResultCode == LDAPResultSuccess {
+				boundDN, ok = req.Children[1].Value.(string)
+				if !ok {
+					log.Printf("Malformed Bind DN")
+					break handler
+				}
+			}
+			responsePacket := encodeBindResponse(messageID, ldapResultCode)
+			if err = sendPacket(conn, responsePacket); err != nil {
+				log.Printf("sendPacket error %s", err.Error())
+				break handler
+			}
+		case ApplicationSearchRequest:
+			server.Stats.countSearches(1)
+			if err := HandleSearchRequest(req, &controls, messageID, boundDN, server, conn); err != nil {
+				log.Printf("handleSearchRequest error %s", err.Error()) // TODO: make this more testable/better err handling - stop using log, stop using breaks?
+				e := err.(*Error)
+				if err = sendPacket(conn, encodeSearchDone(messageID, e.ResultCode)); err != nil {
+					log.Printf("sendPacket error %s", err.Error())
+					break handler
+				}
+				break handler
+			} else {
+				if err = sendPacket(conn, encodeSearchDone(messageID, LDAPResultSuccess)); err != nil {
+					log.Printf("sendPacket error %s", err.Error())
+					break handler
+				}
+			}
+		case ApplicationUnbindRequest:
+			server.Stats.countUnbinds(1)
+			break handler // simply disconnect
+		case ApplicationExtendedRequest:
+			ldapResultCode := HandleExtendedRequest(req, boundDN, server.ExtendedFns, conn)
+			responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+			if err = sendPacket(conn, responsePacket); err != nil {
+				log.Printf("sendPacket error %s", err.Error())
+				break handler
+			}
+		case ApplicationAbandonRequest:
+			HandleAbandonRequest(req, boundDN, server.AbandonFns, conn)
+			break handler
+
+		case ApplicationAddRequest:
+			ldapResultCode := HandleAddRequest(req, boundDN, server.AddFns, conn)
+			responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+			if err = sendPacket(conn, responsePacket); err != nil {
+				log.Printf("sendPacket error %s", err.Error())
+				break handler
+			}
+		case ApplicationModifyRequest:
+			ldapResultCode := HandleModifyRequest(req, boundDN, server.ModifyFns, conn)
+			responsePacket := encodeLDAPResponse(messageID, ApplicationModifyResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+			if err = sendPacket(conn, responsePacket); err != nil {
+				log.Printf("sendPacket error %s", err.Error())
+				break handler
+			}
+		case ApplicationDelRequest:
+			ldapResultCode := HandleDeleteRequest(req, boundDN, server.DeleteFns, conn)
+			responsePacket := encodeLDAPResponse(messageID, ApplicationDelResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+			if err = sendPacket(conn, responsePacket); err != nil {
+				log.Printf("sendPacket error %s", err.Error())
+				break handler
+			}
+		case ApplicationModifyDNRequest:
+			ldapResultCode := HandleModifyDNRequest(req, boundDN, server.ModifyDNFns, conn)
+			responsePacket := encodeLDAPResponse(messageID, ApplicationModifyDNResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+			if err = sendPacket(conn, responsePacket); err != nil {
+				log.Printf("sendPacket error %s", err.Error())
+				break handler
+			}
+		case ApplicationCompareRequest:
+			ldapResultCode := HandleCompareRequest(req, boundDN, server.CompareFns, conn)
+			responsePacket := encodeLDAPResponse(messageID, ApplicationCompareResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+			if err = sendPacket(conn, responsePacket); err != nil {
+				log.Printf("sendPacket error %s", err.Error())
+				break handler
+			}
+		}
+	}
+
+	for _, c := range server.CloseFns {
+		c.Close(boundDN, conn)
+	}
+
+	conn.Close()
+}
+
+//
+func sendPacket(conn net.Conn, packet *ber.Packet) error {
+	_, err := conn.Write(packet.Bytes())
+	if err != nil {
+		log.Printf("Error Sending Message: %s", err.Error())
+		return err
+	}
+	return nil
+}
+
+//
+func routeFunc(dn string, funcNames []string) string {
+	bestPick := ""
+	for _, fn := range funcNames {
+		if strings.HasSuffix(dn, fn) {
+			l := len(strings.Split(bestPick, ","))
+			if bestPick == "" {
+				l = 0
+			}
+			if len(strings.Split(fn, ",")) > l {
+				bestPick = fn
+			}
+		}
+	}
+	return bestPick
+}
+
+//
+func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode LDAPResultCode, message string) *ber.Packet {
+	responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+	responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+	reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType])
+	reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
+	reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+	reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: "))
+	responsePacket.AppendChild(reponse)
+	return responsePacket
+}
+
+//
+type defaultHandler struct {
+}
+
+func (h defaultHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+	return LDAPResultInvalidCredentials, nil
+}
+func (h defaultHandler) Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+	return ServerSearchResult{make([]*Entry, 0), []string{}, []Control{}, LDAPResultSuccess}, nil
+}
+func (h defaultHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) {
+	return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error) {
+	return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error) {
+	return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error) {
+	return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error) {
+	return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) Abandon(boundDN string, conn net.Conn) error {
+	return nil
+}
+func (h defaultHandler) Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error) {
+	return LDAPResultProtocolError, nil
+}
+func (h defaultHandler) Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error) {
+	return LDAPResultSuccess, nil
+}
+func (h defaultHandler) Close(boundDN string, conn net.Conn) error {
+	conn.Close()
+	return nil
+}
+
+//
+func (stats *Stats) countConns(delta int) {
+	if stats != nil {
+		stats.statsMutex.Lock()
+		stats.Conns += delta
+		stats.statsMutex.Unlock()
+	}
+}
+func (stats *Stats) countBinds(delta int) {
+	if stats != nil {
+		stats.statsMutex.Lock()
+		stats.Binds += delta
+		stats.statsMutex.Unlock()
+	}
+}
+func (stats *Stats) countUnbinds(delta int) {
+	if stats != nil {
+		stats.statsMutex.Lock()
+		stats.Unbinds += delta
+		stats.statsMutex.Unlock()
+	}
+}
+func (stats *Stats) countSearches(delta int) {
+	if stats != nil {
+		stats.statsMutex.Lock()
+		stats.Searches += delta
+		stats.statsMutex.Unlock()
+	}
+}
+
+//

+ 73 - 0
ldap/server_bind.go

@@ -0,0 +1,73 @@
+package ldap
+
+import (
+	"github.com/nmcclain/asn1-ber"
+	"log"
+	"net"
+)
+
+func HandleBindRequest(req *ber.Packet, fns map[string]Binder, conn net.Conn) (resultCode LDAPResultCode) {
+	defer func() {
+		if r := recover(); r != nil {
+			resultCode = LDAPResultOperationsError
+		}
+	}()
+
+	// we only support ldapv3
+	ldapVersion, ok := req.Children[0].Value.(uint64)
+	if !ok {
+		return LDAPResultProtocolError
+	}
+	if ldapVersion != 3 {
+		log.Printf("Unsupported LDAP version: %d", ldapVersion)
+		return LDAPResultInappropriateAuthentication
+	}
+
+	// auth types
+	bindDN, ok := req.Children[1].Value.(string)
+	if !ok {
+		return LDAPResultProtocolError
+	}
+	bindAuth := req.Children[2]
+	switch bindAuth.Tag {
+	default:
+		log.Print("Unknown LDAP authentication method")
+		return LDAPResultInappropriateAuthentication
+	case LDAPBindAuthSimple:
+		if len(req.Children) == 3 {
+			fnNames := []string{}
+			for k := range fns {
+				fnNames = append(fnNames, k)
+			}
+			fn := routeFunc(bindDN, fnNames)
+			resultCode, err := fns[fn].Bind(bindDN, bindAuth.Data.String(), conn)
+			if err != nil {
+				log.Printf("BindFn Error %s", err.Error())
+				return LDAPResultOperationsError
+			}
+			return resultCode
+		} else {
+			log.Print("Simple bind request has wrong # children.  len(req.Children) != 3")
+			return LDAPResultInappropriateAuthentication
+		}
+	case LDAPBindAuthSASL:
+		log.Print("SASL authentication is not supported")
+		return LDAPResultInappropriateAuthentication
+	}
+	return LDAPResultOperationsError
+}
+
+func encodeBindResponse(messageID uint64, ldapResultCode LDAPResultCode) *ber.Packet {
+	responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+	responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+
+	bindReponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response")
+	bindReponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
+	bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+	bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: "))
+
+	responsePacket.AppendChild(bindReponse)
+
+	// ber.PrintPacket(responsePacket)
+	return responsePacket
+}

+ 232 - 0
ldap/server_modify.go

@@ -0,0 +1,232 @@
+package ldap
+
+import (
+	"log"
+	"net"
+
+	"github.com/nmcclain/asn1-ber"
+)
+
+func HandleAddRequest(req *ber.Packet, boundDN string, fns map[string]Adder, conn net.Conn) (resultCode LDAPResultCode) {
+	if len(req.Children) != 2 {
+		return LDAPResultProtocolError
+	}
+	var ok bool
+	addReq := AddRequest{}
+	addReq.dn, ok = req.Children[0].Value.(string)
+	if !ok {
+		return LDAPResultProtocolError
+	}
+	addReq.attributes = []Attribute{}
+	for _, attr := range req.Children[1].Children {
+		if len(attr.Children) != 2 {
+			return LDAPResultProtocolError
+		}
+
+		a := Attribute{}
+		a.attrType, ok = attr.Children[0].Value.(string)
+		if !ok {
+			return LDAPResultProtocolError
+		}
+		a.attrVals = []string{}
+		for _, val := range attr.Children[1].Children {
+			v, ok := val.Value.(string)
+			if !ok {
+				return LDAPResultProtocolError
+			}
+			a.attrVals = append(a.attrVals, v)
+		}
+		addReq.attributes = append(addReq.attributes, a)
+	}
+	fnNames := []string{}
+	for k := range fns {
+		fnNames = append(fnNames, k)
+	}
+	fn := routeFunc(boundDN, fnNames)
+	resultCode, err := fns[fn].Add(boundDN, addReq, conn)
+	if err != nil {
+		log.Printf("AddFn Error %s", err.Error())
+		return LDAPResultOperationsError
+	}
+	return resultCode
+}
+
+func HandleDeleteRequest(req *ber.Packet, boundDN string, fns map[string]Deleter, conn net.Conn) (resultCode LDAPResultCode) {
+	deleteDN := ber.DecodeString(req.Data.Bytes())
+	fnNames := []string{}
+	for k := range fns {
+		fnNames = append(fnNames, k)
+	}
+	fn := routeFunc(boundDN, fnNames)
+	resultCode, err := fns[fn].Delete(boundDN, deleteDN, conn)
+	if err != nil {
+		log.Printf("DeleteFn Error %s", err.Error())
+		return LDAPResultOperationsError
+	}
+	return resultCode
+}
+
+func HandleModifyRequest(req *ber.Packet, boundDN string, fns map[string]Modifier, conn net.Conn) (resultCode LDAPResultCode) {
+	if len(req.Children) != 2 {
+		return LDAPResultProtocolError
+	}
+	var ok bool
+	modReq := ModifyRequest{}
+	modReq.Dn, ok = req.Children[0].Value.(string)
+	if !ok {
+		return LDAPResultProtocolError
+	}
+	for _, change := range req.Children[1].Children {
+		if len(change.Children) != 2 {
+			return LDAPResultProtocolError
+		}
+		attr := PartialAttribute{}
+		attrs := change.Children[1].Children
+		if len(attrs) != 2 {
+			return LDAPResultProtocolError
+		}
+		attr.AttrType, ok = attrs[0].Value.(string)
+		if !ok {
+			return LDAPResultProtocolError
+		}
+		for _, val := range attrs[1].Children {
+			v, ok := val.Value.(string)
+			if !ok {
+				return LDAPResultProtocolError
+			}
+			attr.AttrVals = append(attr.AttrVals, v)
+		}
+		op, ok := change.Children[0].Value.(uint64)
+		if !ok {
+			return LDAPResultProtocolError
+		}
+		switch op {
+		default:
+			log.Printf("Unrecognized Modify attribute %d", op)
+			return LDAPResultProtocolError
+		case AddAttribute:
+			modReq.Add(attr.AttrType, attr.AttrVals)
+		case DeleteAttribute:
+			modReq.Delete(attr.AttrType, attr.AttrVals)
+		case ReplaceAttribute:
+			modReq.Replace(attr.AttrType, attr.AttrVals)
+		}
+	}
+	fnNames := []string{}
+	for k := range fns {
+		fnNames = append(fnNames, k)
+	}
+	fn := routeFunc(boundDN, fnNames)
+	resultCode, err := fns[fn].Modify(boundDN, modReq, conn)
+	if err != nil {
+		log.Printf("ModifyFn Error %s", err.Error())
+		return LDAPResultOperationsError
+	}
+	return resultCode
+}
+
+func HandleCompareRequest(req *ber.Packet, boundDN string, fns map[string]Comparer, conn net.Conn) (resultCode LDAPResultCode) {
+	if len(req.Children) != 2 {
+		return LDAPResultProtocolError
+	}
+	var ok bool
+	compReq := CompareRequest{}
+	compReq.dn, ok = req.Children[0].Value.(string)
+	if !ok {
+		return LDAPResultProtocolError
+	}
+	ava := req.Children[1]
+	if len(ava.Children) != 2 {
+		return LDAPResultProtocolError
+	}
+	attr, ok := ava.Children[0].Value.(string)
+	if !ok {
+		return LDAPResultProtocolError
+	}
+	val, ok := ava.Children[1].Value.(string)
+	if !ok {
+		return LDAPResultProtocolError
+	}
+	compReq.ava = []AttributeValueAssertion{AttributeValueAssertion{attr, val}}
+	fnNames := []string{}
+	for k := range fns {
+		fnNames = append(fnNames, k)
+	}
+	fn := routeFunc(boundDN, fnNames)
+	resultCode, err := fns[fn].Compare(boundDN, compReq, conn)
+	if err != nil {
+		log.Printf("CompareFn Error %s", err.Error())
+		return LDAPResultOperationsError
+	}
+	return resultCode
+}
+
+func HandleExtendedRequest(req *ber.Packet, boundDN string, fns map[string]Extender, conn net.Conn) (resultCode LDAPResultCode) {
+	if len(req.Children) != 1 && len(req.Children) != 2 {
+		return LDAPResultProtocolError
+	}
+	name := ber.DecodeString(req.Children[0].Data.Bytes())
+	var val string
+	if len(req.Children) == 2 {
+		val = ber.DecodeString(req.Children[1].Data.Bytes())
+	}
+	extReq := ExtendedRequest{name, val}
+	fnNames := []string{}
+	for k := range fns {
+		fnNames = append(fnNames, k)
+	}
+	fn := routeFunc(boundDN, fnNames)
+	resultCode, err := fns[fn].Extended(boundDN, extReq, conn)
+	if err != nil {
+		log.Printf("ExtendedFn Error %s", err.Error())
+		return LDAPResultOperationsError
+	}
+	return resultCode
+}
+
+func HandleAbandonRequest(req *ber.Packet, boundDN string, fns map[string]Abandoner, conn net.Conn) error {
+	fnNames := []string{}
+	for k := range fns {
+		fnNames = append(fnNames, k)
+	}
+	fn := routeFunc(boundDN, fnNames)
+	err := fns[fn].Abandon(boundDN, conn)
+	return err
+}
+
+func HandleModifyDNRequest(req *ber.Packet, boundDN string, fns map[string]ModifyDNr, conn net.Conn) (resultCode LDAPResultCode) {
+	if len(req.Children) != 3 && len(req.Children) != 4 {
+		return LDAPResultProtocolError
+	}
+	var ok bool
+	mdnReq := ModifyDNRequest{}
+	mdnReq.dn, ok = req.Children[0].Value.(string)
+	if !ok {
+		return LDAPResultProtocolError
+	}
+	mdnReq.newrdn, ok = req.Children[1].Value.(string)
+	if !ok {
+		return LDAPResultProtocolError
+	}
+	mdnReq.deleteoldrdn, ok = req.Children[2].Value.(bool)
+	if !ok {
+		return LDAPResultProtocolError
+	}
+	if len(req.Children) == 4 {
+		mdnReq.newSuperior, ok = req.Children[3].Value.(string)
+		if !ok {
+			return LDAPResultProtocolError
+		}
+	}
+	fnNames := []string{}
+	for k := range fns {
+		fnNames = append(fnNames, k)
+	}
+	fn := routeFunc(boundDN, fnNames)
+	resultCode, err := fns[fn].ModifyDN(boundDN, mdnReq, conn)
+	if err != nil {
+		log.Printf("ModifyDN Error %s", err.Error())
+		return LDAPResultOperationsError
+	}
+	return resultCode
+}

+ 191 - 0
ldap/server_modify_test.go

@@ -0,0 +1,191 @@
+package ldap
+
+import (
+	"net"
+	"os/exec"
+	"strings"
+	"testing"
+	"time"
+)
+
+//
+func TestAdd(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.BindFunc("", modifyTestHandler{})
+		s.AddFunc("", modifyTestHandler{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+	go func() {
+		cmd := exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add.ldif")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "modify complete") {
+			t.Errorf("ldapadd failed: %v", string(out))
+		}
+		cmd = exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add2.ldif")
+		out, _ = cmd.CombinedOutput()
+		if !strings.Contains(string(out), "ldap_add: Insufficient access") {
+			t.Errorf("ldapadd should have failed: %v", string(out))
+		}
+		if strings.Contains(string(out), "modify complete") {
+			t.Errorf("ldapadd should have failed: %v", string(out))
+		}
+		done <- true
+	}()
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapadd command timed out")
+	}
+	quit <- true
+}
+
+//
+func TestDelete(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.BindFunc("", modifyTestHandler{})
+		s.DeleteFunc("", modifyTestHandler{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+	go func() {
+		cmd := exec.Command("ldapdelete", "-v", "-H", ldapURL, "-x", "cn=Delete Me,dc=example,dc=com")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "Delete Result: Success (0)") || !strings.Contains(string(out), "Additional info: Success") {
+			t.Errorf("ldapdelete failed: %v", string(out))
+		}
+		cmd = exec.Command("ldapdelete", "-v", "-H", ldapURL, "-x", "cn=Bob,dc=example,dc=com")
+		out, _ = cmd.CombinedOutput()
+		if strings.Contains(string(out), "Success") || !strings.Contains(string(out), "ldap_delete: Insufficient access") {
+			t.Errorf("ldapdelete should have failed: %v", string(out))
+		}
+		done <- true
+	}()
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapdelete command timed out")
+	}
+	quit <- true
+}
+
+func TestModify(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.BindFunc("", modifyTestHandler{})
+		s.ModifyFunc("", modifyTestHandler{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+	go func() {
+		cmd := exec.Command("ldapmodify", "-v", "-H", ldapURL, "-x", "-f", "tests/modify.ldif")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "modify complete") {
+			t.Errorf("ldapmodify failed: %v", string(out))
+		}
+		cmd = exec.Command("ldapmodify", "-v", "-H", ldapURL, "-x", "-f", "tests/modify2.ldif")
+		out, _ = cmd.CombinedOutput()
+		if !strings.Contains(string(out), "ldap_modify: Insufficient access") || strings.Contains(string(out), "modify complete") {
+			t.Errorf("ldapmodify should have failed: %v", string(out))
+		}
+		done <- true
+	}()
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapadd command timed out")
+	}
+	quit <- true
+}
+
+/*
+func TestModifyDN(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.BindFunc("", modifyTestHandler{})
+		s.AddFunc("", modifyTestHandler{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+	go func() {
+		cmd := exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add.ldif")
+		//ldapmodrdn -H ldap://localhost:3389 -x "uid=babs,dc=example,dc=com" "uid=babsy,dc=example,dc=com"
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "modify complete") {
+			t.Errorf("ldapadd failed: %v", string(out))
+		}
+		cmd = exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add2.ldif")
+		out, _ = cmd.CombinedOutput()
+		if !strings.Contains(string(out), "ldap_add: Insufficient access") {
+			t.Errorf("ldapadd should have failed: %v", string(out))
+		}
+		if strings.Contains(string(out), "modify complete") {
+			t.Errorf("ldapadd should have failed: %v", string(out))
+		}
+		done <- true
+	}()
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapadd command timed out")
+	}
+	quit <- true
+}
+*/
+
+//
+type modifyTestHandler struct {
+}
+
+func (h modifyTestHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+	if bindDN == "" && bindSimplePw == "" {
+		return LDAPResultSuccess, nil
+	}
+	return LDAPResultInvalidCredentials, nil
+}
+func (h modifyTestHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) {
+	// only succeed on expected contents of add.ldif:
+	if len(req.attributes) == 5 && req.dn == "cn=Barbara Jensen,dc=example,dc=com" &&
+		req.attributes[2].attrType == "sn" && len(req.attributes[2].attrVals) == 1 &&
+		req.attributes[2].attrVals[0] == "Jensen" {
+		return LDAPResultSuccess, nil
+	}
+	return LDAPResultInsufficientAccessRights, nil
+}
+func (h modifyTestHandler) Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error) {
+	// only succeed on expected deleteDN
+	if deleteDN == "cn=Delete Me,dc=example,dc=com" {
+		return LDAPResultSuccess, nil
+	}
+	return LDAPResultInsufficientAccessRights, nil
+}
+func (h modifyTestHandler) Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error) {
+	// only succeed on expected contents of modify.ldif:
+	if req.Dn == "cn=testy,dc=example,dc=com" && len(req.AddAttributes) == 1 &&
+		len(req.DeleteAttributes) == 3 && len(req.ReplaceAttributes) == 2 &&
+		req.DeleteAttributes[2].AttrType == "details" && len(req.DeleteAttributes[2].AttrVals) == 0 {
+		return LDAPResultSuccess, nil
+	}
+	return LDAPResultInsufficientAccessRights, nil
+}
+func (h modifyTestHandler) ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error) {
+	return LDAPResultInsufficientAccessRights, nil
+}

+ 218 - 0
ldap/server_search.go

@@ -0,0 +1,218 @@
+package ldap
+
+import (
+	"errors"
+	"fmt"
+	"net"
+	"strings"
+
+	ber "github.com/nmcclain/asn1-ber"
+)
+
+func HandleSearchRequest(req *ber.Packet, controls *[]Control, messageID uint64, boundDN string, server *Server, conn net.Conn) (resultErr error) {
+	defer func() {
+		if r := recover(); r != nil {
+			resultErr = NewError(LDAPResultOperationsError, fmt.Errorf("Search function panic: %s", r))
+		}
+	}()
+
+	searchReq, err := parseSearchRequest(boundDN, req, controls)
+	if err != nil {
+		return NewError(LDAPResultOperationsError, err)
+	}
+
+	filterPacket, err := CompileFilter(searchReq.Filter)
+	if err != nil {
+		return NewError(LDAPResultOperationsError, err)
+	}
+
+	fnNames := []string{}
+	for k := range server.SearchFns {
+		fnNames = append(fnNames, k)
+	}
+	fn := routeFunc(searchReq.BaseDN, fnNames)
+	searchResp, err := server.SearchFns[fn].Search(boundDN, searchReq, conn)
+	if err != nil {
+		return NewError(searchResp.ResultCode, err)
+	}
+
+	if server.EnforceLDAP {
+		if searchReq.DerefAliases != NeverDerefAliases { // [-a {never|always|search|find}
+			// TODO: Server DerefAliases not supported: RFC4511 4.5.1.3
+		}
+		if searchReq.TimeLimit > 0 {
+			// TODO: Server TimeLimit not implemented
+		}
+	}
+
+	i := 0
+	for _, entry := range searchResp.Entries {
+		if server.EnforceLDAP {
+			// filter
+			keep, resultCode := ServerApplyFilter(filterPacket, entry)
+			if resultCode != LDAPResultSuccess {
+				return NewError(resultCode, errors.New("ServerApplyFilter error"))
+			}
+			if !keep {
+				continue
+			}
+
+			// constrained search scope
+			switch searchReq.Scope {
+			case ScopeWholeSubtree: // The scope is constrained to the entry named by baseObject and to all its subordinates.
+			case ScopeBaseObject: // The scope is constrained to the entry named by baseObject.
+				if entry.DN != searchReq.BaseDN {
+					continue
+				}
+			case ScopeSingleLevel: // The scope is constrained to the immediate subordinates of the entry named by baseObject.
+				parts := strings.Split(entry.DN, ",")
+				if len(parts) < 2 && entry.DN != searchReq.BaseDN {
+					continue
+				}
+				if dn := strings.Join(parts[1:], ","); dn != searchReq.BaseDN {
+					continue
+				}
+			}
+
+			// attributes
+			if len(searchReq.Attributes) > 1 || (len(searchReq.Attributes) == 1 && len(searchReq.Attributes[0]) > 0) {
+				entry, err = filterAttributes(entry, searchReq.Attributes)
+				if err != nil {
+					return NewError(LDAPResultOperationsError, err)
+				}
+			}
+
+			// size limit
+			if searchReq.SizeLimit > 0 && i >= searchReq.SizeLimit {
+				break
+			}
+			i++
+		}
+
+		// respond
+		responsePacket := encodeSearchResponse(messageID, searchReq, entry)
+		if err = sendPacket(conn, responsePacket); err != nil {
+			return NewError(LDAPResultOperationsError, err)
+		}
+	}
+	return nil
+}
+
+/////////////////////////
+func parseSearchRequest(boundDN string, req *ber.Packet, controls *[]Control) (SearchRequest, error) {
+	if len(req.Children) != 8 {
+		return SearchRequest{}, NewError(LDAPResultOperationsError, errors.New("Bad search request"))
+	}
+
+	// Parse the request
+	baseObject, ok := req.Children[0].Value.(string)
+	if !ok {
+		return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+	}
+	s, ok := req.Children[1].Value.(uint64)
+	if !ok {
+		return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+	}
+	scope := int(s)
+	d, ok := req.Children[2].Value.(uint64)
+	if !ok {
+		return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+	}
+	derefAliases := int(d)
+	s, ok = req.Children[3].Value.(uint64)
+	if !ok {
+		return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+	}
+	sizeLimit := int(s)
+	t, ok := req.Children[4].Value.(uint64)
+	if !ok {
+		return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+	}
+	timeLimit := int(t)
+	typesOnly := false
+	if req.Children[5].Value != nil {
+		typesOnly, ok = req.Children[5].Value.(bool)
+		if !ok {
+			return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+		}
+	}
+	filter, err := DecompileFilter(req.Children[6])
+	if err != nil {
+		return SearchRequest{}, err
+	}
+	attributes := []string{}
+	for _, attr := range req.Children[7].Children {
+		a, ok := attr.Value.(string)
+		if !ok {
+			return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+		}
+		attributes = append(attributes, a)
+	}
+	searchReq := SearchRequest{baseObject, scope,
+		derefAliases, sizeLimit, timeLimit,
+		typesOnly, filter, attributes, *controls}
+
+	return searchReq, nil
+}
+
+/////////////////////////
+func filterAttributes(entry *Entry, attributes []string) (*Entry, error) {
+	// only return requested attributes
+	newAttributes := []*EntryAttribute{}
+
+	for _, attr := range entry.Attributes {
+		for _, requested := range attributes {
+			if requested == "*" || strings.ToLower(attr.Name) == strings.ToLower(requested) {
+				newAttributes = append(newAttributes, attr)
+			}
+		}
+	}
+	entry.Attributes = newAttributes
+
+	return entry, nil
+}
+
+/////////////////////////
+func encodeSearchResponse(messageID uint64, req SearchRequest, res *Entry) *ber.Packet {
+	responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+	responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+
+	searchEntry := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, "Search Result Entry")
+	searchEntry.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, res.DN, "Object Name"))
+
+	attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes:")
+	for _, attribute := range res.Attributes {
+		attrs.AppendChild(encodeSearchAttribute(attribute.Name, attribute.Values))
+	}
+
+	searchEntry.AppendChild(attrs)
+	responsePacket.AppendChild(searchEntry)
+
+	return responsePacket
+}
+
+func encodeSearchAttribute(name string, values []string) *ber.Packet {
+	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute")
+	packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute Name"))
+
+	valuesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "Attribute Values")
+	for _, value := range values {
+		valuesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Attribute Value"))
+	}
+
+	packet.AppendChild(valuesPacket)
+
+	return packet
+}
+
+func encodeSearchDone(messageID uint64, ldapResultCode LDAPResultCode) *ber.Packet {
+	responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+	responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+	donePacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, "Search result done")
+	donePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
+	donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+	donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: "))
+	responsePacket.AppendChild(donePacket)
+
+	return responsePacket
+}

+ 505 - 0
ldap/server_search_test.go

@@ -0,0 +1,505 @@
+package ldap
+
+import (
+	"os/exec"
+	"strings"
+	"testing"
+	"time"
+)
+
+//
+func TestSearchSimpleOK(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.SearchFunc("", searchSimple{})
+		s.BindFunc("", bindSimple{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	serverBaseDN := "o=testers,c=test"
+
+	go func() {
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "uidNumber: 5000") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "numResponses: 4") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+	quit <- true
+}
+
+func TestSearchSizelimit(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.EnforceLDAP = true
+		s.QuitChannel(quit)
+		s.SearchFunc("", searchSimple{})
+		s.BindFunc("", bindSimple{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	go func() {
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test") // no limit for this test
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "numEntries: 3") {
+			t.Errorf("ldapsearch sizelimit unlimited failed - not enough entries: %v", string(out))
+		}
+
+		cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "9") // effectively no limit for this test
+		out, _ = cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "numEntries: 3") {
+			t.Errorf("ldapsearch sizelimit 9 failed - not enough entries: %v", string(out))
+		}
+
+		cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "2")
+		out, _ = cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "numEntries: 2") {
+			t.Errorf("ldapsearch sizelimit 2 failed - too many entries: %v", string(out))
+		}
+
+		cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "1")
+		out, _ = cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "numEntries: 1") {
+			t.Errorf("ldapsearch sizelimit 1 failed - too many entries: %v", string(out))
+		}
+
+		cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "0")
+		out, _ = cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "numEntries: 3") {
+			t.Errorf("ldapsearch sizelimit 0 failed - wrong number of entries: %v", string(out))
+		}
+
+		cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "1", "(uid=trent)")
+		out, _ = cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "numEntries: 1") {
+			t.Errorf("ldapsearch sizelimit 1 with filter failed - wrong number of entries: %v", string(out))
+		}
+
+		cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "0", "(uid=trent)")
+		out, _ = cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "numEntries: 1") {
+			t.Errorf("ldapsearch sizelimit 0 with filter failed - wrong number of entries: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+	quit <- true
+}
+
+/////////////////////////
+func TestBindSearchMulti(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.BindFunc("", bindSimple{})
+		s.BindFunc("c=testz", bindSimple2{})
+		s.SearchFunc("", searchSimple{})
+		s.SearchFunc("c=testz", searchSimple2{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	go func() {
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test",
+			"-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "cn=ned")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("error routing default bind/search functions: %v", string(out))
+		}
+		if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
+			t.Errorf("search default routing failed: %v", string(out))
+		}
+		cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=testz",
+			"-D", "cn=testy,o=testers,c=testz", "-w", "ZLike2test", "cn=hamburger")
+		out, _ = cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("error routing custom bind/search functions: %v", string(out))
+		}
+		if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") {
+			t.Errorf("search custom routing failed: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+
+	quit <- true
+}
+
+/////////////////////////
+func TestSearchPanic(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.SearchFunc("", searchPanic{})
+		s.BindFunc("", bindAnonOK{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	go func() {
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 1 Operations error") {
+			t.Errorf("ldapsearch should have returned operations error due to panic: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+	quit <- true
+}
+
+/////////////////////////
+type compileSearchFilterTest struct {
+	name         string
+	filterStr    string
+	numResponses string
+}
+
+var searchFilterTestFilters = []compileSearchFilterTest{
+	compileSearchFilterTest{name: "equalityOk", filterStr: "(uid=ned)", numResponses: "2"},
+	compileSearchFilterTest{name: "equalityNo", filterStr: "(uid=foo)", numResponses: "1"},
+	compileSearchFilterTest{name: "equalityOk", filterStr: "(objectclass=posixaccount)", numResponses: "4"},
+	compileSearchFilterTest{name: "presentEmptyOk", filterStr: "", numResponses: "4"},
+	compileSearchFilterTest{name: "presentOk", filterStr: "(objectclass=*)", numResponses: "4"},
+	compileSearchFilterTest{name: "presentOk", filterStr: "(description=*)", numResponses: "3"},
+	compileSearchFilterTest{name: "presentNo", filterStr: "(foo=*)", numResponses: "1"},
+	compileSearchFilterTest{name: "andOk", filterStr: "(&(uid=ned)(objectclass=posixaccount))", numResponses: "2"},
+	compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(objectclass=posixgroup))", numResponses: "1"},
+	compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(uid=trent))", numResponses: "1"},
+	compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(uid=trent))", numResponses: "3"},
+	compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(objectclass=posixaccount))", numResponses: "4"},
+	compileSearchFilterTest{name: "orNo", filterStr: "(|(uid=foo)(objectclass=foo))", numResponses: "1"},
+	compileSearchFilterTest{name: "andOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(objectclass=posixaccount))", numResponses: "3"},
+	compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=ned))", numResponses: "3"},
+	compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=foo))", numResponses: "4"},
+	compileSearchFilterTest{name: "notAndOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(!(objectclass=posixgroup)))", numResponses: "3"},
+	/*
+		compileSearchFilterTest{filterStr: "(sn=Mill*)", filterType: FilterSubstrings},
+		compileSearchFilterTest{filterStr: "(sn=*Mill)", filterType: FilterSubstrings},
+		compileSearchFilterTest{filterStr: "(sn=*Mill*)", filterType: FilterSubstrings},
+		compileSearchFilterTest{filterStr: "(sn>=Miller)", filterType: FilterGreaterOrEqual},
+		compileSearchFilterTest{filterStr: "(sn<=Miller)", filterType: FilterLessOrEqual},
+		compileSearchFilterTest{filterStr: "(sn~=Miller)", filterType: FilterApproxMatch},
+	*/
+}
+
+/////////////////////////
+func TestSearchFiltering(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.EnforceLDAP = true
+		s.QuitChannel(quit)
+		s.SearchFunc("", searchSimple{})
+		s.BindFunc("", bindSimple{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	for _, i := range searchFilterTestFilters {
+		t.Log(i.name)
+
+		go func() {
+			cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+				"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", i.filterStr)
+			out, _ := cmd.CombinedOutput()
+			if !strings.Contains(string(out), "numResponses: "+i.numResponses) {
+				t.Errorf("ldapsearch failed - expected numResponses==%s: %v", i.numResponses, string(out))
+			}
+			done <- true
+		}()
+
+		select {
+		case <-done:
+		case <-time.After(timeout):
+			t.Errorf("ldapsearch command timed out")
+		}
+	}
+	quit <- true
+}
+
+/////////////////////////
+func TestSearchAttributes(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.EnforceLDAP = true
+		s.QuitChannel(quit)
+		s.SearchFunc("", searchSimple{})
+		s.BindFunc("", bindSimple{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	go func() {
+		filterString := ""
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", filterString, "cn")
+		out, _ := cmd.CombinedOutput()
+
+		if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
+			t.Errorf("ldapsearch failed - missing requested DN attribute: %v", string(out))
+		}
+		if !strings.Contains(string(out), "cn: ned") {
+			t.Errorf("ldapsearch failed - missing requested CN attribute: %v", string(out))
+		}
+		if strings.Contains(string(out), "uidNumber") {
+			t.Errorf("ldapsearch failed - uidNumber attr should not be displayed: %v", string(out))
+		}
+		if strings.Contains(string(out), "accountstatus") {
+			t.Errorf("ldapsearch failed - accountstatus attr should not be displayed: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+	quit <- true
+}
+
+func TestSearchAllUserAttributes(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.EnforceLDAP = true
+		s.QuitChannel(quit)
+		s.SearchFunc("", searchSimple{})
+		s.BindFunc("", bindSimple{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	go func() {
+		filterString := ""
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", filterString, "*")
+		out, _ := cmd.CombinedOutput()
+
+		if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
+			t.Errorf("ldapsearch failed - missing requested DN attribute: %v", string(out))
+		}
+		if !strings.Contains(string(out), "cn: ned") {
+			t.Errorf("ldapsearch failed - missing requested CN attribute: %v", string(out))
+		}
+		if !strings.Contains(string(out), "uidNumber") {
+			t.Errorf("ldapsearch failed - missing requested uidNumber attribute: %v", string(out))
+		}
+		if !strings.Contains(string(out), "accountstatus") {
+			t.Errorf("ldapsearch failed - missing requested accountstatus attribute: %v", string(out))
+		}
+		if !strings.Contains(string(out), "o: ate") {
+			t.Errorf("ldapsearch failed - missing requested o attribute: %v", string(out))
+		}
+		if !strings.Contains(string(out), "description") {
+			t.Errorf("ldapsearch failed - missing requested description attribute: %v", string(out))
+		}
+		if !strings.Contains(string(out), "objectclass") {
+			t.Errorf("ldapsearch failed - missing requested objectclass attribute: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+	quit <- true
+}
+
+/////////////////////////
+func TestSearchScope(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.EnforceLDAP = true
+		s.QuitChannel(quit)
+		s.SearchFunc("", searchSimple{})
+		s.BindFunc("", bindSimple{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	go func() {
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "sub", "cn=trent")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+			t.Errorf("ldapsearch 'sub' scope failed - didn't find expected DN: %v", string(out))
+		}
+
+		cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent")
+		out, _ = cmd.CombinedOutput()
+		if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+			t.Errorf("ldapsearch 'one' scope failed - didn't find expected DN: %v", string(out))
+		}
+		cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent")
+		out, _ = cmd.CombinedOutput()
+		if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+			t.Errorf("ldapsearch 'one' scope failed - found unexpected DN: %v", string(out))
+		}
+
+		cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", "cn=trent,o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent")
+		out, _ = cmd.CombinedOutput()
+		if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+			t.Errorf("ldapsearch 'base' scope failed - didn't find expected DN: %v", string(out))
+		}
+		cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent")
+		out, _ = cmd.CombinedOutput()
+		if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+			t.Errorf("ldapsearch 'base' scope failed - found unexpected DN: %v", string(out))
+		}
+
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+	quit <- true
+}
+
+func TestSearchControls(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.SearchFunc("", searchControls{})
+		s.BindFunc("", bindSimple{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	serverBaseDN := "o=testers,c=test"
+
+	go func() {
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-e", "1.2.3.4.5")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") {
+			t.Errorf("ldapsearch with control failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch with control failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "numResponses: 2") {
+			t.Errorf("ldapsearch with control failed: %v", string(out))
+		}
+
+		cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test")
+		out, _ = cmd.CombinedOutput()
+		if strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") {
+			t.Errorf("ldapsearch without control failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch without control failed: %v", string(out))
+		}
+		if !strings.Contains(string(out), "numResponses: 1") {
+			t.Errorf("ldapsearch without control failed: %v", string(out))
+		}
+
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+	quit <- true
+}

+ 410 - 0
ldap/server_test.go

@@ -0,0 +1,410 @@
+package ldap
+
+import (
+	"bytes"
+	"log"
+	"net"
+	"os/exec"
+	"strings"
+	"testing"
+	"time"
+)
+
+var listenString = "localhost:3389"
+var ldapURL = "ldap://" + listenString
+var timeout = 400 * time.Millisecond
+var serverBaseDN = "o=testers,c=test"
+
+/////////////////////////
+func TestBindAnonOK(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.BindFunc("", bindAnonOK{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	go func() {
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+	quit <- true
+}
+
+/////////////////////////
+func TestBindAnonFail(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	time.Sleep(timeout)
+	go func() {
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "ldap_bind: Invalid credentials (49)") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+	time.Sleep(timeout)
+	quit <- true
+}
+
+/////////////////////////
+func TestBindSimpleOK(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.SearchFunc("", searchSimple{})
+		s.BindFunc("", bindSimple{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	serverBaseDN := "o=testers,c=test"
+
+	go func() {
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+	quit <- true
+}
+
+/////////////////////////
+func TestBindSimpleFailBadPw(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.BindFunc("", bindSimple{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	serverBaseDN := "o=testers,c=test"
+
+	go func() {
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "BADPassword")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "ldap_bind: Invalid credentials (49)") {
+			t.Errorf("ldapsearch succeeded - should have failed: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+	quit <- true
+}
+
+/////////////////////////
+func TestBindSimpleFailBadDn(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.BindFunc("", bindSimple{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	serverBaseDN := "o=testers,c=test"
+
+	go func() {
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+			"-b", serverBaseDN, "-D", "cn=testoy,"+serverBaseDN, "-w", "iLike2test")
+		out, _ := cmd.CombinedOutput()
+		if string(out) != "ldap_bind: Invalid credentials (49)\n" {
+			t.Errorf("ldapsearch succeeded - should have failed: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+	quit <- true
+}
+
+/////////////////////////
+func TestBindSSL(t *testing.T) {
+	ldapURLSSL := "ldaps://" + listenString
+	longerTimeout := 300 * time.Millisecond
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.BindFunc("", bindAnonOK{})
+		if err := s.ListenAndServeTLS(listenString, "tests/cert_DONOTUSE.pem", "tests/key_DONOTUSE.pem"); err != nil {
+			t.Errorf("s.ListenAndServeTLS failed: %s", err.Error())
+		}
+	}()
+
+	go func() {
+		time.Sleep(longerTimeout * 2)
+		cmd := exec.Command("ldapsearch", "-H", ldapURLSSL, "-x", "-b", "o=testers,c=test")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(longerTimeout * 2):
+		t.Errorf("ldapsearch command timed out")
+	}
+	quit <- true
+}
+
+/////////////////////////
+func TestBindPanic(t *testing.T) {
+	quit := make(chan bool)
+	done := make(chan bool)
+	go func() {
+		s := NewServer()
+		s.QuitChannel(quit)
+		s.BindFunc("", bindPanic{})
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	go func() {
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "ldap_bind: Operations error") {
+			t.Errorf("ldapsearch should have returned operations error due to panic: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+	quit <- true
+}
+
+/////////////////////////
+type testStatsWriter struct {
+	buffer *bytes.Buffer
+}
+
+func (tsw testStatsWriter) Write(buf []byte) (int, error) {
+	tsw.buffer.Write(buf)
+	return len(buf), nil
+}
+
+func TestSearchStats(t *testing.T) {
+	w := testStatsWriter{&bytes.Buffer{}}
+	log.SetOutput(w)
+
+	quit := make(chan bool)
+	done := make(chan bool)
+	s := NewServer()
+
+	go func() {
+		s.QuitChannel(quit)
+		s.SearchFunc("", searchSimple{})
+		s.BindFunc("", bindAnonOK{})
+		s.SetStats(true)
+		if err := s.ListenAndServe(listenString); err != nil {
+			t.Errorf("s.ListenAndServe failed: %s", err.Error())
+		}
+	}()
+
+	go func() {
+		cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test")
+		out, _ := cmd.CombinedOutput()
+		if !strings.Contains(string(out), "result: 0 Success") {
+			t.Errorf("ldapsearch failed: %v", string(out))
+		}
+		done <- true
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(timeout):
+		t.Errorf("ldapsearch command timed out")
+	}
+
+	stats := s.GetStats()
+	log.Println(stats)
+	if stats.Conns != 1 || stats.Binds != 1 {
+		t.Errorf("Stats data missing or incorrect: %v", w.buffer.String())
+	}
+	quit <- true
+}
+
+/////////////////////////
+type bindAnonOK struct {
+}
+
+func (b bindAnonOK) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+	if bindDN == "" && bindSimplePw == "" {
+		return LDAPResultSuccess, nil
+	}
+	return LDAPResultInvalidCredentials, nil
+}
+
+type bindSimple struct {
+}
+
+func (b bindSimple) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+	if bindDN == "cn=testy,o=testers,c=test" && bindSimplePw == "iLike2test" {
+		return LDAPResultSuccess, nil
+	}
+	return LDAPResultInvalidCredentials, nil
+}
+
+type bindSimple2 struct {
+}
+
+func (b bindSimple2) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+	if bindDN == "cn=testy,o=testers,c=testz" && bindSimplePw == "ZLike2test" {
+		return LDAPResultSuccess, nil
+	}
+	return LDAPResultInvalidCredentials, nil
+}
+
+type bindPanic struct {
+}
+
+func (b bindPanic) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+	panic("test panic at the disco")
+	return LDAPResultInvalidCredentials, nil
+}
+
+type searchSimple struct {
+}
+
+func (s searchSimple) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+	entries := []*Entry{
+		&Entry{"cn=ned,o=testers,c=test", []*EntryAttribute{
+			&EntryAttribute{"cn", []string{"ned"}},
+			&EntryAttribute{"o", []string{"ate"}},
+			&EntryAttribute{"uidNumber", []string{"5000"}},
+			&EntryAttribute{"accountstatus", []string{"active"}},
+			&EntryAttribute{"uid", []string{"ned"}},
+			&EntryAttribute{"description", []string{"ned via sa"}},
+			&EntryAttribute{"objectclass", []string{"posixaccount"}},
+		}},
+		&Entry{"cn=trent,o=testers,c=test", []*EntryAttribute{
+			&EntryAttribute{"cn", []string{"trent"}},
+			&EntryAttribute{"o", []string{"ate"}},
+			&EntryAttribute{"uidNumber", []string{"5005"}},
+			&EntryAttribute{"accountstatus", []string{"active"}},
+			&EntryAttribute{"uid", []string{"trent"}},
+			&EntryAttribute{"description", []string{"trent via sa"}},
+			&EntryAttribute{"objectclass", []string{"posixaccount"}},
+		}},
+		&Entry{"cn=randy,o=testers,c=test", []*EntryAttribute{
+			&EntryAttribute{"cn", []string{"randy"}},
+			&EntryAttribute{"o", []string{"ate"}},
+			&EntryAttribute{"uidNumber", []string{"5555"}},
+			&EntryAttribute{"accountstatus", []string{"active"}},
+			&EntryAttribute{"uid", []string{"randy"}},
+			&EntryAttribute{"objectclass", []string{"posixaccount"}},
+		}},
+	}
+	return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil
+}
+
+type searchSimple2 struct {
+}
+
+func (s searchSimple2) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+	entries := []*Entry{
+		&Entry{"cn=hamburger,o=testers,c=testz", []*EntryAttribute{
+			&EntryAttribute{"cn", []string{"hamburger"}},
+			&EntryAttribute{"o", []string{"testers"}},
+			&EntryAttribute{"uidNumber", []string{"5000"}},
+			&EntryAttribute{"accountstatus", []string{"active"}},
+			&EntryAttribute{"uid", []string{"hamburger"}},
+			&EntryAttribute{"objectclass", []string{"posixaccount"}},
+		}},
+	}
+	return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil
+}
+
+type searchPanic struct {
+}
+
+func (s searchPanic) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+	entries := []*Entry{}
+	panic("this is a test panic")
+	return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil
+}
+
+type searchControls struct {
+}
+
+func (s searchControls) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+	entries := []*Entry{}
+	if len(searchReq.Controls) == 1 && searchReq.Controls[0].GetControlType() == "1.2.3.4.5" {
+		newEntry := &Entry{"cn=hamburger,o=testers,c=testz", []*EntryAttribute{
+			&EntryAttribute{"cn", []string{"hamburger"}},
+			&EntryAttribute{"o", []string{"testers"}},
+			&EntryAttribute{"uidNumber", []string{"5000"}},
+			&EntryAttribute{"accountstatus", []string{"active"}},
+			&EntryAttribute{"uid", []string{"hamburger"}},
+			&EntryAttribute{"objectclass", []string{"posixaccount"}},
+		}}
+		entries = append(entries, newEntry)
+	}
+	return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil
+}

+ 65 - 0
utils/auth/ldap_auth.go

@@ -0,0 +1,65 @@
+package auth
+
+import (
+	"git.qianqiusoft.com/qianqiusoft/light-apiengine/ldap"
+	"git.qianqiusoft.com/qianqiusoft/light-apiengine/entitys"
+	"git.qianqiusoft.com/qianqiusoft/light-apiengine/logs"
+	"net"
+	"fmt"
+)
+type LdapAuth struct {
+	IAuth
+}
+
+func (this *LdapAuth)Login(c *entitys.CtrlContext) {
+}
+
+func (this *LdapAuth)Logout(c *entitys.CtrlContext){
+}
+
+func (this* LdapAuth)Init(){
+
+	s := ldap.NewServer()
+
+	// register Bind and Search function handlers
+	handler := ldapHandler{}
+	s.BindFunc("", handler)
+	s.SearchFunc("", handler)
+
+	// start the server
+	listen := "0.0.0.0:389"
+	logs.Info("Starting example LDAP server on %s", listen)
+	if err := s.ListenAndServe(listen); err != nil {
+		logs.Error("LDAP Server Failed: %s", err.Error())
+	}
+}
+
+type ldapHandler struct {
+}
+
+///////////// Allow anonymous binds only
+func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (ldap.LDAPResultCode, error) {
+
+	fmt.Println(bindDN)
+	fmt.Println(bindSimplePw)
+	/*if bindDN == "" && bindSimplePw == "" {
+		return ldap.LDAPResultSuccess, nil
+	}*/
+	return ldap.LDAPResultSuccess, nil
+}
+
+///////////// Return some hardcoded search results - we'll respond to any baseDN for testing
+func (h ldapHandler) Search(boundDN string, searchReq ldap.SearchRequest, conn net.Conn) (ldap.ServerSearchResult, error) {
+	fmt.Print("%s,search......%s", boundDN, searchReq)
+	entries := []*ldap.Entry{
+		&ldap.Entry{"cn=ned," + searchReq.BaseDN, []*ldap.EntryAttribute{
+			&ldap.EntryAttribute{"cn", []string{"ned"}},
+			&ldap.EntryAttribute{"uidNumber", []string{"5000"}},
+			&ldap.EntryAttribute{"accountStatus", []string{"active"}},
+			&ldap.EntryAttribute{"uid", []string{"ned"}},
+			&ldap.EntryAttribute{"description", []string{"ned"}},
+			&ldap.EntryAttribute{"objectClass", []string{"posixAccount"}},
+		}},
+	}
+	return ldap.ServerSearchResult{entries, []string{}, []ldap.Control{}, ldap.LDAPResultSuccess}, nil
+}