fix(store): r3 — transactional import + DELETE RETURNING + token home order on PR #4
Round 3 follow-up to ebe5377. Closes residual CodeRabbit findings.
Code:
- import.go: ImportAll DB mutations wrapped in a transaction with
  rollback-on-error (shape sketched just after this list)
- import.go: malformed JSONL returns file/line parse errors in all
three import helpers (was silently swallowing per-line errors)
- import.go: walkDir returns + propagates traversal/list/type errors
- medium.go: JSON export uses aggregateFields() + propagates
workspace failures
- publish.go: dataset_card.md excluded from Parquet split count
- store.go: medium-backed Close() remains retryable after sync
failure; operations see closing state as closed
- store.go + scope.go + transaction.go: purge uses
DELETE ... RETURNING so notifications come from rows actually
deleted (was reading first then deleting separately)
- publish.go: token lookup uses Core's DIR_HOME (populated via
os.UserHomeDir) then falls back to HOME — preserves direct-os
import ban while picking up real home
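
For reference, a minimal sketch of the begin / rollback-on-error / commit shape the first
bullet describes. It is written against plain database/sql with placeholder names
(importEverything, the golden_set DDL), not the package's DuckDB wrapper:

package importsketch

import "database/sql"

// importEverything is a hypothetical stand-in for ImportAll: every statement
// runs on one transaction, and the deferred Rollback discards a partial import
// if any step fails before Commit.
func importEverything(db *sql.DB) error {
    transaction, err := db.Begin()
    if err != nil {
        return err
    }
    committed := false
    defer func() {
        if !committed {
            _ = transaction.Rollback() // undo partial work on any early return
        }
    }()

    if _, err := transaction.Exec("DROP TABLE IF EXISTS golden_set"); err != nil {
        return err
    }
    if _, err := transaction.Exec("CREATE TABLE golden_set (idx INT)"); err != nil {
        return err
    }

    if err := transaction.Commit(); err != nil {
        return err
    }
    committed = true
    return nil
}
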
Tests:
- import_test.go (new): coverage of transactional import +
malformed-JSONL error path
Doc:
- README.md: footer licence link targets LICENCE.md (UK English)
Verification: gofmt clean, golangci-lint v2 reports 0 issues, and GOWORK=off
go vet + go test -count=1 ./... pass with explicit cache paths.
Closes residual findings on https://github.com/dAppCore/go-store/pull/4
Co-authored-by: Codex <noreply@openai.com>
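
The purge bullet above relies on SQLite's DELETE ... RETURNING so the delete and the
notification payload come from a single statement. A minimal sketch of that pattern,
assuming a SQLite build/driver with RETURNING support and placeholder table and column
names (entries, entry_group, entry_key) rather than this package's schema:

package purgesketch

import "database/sql"

type deletedEntry struct {
    group string
    key   string
}

// purgeExpired deletes expired rows and reports exactly the rows it removed,
// so delete notifications cannot drift from what actually left the table.
func purgeExpired(db *sql.DB, cutoffUnixMilli int64) ([]deletedEntry, error) {
    rows, err := db.Query(
        "DELETE FROM entries WHERE expires_at IS NOT NULL AND expires_at <= ? RETURNING entry_group, entry_key",
        cutoffUnixMilli,
    )
    if err != nil {
        return nil, err
    }
    defer rows.Close()

    var deleted []deletedEntry
    for rows.Next() {
        var entry deletedEntry
        if err := rows.Scan(&entry.group, &entry.key); err != nil {
            return nil, err
        }
        deleted = append(deleted, entry)
    }
    return deleted, rows.Err()
}
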
commit fc77445de0 (parent ebe5377871)
12 changed files with 341 additions and 141 deletions
README.md

@@ -94,4 +94,4 @@ go build ./...
 
 ## Licence
 
-European Union Public Licence 1.2 — see [LICENCE](LICENCE) for details.
+European Union Public Licence 1.2 — see [LICENCE.md](LICENCE.md) for details.
events.go (10 changed lines)
@@ -76,7 +76,7 @@ func (storeInstance *Store) Watch(group string) <-chan Event {
     storeInstance.lifecycleLock.Lock()
     defer storeInstance.lifecycleLock.Unlock()
-    if storeInstance.isClosed {
+    if storeInstance.isClosed || storeInstance.isClosing {
         return closedEventChannel()
     }
@@ -97,7 +97,7 @@ func (storeInstance *Store) Unwatch(group string, events <-chan Event) {
     }

     storeInstance.lifecycleLock.Lock()
-    closed := storeInstance.isClosed
+    closed := storeInstance.isClosed || storeInstance.isClosing
     storeInstance.lifecycleLock.Unlock()
     if closed {
         return
@@ -146,7 +146,7 @@ func (storeInstance *Store) OnChange(callback func(Event)) func() {
     storeInstance.lifecycleLock.Lock()
     defer storeInstance.lifecycleLock.Unlock()
-    if storeInstance.isClosed {
+    if storeInstance.isClosed || storeInstance.isClosing {
         return func() {}
     }
@@ -188,7 +188,7 @@ func (storeInstance *Store) notify(event Event) {
     }

     storeInstance.lifecycleLock.Lock()
-    if storeInstance.isClosed {
+    if storeInstance.isClosed || storeInstance.isClosing {
         storeInstance.lifecycleLock.Unlock()
         return
     }
@@ -210,7 +210,7 @@ func (storeInstance *Store) notify(event Event) {
     storeInstance.watcherLock.RUnlock()

     storeInstance.lifecycleLock.Lock()
-    if storeInstance.isClosed {
+    if storeInstance.isClosed || storeInstance.isClosing {
         storeInstance.lifecycleLock.Unlock()
         return
     }
import.go (179 changed lines)
@@ -4,6 +4,7 @@ package store

 import (
     "bufio"
+    "database/sql"
     "io"
     "io/fs"
@@ -13,6 +14,27 @@ import (
 // localFs provides unrestricted filesystem access for import operations.
 var localFs = (&core.Fs{}).New("/")

+type duckDBImportSession interface {
+    exec(query string, args ...any) error
+    queryRowScan(query string, dest any, args ...any) error
+}
+
+type duckDBImportTransaction struct {
+    transaction *sql.Tx
+}
+
+func (session duckDBImportTransaction) exec(query string, args ...any) error {
+    _, err := session.transaction.Exec(query, args...)
+    if err != nil {
+        return core.E("store.duckDBImportTransaction.Exec", "execute query", err)
+    }
+    return nil
+}
+
+func (session duckDBImportTransaction) queryRowScan(query string, dest any, args ...any) error {
+    return session.transaction.QueryRow(query, args...).Scan(dest)
+}
+
 // ScpFunc is a callback for executing SCP file transfers.
 // The function receives remote source and local destination paths.
 //
@@ -77,6 +99,10 @@ type ImportConfig struct {
 //
 //	err := store.ImportAll(db, store.ImportConfig{DataDir: "/Volumes/Data/lem"}, os.Stdout)
 func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
+    if db == nil || db.Conn() == nil {
+        return core.E("store.ImportAll", "database is nil", nil)
+    }
+
     m3Host := cfg.M3Host
     if m3Host == "" {
         m3Host = "m3"
@@ -93,14 +119,26 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
             core.Print(w, " WARNING: could not pull golden set from M3: %v", err)
         }
     }
+    transaction, err := db.Conn().Begin()
+    if err != nil {
+        return core.E("store.ImportAll", "begin import transaction", err)
+    }
+    committed := false
+    defer func() {
+        if !committed {
+            _ = transaction.Rollback()
+        }
+    }()
+    importSession := duckDBImportTransaction{transaction: transaction}

     if isFile(goldenPath) {
-        if err := db.Exec("DROP TABLE IF EXISTS golden_set"); err != nil {
+        if err := importSession.exec("DROP TABLE IF EXISTS golden_set"); err != nil {
             return core.E("store.ImportAll", "drop golden_set", err)
         }
-        err := db.Exec(core.Sprintf(`
-            CREATE TABLE golden_set AS
-            SELECT
-                idx::INT AS idx,
+        err := importSession.exec(core.Sprintf(`
+            CREATE TABLE golden_set AS
+            SELECT
+                idx::INT AS idx,
                 seed_id::VARCHAR AS seed_id,
                 domain::VARCHAR AS domain,
                 voice::VARCHAR AS voice,
@@ -115,7 +153,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
             return core.E("store.ImportAll", "import golden_set", err)
         } else {
             var n int
-            if err := db.QueryRowScan("SELECT count(*) FROM golden_set", &n); err != nil {
+            if err := importSession.queryRowScan("SELECT count(*) FROM golden_set", &n); err != nil {
                 return core.E("store.ImportAll", "count golden_set", err)
             }
             totals["golden_set"] = n
@@ -160,13 +198,13 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
         }
     }

-    if err := db.Exec("DROP TABLE IF EXISTS training_examples"); err != nil {
+    if err := importSession.exec("DROP TABLE IF EXISTS training_examples"); err != nil {
         return core.E("store.ImportAll", "drop training_examples", err)
     }
-    if err := db.Exec(`
-        CREATE TABLE training_examples (
-            source VARCHAR,
-            split VARCHAR,
+    if err := importSession.exec(`
+        CREATE TABLE training_examples (
+            source VARCHAR,
+            split VARCHAR,
             prompt TEXT,
             response TEXT,
             num_turns INT,
@@ -192,7 +230,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
             split = "test"
         }

-        n, err := importTrainingFile(db, local, td.name, split)
+        n, err := importTrainingFile(importSession, local, td.name, split)
         if err != nil {
             return core.E("store.ImportAll", core.Sprintf("import training file %s", local), err)
         }
@@ -224,13 +262,13 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
         }
     }

-    if err := db.Exec("DROP TABLE IF EXISTS benchmark_results"); err != nil {
+    if err := importSession.exec("DROP TABLE IF EXISTS benchmark_results"); err != nil {
         return core.E("store.ImportAll", "drop benchmark_results", err)
     }
-    if err := db.Exec(`
-        CREATE TABLE benchmark_results (
-            source VARCHAR, id VARCHAR, benchmark VARCHAR, model VARCHAR,
-            prompt TEXT, response TEXT, elapsed_seconds DOUBLE, domain VARCHAR
+    if err := importSession.exec(`
+        CREATE TABLE benchmark_results (
+            source VARCHAR, id VARCHAR, benchmark VARCHAR, model VARCHAR,
+            prompt TEXT, response TEXT, elapsed_seconds DOUBLE, domain VARCHAR
         )
     `); err != nil {
         return core.E("store.ImportAll", "create benchmark_results", err)
@@ -241,7 +279,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
         resultDir := core.JoinPath(benchLocal, subdir)
         matches := core.PathGlob(core.JoinPath(resultDir, "*.jsonl"))
         for _, jf := range matches {
-            n, err := importBenchmarkFile(db, jf, subdir)
+            n, err := importBenchmarkFile(importSession, jf, subdir)
             if err != nil {
                 return core.E("store.ImportAll", core.Sprintf("import benchmark file %s", jf), err)
             }
@@ -259,7 +297,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
         }
     }
     if isFile(local) {
-        n, err := importBenchmarkFile(db, local, "benchmark")
+        n, err := importBenchmarkFile(importSession, local, "benchmark")
         if err != nil {
             return core.E("store.ImportAll", core.Sprintf("import benchmark file %s", local), err)
         }
@@ -270,13 +308,13 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
     core.Print(w, " benchmark_results: %d rows", benchTotal)

     // ── 4. Benchmark questions ──
-    if err := db.Exec("DROP TABLE IF EXISTS benchmark_questions"); err != nil {
+    if err := importSession.exec("DROP TABLE IF EXISTS benchmark_questions"); err != nil {
         return core.E("store.ImportAll", "drop benchmark_questions", err)
     }
-    if err := db.Exec(`
-        CREATE TABLE benchmark_questions (
-            benchmark VARCHAR, id VARCHAR, question TEXT,
-            best_answer TEXT, correct_answers TEXT, incorrect_answers TEXT, category VARCHAR
+    if err := importSession.exec(`
+        CREATE TABLE benchmark_questions (
+            benchmark VARCHAR, id VARCHAR, question TEXT,
+            best_answer TEXT, correct_answers TEXT, incorrect_answers TEXT, category VARCHAR
         )
     `); err != nil {
         return core.E("store.ImportAll", "create benchmark_questions", err)
@@ -286,7 +324,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
     for _, bname := range []string{"truthfulqa", "gsm8k", "do_not_answer", "toxigen"} {
         local := core.JoinPath(benchLocal, bname+".jsonl")
         if isFile(local) {
-            n, err := importBenchmarkQuestions(db, local, bname)
+            n, err := importBenchmarkQuestions(importSession, local, bname)
             if err != nil {
                 return core.E("store.ImportAll", core.Sprintf("import benchmark questions %s", local), err)
             }
@@ -297,13 +335,13 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
     core.Print(w, " benchmark_questions: %d rows", benchQTotal)

     // ── 5. Seeds ──
-    if err := db.Exec("DROP TABLE IF EXISTS seeds"); err != nil {
+    if err := importSession.exec("DROP TABLE IF EXISTS seeds"); err != nil {
         return core.E("store.ImportAll", "drop seeds", err)
     }
-    if err := db.Exec(`
-        CREATE TABLE seeds (
-            source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT
-        )
+    if err := importSession.exec(`
+        CREATE TABLE seeds (
+            source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT
+        )
     `); err != nil {
         return core.E("store.ImportAll", "create seeds", err)
     }
@@ -314,7 +352,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
         if !isDir(seedDir) {
             continue
         }
-        n, err := importSeeds(db, seedDir)
+        n, err := importSeeds(importSession, seedDir)
         if err != nil {
             return core.E("store.ImportAll", core.Sprintf("import seeds %s", seedDir), err)
         }
@@ -323,6 +361,11 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
     totals["seeds"] = seedTotal
     core.Print(w, " seeds: %d rows", seedTotal)

+    if err := transaction.Commit(); err != nil {
+        return core.E("store.ImportAll", "commit import transaction", err)
+    }
+    committed = true
+
     // ── Summary ──
     grandTotal := 0
     core.Print(w, "\n%s", repeat("=", 50))
@@ -339,7 +382,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
     return nil
 }

-func importTrainingFile(db *DuckDB, path, source, split string) (int, error) {
+func importTrainingFile(db duckDBImportSession, path, source, split string) (int, error) {
     r := localFs.Open(path)
     if !r.OK {
         return 0, core.E("store.importTrainingFile", core.Sprintf("open %s", path), r.Value.(error))
@@ -351,12 +394,15 @@ func importTrainingFile(db *DuckDB, path, source, split string) (int, error) {
     scanner := bufio.NewScanner(f)
     scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

+    lineNumber := 0
     for scanner.Scan() {
+        lineNumber++
         var rec struct {
             Messages []ChatMessage `json:"messages"`
         }
         if r := core.JSONUnmarshal(scanner.Bytes(), &rec); !r.OK {
-            continue
+            parseErr, _ := r.Value.(error)
+            return count, core.E("store.importTrainingFile", core.Sprintf("parse %s line %d", path, lineNumber), parseErr)
         }

         prompt := ""
@@ -375,7 +421,7 @@ func importTrainingFile(db *DuckDB, path, source, split string) (int, error) {
         }

         msgsJSON := core.JSONMarshalString(rec.Messages)
-        if err := db.Exec(`INSERT INTO training_examples VALUES (?, ?, ?, ?, ?, ?, ?)`,
+        if err := db.exec(`INSERT INTO training_examples VALUES (?, ?, ?, ?, ?, ?, ?)`,
             source, split, prompt, response, assistantCount, msgsJSON, len(response)); err != nil {
             return count, core.E("store.importTrainingFile", "insert training example", err)
         }
@@ -387,7 +433,7 @@ func importTrainingFile(db *DuckDB, path, source, split string) (int, error) {
     return count, nil
 }

-func importBenchmarkFile(db *DuckDB, path, source string) (int, error) {
+func importBenchmarkFile(db duckDBImportSession, path, source string) (int, error) {
     r := localFs.Open(path)
     if !r.OK {
         return 0, core.E("store.importBenchmarkFile", core.Sprintf("open %s", path), r.Value.(error))
@@ -399,13 +445,16 @@ func importBenchmarkFile(db *DuckDB, path, source string) (int, error) {
     scanner := bufio.NewScanner(f)
     scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

+    lineNumber := 0
     for scanner.Scan() {
+        lineNumber++
         var rec map[string]any
         if r := core.JSONUnmarshal(scanner.Bytes(), &rec); !r.OK {
-            continue
+            parseErr, _ := r.Value.(error)
+            return count, core.E("store.importBenchmarkFile", core.Sprintf("parse %s line %d", path, lineNumber), parseErr)
         }

-        if err := db.Exec(`INSERT INTO benchmark_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
+        if err := db.exec(`INSERT INTO benchmark_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
             source,
             core.Sprint(rec["id"]),
             strOrEmpty(rec, "benchmark"),
@@ -425,7 +474,7 @@ func importBenchmarkFile(db *DuckDB, path, source string) (int, error) {
     return count, nil
 }

-func importBenchmarkQuestions(db *DuckDB, path, benchmark string) (int, error) {
+func importBenchmarkQuestions(db duckDBImportSession, path, benchmark string) (int, error) {
     r := localFs.Open(path)
     if !r.OK {
         return 0, core.E("store.importBenchmarkQuestions", core.Sprintf("open %s", path), r.Value.(error))
@@ -437,16 +486,19 @@ func importBenchmarkQuestions(db *DuckDB, path, benchmark string) (int, error) {
     scanner := bufio.NewScanner(f)
     scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

+    lineNumber := 0
     for scanner.Scan() {
+        lineNumber++
         var rec map[string]any
         if r := core.JSONUnmarshal(scanner.Bytes(), &rec); !r.OK {
-            continue
+            parseErr, _ := r.Value.(error)
+            return count, core.E("store.importBenchmarkQuestions", core.Sprintf("parse %s line %d", path, lineNumber), parseErr)
         }

         correctJSON := core.JSONMarshalString(rec["correct_answers"])
         incorrectJSON := core.JSONMarshalString(rec["incorrect_answers"])

-        if err := db.Exec(`INSERT INTO benchmark_questions VALUES (?, ?, ?, ?, ?, ?, ?)`,
+        if err := db.exec(`INSERT INTO benchmark_questions VALUES (?, ?, ?, ?, ?, ?, ?)`,
             benchmark,
             core.Sprint(rec["id"]),
             strOrEmpty(rec, "question"),
@@ -465,15 +517,11 @@ func importBenchmarkQuestions(db *DuckDB, path, benchmark string) (int, error) {
     return count, nil
 }

-func importSeeds(db *DuckDB, seedDir string) (int, error) {
+func importSeeds(db duckDBImportSession, seedDir string) (int, error) {
     count := 0
-    var firstErr error
-    walkDir(seedDir, func(path string) {
-        if firstErr != nil {
-            return
-        }
+    if err := walkDir(seedDir, func(path string) error {
         if !core.HasSuffix(path, ".json") {
-            return
+            return nil
         }

         rel := core.TrimPrefix(path, seedDir+"/")
@@ -481,8 +529,7 @@ func importSeeds(db *DuckDB, seedDir string) (int, error) {

         readResult := localFs.Read(path)
         if !readResult.OK {
-            firstErr = core.E("store.importSeeds", core.Sprintf("read seed file %s", rel), readResult.Value.(error))
-            return
+            return core.E("store.importSeeds", core.Sprintf("read seed file %s", rel), readResult.Value.(error))
         }
         data := []byte(readResult.Value.(string))
@@ -491,8 +538,7 @@ func importSeeds(db *DuckDB, seedDir string) (int, error) {
         var raw any
         if r := core.JSONUnmarshal(data, &raw); !r.OK {
             err, _ := r.Value.(error)
-            firstErr = core.E("store.importSeeds", core.Sprintf("parse seed file %s", rel), err)
-            return
+            return core.E("store.importSeeds", core.Sprintf("parse seed file %s", rel), err)
         }

         switch v := raw.(type) {
@@ -516,50 +562,53 @@ func importSeeds(db *DuckDB, seedDir string) (int, error) {
                 if prompt == "" {
                     prompt = strOrEmpty(seed, "question")
                 }
-                if err := db.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
+                if err := db.exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
                     rel, region,
                     strOrEmpty(seed, "seed_id"),
                     strOrEmpty(seed, "domain"),
                     prompt,
                 ); err != nil {
-                    firstErr = core.E("store.importSeeds", "insert seed prompt", err)
-                    return
+                    return core.E("store.importSeeds", "insert seed prompt", err)
                 }
                 count++
         case string:
-            if err := db.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
+            if err := db.exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
                 rel, region, "", "", seed); err != nil {
-                firstErr = core.E("store.importSeeds", "insert seed string", err)
-                return
+                return core.E("store.importSeeds", "insert seed string", err)
             }
             count++
         }
     }
-    })
-    if firstErr != nil {
-        return count, firstErr
+        return nil
+    }); err != nil {
+        return count, err
     }
     return count, nil
 }

 // walkDir recursively visits all regular files under root, calling fn for each.
-func walkDir(root string, fn func(path string)) {
+func walkDir(root string, fn func(path string) error) error {
     r := localFs.List(root)
     if !r.OK {
-        return
+        return core.E("store.walkDir", core.Sprintf("list %s", root), r.Value.(error))
     }
     entries, ok := r.Value.([]fs.DirEntry)
     if !ok {
-        return
+        return core.E("store.walkDir", core.Sprintf("list %s returned invalid entries", root), nil)
     }
     for _, entry := range entries {
         full := core.JoinPath(root, entry.Name())
         if entry.IsDir() {
-            walkDir(full, fn)
+            if err := walkDir(full, fn); err != nil {
+                return err
+            }
         } else {
-            fn(full)
+            if err := fn(full); err != nil {
+                return err
+            }
         }
     }
+    return nil
 }

 // strOrEmpty extracts a string value from a map, returning an empty string if
import_test.go (new file, 70 added lines)
@@ -0,0 +1,70 @@
+package store
+
+import (
+    "testing"
+
+    core "dappco.re/go/core"
+)
+
+type importSessionStub struct {
+    inserts int
+}
+
+func (session *importSessionStub) exec(string, ...any) error {
+    session.inserts++
+    return nil
+}
+
+func (session *importSessionStub) queryRowScan(string, any, ...any) error {
+    return nil
+}
+
+func TestImport_ImportTrainingFile_Bad_MalformedJSONL(t *testing.T) {
+    path := testPath(t, "training.jsonl")
+    requireCoreWriteBytes(t, path, []byte("{\"messages\":[]}\n{broken\n"))
+    session := &importSessionStub{}
+
+    count, err := importTrainingFile(session, path, "training", "train")
+
+    assertError(t, err)
+    assertContainsString(t, err.Error(), "line 2")
+    assertEqual(t, 1, count)
+    assertEqual(t, 1, session.inserts)
+}
+
+func TestImport_ImportBenchmarkFile_Bad_MalformedJSONL(t *testing.T) {
+    path := testPath(t, "benchmark.jsonl")
+    requireCoreWriteBytes(t, path, []byte("{\"id\":\"row-1\"}\n{broken\n"))
+    session := &importSessionStub{}
+
+    count, err := importBenchmarkFile(session, path, "benchmark")
+
+    assertError(t, err)
+    assertContainsString(t, err.Error(), "line 2")
+    assertEqual(t, 1, count)
+    assertEqual(t, 1, session.inserts)
+}
+
+func TestImport_ImportBenchmarkQuestions_Bad_MalformedJSONL(t *testing.T) {
+    path := testPath(t, "questions.jsonl")
+    requireCoreWriteBytes(t, path, []byte("{\"id\":\"q-1\"}\n{broken\n"))
+    session := &importSessionStub{}
+
+    count, err := importBenchmarkQuestions(session, path, "truthfulqa")
+
+    assertError(t, err)
+    assertContainsString(t, err.Error(), "line 2")
+    assertEqual(t, 1, count)
+    assertEqual(t, 1, session.inserts)
+}
+
+func TestImport_ImportSeeds_Bad_WalkFailure(t *testing.T) {
+    session := &importSessionStub{}
+
+    count, err := importSeeds(session, core.JoinPath(t.TempDir(), "missing-seeds"))
+
+    assertError(t, err)
+    assertContainsString(t, err.Error(), "store.walkDir")
+    assertEqual(t, 0, count)
+    assertEqual(t, 0, session.inserts)
+}
medium.go

@@ -243,7 +243,10 @@ func importCSV(workspace *Workspace, kind, content string) error {
 }

 func exportJSON(workspace *Workspace, medium Medium, path string) error {
-    summary := workspace.Aggregate()
+    summary, err := workspace.aggregateFields()
+    if err != nil {
+        return core.E("store.Export", "aggregate workspace", err)
+    }
     content := core.JSONMarshalString(summary)
     if err := medium.Write(path, content); err != nil {
         return core.E("store.Export", "write json", err)
medium_test.go

@@ -114,6 +114,19 @@ func (medium *renameFailMedium) Rename(string, string) error {
     return core.E("renameFailMedium.Rename", "forced rename failure", nil)
 }

+type writeFailOnceMedium struct {
+    *memoryMedium
+    failures int
+}
+
+func (medium *writeFailOnceMedium) Write(path, content string) error {
+    if medium.failures > 0 {
+        medium.failures--
+        return core.E("writeFailOnceMedium.Write", "forced write failure", nil)
+    }
+    return medium.memoryMedium.Write(path, content)
+}
+
 func (medium *memoryMedium) List(path string) ([]fs.DirEntry, error) { return nil, nil }

 func (medium *memoryMedium) Stat(path string) (fs.FileInfo, error) {
@@ -462,6 +475,30 @@ func TestMedium_Export_Bad_NilArguments(t *testing.T) {
     assertError(t, Export(workspace, medium, ""))
 }

+func TestMedium_Export_Bad_JSONPropagatesWorkspaceFailure(t *testing.T) {
+    useWorkspaceStateDirectory(t)
+
+    storeInstance, err := New(":memory:")
+    assertNoError(t, err)
+    defer func() { _ = storeInstance.Close() }()
+
+    workspace, err := storeInstance.NewWorkspace("medium-export-json-closed")
+    assertNoError(t, err)
+    assertNoError(t, workspace.Put("like", map[string]any{"user": "@alice"}))
+    assertNoError(t, workspace.Close())
+
+    medium := newMemoryMedium()
+    assertNoError(t, medium.Write("report.json", `{"previous":true}`))
+
+    err = Export(workspace, medium, "report.json")
+
+    assertError(t, err)
+    assertContainsString(t, err.Error(), "aggregate workspace")
+    content, readErr := medium.Read("report.json")
+    assertNoError(t, readErr)
+    assertEqual(t, `{"previous":true}`, content)
+}
+
 func TestMedium_Compact_Good_MediumRoutesArchive(t *testing.T) {
     useWorkspaceStateDirectory(t)
     useArchiveOutputDirectory(t)
publish.go (35 changed lines)
@@ -97,11 +97,11 @@ func Publish(cfg PublishConfig, w io.Writer) error {
         return core.E("store.Publish", "HuggingFace token required (--token, HF_TOKEN env, or ~/.huggingface/token)", nil)
     }

-    files, err := collectUploadFiles(cfg.InputDir)
+    files, hasSplit, err := collectUploadFiles(cfg.InputDir)
     if err != nil {
         return err
     }
-    if len(files) == 0 {
+    if !hasSplit {
         return core.E("store.Publish", core.Sprintf("no Parquet files found in %s", cfg.InputDir), nil)
     }
@@ -150,21 +150,33 @@ func resolveHFToken(explicit string) string {
     if env := core.Env("HF_TOKEN"); env != "" {
         return env
     }
-    home := core.Env("HOME")
-    if home == "" {
-        return ""
+    // Core populates DIR_HOME via os.UserHomeDir while this package keeps the
+    // repository-wide ban on direct os imports.
+    homes := []string{core.Env("DIR_HOME")}
+    if homeEnv := core.Env("HOME"); homeEnv != "" && homeEnv != homes[0] {
+        homes = append(homes, homeEnv)
     }
-    r := localFs.Read(core.JoinPath(home, ".huggingface", "token"))
-    if !r.OK {
-        return ""
+    for _, home := range homes {
+        if home == "" {
+            continue
+        }
+        r := localFs.Read(core.JoinPath(home, ".huggingface", "token"))
+        if !r.OK {
+            continue
+        }
+        token := core.Trim(r.Value.(string))
+        if token != "" {
+            return token
+        }
     }
-    return core.Trim(r.Value.(string))
+    return ""
 }

 // collectUploadFiles finds Parquet split files and an optional dataset card.
-func collectUploadFiles(inputDir string) ([]uploadEntry, error) {
+func collectUploadFiles(inputDir string) ([]uploadEntry, bool, error) {
     splits := []string{"train", "valid", "test"}
     var files []uploadEntry
+    hasSplit := false

     for _, split := range splits {
         path := core.JoinPath(inputDir, split+".parquet")
@@ -172,6 +184,7 @@ func collectUploadFiles(inputDir string) ([]uploadEntry, error) {
             continue
         }
         files = append(files, uploadEntry{path, core.Sprintf("data/%s.parquet", split)})
+        hasSplit = true
     }

     // Check for dataset card in parent directory.
@@ -180,7 +193,7 @@ func collectUploadFiles(inputDir string) ([]uploadEntry, error) {
         files = append(files, uploadEntry{cardPath, "README.md"})
     }

-    return files, nil
+    return files, hasSplit, nil
 }

 func ensureHFDatasetRepo(ctx context.Context, token, repoID string, public bool) error {
publish_test.go

@@ -16,6 +16,18 @@ func TestPublish_Publish_Bad_EmptyRepository(t *testing.T) {
     assertContainsString(t, err.Error(), "repository is required")
 }

+func TestPublish_Publish_Bad_DatasetCardWithoutParquetSplit(t *testing.T) {
+    inputDir := core.JoinPath(t.TempDir(), "data")
+    requireCoreOK(t, testFilesystem().EnsureDir(inputDir))
+    requireCoreWriteBytes(t, core.JoinPath(inputDir, "..", "dataset_card.md"), []byte("# Dataset\n"))
+
+    var output bytes.Buffer
+    err := Publish(PublishConfig{InputDir: inputDir, Repo: "snider/lem-training", DryRun: true}, &output)
+
+    assertError(t, err)
+    assertContainsString(t, err.Error(), "no Parquet files found")
+}
+
 func TestPublish_ResolveHFToken_Good_UserHomeFallback(t *testing.T) {
     homeDirectory := t.TempDir()
     t.Setenv("HF_TOKEN", "")

scope.go (16 changed lines)
@@ -397,15 +397,11 @@ func (scopedStore *ScopedStore) PurgeExpired() (int64, error) {
     }

     cutoffUnixMilli := time.Now().UnixMilli()
-    expiredEntries, err := listExpiredEntriesMatchingGroupPrefix(scopedStore.store.sqliteDatabase, scopedStore.namespacePrefix(), cutoffUnixMilli)
-    if err != nil {
-        return 0, core.E("store.ScopedStore.PurgeExpired", "list expired rows", err)
-    }
-
-    removedRows, err := purgeExpiredMatchingGroupPrefix(scopedStore.store.sqliteDatabase, scopedStore.namespacePrefix(), cutoffUnixMilli)
+    expiredEntries, err := deleteExpiredEntriesMatchingGroupPrefix(scopedStore.store.sqliteDatabase, scopedStore.namespacePrefix(), cutoffUnixMilli)
     if err != nil {
         return 0, core.E("store.ScopedStore.PurgeExpired", "delete expired rows", err)
     }
+    removedRows := int64(len(expiredEntries))
     if removedRows > 0 {
         for _, expiredEntry := range expiredEntries {
             scopedStore.store.notify(Event{
@@ -822,15 +818,11 @@ func (scopedStoreTransaction *ScopedStoreTransaction) PurgeExpired() (int64, err
     }

     cutoffUnixMilli := time.Now().UnixMilli()
-    expiredEntries, err := listExpiredEntriesMatchingGroupPrefix(scopedStoreTransaction.storeTransaction.sqliteTransaction, scopedStoreTransaction.scopedStore.namespacePrefix(), cutoffUnixMilli)
-    if err != nil {
-        return 0, core.E("store.ScopedStoreTransaction.PurgeExpired", "list expired rows", err)
-    }
-
-    removedRows, err := purgeExpiredMatchingGroupPrefix(scopedStoreTransaction.storeTransaction.sqliteTransaction, scopedStoreTransaction.scopedStore.namespacePrefix(), cutoffUnixMilli)
+    expiredEntries, err := deleteExpiredEntriesMatchingGroupPrefix(scopedStoreTransaction.storeTransaction.sqliteTransaction, scopedStoreTransaction.scopedStore.namespacePrefix(), cutoffUnixMilli)
     if err != nil {
         return 0, core.E("store.ScopedStoreTransaction.PurgeExpired", "delete expired rows", err)
     }
+    removedRows := int64(len(expiredEntries))
     if removedRows > 0 {
         for _, expiredEntry := range expiredEntries {
             scopedStoreTransaction.storeTransaction.recordEvent(Event{

store.go (61 changed lines)
@@ -154,7 +154,9 @@ type Store struct {
     journalConfiguration JournalConfiguration
     medium Medium
     lifecycleLock sync.Mutex
+    closeLock sync.Mutex
     isClosed bool
+    isClosing bool

     // Event dispatch state.
     watchers map[string][]chan Event
@@ -182,7 +184,7 @@ func (storeInstance *Store) ensureReady(operation string) error {
     }

     storeInstance.lifecycleLock.Lock()
-    closed := storeInstance.isClosed
+    closed := storeInstance.isClosed || storeInstance.isClosing
     storeInstance.lifecycleLock.Unlock()
     if closed {
         return core.E(operation, "store is closed", nil)
@@ -423,12 +425,15 @@ func (storeInstance *Store) Close() error {
         return nil
     }

+    storeInstance.closeLock.Lock()
+    defer storeInstance.closeLock.Unlock()
+
     storeInstance.lifecycleLock.Lock()
     if storeInstance.isClosed {
         storeInstance.lifecycleLock.Unlock()
         return nil
     }
-    storeInstance.isClosed = true
+    storeInstance.isClosing = true
     storeInstance.lifecycleLock.Unlock()

     if storeInstance.cancelPurge != nil {
@@ -470,6 +475,7 @@ func (storeInstance *Store) Close() error {
         storeInstance.sqliteDatabase = storeInstance.db
     }
     if storeInstance.sqliteDatabase == nil {
+        storeInstance.markClosed()
         return orphanCleanupErr
     }
     if err := storeInstance.sqliteDatabase.Close(); err != nil {
@@ -478,12 +484,20 @@ func (storeInstance *Store) Close() error {
     if err := storeInstance.syncMediumBackedDatabase(); err != nil {
         return core.E("store.Close", "sync medium-backed database", err)
     }
+    storeInstance.markClosed()
+    if orphanCleanupErr != nil {
+        return core.E("store.Close", "close orphan workspaces", orphanCleanupErr)
+    }
     return orphanCleanupErr
 }

+func (storeInstance *Store) markClosed() {
+    storeInstance.lifecycleLock.Lock()
+    storeInstance.isClosed = true
+    storeInstance.isClosing = false
+    storeInstance.lifecycleLock.Unlock()
+}
+
 func (storeInstance *Store) syncMediumBackedDatabase() error {
     if storeInstance == nil || !storeInstance.mediumBacked || storeInstance.medium == nil {
         return nil
@@ -965,15 +979,11 @@ func (storeInstance *Store) PurgeExpired() (int64, error) {
     }

     cutoffUnixMilli := time.Now().UnixMilli()
-    expiredEntries, err := listExpiredEntriesMatchingGroupPrefix(storeInstance.sqliteDatabase, "", cutoffUnixMilli)
-    if err != nil {
-        return 0, core.E("store.PurgeExpired", "list expired rows", err)
-    }
-
-    removedRows, err := purgeExpiredMatchingGroupPrefix(storeInstance.sqliteDatabase, "", cutoffUnixMilli)
+    expiredEntries, err := deleteExpiredEntriesMatchingGroupPrefix(storeInstance.sqliteDatabase, "", cutoffUnixMilli)
     if err != nil {
         return 0, core.E("store.PurgeExpired", "delete expired rows", err)
     }
+    removedRows := int64(len(expiredEntries))
     if removedRows > 0 {
         for _, expiredEntry := range expiredEntries {
             storeInstance.notify(Event{
@@ -1062,19 +1072,19 @@ type expiredEntryRef struct {
     key string
 }

-func listExpiredEntriesMatchingGroupPrefix(database schemaDatabase, groupPrefix string, cutoffUnixMilli int64) ([]expiredEntryRef, error) {
+func deleteExpiredEntriesMatchingGroupPrefix(database schemaDatabase, groupPrefix string, cutoffUnixMilli int64) ([]expiredEntryRef, error) {
     var (
         rows *sql.Rows
         err error
     )
     if groupPrefix == "" {
         rows, err = database.Query(
-            "SELECT "+entryGroupColumn+", "+entryKeyColumn+" FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? ORDER BY "+entryGroupColumn+", "+entryKeyColumn,
+            "DELETE FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? RETURNING "+entryGroupColumn+", "+entryKeyColumn,
             cutoffUnixMilli,
         )
     } else {
         rows, err = database.Query(
-            "SELECT "+entryGroupColumn+", "+entryKeyColumn+" FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? AND "+entryGroupColumn+" LIKE ? ESCAPE '^' ORDER BY "+entryGroupColumn+", "+entryKeyColumn,
+            "DELETE FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? AND "+entryGroupColumn+" LIKE ? ESCAPE '^' RETURNING "+entryGroupColumn+", "+entryKeyColumn,
             cutoffUnixMilli, escapeLike(groupPrefix)+"%",
         )
     }
@@ -1097,35 +1107,6 @@ func listExpiredEntriesMatchingGroupPrefix(database schemaDatabase, groupPrefix
     return expiredEntries, nil
 }

-// purgeExpiredMatchingGroupPrefix deletes expired rows globally when
-// groupPrefix is empty, otherwise only rows whose group starts with the given
-// prefix.
-func purgeExpiredMatchingGroupPrefix(database schemaDatabase, groupPrefix string, cutoffUnixMilli int64) (int64, error) {
-    var (
-        deleteResult sql.Result
-        err error
-    )
-    if groupPrefix == "" {
-        deleteResult, err = database.Exec(
-            "DELETE FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ?",
-            cutoffUnixMilli,
-        )
-    } else {
-        deleteResult, err = database.Exec(
-            "DELETE FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? AND "+entryGroupColumn+" LIKE ? ESCAPE '^'",
-            cutoffUnixMilli, escapeLike(groupPrefix)+"%",
-        )
-    }
-    if err != nil {
-        return 0, err
-    }
-    removedRows, rowsAffectedErr := deleteResult.RowsAffected()
-    if rowsAffectedErr != nil {
-        return 0, rowsAffectedErr
-    }
-    return removedRows, nil
-}
-
 type schemaDatabase interface {
     Exec(query string, args ...any) (sql.Result, error)
     QueryRow(query string, args ...any) *sql.Row
store_test.go

@@ -1082,6 +1082,27 @@ func TestStore_Close_Bad_DriverCloseError(t *testing.T) {
     assertContainsString(t, err.Error(), "store.Close")
 }

+func TestStore_Close_Bad_MediumSyncFailureRetryable(t *testing.T) {
+    useWorkspaceStateDirectory(t)
+
+    medium := &writeFailOnceMedium{memoryMedium: newMemoryMedium(), failures: 1}
+    storeInstance, err := New("retryable-close.db", WithMedium(medium))
+    assertNoError(t, err)
+    assertNoError(t, storeInstance.Set("g", "k", "v"))
+
+    err = storeInstance.Close()
+    assertError(t, err)
+    assertContainsString(t, err.Error(), "sync medium-backed database")
+    assertFalse(t, storeInstance.IsClosed())
+
+    _, err = storeInstance.Get("g", "k")
+    assertError(t, err)
+
+    assertNoError(t, storeInstance.Close())
+    assertTrue(t, storeInstance.IsClosed())
+    assertTrue(t, medium.Exists("retryable-close.db"))
+}
+
 // ---------------------------------------------------------------------------
 // Test helpers
 // ---------------------------------------------------------------------------
@@ -1600,6 +1621,32 @@ func TestStore_PurgeExpired_Good(t *testing.T) {
     assertEqualf(t, 1, count, "only non-expiring key should remain")
 }

+func TestStore_PurgeExpired_Good_NotifiesDeletedRows(t *testing.T) {
+    storeInstance, _ := New(":memory:")
+    defer func() { _ = storeInstance.Close() }()
+
+    assertNoError(t, storeInstance.SetWithTTL("g", "expired", "1", 1*time.Millisecond))
+    assertNoError(t, storeInstance.SetWithTTL("g", "live", "2", time.Hour))
+    time.Sleep(5 * time.Millisecond)
+
+    events := storeInstance.Watch("*")
+    defer storeInstance.Unwatch("*", events)
+
+    removed, err := storeInstance.PurgeExpired()
+    assertNoError(t, err)
+    assertEqual(t, int64(1), removed)
+
+    event := <-events
+    assertEqual(t, EventDelete, event.Type)
+    assertEqual(t, "g", event.Group)
+    assertEqual(t, "expired", event.Key)
+    select {
+    case extraEvent := <-events:
+        t.Fatalf("unexpected extra purge event: %#v", extraEvent)
+    default:
+    }
+}
+
 func TestStore_PurgeExpired_Good_NoneExpired(t *testing.T) {
     storeInstance, _ := New(":memory:")
     defer func() { _ = storeInstance.Close() }()
transaction.go

@@ -512,15 +512,11 @@ func (storeTransaction *StoreTransaction) PurgeExpired() (int64, error) {
     }

     cutoffUnixMilli := time.Now().UnixMilli()
-    expiredEntries, err := listExpiredEntriesMatchingGroupPrefix(storeTransaction.sqliteTransaction, "", cutoffUnixMilli)
-    if err != nil {
-        return 0, core.E("store.Transaction.PurgeExpired", "list expired rows", err)
-    }
-
-    removedRows, err := purgeExpiredMatchingGroupPrefix(storeTransaction.sqliteTransaction, "", cutoffUnixMilli)
+    expiredEntries, err := deleteExpiredEntriesMatchingGroupPrefix(storeTransaction.sqliteTransaction, "", cutoffUnixMilli)
     if err != nil {
         return 0, core.E("store.Transaction.PurgeExpired", "delete expired rows", err)
     }
+    removedRows := int64(len(expiredEntries))
     if removedRows > 0 {
         for _, expiredEntry := range expiredEntries {
             storeTransaction.recordEvent(Event{