server_search.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. package ldap
  2. import (
  3. "errors"
  4. "fmt"
  5. "net"
  6. "strings"
  7. ber "github.com/nmcclain/asn1-ber"
  8. )
  9. func HandleSearchRequest(req *ber.Packet, controls *[]Control, messageID uint64, boundDN string, server *Server, conn net.Conn) (resultErr error) {
  10. defer func() {
  11. if r := recover(); r != nil {
  12. resultErr = NewError(LDAPResultOperationsError, fmt.Errorf("Search function panic: %s", r))
  13. }
  14. }()
  15. searchReq, err := parseSearchRequest(boundDN, req, controls)
  16. if err != nil {
  17. return NewError(LDAPResultOperationsError, err)
  18. }
  19. filterPacket, err := CompileFilter(searchReq.Filter)
  20. if err != nil {
  21. return NewError(LDAPResultOperationsError, err)
  22. }
  23. fnNames := []string{}
  24. for k := range server.SearchFns {
  25. fnNames = append(fnNames, k)
  26. }
  27. fn := routeFunc(searchReq.BaseDN, fnNames)
  28. searchResp, err := server.SearchFns[fn].Search(boundDN, searchReq, conn)
  29. if err != nil {
  30. return NewError(searchResp.ResultCode, err)
  31. }
  32. if server.EnforceLDAP {
  33. if searchReq.DerefAliases != NeverDerefAliases { // [-a {never|always|search|find}
  34. // TODO: Server DerefAliases not supported: RFC4511 4.5.1.3
  35. }
  36. if searchReq.TimeLimit > 0 {
  37. // TODO: Server TimeLimit not implemented
  38. }
  39. }
  40. i := 0
  41. for _, entry := range searchResp.Entries {
  42. if server.EnforceLDAP {
  43. // filter
  44. keep, resultCode := ServerApplyFilter(filterPacket, entry)
  45. if resultCode != LDAPResultSuccess {
  46. return NewError(resultCode, errors.New("ServerApplyFilter error"))
  47. }
  48. if !keep {
  49. continue
  50. }
  51. // constrained search scope
  52. switch searchReq.Scope {
  53. case ScopeWholeSubtree: // The scope is constrained to the entry named by baseObject and to all its subordinates.
  54. case ScopeBaseObject: // The scope is constrained to the entry named by baseObject.
  55. if entry.DN != searchReq.BaseDN {
  56. continue
  57. }
  58. case ScopeSingleLevel: // The scope is constrained to the immediate subordinates of the entry named by baseObject.
  59. parts := strings.Split(entry.DN, ",")
  60. if len(parts) < 2 && entry.DN != searchReq.BaseDN {
  61. continue
  62. }
  63. if dn := strings.Join(parts[1:], ","); dn != searchReq.BaseDN {
  64. continue
  65. }
  66. }
  67. // attributes
  68. if len(searchReq.Attributes) > 1 || (len(searchReq.Attributes) == 1 && len(searchReq.Attributes[0]) > 0) {
  69. entry, err = filterAttributes(entry, searchReq.Attributes)
  70. if err != nil {
  71. return NewError(LDAPResultOperationsError, err)
  72. }
  73. }
  74. // size limit
  75. if searchReq.SizeLimit > 0 && i >= searchReq.SizeLimit {
  76. break
  77. }
  78. i++
  79. }
  80. // respond
  81. responsePacket := encodeSearchResponse(messageID, searchReq, entry)
  82. if err = sendPacket(conn, responsePacket); err != nil {
  83. return NewError(LDAPResultOperationsError, err)
  84. }
  85. }
  86. return nil
  87. }
  88. /////////////////////////
  89. func parseSearchRequest(boundDN string, req *ber.Packet, controls *[]Control) (SearchRequest, error) {
  90. if len(req.Children) != 8 {
  91. return SearchRequest{}, NewError(LDAPResultOperationsError, errors.New("Bad search request"))
  92. }
  93. // Parse the request
  94. baseObject, ok := req.Children[0].Value.(string)
  95. if !ok {
  96. return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
  97. }
  98. s, ok := req.Children[1].Value.(uint64)
  99. if !ok {
  100. return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
  101. }
  102. scope := int(s)
  103. d, ok := req.Children[2].Value.(uint64)
  104. if !ok {
  105. return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
  106. }
  107. derefAliases := int(d)
  108. s, ok = req.Children[3].Value.(uint64)
  109. if !ok {
  110. return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
  111. }
  112. sizeLimit := int(s)
  113. t, ok := req.Children[4].Value.(uint64)
  114. if !ok {
  115. return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
  116. }
  117. timeLimit := int(t)
  118. typesOnly := false
  119. if req.Children[5].Value != nil {
  120. typesOnly, ok = req.Children[5].Value.(bool)
  121. if !ok {
  122. return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
  123. }
  124. }
  125. filter, err := DecompileFilter(req.Children[6])
  126. if err != nil {
  127. return SearchRequest{}, err
  128. }
  129. attributes := []string{}
  130. for _, attr := range req.Children[7].Children {
  131. a, ok := attr.Value.(string)
  132. if !ok {
  133. return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
  134. }
  135. attributes = append(attributes, a)
  136. }
  137. searchReq := SearchRequest{baseObject, scope,
  138. derefAliases, sizeLimit, timeLimit,
  139. typesOnly, filter, attributes, *controls}
  140. return searchReq, nil
  141. }
  142. /////////////////////////
  143. func filterAttributes(entry *Entry, attributes []string) (*Entry, error) {
  144. // only return requested attributes
  145. newAttributes := []*EntryAttribute{}
  146. for _, attr := range entry.Attributes {
  147. for _, requested := range attributes {
  148. if requested == "*" || strings.ToLower(attr.Name) == strings.ToLower(requested) {
  149. newAttributes = append(newAttributes, attr)
  150. }
  151. }
  152. }
  153. entry.Attributes = newAttributes
  154. return entry, nil
  155. }
  156. /////////////////////////
  157. func encodeSearchResponse(messageID uint64, req SearchRequest, res *Entry) *ber.Packet {
  158. responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
  159. responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
  160. searchEntry := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, "Search Result Entry")
  161. searchEntry.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, res.DN, "Object Name"))
  162. attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes:")
  163. for _, attribute := range res.Attributes {
  164. attrs.AppendChild(encodeSearchAttribute(attribute.Name, attribute.Values))
  165. }
  166. searchEntry.AppendChild(attrs)
  167. responsePacket.AppendChild(searchEntry)
  168. return responsePacket
  169. }
  170. func encodeSearchAttribute(name string, values []string) *ber.Packet {
  171. packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute")
  172. packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute Name"))
  173. valuesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "Attribute Values")
  174. for _, value := range values {
  175. valuesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Attribute Value"))
  176. }
  177. packet.AppendChild(valuesPacket)
  178. return packet
  179. }
  180. func encodeSearchDone(messageID uint64, ldapResultCode LDAPResultCode) *ber.Packet {
  181. responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
  182. responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
  183. donePacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, "Search result done")
  184. donePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
  185. donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
  186. donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: "))
  187. responsePacket.AppendChild(donePacket)
  188. return responsePacket
  189. }