// Copyright 2015 CoreOS, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package types import ( "reflect" "sort" "sync" ) type Set interface { Add(string) Remove(string) Contains(string) bool Equals(Set) bool Length() int Values() []string Copy() Set Sub(Set) Set } func NewUnsafeSet(values ...string) *unsafeSet { set := &unsafeSet{make(map[string]struct{})} for _, v := range values { set.Add(v) } return set } func NewThreadsafeSet(values ...string) *tsafeSet { us := NewUnsafeSet(values...) return &tsafeSet{us, sync.RWMutex{}} } type unsafeSet struct { d map[string]struct{} } // Add adds a new value to the set (no-op if the value is already present) func (us *unsafeSet) Add(value string) { us.d[value] = struct{}{} } // Remove removes the given value from the set func (us *unsafeSet) Remove(value string) { delete(us.d, value) } // Contains returns whether the set contains the given value func (us *unsafeSet) Contains(value string) (exists bool) { _, exists = us.d[value] return } // ContainsAll returns whether the set contains all given values func (us *unsafeSet) ContainsAll(values []string) bool { for _, s := range values { if !us.Contains(s) { return false } } return true } // Equals returns whether the contents of two sets are identical func (us *unsafeSet) Equals(other Set) bool { v1 := sort.StringSlice(us.Values()) v2 := sort.StringSlice(other.Values()) v1.Sort() v2.Sort() return reflect.DeepEqual(v1, v2) } // Length returns the number of elements in the set func (us *unsafeSet) Length() int { return len(us.d) } // Values returns the values of the Set in an unspecified order. func (us *unsafeSet) Values() (values []string) { values = make([]string, 0) for val, _ := range us.d { values = append(values, val) } return } // Copy creates a new Set containing the values of the first func (us *unsafeSet) Copy() Set { cp := NewUnsafeSet() for val, _ := range us.d { cp.Add(val) } return cp } // Sub removes all elements in other from the set func (us *unsafeSet) Sub(other Set) Set { oValues := other.Values() result := us.Copy().(*unsafeSet) for _, val := range oValues { if _, ok := result.d[val]; !ok { continue } delete(result.d, val) } return result } type tsafeSet struct { us *unsafeSet m sync.RWMutex } func (ts *tsafeSet) Add(value string) { ts.m.Lock() defer ts.m.Unlock() ts.us.Add(value) } func (ts *tsafeSet) Remove(value string) { ts.m.Lock() defer ts.m.Unlock() ts.us.Remove(value) } func (ts *tsafeSet) Contains(value string) (exists bool) { ts.m.RLock() defer ts.m.RUnlock() return ts.us.Contains(value) } func (ts *tsafeSet) Equals(other Set) bool { ts.m.RLock() defer ts.m.RUnlock() return ts.us.Equals(other) } func (ts *tsafeSet) Length() int { ts.m.RLock() defer ts.m.RUnlock() return ts.us.Length() } func (ts *tsafeSet) Values() (values []string) { ts.m.RLock() defer ts.m.RUnlock() return ts.us.Values() } func (ts *tsafeSet) Copy() Set { ts.m.RLock() defer ts.m.RUnlock() usResult := ts.us.Copy().(*unsafeSet) return &tsafeSet{usResult, sync.RWMutex{}} } func (ts *tsafeSet) Sub(other Set) Set { ts.m.RLock() defer ts.m.RUnlock() usResult := ts.us.Sub(other).(*unsafeSet) return &tsafeSet{usResult, sync.RWMutex{}} }