diff --git a/events.go b/events.go index b054c3f..6cdac54 100644 --- a/events.go +++ b/events.go @@ -54,6 +54,7 @@ type Event struct { // the exact callback later. type changeCallbackRegistration struct { registrationID uint64 + group string callback func(Event) } @@ -116,14 +117,52 @@ func (storeInstance *Store) Unwatch(group string, events <-chan Event) { } // OnChange registers a synchronous mutation callback. -// Usage example: `events := make(chan store.Event, 1); unregister := storeInstance.OnChange(func(event store.Event) { events <- event }); defer unregister()` -func (storeInstance *Store) OnChange(callback func(Event)) func() { +// Usage example: `unregister := storeInstance.OnChange(func(event store.Event) { fmt.Println(event.Group, event.Key, event.Value) })` +// Usage example: `unregister := storeInstance.OnChange("config", func(key, value string) { fmt.Println(key, value) })` +func (storeInstance *Store) OnChange(arguments ...any) func() { + if len(arguments) == 0 { + return func() {} + } + + var ( + callbackGroup string + callback func(Event) + ) + + switch len(arguments) { + case 1: + switch typedCallback := arguments[0].(type) { + case func(Event): + callback = typedCallback + default: + return func() {} + } + case 2: + groupName, ok := arguments[0].(string) + if !ok { + return func() {} + } + callbackGroup = groupName + switch typedCallback := arguments[1].(type) { + case func(Event): + callback = typedCallback + case func(string, string): + callback = func(event Event) { + typedCallback(event.Key, event.Value) + } + default: + return func() {} + } + default: + return func() {} + } + if callback == nil { return func() {} } registrationID := atomic.AddUint64(&storeInstance.nextCallbackRegistrationID, 1) - callbackRegistration := changeCallbackRegistration{registrationID: registrationID, callback: callback} + callbackRegistration := changeCallbackRegistration{registrationID: registrationID, group: callbackGroup, callback: callback} storeInstance.callbacksLock.Lock() storeInstance.callbacks = append(storeInstance.callbacks, callbackRegistration) @@ -172,6 +211,9 @@ func (storeInstance *Store) notify(event Event) { storeInstance.callbacksLock.RUnlock() for _, callback := range callbacks { + if callback.group != "" && callback.group != "*" && callback.group != event.Group { + continue + } callback.callback(event) } } diff --git a/events_test.go b/events_test.go index b1f491c..1767ca7 100644 --- a/events_test.go +++ b/events_test.go @@ -148,6 +148,22 @@ func TestEvents_OnChange_Good_Fires(t *testing.T) { assert.Equal(t, EventDelete, events[1].Type) } +func TestEvents_OnChange_Good_GroupFilteredCallback(t *testing.T) { + storeInstance, _ := New(":memory:") + defer storeInstance.Close() + + var seen []string + unregister := storeInstance.OnChange("config", func(key, value string) { + seen = append(seen, key+"="+value) + }) + defer unregister() + + require.NoError(t, storeInstance.Set("config", "theme", "dark")) + require.NoError(t, storeInstance.Set("other", "theme", "light")) + + assert.Equal(t, []string{"theme=dark"}, seen) +} + func TestEvents_Watch_Good_BufferDrops(t *testing.T) { storeInstance, _ := New(":memory:") defer storeInstance.Close() diff --git a/scope.go b/scope.go index 5ef1e30..c9f3bef 100644 --- a/scope.go +++ b/scope.go @@ -74,10 +74,6 @@ func (scopedStore *ScopedStore) namespacePrefix() string { return scopedStore.namespace + ":" } -func (scopedStore *ScopedStore) defaultGroup() string { - return defaultScopedGroupName -} - func (scopedStore *ScopedStore) trimNamespacePrefix(groupName string) string { return core.TrimPrefix(groupName, scopedStore.namespacePrefix()) } @@ -91,21 +87,45 @@ func (scopedStore *ScopedStore) Namespace() string { // Usage example: `colourValue, err := scopedStore.Get("colour")` // Usage example: `colourValue, err := scopedStore.Get("config", "colour")` func (scopedStore *ScopedStore) Get(arguments ...string) (string, error) { - group, key, err := scopedStore.getArguments(arguments) - if err != nil { - return "", err + switch len(arguments) { + case 1: + return scopedStore.GetFrom(defaultScopedGroupName, arguments[0]) + case 2: + return scopedStore.GetFrom(arguments[0], arguments[1]) + default: + return "", core.E( + "store.ScopedStore.Get", + core.Sprintf("expected 1 or 2 arguments; got %d", len(arguments)), + nil, + ) } +} + +// Usage example: `colourValue, err := scopedStore.GetFrom("config", "colour")` +func (scopedStore *ScopedStore) GetFrom(group, key string) (string, error) { return scopedStore.storeInstance.Get(scopedStore.namespacedGroup(group), key) } // Usage example: `if err := scopedStore.Set("colour", "blue"); err != nil { return }` // Usage example: `if err := scopedStore.Set("config", "colour", "blue"); err != nil { return }` func (scopedStore *ScopedStore) Set(arguments ...string) error { - group, key, value, err := scopedStore.setArguments(arguments) - if err != nil { - return err + switch len(arguments) { + case 2: + return scopedStore.SetIn(defaultScopedGroupName, arguments[0], arguments[1]) + case 3: + return scopedStore.SetIn(arguments[0], arguments[1], arguments[2]) + default: + return core.E( + "store.ScopedStore.Set", + core.Sprintf("expected 2 or 3 arguments; got %d", len(arguments)), + nil, + ) } - if err := scopedStore.checkQuota("store.ScopedStore.Set", group, key); err != nil { +} + +// Usage example: `if err := scopedStore.SetIn("config", "colour", "blue"); err != nil { return }` +func (scopedStore *ScopedStore) SetIn(group, key, value string) error { + if err := scopedStore.checkQuota("store.ScopedStore.SetIn", group, key); err != nil { return err } return scopedStore.storeInstance.Set(scopedStore.namespacedGroup(group), key, value) @@ -139,6 +159,11 @@ func (scopedStore *ScopedStore) All(group string) iter.Seq2[KeyValue, error] { return scopedStore.storeInstance.All(scopedStore.namespacedGroup(group)) } +// Usage example: `for entry, err := range scopedStore.AllSeq("config") { if err != nil { break }; fmt.Println(entry.Key, entry.Value) }` +func (scopedStore *ScopedStore) AllSeq(group string) iter.Seq2[KeyValue, error] { + return scopedStore.storeInstance.AllSeq(scopedStore.namespacedGroup(group)) +} + // Usage example: `keyCount, err := scopedStore.Count("config")` func (scopedStore *ScopedStore) Count(group string) (int, error) { return scopedStore.storeInstance.Count(scopedStore.namespacedGroup(group)) @@ -263,33 +288,3 @@ func (scopedStore *ScopedStore) checkQuota(operation, group, key string) error { return nil } - -func (scopedStore *ScopedStore) getArguments(arguments []string) (string, string, error) { - switch len(arguments) { - case 1: - return scopedStore.defaultGroup(), arguments[0], nil - case 2: - return arguments[0], arguments[1], nil - default: - return "", "", core.E( - "store.ScopedStore.Get", - core.Sprintf("expected 1 or 2 arguments; got %d", len(arguments)), - nil, - ) - } -} - -func (scopedStore *ScopedStore) setArguments(arguments []string) (string, string, string, error) { - switch len(arguments) { - case 2: - return scopedStore.defaultGroup(), arguments[0], arguments[1], nil - case 3: - return arguments[0], arguments[1], arguments[2], nil - default: - return "", "", "", core.E( - "store.ScopedStore.Set", - core.Sprintf("expected 2 or 3 arguments; got %d", len(arguments)), - nil, - ) - } -} diff --git a/scope_test.go b/scope_test.go index e3a8bfc..40b2a81 100644 --- a/scope_test.go +++ b/scope_test.go @@ -134,6 +134,35 @@ func TestScope_ScopedStore_Good_DefaultGroupHelpers(t *testing.T) { assert.Equal(t, "dark", rawValue) } +func TestScope_ScopedStore_Good_SetInAndGetFrom(t *testing.T) { + storeInstance, _ := New(":memory:") + defer storeInstance.Close() + + scopedStore := mustScoped(t, storeInstance, "tenant-a") + require.NoError(t, scopedStore.SetIn("config", "colour", "blue")) + + value, err := scopedStore.GetFrom("config", "colour") + require.NoError(t, err) + assert.Equal(t, "blue", value) +} + +func TestScope_ScopedStore_Good_AllSeq(t *testing.T) { + storeInstance, _ := New(":memory:") + defer storeInstance.Close() + + scopedStore := mustScoped(t, storeInstance, "tenant-a") + require.NoError(t, scopedStore.Set("items", "first", "1")) + require.NoError(t, scopedStore.Set("items", "second", "2")) + + var keys []string + for entry, err := range scopedStore.AllSeq("items") { + require.NoError(t, err) + keys = append(keys, entry.Key) + } + + assert.ElementsMatch(t, []string{"first", "second"}, keys) +} + func TestScope_ScopedStore_Good_PrefixedInUnderlyingStore(t *testing.T) { storeInstance, _ := New(":memory:") defer storeInstance.Close() diff --git a/store.go b/store.go index 9a57f73..00a4aa4 100644 --- a/store.go +++ b/store.go @@ -245,8 +245,8 @@ func (storeInstance *Store) GetAll(group string) (map[string]string, error) { return entriesByKey, nil } -// Usage example: `for entry, err := range storeInstance.All("config") { if err != nil { break }; fmt.Println(entry.Key, entry.Value) }` -func (storeInstance *Store) All(group string) iter.Seq2[KeyValue, error] { +// Usage example: `for entry, err := range storeInstance.AllSeq("config") { if err != nil { break }; fmt.Println(entry.Key, entry.Value) }` +func (storeInstance *Store) AllSeq(group string) iter.Seq2[KeyValue, error] { return func(yield func(KeyValue, error) bool) { rows, err := storeInstance.database.Query( "SELECT "+entryKeyColumn+", "+entryValueColumn+" FROM "+entriesTableName+" WHERE "+entryGroupColumn+" = ? AND (expires_at IS NULL OR expires_at > ?) ORDER BY "+entryKeyColumn, @@ -276,6 +276,11 @@ func (storeInstance *Store) All(group string) iter.Seq2[KeyValue, error] { } } +// Usage example: `for entry, err := range storeInstance.All("config") { if err != nil { break }; fmt.Println(entry.Key, entry.Value) }` +func (storeInstance *Store) All(group string) iter.Seq2[KeyValue, error] { + return storeInstance.AllSeq(group) +} + // Usage example: `parts, err := storeInstance.GetSplit("config", "hosts", ","); if err != nil { return }; for part := range parts { fmt.Println(part) }` func (storeInstance *Store) GetSplit(group, key, separator string) (iter.Seq[string], error) { value, err := storeInstance.Get(group, key) diff --git a/store_test.go b/store_test.go index 03b0c9f..f3eb9a3 100644 --- a/store_test.go +++ b/store_test.go @@ -382,6 +382,23 @@ func TestStore_All_Good_SortedByKey(t *testing.T) { assert.Equal(t, []string{"alpha", "bravo", "charlie"}, keys) } +func TestStore_AllSeq_Good_SortedByKey(t *testing.T) { + storeInstance, _ := New(":memory:") + defer storeInstance.Close() + + require.NoError(t, storeInstance.Set("g", "charlie", "3")) + require.NoError(t, storeInstance.Set("g", "alpha", "1")) + require.NoError(t, storeInstance.Set("g", "bravo", "2")) + + var keys []string + for entry, err := range storeInstance.AllSeq("g") { + require.NoError(t, err) + keys = append(keys, entry.Key) + } + + assert.Equal(t, []string{"alpha", "bravo", "charlie"}, keys) +} + func TestStore_All_Bad_ClosedStore(t *testing.T) { storeInstance, _ := New(":memory:") storeInstance.Close()