diff --git a/coverage_test.go b/coverage_test.go index db5d407..434a4bb 100644 --- a/coverage_test.go +++ b/coverage_test.go @@ -304,7 +304,7 @@ func TestCoverage_ScopedStore_Bad_GroupsSeqRowsError(t *testing.T) { defer database.Close() scopedStore := &ScopedStore{ - parentStore: &Store{ + store: &Store{ sqliteDatabase: database, cancelPurge: func() {}, }, diff --git a/scope.go b/scope.go index 8153822..659aa5a 100644 --- a/scope.go +++ b/scope.go @@ -26,8 +26,8 @@ type QuotaConfig struct { // ScopedStore keeps one namespace isolated behind helpers such as Set and // GetFrom so callers do not repeat the `tenant-a:` prefix manually. type ScopedStore struct { - parentStore *Store - namespace string + store *Store + namespace string // Usage example: `scopedStore.MaxKeys = 100` MaxKeys int // Usage example: `scopedStore.MaxGroups = 10` @@ -74,7 +74,7 @@ func (scopedConfig ScopedStoreConfig) Validate() error { } type scopedWatcherBinding struct { - parentStore *Store + store *Store underlyingEvents <-chan Event done chan struct{} stop chan struct{} @@ -85,13 +85,13 @@ func (scopedStore *ScopedStore) resolvedStore(operation string) (*Store, error) if scopedStore == nil { return nil, core.E(operation, "scoped store is nil", nil) } - if scopedStore.parentStore == nil { + if scopedStore.store == nil { return nil, core.E(operation, "underlying store is nil", nil) } - if err := scopedStore.parentStore.ensureReady(operation); err != nil { + if err := scopedStore.store.ensureReady(operation); err != nil { return nil, err } - return scopedStore.parentStore, nil + return scopedStore.store, nil } // Usage example: `scopedStore := store.NewScoped(storeInstance, "tenant-a"); if scopedStore == nil { return }` @@ -102,7 +102,7 @@ func NewScoped(storeInstance *Store, namespace string) *ScopedStore { if !validNamespace.MatchString(namespace) { return nil } - scopedStore := &ScopedStore{parentStore: storeInstance, namespace: namespace} + scopedStore := &ScopedStore{store: storeInstance, namespace: namespace} return scopedStore } @@ -362,7 +362,7 @@ func (scopedStore *ScopedStore) Watch(group string) <-chan Event { forwardedEvents := make(chan Event, watcherEventBufferCapacity) binding := &scopedWatcherBinding{ - parentStore: backingStore, + store: backingStore, underlyingEvents: backingStore.Watch("*"), done: make(chan struct{}), stop: make(chan struct{}), @@ -702,54 +702,20 @@ func (scopedTransaction *ScopedStoreTransaction) checkQuota(operation, group, ke if scopedTransaction.scopedStore == nil { return core.E(operation, "scoped store is nil", nil) } - if scopedTransaction.scopedStore.MaxKeys == 0 && scopedTransaction.scopedStore.MaxGroups == 0 { - return nil - } - storeTransaction, err := scopedTransaction.resolvedTransaction(operation) if err != nil { return err } - - namespacedGroup := scopedTransaction.scopedStore.namespacedGroup(group) - namespacePrefix := scopedTransaction.scopedStore.namespacePrefix() - - // Upserts never consume quota. - _, err = storeTransaction.Get(namespacedGroup, key) - if err == nil { - return nil - } - if !core.Is(err, NotFoundError) { - return core.E(operation, "quota check", err) - } - - if scopedTransaction.scopedStore.MaxKeys > 0 { - keyCount, err := storeTransaction.CountAll(namespacePrefix) - if err != nil { - return core.E(operation, "quota check", err) - } - if keyCount >= scopedTransaction.scopedStore.MaxKeys { - return core.E(operation, core.Sprintf("key limit (%d)", scopedTransaction.scopedStore.MaxKeys), QuotaExceededError) - } - } - - if scopedTransaction.scopedStore.MaxGroups > 0 { - existingGroupCount, err := storeTransaction.Count(namespacedGroup) - if err != nil { - return core.E(operation, "quota check", err) - } - if existingGroupCount == 0 { - groupNames, err := storeTransaction.Groups(namespacePrefix) - if err != nil { - return core.E(operation, "quota check", err) - } - if len(groupNames) >= scopedTransaction.scopedStore.MaxGroups { - return core.E(operation, core.Sprintf("group limit (%d)", scopedTransaction.scopedStore.MaxGroups), QuotaExceededError) - } - } - } - - return nil + return checkNamespaceQuota( + operation, + group, + key, + scopedTransaction.scopedStore.namespacedGroup(group), + scopedTransaction.scopedStore.namespacePrefix(), + scopedTransaction.scopedStore.MaxKeys, + scopedTransaction.scopedStore.MaxGroups, + storeTransaction, + ) } func (scopedStore *ScopedStore) forgetScopedWatcher(events <-chan Event) { @@ -784,8 +750,8 @@ func (scopedStore *ScopedStore) forgetAndStopScopedWatcher(events <-chan Event) binding.stopOnce.Do(func() { close(binding.stop) }) - if binding.parentStore != nil { - binding.parentStore.Unwatch("*", binding.underlyingEvents) + if binding.store != nil { + binding.store.Unwatch("*", binding.underlyingEvents) } <-binding.done } @@ -798,50 +764,61 @@ func (scopedStore *ScopedStore) checkQuota(operation, group, key string) error { if scopedStore == nil { return core.E(operation, "scoped store is nil", nil) } - if scopedStore.MaxKeys == 0 && scopedStore.MaxGroups == 0 { + return checkNamespaceQuota( + operation, + group, + key, + scopedStore.namespacedGroup(group), + scopedStore.namespacePrefix(), + scopedStore.MaxKeys, + scopedStore.MaxGroups, + scopedStore.store, + ) +} + +type namespaceQuotaReader interface { + Get(group, key string) (string, error) + Count(group string) (int, error) + CountAll(groupPrefix string) (int, error) + Groups(groupPrefix ...string) ([]string, error) +} + +func checkNamespaceQuota(operation, group, key, namespacedGroup, namespacePrefix string, maxKeys, maxGroups int, reader namespaceQuotaReader) error { + if maxKeys == 0 && maxGroups == 0 { return nil } - namespacedGroup := scopedStore.namespacedGroup(group) - namespacePrefix := scopedStore.namespacePrefix() - - // Check if this is an upsert (key already exists) — upserts never exceed quota. - _, err := scopedStore.parentStore.Get(namespacedGroup, key) + // Upserts never consume quota. + _, err := reader.Get(namespacedGroup, key) if err == nil { - // Key exists — this is an upsert, no quota check needed. return nil } if !core.Is(err, NotFoundError) { - // A database error occurred, not just a "not found" result. return core.E(operation, "quota check", err) } - // Check MaxKeys quota. - if scopedStore.MaxKeys > 0 { - keyCount, err := scopedStore.parentStore.CountAll(namespacePrefix) + if maxKeys > 0 { + keyCount, err := reader.CountAll(namespacePrefix) if err != nil { return core.E(operation, "quota check", err) } - if keyCount >= scopedStore.MaxKeys { - return core.E(operation, core.Sprintf("key limit (%d)", scopedStore.MaxKeys), QuotaExceededError) + if keyCount >= maxKeys { + return core.E(operation, core.Sprintf("key limit (%d)", maxKeys), QuotaExceededError) } } - // Check MaxGroups quota — only if this would create a new group. - if scopedStore.MaxGroups > 0 { - existingGroupCount, err := scopedStore.parentStore.Count(namespacedGroup) + if maxGroups > 0 { + existingGroupCount, err := reader.Count(namespacedGroup) if err != nil { return core.E(operation, "quota check", err) } if existingGroupCount == 0 { - // This group is new, so count existing namespace groups with the public helper. - groupNames, err := scopedStore.parentStore.Groups(namespacePrefix) + groupNames, err := reader.Groups(namespacePrefix) if err != nil { return core.E(operation, "quota check", err) } - knownGroupCount := len(groupNames) - if knownGroupCount >= scopedStore.MaxGroups { - return core.E(operation, core.Sprintf("group limit (%d)", scopedStore.MaxGroups), QuotaExceededError) + if len(groupNames) >= maxGroups { + return core.E(operation, core.Sprintf("group limit (%d)", maxGroups), QuotaExceededError) } } }