diff --git a/scope.go b/scope.go index 1763b68..f20ceec 100644 --- a/scope.go +++ b/scope.go @@ -543,51 +543,17 @@ func (scopedStoreTransaction *ScopedStoreTransaction) PurgeExpired() (int64, err } func (scopedStoreTransaction *ScopedStoreTransaction) checkQuota(operation, group, key string) error { - if scopedStoreTransaction.scopedStore.MaxKeys == 0 && scopedStoreTransaction.scopedStore.MaxGroups == 0 { - return nil - } - - namespacedGroup := scopedStoreTransaction.scopedStore.namespacedGroup(group) - namespacePrefix := scopedStoreTransaction.scopedStore.namespacePrefix() - - exists, err := liveEntryExists(scopedStoreTransaction.storeTransaction.sqliteTransaction, namespacedGroup, key) - if err != nil { - return core.E(operation, "quota check", err) - } - if exists { - return nil - } - - if scopedStoreTransaction.scopedStore.MaxKeys > 0 { - keyCount, err := scopedStoreTransaction.storeTransaction.CountAll(namespacePrefix) - if err != nil { - return core.E(operation, "quota check", err) - } - if keyCount >= scopedStoreTransaction.scopedStore.MaxKeys { - return core.E(operation, core.Sprintf("key limit (%d)", scopedStoreTransaction.scopedStore.MaxKeys), QuotaExceededError) - } - } - - if scopedStoreTransaction.scopedStore.MaxGroups > 0 { - existingGroupCount, err := scopedStoreTransaction.storeTransaction.Count(namespacedGroup) - if err != nil { - return core.E(operation, "quota check", err) - } - if existingGroupCount == 0 { - knownGroupCount := 0 - for _, iterationErr := range scopedStoreTransaction.storeTransaction.GroupsSeq(namespacePrefix) { - if iterationErr != nil { - return core.E(operation, "quota check", iterationErr) - } - knownGroupCount++ - } - if knownGroupCount >= scopedStoreTransaction.scopedStore.MaxGroups { - return core.E(operation, core.Sprintf("group limit (%d)", scopedStoreTransaction.scopedStore.MaxGroups), QuotaExceededError) - } - } - } - - return nil + return enforceQuota( + operation, + group, + key, + scopedStoreTransaction.scopedStore.namespacePrefix(), + scopedStoreTransaction.scopedStore.namespacedGroup(group), + scopedStoreTransaction.scopedStore.MaxKeys, + scopedStoreTransaction.scopedStore.MaxGroups, + scopedStoreTransaction.storeTransaction.sqliteTransaction, + scopedStoreTransaction.storeTransaction, + ) } // checkQuota("store.ScopedStore.Set", "config", "colour") returns nil when the @@ -595,51 +561,70 @@ func (scopedStoreTransaction *ScopedStoreTransaction) checkQuota(operation, grou // group would exceed the configured limit. Existing keys are treated as // upserts and do not consume quota. func (scopedStore *ScopedStore) checkQuota(operation, group, key string) error { - if scopedStore.MaxKeys == 0 && scopedStore.MaxGroups == 0 { + return enforceQuota( + operation, + group, + key, + scopedStore.namespacePrefix(), + scopedStore.namespacedGroup(group), + scopedStore.MaxKeys, + scopedStore.MaxGroups, + scopedStore.storeInstance.sqliteDatabase, + scopedStore.storeInstance, + ) +} + +type quotaCounter interface { + CountAll(groupPrefix string) (int, error) + Count(group string) (int, error) + GroupsSeq(groupPrefix ...string) iter.Seq2[string, error] +} + +func enforceQuota( + operation, group, key, namespacePrefix, namespacedGroup string, + maxKeys, maxGroups int, + queryable keyExistenceQuery, + counter quotaCounter, +) error { + if maxKeys == 0 && maxGroups == 0 { return nil } - namespacedGroup := scopedStore.namespacedGroup(group) - namespacePrefix := scopedStore.namespacePrefix() - - exists, err := liveEntryExists(scopedStore.storeInstance.sqliteDatabase, namespacedGroup, key) + exists, err := liveEntryExists(queryable, namespacedGroup, key) if err != nil { // A database error occurred, not just a "not found" result. return core.E(operation, "quota check", err) } if exists { - // Key exists — this is an upsert, no quota check needed. + // Key exists - this is an upsert, no quota check needed. return nil } - // Check MaxKeys quota. - if scopedStore.MaxKeys > 0 { - keyCount, err := scopedStore.storeInstance.CountAll(namespacePrefix) + if maxKeys > 0 { + keyCount, err := counter.CountAll(namespacePrefix) if err != nil { return core.E(operation, "quota check", err) } - if keyCount >= scopedStore.MaxKeys { - return core.E(operation, core.Sprintf("key limit (%d)", scopedStore.MaxKeys), QuotaExceededError) + if keyCount >= maxKeys { + return core.E(operation, core.Sprintf("key limit (%d)", maxKeys), QuotaExceededError) } } - // Check MaxGroups quota — only if this would create a new group. - if scopedStore.MaxGroups > 0 { - existingGroupCount, err := scopedStore.storeInstance.Count(namespacedGroup) + if maxGroups > 0 { + existingGroupCount, err := counter.Count(namespacedGroup) if err != nil { return core.E(operation, "quota check", err) } if existingGroupCount == 0 { - // This group is new — check if adding it would exceed the group limit. knownGroupCount := 0 - for _, iterationErr := range scopedStore.storeInstance.GroupsSeq(namespacePrefix) { + for _, iterationErr := range counter.GroupsSeq(namespacePrefix) { if iterationErr != nil { return core.E(operation, "quota check", iterationErr) } knownGroupCount++ } - if knownGroupCount >= scopedStore.MaxGroups { - return core.E(operation, core.Sprintf("group limit (%d)", scopedStore.MaxGroups), QuotaExceededError) + if knownGroupCount >= maxGroups { + return core.E(operation, core.Sprintf("group limit (%d)", maxGroups), QuotaExceededError) } } }