diff --git a/scope.go b/scope.go index 524f038..7cafbfa 100644 --- a/scope.go +++ b/scope.go @@ -497,6 +497,9 @@ func (scopedTransaction *ScopedStoreTransaction) SetIn(group, key, value string) if err != nil { return err } + if err := scopedTransaction.checkQuota("store.ScopedStoreTransaction.SetIn", group, key); err != nil { + return err + } return storeTransaction.Set(scopedTransaction.scopedStore.namespacedGroup(group), key, value) } @@ -506,6 +509,9 @@ func (scopedTransaction *ScopedStoreTransaction) SetWithTTL(group, key, value st if err != nil { return err } + if err := scopedTransaction.checkQuota("store.ScopedStoreTransaction.SetWithTTL", group, key); err != nil { + return err + } return storeTransaction.SetWithTTL(scopedTransaction.scopedStore.namespacedGroup(group), key, value, timeToLive) } @@ -669,6 +675,66 @@ func (scopedTransaction *ScopedStoreTransaction) GetFields(group, key string) (i return storeTransaction.GetFields(scopedTransaction.scopedStore.namespacedGroup(group), key) } +// checkQuota("store.ScopedStoreTransaction.SetIn", "config", "colour") uses +// the transaction's own read state so staged writes inside the same +// transaction count towards the namespace limits. +func (scopedTransaction *ScopedStoreTransaction) checkQuota(operation, group, key string) error { + if scopedTransaction == nil { + return core.E(operation, "scoped transaction is nil", nil) + } + if scopedTransaction.scopedStore == nil { + return core.E(operation, "scoped store is nil", nil) + } + if scopedTransaction.scopedStore.MaxKeys == 0 && scopedTransaction.scopedStore.MaxGroups == 0 { + return nil + } + + storeTransaction, err := scopedTransaction.resolvedTransaction(operation) + if err != nil { + return err + } + + namespacedGroup := scopedTransaction.scopedStore.namespacedGroup(group) + namespacePrefix := scopedTransaction.scopedStore.namespacePrefix() + + // Upserts never consume quota. + _, err = storeTransaction.Get(namespacedGroup, key) + if err == nil { + return nil + } + if !core.Is(err, NotFoundError) { + return core.E(operation, "quota check", err) + } + + if scopedTransaction.scopedStore.MaxKeys > 0 { + keyCount, err := storeTransaction.CountAll(namespacePrefix) + if err != nil { + return core.E(operation, "quota check", err) + } + if keyCount >= scopedTransaction.scopedStore.MaxKeys { + return core.E(operation, core.Sprintf("key limit (%d)", scopedTransaction.scopedStore.MaxKeys), QuotaExceededError) + } + } + + if scopedTransaction.scopedStore.MaxGroups > 0 { + existingGroupCount, err := storeTransaction.Count(namespacedGroup) + if err != nil { + return core.E(operation, "quota check", err) + } + if existingGroupCount == 0 { + groupNames, err := storeTransaction.Groups(namespacePrefix) + if err != nil { + return core.E(operation, "quota check", err) + } + if len(groupNames) >= scopedTransaction.scopedStore.MaxGroups { + return core.E(operation, core.Sprintf("group limit (%d)", scopedTransaction.scopedStore.MaxGroups), QuotaExceededError) + } + } + } + + return nil +} + func (scopedStore *ScopedStore) forgetScopedWatcher(events <-chan Event) { if scopedStore == nil || events == nil { return diff --git a/scope_test.go b/scope_test.go index c0a6fc3..1dcd642 100644 --- a/scope_test.go +++ b/scope_test.go @@ -637,6 +637,42 @@ func TestScope_ScopedStoreTransaction_Good_PrefixesAndReadsPendingWrites(t *test assert.Equal(t, "tenant-a:config", received[1].Group) } +func TestScope_Quota_Good_TransactionEnforcesMaxKeys(t *testing.T) { + storeInstance, _ := New(":memory:") + defer storeInstance.Close() + + scopedStore, err := NewScopedWithQuota(storeInstance, "tenant-a", QuotaConfig{MaxKeys: 1}) + require.NoError(t, err) + + err = scopedStore.Transaction(func(transaction *ScopedStoreTransaction) error { + require.NoError(t, transaction.SetIn("config", "colour", "blue")) + return transaction.SetIn("config", "language", "en-GB") + }) + require.Error(t, err) + assert.True(t, core.Is(err, QuotaExceededError)) + + _, err = scopedStore.GetFrom("config", "colour") + assert.ErrorIs(t, err, NotFoundError) +} + +func TestScope_Quota_Good_TransactionEnforcesMaxGroups(t *testing.T) { + storeInstance, _ := New(":memory:") + defer storeInstance.Close() + + scopedStore, err := NewScopedWithQuota(storeInstance, "tenant-a", QuotaConfig{MaxGroups: 1}) + require.NoError(t, err) + + err = scopedStore.Transaction(func(transaction *ScopedStoreTransaction) error { + require.NoError(t, transaction.SetIn("config", "colour", "blue")) + return transaction.SetWithTTL("preferences", "language", "en-GB", time.Hour) + }) + require.Error(t, err) + assert.True(t, core.Is(err, QuotaExceededError)) + + _, err = scopedStore.GetFrom("config", "colour") + assert.ErrorIs(t, err, NotFoundError) +} + // --------------------------------------------------------------------------- // Quota enforcement — MaxKeys // ---------------------------------------------------------------------------