diff --git a/scope.go b/scope.go index aec2142..eaf3b36 100644 --- a/scope.go +++ b/scope.go @@ -4,6 +4,7 @@ import ( "database/sql" "iter" "regexp" + "sync" "time" core "dappco.re/go/core" @@ -69,6 +70,15 @@ type ScopedStore struct { MaxKeys int // Usage example: `scopedStore.MaxGroups = 10` MaxGroups int + + watcherLock sync.Mutex + watcherBridges map[uintptr]scopedWatcherBridge +} + +type scopedWatcherBridge struct { + sourceGroup string + sourceEvents <-chan Event + done chan struct{} } // NewScoped validates a namespace and prefixes groups with namespace + ":". @@ -80,7 +90,11 @@ func NewScoped(storeInstance *Store, namespace string) (*ScopedStore, error) { if !validNamespace.MatchString(namespace) { return nil, core.E("store.NewScoped", core.Sprintf("namespace %q is invalid; use names like %q or %q", namespace, "tenant-a", "tenant-42"), nil) } - scopedStore := &ScopedStore{store: storeInstance, namespace: namespace} + scopedStore := &ScopedStore{ + store: storeInstance, + namespace: namespace, + watcherBridges: make(map[uintptr]scopedWatcherBridge), + } return scopedStore, nil } @@ -271,6 +285,114 @@ func (scopedStore *ScopedStore) PurgeExpired() (int64, error) { return removedRows, nil } +// Usage example: `events := scopedStore.Watch("config")` +// Usage example: `events := scopedStore.Watch("*")` +// The returned events always use namespace-local group names, so a write to +// `tenant-a:config` is delivered as `config`. +func (scopedStore *ScopedStore) Watch(group string) <-chan Event { + if scopedStore == nil || scopedStore.store == nil { + return closedEventChannel() + } + + sourceGroup := scopedStore.namespacedGroup(group) + if group == "*" { + sourceGroup = "*" + } + + sourceEvents := scopedStore.store.Watch(sourceGroup) + localEvents := make(chan Event, watcherEventBufferCapacity) + done := make(chan struct{}) + localEventsPointer := channelPointer(localEvents) + + scopedStore.watcherLock.Lock() + if scopedStore.watcherBridges == nil { + scopedStore.watcherBridges = make(map[uintptr]scopedWatcherBridge) + } + scopedStore.watcherBridges[localEventsPointer] = scopedWatcherBridge{ + sourceGroup: sourceGroup, + sourceEvents: sourceEvents, + done: done, + } + scopedStore.watcherLock.Unlock() + + go func() { + defer close(localEvents) + defer scopedStore.removeWatcherBridge(localEventsPointer) + + for { + select { + case <-done: + return + case event, ok := <-sourceEvents: + if !ok { + return + } + + localEvent, allowed := scopedStore.localiseWatchedEvent(event) + if !allowed { + continue + } + + select { + case localEvents <- localEvent: + default: + } + } + } + }() + + return localEvents +} + +// Usage example: `events := scopedStore.Watch("config"); scopedStore.Unwatch("config", events)` +// Usage example: `events := scopedStore.Watch("*"); scopedStore.Unwatch("*", events)` +func (scopedStore *ScopedStore) Unwatch(group string, events <-chan Event) { + if scopedStore == nil || events == nil { + return + } + + scopedStore.watcherLock.Lock() + watcherBridge, ok := scopedStore.watcherBridges[channelPointer(events)] + if ok { + delete(scopedStore.watcherBridges, channelPointer(events)) + } + scopedStore.watcherLock.Unlock() + + if !ok { + return + } + + close(watcherBridge.done) + scopedStore.store.Unwatch(watcherBridge.sourceGroup, watcherBridge.sourceEvents) +} + +func (scopedStore *ScopedStore) removeWatcherBridge(pointer uintptr) { + if scopedStore == nil { + return + } + + scopedStore.watcherLock.Lock() + delete(scopedStore.watcherBridges, pointer) + scopedStore.watcherLock.Unlock() +} + +func (scopedStore *ScopedStore) localiseWatchedEvent(event Event) (Event, bool) { + if scopedStore == nil { + return Event{}, false + } + + namespacePrefix := scopedStore.namespacePrefix() + if event.Group == "*" { + return event, true + } + if !core.HasPrefix(event.Group, namespacePrefix) { + return Event{}, false + } + + event.Group = core.TrimPrefix(event.Group, namespacePrefix) + return event, true +} + // Usage example: `unregister := scopedStore.OnChange(func(event store.Event) { fmt.Println(event.Group, event.Key, event.Value) })` // The callback receives the namespace-local group name, so a write to // `tenant-a:config` is reported as `config`. diff --git a/scope_test.go b/scope_test.go index 9c52c83..526d657 100644 --- a/scope_test.go +++ b/scope_test.go @@ -314,6 +314,93 @@ func TestScope_ScopedStore_Good_OnChange_NamespaceLocal(t *testing.T) { assert.Equal(t, "", events[1].Value) } +func TestScope_ScopedStore_Good_Watch_NamespaceLocal(t *testing.T) { + storeInstance, _ := New(":memory:") + defer storeInstance.Close() + + scopedStore, _ := NewScoped(storeInstance, "tenant-a") + otherScopedStore, _ := NewScoped(storeInstance, "tenant-b") + + events := scopedStore.Watch("config") + defer scopedStore.Unwatch("config", events) + + require.NoError(t, scopedStore.SetIn("config", "colour", "blue")) + require.NoError(t, otherScopedStore.SetIn("config", "colour", "red")) + + select { + case event, ok := <-events: + require.True(t, ok) + assert.Equal(t, EventSet, event.Type) + assert.Equal(t, "config", event.Group) + assert.Equal(t, "colour", event.Key) + assert.Equal(t, "blue", event.Value) + case <-time.After(time.Second): + t.Fatal("timed out waiting for scoped watch event") + } + + select { + case event := <-events: + t.Fatalf("unexpected event from another namespace: %#v", event) + case <-time.After(50 * time.Millisecond): + } +} + +func TestScope_ScopedStore_Good_Watch_All_NamespaceLocal(t *testing.T) { + storeInstance, _ := New(":memory:") + defer storeInstance.Close() + + scopedStore, _ := NewScoped(storeInstance, "tenant-a") + otherScopedStore, _ := NewScoped(storeInstance, "tenant-b") + + events := scopedStore.Watch("*") + defer scopedStore.Unwatch("*", events) + + require.NoError(t, scopedStore.SetIn("config", "colour", "blue")) + require.NoError(t, scopedStore.SetIn("cache", "page", "home")) + require.NoError(t, otherScopedStore.SetIn("config", "colour", "red")) + + select { + case event, ok := <-events: + require.True(t, ok) + assert.Equal(t, "config", event.Group) + assert.Equal(t, "colour", event.Key) + case <-time.After(time.Second): + t.Fatal("timed out waiting for first wildcard scoped watch event") + } + + select { + case event, ok := <-events: + require.True(t, ok) + assert.Equal(t, "cache", event.Group) + assert.Equal(t, "page", event.Key) + case <-time.After(time.Second): + t.Fatal("timed out waiting for second wildcard scoped watch event") + } + + select { + case event := <-events: + t.Fatalf("unexpected wildcard event from another namespace: %#v", event) + case <-time.After(50 * time.Millisecond): + } +} + +func TestScope_ScopedStore_Good_Unwatch_ClosesLocalChannel(t *testing.T) { + storeInstance, _ := New(":memory:") + defer storeInstance.Close() + + scopedStore, _ := NewScoped(storeInstance, "tenant-a") + + events := scopedStore.Watch("config") + scopedStore.Unwatch("config", events) + + select { + case _, ok := <-events: + assert.False(t, ok) + case <-time.After(time.Second): + t.Fatal("timed out waiting for scoped watch channel to close") + } +} + func TestScope_ScopedStore_Good_GetAll(t *testing.T) { storeInstance, _ := New(":memory:") defer storeInstance.Close()