filter.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ldap
  5. import (
  6. "errors"
  7. "fmt"
  8. "github.com/nmcclain/asn1-ber"
  9. "strings"
  10. )
  // Filter tags: the context-specific tag numbers of the choices of the LDAP
  // Filter type (RFC 4511, section 4.5.1). They are used as the Tag of the
  // BER packets built by CompileFilter and inspected by DecompileFilter.
  const (
  	FilterAnd = iota // and [0]
  	FilterOr // or [1]
  	FilterNot // not [2]
  	FilterEqualityMatch // equalityMatch [3]
  	FilterSubstrings // substrings [4]
  	FilterGreaterOrEqual // greaterOrEqual [5]
  	FilterLessOrEqual // lessOrEqual [6]
  	FilterPresent // present [7]
  	FilterApproxMatch // approxMatch [8]
  	FilterExtensibleMatch // extensibleMatch [9]
  )
  23. var FilterMap = map[uint8]string{
  24. FilterAnd: "And",
  25. FilterOr: "Or",
  26. FilterNot: "Not",
  27. FilterEqualityMatch: "Equality Match",
  28. FilterSubstrings: "Substrings",
  29. FilterGreaterOrEqual: "Greater Or Equal",
  30. FilterLessOrEqual: "Less Or Equal",
  31. FilterPresent: "Present",
  32. FilterApproxMatch: "Approx Match",
  33. FilterExtensibleMatch: "Extensible Match",
  34. }
  // Tags of the elements inside a Substrings filter's SEQUENCE: the initial,
  // any (middle) and final pieces of the pattern (RFC 4511 SubstringFilter).
  const (
  	FilterSubstringsInitial = iota // initial [0]
  	FilterSubstringsAny // any [1]
  	FilterSubstringsFinal // final [2]
  )
  40. func CompileFilter(filter string) (*ber.Packet, error) {
  41. if len(filter) == 0 || filter[0] != '(' {
  42. return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('"))
  43. }
  44. packet, pos, err := compileFilter(filter, 1)
  45. if err != nil {
  46. return nil, err
  47. }
  48. if pos != len(filter) {
  49. return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:])))
  50. }
  51. return packet, nil
  52. }
  53. func DecompileFilter(packet *ber.Packet) (ret string, err error) {
  54. defer func() {
  55. if r := recover(); r != nil {
  56. err = NewError(ErrorFilterDecompile, errors.New("ldap: error decompiling filter"))
  57. }
  58. }()
  59. ret = "("
  60. err = nil
  61. childStr := ""
  62. switch packet.Tag {
  63. case FilterAnd:
  64. ret += "&"
  65. for _, child := range packet.Children {
  66. childStr, err = DecompileFilter(child)
  67. if err != nil {
  68. return
  69. }
  70. ret += childStr
  71. }
  72. case FilterOr:
  73. ret += "|"
  74. for _, child := range packet.Children {
  75. childStr, err = DecompileFilter(child)
  76. if err != nil {
  77. return
  78. }
  79. ret += childStr
  80. }
  81. case FilterNot:
  82. ret += "!"
  83. childStr, err = DecompileFilter(packet.Children[0])
  84. if err != nil {
  85. return
  86. }
  87. ret += childStr
  88. case FilterSubstrings:
  89. ret += ber.DecodeString(packet.Children[0].Data.Bytes())
  90. ret += "="
  91. switch packet.Children[1].Children[0].Tag {
  92. case FilterSubstringsInitial:
  93. ret += ber.DecodeString(packet.Children[1].Children[0].Data.Bytes()) + "*"
  94. case FilterSubstringsAny:
  95. ret += "*" + ber.DecodeString(packet.Children[1].Children[0].Data.Bytes()) + "*"
  96. case FilterSubstringsFinal:
  97. ret += "*" + ber.DecodeString(packet.Children[1].Children[0].Data.Bytes())
  98. }
  99. case FilterEqualityMatch:
  100. ret += ber.DecodeString(packet.Children[0].Data.Bytes())
  101. ret += "="
  102. ret += ber.DecodeString(packet.Children[1].Data.Bytes())
  103. case FilterGreaterOrEqual:
  104. ret += ber.DecodeString(packet.Children[0].Data.Bytes())
  105. ret += ">="
  106. ret += ber.DecodeString(packet.Children[1].Data.Bytes())
  107. case FilterLessOrEqual:
  108. ret += ber.DecodeString(packet.Children[0].Data.Bytes())
  109. ret += "<="
  110. ret += ber.DecodeString(packet.Children[1].Data.Bytes())
  111. case FilterPresent:
  112. ret += ber.DecodeString(packet.Data.Bytes())
  113. ret += "=*"
  114. case FilterApproxMatch:
  115. ret += ber.DecodeString(packet.Children[0].Data.Bytes())
  116. ret += "~="
  117. ret += ber.DecodeString(packet.Children[1].Data.Bytes())
  118. }
  119. ret += ")"
  120. return
  121. }
  122. func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, error) {
  123. for pos < len(filter) && filter[pos] == '(' {
  124. child, newPos, err := compileFilter(filter, pos+1)
  125. if err != nil {
  126. return pos, err
  127. }
  128. pos = newPos
  129. parent.AppendChild(child)
  130. }
  131. if pos == len(filter) {
  132. return pos, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
  133. }
  134. return pos + 1, nil
  135. }
  136. func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
  137. var packet *ber.Packet
  138. var err error
  139. defer func() {
  140. if r := recover(); r != nil {
  141. err = NewError(ErrorFilterCompile, errors.New("ldap: error compiling filter"))
  142. }
  143. }()
  144. newPos := pos
  145. switch filter[pos] {
  146. case '(':
  147. packet, newPos, err = compileFilter(filter, pos+1)
  148. newPos++
  149. return packet, newPos, err
  150. case '&':
  151. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd])
  152. newPos, err = compileFilterSet(filter, pos+1, packet)
  153. return packet, newPos, err
  154. case '|':
  155. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr])
  156. newPos, err = compileFilterSet(filter, pos+1, packet)
  157. return packet, newPos, err
  158. case '!':
  159. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot])
  160. var child *ber.Packet
  161. child, newPos, err = compileFilter(filter, pos+1)
  162. packet.AppendChild(child)
  163. return packet, newPos, err
  164. default:
  165. attribute := ""
  166. condition := ""
  167. for newPos < len(filter) && filter[newPos] != ')' {
  168. switch {
  169. case packet != nil:
  170. condition += fmt.Sprintf("%c", filter[newPos])
  171. case filter[newPos] == '=':
  172. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch])
  173. case filter[newPos] == '>' && filter[newPos+1] == '=':
  174. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual])
  175. newPos++
  176. case filter[newPos] == '<' && filter[newPos+1] == '=':
  177. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual])
  178. newPos++
  179. case filter[newPos] == '~' && filter[newPos+1] == '=':
  180. packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterLessOrEqual])
  181. newPos++
  182. case packet == nil:
  183. attribute += fmt.Sprintf("%c", filter[newPos])
  184. }
  185. newPos++
  186. }
  187. if newPos == len(filter) {
  188. err = NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
  189. return packet, newPos, err
  190. }
  191. if packet == nil {
  192. err = NewError(ErrorFilterCompile, errors.New("ldap: error parsing filter"))
  193. return packet, newPos, err
  194. }
  195. // Handle FilterEqualityMatch as a separate case (is primitive, not constructed like the other filters)
  196. if packet.Tag == FilterEqualityMatch && condition == "*" {
  197. packet.TagType = ber.TypePrimitive
  198. packet.Tag = FilterPresent
  199. packet.Description = FilterMap[packet.Tag]
  200. packet.Data.WriteString(attribute)
  201. return packet, newPos + 1, nil
  202. }
  203. packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
  204. switch {
  205. case packet.Tag == FilterEqualityMatch && condition[0] == '*' && condition[len(condition)-1] == '*':
  206. // Any
  207. packet.Tag = FilterSubstrings
  208. packet.Description = FilterMap[packet.Tag]
  209. seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
  210. seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsAny, condition[1:len(condition)-1], "Any Substring"))
  211. packet.AppendChild(seq)
  212. case packet.Tag == FilterEqualityMatch && condition[0] == '*':
  213. // Final
  214. packet.Tag = FilterSubstrings
  215. packet.Description = FilterMap[packet.Tag]
  216. seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
  217. seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsFinal, condition[1:], "Final Substring"))
  218. packet.AppendChild(seq)
  219. case packet.Tag == FilterEqualityMatch && condition[len(condition)-1] == '*':
  220. // Initial
  221. packet.Tag = FilterSubstrings
  222. packet.Description = FilterMap[packet.Tag]
  223. seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
  224. seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsInitial, condition[:len(condition)-1], "Initial Substring"))
  225. packet.AppendChild(seq)
  226. default:
  227. packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, condition, "Condition"))
  228. }
  229. newPos++
  230. return packet, newPos, err
  231. }
  232. }
  233. func ServerApplyFilter(f *ber.Packet, entry *Entry) (bool, LDAPResultCode) {
  234. switch FilterMap[f.Tag] {
  235. default:
  236. //log.Fatalf("Unknown LDAP filter code: %d", f.Tag)
  237. return false, LDAPResultOperationsError
  238. case "Equality Match":
  239. if len(f.Children) != 2 {
  240. return false, LDAPResultOperationsError
  241. }
  242. attribute := f.Children[0].Value.(string)
  243. value := f.Children[1].Value.(string)
  244. for _, a := range entry.Attributes {
  245. if strings.ToLower(a.Name) == strings.ToLower(attribute) {
  246. for _, v := range a.Values {
  247. if strings.ToLower(v) == strings.ToLower(value) {
  248. return true, LDAPResultSuccess
  249. }
  250. }
  251. }
  252. }
  253. case "Present":
  254. for _, a := range entry.Attributes {
  255. if strings.ToLower(a.Name) == strings.ToLower(f.Data.String()) {
  256. return true, LDAPResultSuccess
  257. }
  258. }
  259. case "And":
  260. for _, child := range f.Children {
  261. ok, exitCode := ServerApplyFilter(child, entry)
  262. if exitCode != LDAPResultSuccess {
  263. return false, exitCode
  264. }
  265. if !ok {
  266. return false, LDAPResultSuccess
  267. }
  268. }
  269. return true, LDAPResultSuccess
  270. case "Or":
  271. anyOk := false
  272. for _, child := range f.Children {
  273. ok, exitCode := ServerApplyFilter(child, entry)
  274. if exitCode != LDAPResultSuccess {
  275. return false, exitCode
  276. } else if ok {
  277. anyOk = true
  278. }
  279. }
  280. if anyOk {
  281. return true, LDAPResultSuccess
  282. }
  283. case "Not":
  284. if len(f.Children) != 1 {
  285. return false, LDAPResultOperationsError
  286. }
  287. ok, exitCode := ServerApplyFilter(f.Children[0], entry)
  288. if exitCode != LDAPResultSuccess {
  289. return false, exitCode
  290. } else if !ok {
  291. return true, LDAPResultSuccess
  292. }
  293. case "Substrings":
  294. if len(f.Children) != 2 {
  295. return false, LDAPResultOperationsError
  296. }
  297. attribute := f.Children[0].Value.(string)
  298. bytes := f.Children[1].Children[0].Data.Bytes()
  299. value := string(bytes[:])
  300. for _, a := range entry.Attributes {
  301. if strings.ToLower(a.Name) == strings.ToLower(attribute) {
  302. for _, v := range a.Values {
  303. switch f.Children[1].Children[0].Tag {
  304. case FilterSubstringsInitial:
  305. if strings.HasPrefix(v, value) {
  306. return true, LDAPResultSuccess
  307. }
  308. case FilterSubstringsAny:
  309. if strings.Contains(v, value) {
  310. return true, LDAPResultSuccess
  311. }
  312. case FilterSubstringsFinal:
  313. if strings.HasSuffix(v, value) {
  314. return true, LDAPResultSuccess
  315. }
  316. }
  317. }
  318. }
  319. }
  320. case "FilterGreaterOrEqual": // TODO
  321. return false, LDAPResultOperationsError
  322. case "FilterLessOrEqual": // TODO
  323. return false, LDAPResultOperationsError
  324. case "FilterApproxMatch": // TODO
  325. return false, LDAPResultOperationsError
  326. case "FilterExtensibleMatch": // TODO
  327. return false, LDAPResultOperationsError
  328. }
  329. return false, LDAPResultSuccess
  330. }
  331. func GetFilterObjectClass(filter string) (string, error) {
  332. f, err := CompileFilter(filter)
  333. if err != nil {
  334. return "", err
  335. }
  336. return parseFilterObjectClass(f)
  337. }
  338. func parseFilterObjectClass(f *ber.Packet) (string, error) {
  339. objectClass := ""
  340. switch FilterMap[f.Tag] {
  341. case "Equality Match":
  342. if len(f.Children) != 2 {
  343. return "", errors.New("Equality match must have only two children")
  344. }
  345. attribute := strings.ToLower(f.Children[0].Value.(string))
  346. value := f.Children[1].Value.(string)
  347. if attribute == "objectclass" {
  348. objectClass = strings.ToLower(value)
  349. }
  350. case "And":
  351. for _, child := range f.Children {
  352. subType, err := parseFilterObjectClass(child)
  353. if err != nil {
  354. return "", err
  355. }
  356. if len(subType) > 0 {
  357. objectClass = subType
  358. }
  359. }
  360. case "Or":
  361. for _, child := range f.Children {
  362. subType, err := parseFilterObjectClass(child)
  363. if err != nil {
  364. return "", err
  365. }
  366. if len(subType) > 0 {
  367. objectClass = subType
  368. }
  369. }
  370. case "Not":
  371. if len(f.Children) != 1 {
  372. return "", errors.New("Not filter must have only one child")
  373. }
  374. subType, err := parseFilterObjectClass(f.Children[0])
  375. if err != nil {
  376. return "", err
  377. }
  378. if len(subType) > 0 {
  379. objectClass = subType
  380. }
  381. }
  382. return strings.ToLower(objectClass), nil
  383. }