filters.go 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. package gocql
  2. import "fmt"
  3. // HostFilter interface is used when a host is discovered via server sent events.
  4. type HostFilter interface {
  5. // Called when a new host is discovered, returning true will cause the host
  6. // to be added to the pools.
  7. Accept(host *HostInfo) bool
  8. }
  9. // HostFilterFunc converts a func(host HostInfo) bool into a HostFilter
  10. type HostFilterFunc func(host *HostInfo) bool
  11. func (fn HostFilterFunc) Accept(host *HostInfo) bool {
  12. return fn(host)
  13. }
  14. // AcceptAllFilter will accept all hosts
  15. func AcceptAllFilter() HostFilter {
  16. return HostFilterFunc(func(host *HostInfo) bool {
  17. return true
  18. })
  19. }
  20. func DenyAllFilter() HostFilter {
  21. return HostFilterFunc(func(host *HostInfo) bool {
  22. return false
  23. })
  24. }
  25. // DataCentreHostFilter filters all hosts such that they are in the same data centre
  26. // as the supplied data centre.
  27. func DataCentreHostFilter(dataCentre string) HostFilter {
  28. return HostFilterFunc(func(host *HostInfo) bool {
  29. return host.DataCenter() == dataCentre
  30. })
  31. }
  32. // WhiteListHostFilter filters incoming hosts by checking that their address is
  33. // in the initial hosts whitelist.
  34. func WhiteListHostFilter(hosts ...string) HostFilter {
  35. hostInfos, err := addrsToHosts(hosts, 9042)
  36. if err != nil {
  37. // dont want to panic here, but rather not break the API
  38. panic(fmt.Errorf("unable to lookup host info from address: %v", err))
  39. }
  40. m := make(map[string]bool, len(hostInfos))
  41. for _, host := range hostInfos {
  42. m[host.ConnectAddress().String()] = true
  43. }
  44. return HostFilterFunc(func(host *HostInfo) bool {
  45. return m[host.ConnectAddress().String()]
  46. })
  47. }