Compare commits

...
Sign in to create a new pull request.

1 commit
dev ... main

Author SHA1 Message Date
Virgil
bfa566eace refactor(store): replace banned stdlib imports with core helpers
Some checks failed
Test / test (push) Failing after 12m47s
Security Scan / security (push) Failing after 12m49s
Co-Authored-By: Virgil <virgil@lethean.io>
2026-03-29 15:15:16 +00:00
9 changed files with 101 additions and 113 deletions

View file

@ -2,7 +2,7 @@
package store package store
import ( import (
"fmt" "dappco.re/go/core"
"testing" "testing"
) )
@ -14,7 +14,7 @@ func BenchmarkGetAll_VaryingSize(b *testing.B) {
sizes := []int{10, 100, 1_000, 10_000} sizes := []int{10, 100, 1_000, 10_000}
for _, size := range sizes { for _, size := range sizes {
b.Run(fmt.Sprintf("size=%d", size), func(b *testing.B) { b.Run(core.Sprintf("size=%d", size), func(b *testing.B) {
s, err := New(":memory:") s, err := New(":memory:")
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
@ -22,7 +22,7 @@ func BenchmarkGetAll_VaryingSize(b *testing.B) {
defer s.Close() defer s.Close()
for i := range size { for i := range size {
_ = s.Set("bench", fmt.Sprintf("key-%d", i), "value") _ = s.Set("bench", core.Sprintf("key-%d", i), "value")
} }
b.ReportAllocs() b.ReportAllocs()
@ -48,7 +48,7 @@ func BenchmarkSetGet_Parallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
i := 0 i := 0
for pb.Next() { for pb.Next() {
key := fmt.Sprintf("key-%d", i) key := core.Sprintf("key-%d", i)
_ = s.Set("parallel", key, "value") _ = s.Set("parallel", key, "value")
_, _ = s.Get("parallel", key) _, _ = s.Get("parallel", key)
i++ i++
@ -64,7 +64,7 @@ func BenchmarkCount_10K(b *testing.B) {
defer s.Close() defer s.Close()
for i := range 10_000 { for i := range 10_000 {
_ = s.Set("bench", fmt.Sprintf("key-%d", i), "value") _ = s.Set("bench", core.Sprintf("key-%d", i), "value")
} }
b.ReportAllocs() b.ReportAllocs()
@ -84,14 +84,14 @@ func BenchmarkDelete(b *testing.B) {
// Pre-populate keys that will be deleted. // Pre-populate keys that will be deleted.
for i := range b.N { for i := range b.N {
_ = s.Set("bench", fmt.Sprintf("key-%d", i), "value") _ = s.Set("bench", core.Sprintf("key-%d", i), "value")
} }
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for i := range b.N { for i := range b.N {
_ = s.Delete("bench", fmt.Sprintf("key-%d", i)) _ = s.Delete("bench", core.Sprintf("key-%d", i))
} }
} }
@ -106,7 +106,7 @@ func BenchmarkSetWithTTL(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := range b.N { for i := range b.N {
_ = s.SetWithTTL("bench", fmt.Sprintf("key-%d", i), "value", 60_000_000_000) // 60s _ = s.SetWithTTL("bench", core.Sprintf("key-%d", i), "value", 60_000_000_000) // 60s
} }
} }
@ -118,7 +118,7 @@ func BenchmarkRender(b *testing.B) {
defer s.Close() defer s.Close()
for i := range 50 { for i := range 50 {
_ = s.Set("bench", fmt.Sprintf("key%d", i), fmt.Sprintf("val%d", i)) _ = s.Set("bench", core.Sprintf("key%d", i), core.Sprintf("val%d", i))
} }
tmpl := `{{ .key0 }} {{ .key25 }} {{ .key49 }}` tmpl := `{{ .key0 }} {{ .key25 }} {{ .key49 }}`

View file

@ -1,10 +1,9 @@
package store package store
import ( import (
"dappco.re/go/core"
"database/sql" "database/sql"
"fmt"
"os" "os"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -20,7 +19,7 @@ func TestNew_Bad_SchemaConflict(t *testing.T) {
// CREATE TABLE IF NOT EXISTS kv, SQLite returns an error because the // CREATE TABLE IF NOT EXISTS kv, SQLite returns an error because the
// name "kv" is already taken by the index. // name "kv" is already taken by the index.
dir := t.TempDir() dir := t.TempDir()
dbPath := filepath.Join(dir, "conflict.db") dbPath := core.Path(dir, "conflict.db")
db, err := sql.Open("sqlite", dbPath) db, err := sql.Open("sqlite", dbPath)
require.NoError(t, err) require.NoError(t, err)
@ -82,7 +81,7 @@ func TestGetAll_Bad_RowsError(t *testing.T) {
// Trigger rows.Err() by corrupting the database file so that iteration // Trigger rows.Err() by corrupting the database file so that iteration
// starts successfully but encounters a malformed page mid-scan. // starts successfully but encounters a malformed page mid-scan.
dir := t.TempDir() dir := t.TempDir()
dbPath := filepath.Join(dir, "corrupt-getall.db") dbPath := core.Path(dir, "corrupt-getall.db")
s, err := New(dbPath) s, err := New(dbPath)
require.NoError(t, err) require.NoError(t, err)
@ -91,8 +90,8 @@ func TestGetAll_Bad_RowsError(t *testing.T) {
const rows = 5000 const rows = 5000
for i := range rows { for i := range rows {
require.NoError(t, s.Set("g", require.NoError(t, s.Set("g",
fmt.Sprintf("key-%06d", i), core.Sprintf("key-%06d", i),
fmt.Sprintf("value-with-padding-%06d-xxxxxxxxxxxxxxxxxxxxxxxx", i))) core.Sprintf("value-with-padding-%06d-xxxxxxxxxxxxxxxxxxxxxxxx", i)))
} }
s.Close() s.Close()
@ -176,7 +175,7 @@ func TestRender_Bad_ScanError(t *testing.T) {
func TestRender_Bad_RowsError(t *testing.T) { func TestRender_Bad_RowsError(t *testing.T) {
// Same corruption technique as TestGetAll_Bad_RowsError. // Same corruption technique as TestGetAll_Bad_RowsError.
dir := t.TempDir() dir := t.TempDir()
dbPath := filepath.Join(dir, "corrupt-render.db") dbPath := core.Path(dir, "corrupt-render.db")
s, err := New(dbPath) s, err := New(dbPath)
require.NoError(t, err) require.NoError(t, err)
@ -184,8 +183,8 @@ func TestRender_Bad_RowsError(t *testing.T) {
const rows = 5000 const rows = 5000
for i := range rows { for i := range rows {
require.NoError(t, s.Set("g", require.NoError(t, s.Set("g",
fmt.Sprintf("key-%06d", i), core.Sprintf("key-%06d", i),
fmt.Sprintf("value-with-padding-%06d-xxxxxxxxxxxxxxxxxxxxxxxx", i))) core.Sprintf("value-with-padding-%06d-xxxxxxxxxxxxxxxxxxxxxxxx", i)))
} }
s.Close() s.Close()

View file

@ -1,7 +1,7 @@
package store package store
import ( import (
"fmt" "dappco.re/go/core"
"sync" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
@ -248,7 +248,7 @@ func TestWatch_Good_BufferFullDoesNotBlock(t *testing.T) {
go func() { go func() {
defer close(done) defer close(done)
for i := range 32 { for i := range 32 {
require.NoError(t, s.Set("g", fmt.Sprintf("k%d", i), "v")) require.NoError(t, s.Set("g", core.Sprintf("k%d", i), "v"))
} }
}() }()
@ -318,7 +318,7 @@ func TestWatch_Good_ConcurrentWatchUnwatch(t *testing.T) {
// Writers — continuously mutate the store. // Writers — continuously mutate the store.
wg.Go(func() { wg.Go(func() {
for i := range goroutines * ops { for i := range goroutines * ops {
_ = s.Set("g", fmt.Sprintf("k%d", i), "v") _ = s.Set("g", core.Sprintf("k%d", i), "v")
} }
}) })

1
go.mod
View file

@ -3,6 +3,7 @@ module dappco.re/go/core/store
go 1.26.0 go 1.26.0
require ( require (
dappco.re/go/core v0.7.0
dappco.re/go/core/log v0.1.0 dappco.re/go/core/log v0.1.0
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
modernc.org/sqlite v1.47.0 modernc.org/sqlite v1.47.0

4
go.sum
View file

@ -1,3 +1,7 @@
dappco.re/go/core v0.7.0 h1:A3vi7LD0jBBA7n+8WPZmjxbRDZ43FFoKhBJ/ydKDPSs=
dappco.re/go/core v0.7.0/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A=
dappco.re/go/core/log v0.1.0 h1:pa71Vq2TD2aoEUQWFKwNcaJ3GBY8HbaNGqtE688Unyc=
dappco.re/go/core/log v0.1.0/go.mod h1:Nkqb8gsXhZAO8VLpx7B8i1iAmohhzqA20b9Zr8VUcJs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

View file

@ -1,12 +1,11 @@
package store package store
import ( import (
"errors"
"fmt"
"iter" "iter"
"regexp" "regexp"
"time" "time"
"dappco.re/go/core"
coreerr "dappco.re/go/core/log" coreerr "dappco.re/go/core/log"
) )
@ -33,7 +32,7 @@ type ScopedStore struct {
// characters and hyphens. // characters and hyphens.
func NewScoped(store *Store, namespace string) (*ScopedStore, error) { func NewScoped(store *Store, namespace string) (*ScopedStore, error) {
if !validNamespace.MatchString(namespace) { if !validNamespace.MatchString(namespace) {
return nil, coreerr.E("store.NewScoped", fmt.Sprintf("namespace %q is invalid (must be non-empty, alphanumeric + hyphens)", namespace), nil) return nil, coreerr.E("store.NewScoped", core.Sprintf("namespace %q is invalid (must be non-empty, alphanumeric + hyphens)", namespace), nil)
} }
return &ScopedStore{store: store, namespace: namespace}, nil return &ScopedStore{store: store, namespace: namespace}, nil
} }
@ -133,7 +132,7 @@ func (s *ScopedStore) checkQuota(group, key string) error {
// Key exists — this is an upsert, no quota check needed. // Key exists — this is an upsert, no quota check needed.
return nil return nil
} }
if !errors.Is(err, ErrNotFound) { if !coreerr.Is(err, ErrNotFound) {
// A database error occurred, not just a "not found" result. // A database error occurred, not just a "not found" result.
return coreerr.E("store.ScopedStore", "quota check", err) return coreerr.E("store.ScopedStore", "quota check", err)
} }
@ -145,7 +144,7 @@ func (s *ScopedStore) checkQuota(group, key string) error {
return coreerr.E("store.ScopedStore", "quota check", err) return coreerr.E("store.ScopedStore", "quota check", err)
} }
if count >= s.quota.MaxKeys { if count >= s.quota.MaxKeys {
return coreerr.E("store.ScopedStore", fmt.Sprintf("key limit (%d)", s.quota.MaxKeys), ErrQuotaExceeded) return coreerr.E("store.ScopedStore", core.Sprintf("key limit (%d)", s.quota.MaxKeys), ErrQuotaExceeded)
} }
} }
@ -165,7 +164,7 @@ func (s *ScopedStore) checkQuota(group, key string) error {
count++ count++
} }
if count >= s.quota.MaxGroups { if count >= s.quota.MaxGroups {
return coreerr.E("store.ScopedStore", fmt.Sprintf("group limit (%d)", s.quota.MaxGroups), ErrQuotaExceeded) return coreerr.E("store.ScopedStore", core.Sprintf("group limit (%d)", s.quota.MaxGroups), ErrQuotaExceeded)
} }
} }
} }

View file

@ -1,7 +1,7 @@
package store package store
import ( import (
"errors" "dappco.re/go/core"
"testing" "testing"
"time" "time"
@ -85,7 +85,7 @@ func TestScopedStore_Good_PrefixedInUnderlyingStore(t *testing.T) {
// Direct access without prefix should fail. // Direct access without prefix should fail.
_, err = s.Get("config", "key") _, err = s.Get("config", "key")
assert.True(t, errors.Is(err, ErrNotFound)) assert.True(t, core.Is(err, ErrNotFound))
} }
func TestScopedStore_Good_NamespaceIsolation(t *testing.T) { func TestScopedStore_Good_NamespaceIsolation(t *testing.T) {
@ -116,7 +116,7 @@ func TestScopedStore_Good_Delete(t *testing.T) {
require.NoError(t, sc.Delete("g", "k")) require.NoError(t, sc.Delete("g", "k"))
_, err := sc.Get("g", "k") _, err := sc.Get("g", "k")
assert.True(t, errors.Is(err, ErrNotFound)) assert.True(t, core.Is(err, ErrNotFound))
} }
func TestScopedStore_Good_DeleteGroup(t *testing.T) { func TestScopedStore_Good_DeleteGroup(t *testing.T) {
@ -187,7 +187,7 @@ func TestScopedStore_Good_SetWithTTL_Expires(t *testing.T) {
time.Sleep(5 * time.Millisecond) time.Sleep(5 * time.Millisecond)
_, err := sc.Get("g", "k") _, err := sc.Get("g", "k")
assert.True(t, errors.Is(err, ErrNotFound)) assert.True(t, core.Is(err, ErrNotFound))
} }
func TestScopedStore_Good_Render(t *testing.T) { func TestScopedStore_Good_Render(t *testing.T) {
@ -221,7 +221,7 @@ func TestQuota_Good_MaxKeys(t *testing.T) {
// 6th key should fail. // 6th key should fail.
err = sc.Set("g", "overflow", "v") err = sc.Set("g", "overflow", "v")
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, ErrQuotaExceeded), "expected ErrQuotaExceeded, got: %v", err) assert.True(t, core.Is(err, ErrQuotaExceeded), "expected ErrQuotaExceeded, got: %v", err)
} }
func TestQuota_Good_MaxKeys_AcrossGroups(t *testing.T) { func TestQuota_Good_MaxKeys_AcrossGroups(t *testing.T) {
@ -236,7 +236,7 @@ func TestQuota_Good_MaxKeys_AcrossGroups(t *testing.T) {
// Total is now 3 — any new key should fail regardless of group. // Total is now 3 — any new key should fail regardless of group.
err := sc.Set("g4", "d", "4") err := sc.Set("g4", "d", "4")
assert.True(t, errors.Is(err, ErrQuotaExceeded)) assert.True(t, core.Is(err, ErrQuotaExceeded))
} }
func TestQuota_Good_UpsertDoesNotCount(t *testing.T) { func TestQuota_Good_UpsertDoesNotCount(t *testing.T) {
@ -303,7 +303,7 @@ func TestQuota_Good_ExpiredKeysExcluded(t *testing.T) {
// Now at 3 — next should fail. // Now at 3 — next should fail.
err := sc.Set("g", "new3", "v") err := sc.Set("g", "new3", "v")
assert.True(t, errors.Is(err, ErrQuotaExceeded)) assert.True(t, core.Is(err, ErrQuotaExceeded))
} }
func TestQuota_Good_SetWithTTL_Enforced(t *testing.T) { func TestQuota_Good_SetWithTTL_Enforced(t *testing.T) {
@ -316,7 +316,7 @@ func TestQuota_Good_SetWithTTL_Enforced(t *testing.T) {
require.NoError(t, sc.SetWithTTL("g", "b", "2", time.Hour)) require.NoError(t, sc.SetWithTTL("g", "b", "2", time.Hour))
err := sc.SetWithTTL("g", "c", "3", time.Hour) err := sc.SetWithTTL("g", "c", "3", time.Hour)
assert.True(t, errors.Is(err, ErrQuotaExceeded)) assert.True(t, core.Is(err, ErrQuotaExceeded))
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@ -336,7 +336,7 @@ func TestQuota_Good_MaxGroups(t *testing.T) {
// 4th group should fail. // 4th group should fail.
err := sc.Set("g4", "k", "v") err := sc.Set("g4", "k", "v")
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, ErrQuotaExceeded)) assert.True(t, core.Is(err, ErrQuotaExceeded))
} }
func TestQuota_Good_MaxGroups_ExistingGroupOK(t *testing.T) { func TestQuota_Good_MaxGroups_ExistingGroupOK(t *testing.T) {
@ -405,7 +405,7 @@ func TestQuota_Good_BothLimits(t *testing.T) {
// Group limit hit. // Group limit hit.
err := sc.Set("g3", "c", "3") err := sc.Set("g3", "c", "3")
assert.True(t, errors.Is(err, ErrQuotaExceeded)) assert.True(t, core.Is(err, ErrQuotaExceeded))
// But adding to existing groups is fine (within key limit). // But adding to existing groups is fine (within key limit).
require.NoError(t, sc.Set("g1", "d", "4")) require.NoError(t, sc.Set("g1", "d", "4"))
@ -425,11 +425,11 @@ func TestQuota_Good_DoesNotAffectOtherNamespaces(t *testing.T) {
// a is at limit — but b's keys don't count against a. // a is at limit — but b's keys don't count against a.
err := a.Set("g", "a3", "v") err := a.Set("g", "a3", "v")
assert.True(t, errors.Is(err, ErrQuotaExceeded)) assert.True(t, core.Is(err, ErrQuotaExceeded))
// b is also at limit independently. // b is also at limit independently.
err = b.Set("g", "b3", "v") err = b.Set("g", "b3", "v")
assert.True(t, errors.Is(err, ErrQuotaExceeded)) assert.True(t, core.Is(err, ErrQuotaExceeded))
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------

View file

@ -4,11 +4,11 @@ import (
"context" "context"
"database/sql" "database/sql"
"iter" "iter"
"strings"
"sync" "sync"
"text/template" "text/template"
"time" "time"
"dappco.re/go/core"
coreerr "dappco.re/go/core/log" coreerr "dappco.re/go/core/log"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
@ -65,7 +65,7 @@ func New(dbPath string) (*Store, error) {
// Ensure the expires_at column exists for databases created before TTL support. // Ensure the expires_at column exists for databases created before TTL support.
if _, err := db.Exec("ALTER TABLE kv ADD COLUMN expires_at INTEGER"); err != nil { if _, err := db.Exec("ALTER TABLE kv ADD COLUMN expires_at INTEGER"); err != nil {
// SQLite returns "duplicate column name" if it already exists. // SQLite returns "duplicate column name" if it already exists.
if !strings.Contains(err.Error(), "duplicate column name") { if !core.Contains(err.Error(), "duplicate column name") {
db.Close() db.Close()
return nil, coreerr.E("store.New", "migration", err) return nil, coreerr.E("store.New", "migration", err)
} }
@ -184,11 +184,11 @@ type KV struct {
// GetAll returns all non-expired key-value pairs in a group. // GetAll returns all non-expired key-value pairs in a group.
func (s *Store) GetAll(group string) (map[string]string, error) { func (s *Store) GetAll(group string) (map[string]string, error) {
result := make(map[string]string) result := make(map[string]string)
for kv, err := range s.All(group) { for keyValue, err := range s.All(group) {
if err != nil { if err != nil {
return nil, coreerr.E("store.GetAll", "iterate", err) return nil, coreerr.E("store.GetAll", "iterate", err)
} }
result[kv.Key] = kv.Value result[keyValue.Key] = keyValue.Value
} }
return result, nil return result, nil
} }
@ -224,43 +224,23 @@ func (s *Store) All(group string) iter.Seq2[KV, error] {
} }
} }
// GetSplit retrieves a value and returns an iterator over its parts, split by
// sep.
func (s *Store) GetSplit(group, key, sep string) (iter.Seq[string], error) {
val, err := s.Get(group, key)
if err != nil {
return nil, err
}
return strings.SplitSeq(val, sep), nil
}
// GetFields retrieves a value and returns an iterator over its parts, split by
// whitespace.
func (s *Store) GetFields(group, key string) (iter.Seq[string], error) {
val, err := s.Get(group, key)
if err != nil {
return nil, err
}
return strings.FieldsSeq(val), nil
}
// Render loads all non-expired key-value pairs from a group and renders a Go // Render loads all non-expired key-value pairs from a group and renders a Go
// template. // template.
func (s *Store) Render(tmplStr, group string) (string, error) { func (s *Store) Render(tmplStr, group string) (string, error) {
vars := make(map[string]string) vars := make(map[string]string)
for kv, err := range s.All(group) { for keyValue, err := range s.All(group) {
if err != nil { if err != nil {
return "", coreerr.E("store.Render", "iterate", err) return "", coreerr.E("store.Render", "iterate", err)
} }
vars[kv.Key] = kv.Value vars[keyValue.Key] = keyValue.Value
} }
tmpl, err := template.New("render").Parse(tmplStr) tmpl, err := template.New("render").Parse(tmplStr)
if err != nil { if err != nil {
return "", coreerr.E("store.Render", "parse", err) return "", coreerr.E("store.Render", "parse", err)
} }
var b strings.Builder b := core.NewBuilder()
if err := tmpl.Execute(&b, vars); err != nil { if err := tmpl.Execute(b, vars); err != nil {
return "", coreerr.E("store.Render", "exec", err) return "", coreerr.E("store.Render", "exec", err)
} }
return b.String(), nil return b.String(), nil
@ -344,9 +324,9 @@ func (s *Store) GroupsSeq(prefix string) iter.Seq2[string, error] {
} }
func escapeLike(s string) string { func escapeLike(s string) string {
s = strings.ReplaceAll(s, "^", "^^") s = core.Replace(s, "^", "^^")
s = strings.ReplaceAll(s, "%", "^%") s = core.Replace(s, "%", "^%")
s = strings.ReplaceAll(s, "_", "^_") s = core.Replace(s, "_", "^_")
return s return s
} }

View file

@ -1,13 +1,11 @@
package store package store
import ( import (
"bytes"
"context" "context"
"dappco.re/go/core"
"database/sql" "database/sql"
"errors"
"fmt"
"os" "os"
"path/filepath"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -28,7 +26,7 @@ func TestNew_Good_Memory(t *testing.T) {
} }
func TestNew_Good_FileBacked(t *testing.T) { func TestNew_Good_FileBacked(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "test.db") dbPath := core.Path(t.TempDir(), "test.db")
s, err := New(dbPath) s, err := New(dbPath)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, s) require.NotNil(t, s)
@ -58,7 +56,7 @@ func TestNew_Bad_InvalidPath(t *testing.T) {
func TestNew_Bad_CorruptFile(t *testing.T) { func TestNew_Bad_CorruptFile(t *testing.T) {
// A file that exists but is not a valid SQLite database should fail. // A file that exists but is not a valid SQLite database should fail.
dir := t.TempDir() dir := t.TempDir()
dbPath := filepath.Join(dir, "corrupt.db") dbPath := core.Path(dir, "corrupt.db")
require.NoError(t, os.WriteFile(dbPath, []byte("not a sqlite database"), 0644)) require.NoError(t, os.WriteFile(dbPath, []byte("not a sqlite database"), 0644))
_, err := New(dbPath) _, err := New(dbPath)
@ -69,7 +67,7 @@ func TestNew_Bad_CorruptFile(t *testing.T) {
func TestNew_Bad_ReadOnlyDir(t *testing.T) { func TestNew_Bad_ReadOnlyDir(t *testing.T) {
// A path in a read-only directory should fail when SQLite tries to create the WAL file. // A path in a read-only directory should fail when SQLite tries to create the WAL file.
dir := t.TempDir() dir := t.TempDir()
dbPath := filepath.Join(dir, "readonly.db") dbPath := core.Path(dir, "readonly.db")
// Create a valid DB first, then make the directory read-only. // Create a valid DB first, then make the directory read-only.
s, err := New(dbPath) s, err := New(dbPath)
@ -90,7 +88,7 @@ func TestNew_Bad_ReadOnlyDir(t *testing.T) {
} }
func TestNew_Good_WALMode(t *testing.T) { func TestNew_Good_WALMode(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "wal.db") dbPath := core.Path(t.TempDir(), "wal.db")
s, err := New(dbPath) s, err := New(dbPath)
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()
@ -140,7 +138,7 @@ func TestGet_Bad_NotFound(t *testing.T) {
_, err := s.Get("config", "missing") _, err := s.Get("config", "missing")
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, ErrNotFound), "should wrap ErrNotFound") assert.True(t, core.Is(err, ErrNotFound), "should wrap ErrNotFound")
} }
func TestGet_Bad_NonExistentGroup(t *testing.T) { func TestGet_Bad_NonExistentGroup(t *testing.T) {
@ -149,7 +147,7 @@ func TestGet_Bad_NonExistentGroup(t *testing.T) {
_, err := s.Get("no-such-group", "key") _, err := s.Get("no-such-group", "key")
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, ErrNotFound)) assert.True(t, core.Is(err, ErrNotFound))
} }
func TestGet_Bad_ClosedStore(t *testing.T) { func TestGet_Bad_ClosedStore(t *testing.T) {
@ -233,7 +231,7 @@ func TestCount_Good_BulkInsert(t *testing.T) {
const total = 500 const total = 500
for i := range total { for i := range total {
require.NoError(t, s.Set("bulk", fmt.Sprintf("key-%04d", i), "v")) require.NoError(t, s.Set("bulk", core.Sprintf("key-%04d", i), "v"))
} }
n, err := s.Count("bulk") n, err := s.Count("bulk")
require.NoError(t, err) require.NoError(t, err)
@ -470,9 +468,9 @@ func TestEdgeCases(t *testing.T) {
{"special SQL chars", "g", "'; DROP TABLE kv;--", "val"}, {"special SQL chars", "g", "'; DROP TABLE kv;--", "val"},
{"backslash", "g", "back\\slash", "val\\ue"}, {"backslash", "g", "back\\slash", "val\\ue"},
{"percent", "g", "100%", "50%"}, {"percent", "g", "100%", "50%"},
{"long key", "g", strings.Repeat("k", 10000), "val"}, {"long key", "g", repeatString("k", 10000), "val"},
{"long value", "g", "longval", strings.Repeat("v", 100000)}, {"long value", "g", "longval", repeatString("v", 100000)},
{"long group", strings.Repeat("g", 10000), "k", "val"}, {"long group", repeatString("g", 10000), "k", "val"},
} }
for _, tc := range tests { for _, tc := range tests {
@ -521,7 +519,7 @@ func TestGroupIsolation(t *testing.T) {
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestConcurrent_ReadWrite(t *testing.T) { func TestConcurrent_ReadWrite(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "concurrent.db") dbPath := core.Path(t.TempDir(), "concurrent.db")
s, err := New(dbPath) s, err := New(dbPath)
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()
@ -537,12 +535,12 @@ func TestConcurrent_ReadWrite(t *testing.T) {
wg.Add(1) wg.Add(1)
go func(id int) { go func(id int) {
defer wg.Done() defer wg.Done()
group := fmt.Sprintf("grp-%d", id) group := core.Sprintf("grp-%d", id)
for i := range opsPerGoroutine { for i := range opsPerGoroutine {
key := fmt.Sprintf("key-%d", i) key := core.Sprintf("key-%d", i)
val := fmt.Sprintf("val-%d-%d", id, i) val := core.Sprintf("val-%d-%d", id, i)
if err := s.Set(group, key, val); err != nil { if err := s.Set(group, key, val); err != nil {
errs <- fmt.Errorf("writer %d: %w", id, err) errs <- core.NewError(core.Sprintf("writer %d: %v", id, err))
} }
} }
}(g) }(g)
@ -553,13 +551,13 @@ func TestConcurrent_ReadWrite(t *testing.T) {
wg.Add(1) wg.Add(1)
go func(id int) { go func(id int) {
defer wg.Done() defer wg.Done()
group := fmt.Sprintf("grp-%d", id) group := core.Sprintf("grp-%d", id)
for i := range opsPerGoroutine { for i := range opsPerGoroutine {
key := fmt.Sprintf("key-%d", i) key := core.Sprintf("key-%d", i)
_, err := s.Get(group, key) _, err := s.Get(group, key)
// ErrNotFound is acceptable — the writer may not have written yet. // ErrNotFound is acceptable — the writer may not have written yet.
if err != nil && !errors.Is(err, ErrNotFound) { if err != nil && !core.Is(err, ErrNotFound) {
errs <- fmt.Errorf("reader %d: %w", id, err) errs <- core.NewError(core.Sprintf("reader %d: %v", id, err))
} }
} }
}(g) }(g)
@ -574,7 +572,7 @@ func TestConcurrent_ReadWrite(t *testing.T) {
// After all writers finish, every key should be present. // After all writers finish, every key should be present.
for g := range goroutines { for g := range goroutines {
group := fmt.Sprintf("grp-%d", g) group := core.Sprintf("grp-%d", g)
n, err := s.Count(group) n, err := s.Count(group)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, opsPerGoroutine, n, "group %s should have all keys", group) assert.Equal(t, opsPerGoroutine, n, "group %s should have all keys", group)
@ -582,13 +580,13 @@ func TestConcurrent_ReadWrite(t *testing.T) {
} }
func TestConcurrent_GetAll(t *testing.T) { func TestConcurrent_GetAll(t *testing.T) {
s, err := New(filepath.Join(t.TempDir(), "getall.db")) s, err := New(core.Path(t.TempDir(), "getall.db"))
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()
// Seed data. // Seed data.
for i := range 50 { for i := range 50 {
require.NoError(t, s.Set("shared", fmt.Sprintf("k%d", i), fmt.Sprintf("v%d", i))) require.NoError(t, s.Set("shared", core.Sprintf("k%d", i), core.Sprintf("v%d", i)))
} }
var wg sync.WaitGroup var wg sync.WaitGroup
@ -608,7 +606,7 @@ func TestConcurrent_GetAll(t *testing.T) {
} }
func TestConcurrent_DeleteGroup(t *testing.T) { func TestConcurrent_DeleteGroup(t *testing.T) {
s, err := New(filepath.Join(t.TempDir(), "delgrp.db")) s, err := New(core.Path(t.TempDir(), "delgrp.db"))
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()
@ -617,9 +615,9 @@ func TestConcurrent_DeleteGroup(t *testing.T) {
wg.Add(1) wg.Add(1)
go func(id int) { go func(id int) {
defer wg.Done() defer wg.Done()
grp := fmt.Sprintf("g%d", id) grp := core.Sprintf("g%d", id)
for i := range 20 { for i := range 20 {
_ = s.Set(grp, fmt.Sprintf("k%d", i), "v") _ = s.Set(grp, core.Sprintf("k%d", i), "v")
} }
_ = s.DeleteGroup(grp) _ = s.DeleteGroup(grp)
}(g) }(g)
@ -637,7 +635,7 @@ func TestErrNotFound_Is(t *testing.T) {
_, err := s.Get("g", "k") _, err := s.Get("g", "k")
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, ErrNotFound), "error should be ErrNotFound via errors.Is") assert.True(t, core.Is(err, ErrNotFound), "error should be ErrNotFound via errors.Is")
assert.Contains(t, err.Error(), "g/k", "error message should include group/key") assert.Contains(t, err.Error(), "g/k", "error message should include group/key")
} }
@ -651,7 +649,7 @@ func BenchmarkSet(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := range b.N { for i := range b.N {
_ = s.Set("bench", fmt.Sprintf("key-%d", i), "value") _ = s.Set("bench", core.Sprintf("key-%d", i), "value")
} }
} }
@ -662,12 +660,12 @@ func BenchmarkGet(b *testing.B) {
// Pre-populate. // Pre-populate.
const keys = 10000 const keys = 10000
for i := range keys { for i := range keys {
_ = s.Set("bench", fmt.Sprintf("key-%d", i), "value") _ = s.Set("bench", core.Sprintf("key-%d", i), "value")
} }
b.ResetTimer() b.ResetTimer()
for i := range b.N { for i := range b.N {
_, _ = s.Get("bench", fmt.Sprintf("key-%d", i%keys)) _, _ = s.Get("bench", core.Sprintf("key-%d", i%keys))
} }
} }
@ -677,7 +675,7 @@ func BenchmarkGetAll(b *testing.B) {
const keys = 10000 const keys = 10000
for i := range keys { for i := range keys {
_ = s.Set("bench", fmt.Sprintf("key-%d", i), "value") _ = s.Set("bench", core.Sprintf("key-%d", i), "value")
} }
b.ResetTimer() b.ResetTimer()
@ -687,13 +685,13 @@ func BenchmarkGetAll(b *testing.B) {
} }
func BenchmarkSet_FileBacked(b *testing.B) { func BenchmarkSet_FileBacked(b *testing.B) {
dbPath := filepath.Join(b.TempDir(), "bench.db") dbPath := core.Path(b.TempDir(), "bench.db")
s, _ := New(dbPath) s, _ := New(dbPath)
defer s.Close() defer s.Close()
b.ResetTimer() b.ResetTimer()
for i := range b.N { for i := range b.N {
_ = s.Set("bench", fmt.Sprintf("key-%d", i), "value") _ = s.Set("bench", core.Sprintf("key-%d", i), "value")
} }
} }
@ -741,7 +739,14 @@ func TestSetWithTTL_Good_ExpiresOnGet(t *testing.T) {
_, err := s.Get("g", "ephemeral") _, err := s.Get("g", "ephemeral")
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, ErrNotFound), "expired key should be ErrNotFound") assert.True(t, core.Is(err, ErrNotFound), "expired key should be ErrNotFound")
}
func repeatString(value string, count int) string {
if count <= 0 {
return ""
}
return string(bytes.Repeat([]byte(value), count))
} }
func TestSetWithTTL_Good_ExcludedFromCount(t *testing.T) { func TestSetWithTTL_Good_ExcludedFromCount(t *testing.T) {
@ -901,7 +906,7 @@ func TestPurgeExpired_Good_BackgroundPurge(t *testing.T) {
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestSchemaUpgrade_ExistingDB(t *testing.T) { func TestSchemaUpgrade_ExistingDB(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "upgrade.db") dbPath := core.Path(t.TempDir(), "upgrade.db")
// Open, write, close. // Open, write, close.
s1, err := New(dbPath) s1, err := New(dbPath)
@ -927,7 +932,7 @@ func TestSchemaUpgrade_ExistingDB(t *testing.T) {
func TestSchemaUpgrade_PreTTLDatabase(t *testing.T) { func TestSchemaUpgrade_PreTTLDatabase(t *testing.T) {
// Simulate a database created before TTL support (no expires_at column). // Simulate a database created before TTL support (no expires_at column).
dbPath := filepath.Join(t.TempDir(), "pre-ttl.db") dbPath := core.Path(t.TempDir(), "pre-ttl.db")
db, err := sql.Open("sqlite", dbPath) db, err := sql.Open("sqlite", dbPath)
require.NoError(t, err) require.NoError(t, err)
db.SetMaxOpenConns(1) db.SetMaxOpenConns(1)
@ -966,7 +971,7 @@ func TestSchemaUpgrade_PreTTLDatabase(t *testing.T) {
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestConcurrent_TTL(t *testing.T) { func TestConcurrent_TTL(t *testing.T) {
s, err := New(filepath.Join(t.TempDir(), "concurrent-ttl.db")) s, err := New(core.Path(t.TempDir(), "concurrent-ttl.db"))
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()
@ -978,9 +983,9 @@ func TestConcurrent_TTL(t *testing.T) {
wg.Add(1) wg.Add(1)
go func(id int) { go func(id int) {
defer wg.Done() defer wg.Done()
grp := fmt.Sprintf("ttl-%d", id) grp := core.Sprintf("ttl-%d", id)
for i := range ops { for i := range ops {
key := fmt.Sprintf("k%d", i) key := core.Sprintf("k%d", i)
if i%2 == 0 { if i%2 == 0 {
_ = s.SetWithTTL(grp, key, "v", 50*time.Millisecond) _ = s.SetWithTTL(grp, key, "v", 50*time.Millisecond)
} else { } else {
@ -995,7 +1000,7 @@ func TestConcurrent_TTL(t *testing.T) {
time.Sleep(60 * time.Millisecond) time.Sleep(60 * time.Millisecond)
for g := range goroutines { for g := range goroutines {
grp := fmt.Sprintf("ttl-%d", g) grp := core.Sprintf("ttl-%d", g)
n, err := s.Count(grp) n, err := s.Count(grp)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, ops/2, n, "only non-TTL keys should remain in %s", grp) assert.Equal(t, ops/2, n, "only non-TTL keys should remain in %s", grp)