diff --git a/README.md b/README.md index 7660e53..f78416f 100644 --- a/README.md +++ b/README.md @@ -94,4 +94,4 @@ go build ./... ## Licence -European Union Public Licence 1.2 — see [LICENCE](LICENCE) for details. +European Union Public Licence 1.2 — see [LICENCE.md](LICENCE.md) for details. diff --git a/events.go b/events.go index 00068d5..00bc135 100644 --- a/events.go +++ b/events.go @@ -76,7 +76,7 @@ func (storeInstance *Store) Watch(group string) <-chan Event { storeInstance.lifecycleLock.Lock() defer storeInstance.lifecycleLock.Unlock() - if storeInstance.isClosed { + if storeInstance.isClosed || storeInstance.isClosing { return closedEventChannel() } @@ -97,7 +97,7 @@ func (storeInstance *Store) Unwatch(group string, events <-chan Event) { } storeInstance.lifecycleLock.Lock() - closed := storeInstance.isClosed + closed := storeInstance.isClosed || storeInstance.isClosing storeInstance.lifecycleLock.Unlock() if closed { return @@ -146,7 +146,7 @@ func (storeInstance *Store) OnChange(callback func(Event)) func() { storeInstance.lifecycleLock.Lock() defer storeInstance.lifecycleLock.Unlock() - if storeInstance.isClosed { + if storeInstance.isClosed || storeInstance.isClosing { return func() {} } @@ -188,7 +188,7 @@ func (storeInstance *Store) notify(event Event) { } storeInstance.lifecycleLock.Lock() - if storeInstance.isClosed { + if storeInstance.isClosed || storeInstance.isClosing { storeInstance.lifecycleLock.Unlock() return } @@ -210,7 +210,7 @@ func (storeInstance *Store) notify(event Event) { storeInstance.watcherLock.RUnlock() storeInstance.lifecycleLock.Lock() - if storeInstance.isClosed { + if storeInstance.isClosed || storeInstance.isClosing { storeInstance.lifecycleLock.Unlock() return } diff --git a/import.go b/import.go index af18e41..8edb98b 100644 --- a/import.go +++ b/import.go @@ -4,6 +4,7 @@ package store import ( "bufio" + "database/sql" "io" "io/fs" @@ -13,6 +14,27 @@ import ( // localFs provides unrestricted filesystem access for import operations. var localFs = (&core.Fs{}).New("/") +type duckDBImportSession interface { + exec(query string, args ...any) error + queryRowScan(query string, dest any, args ...any) error +} + +type duckDBImportTransaction struct { + transaction *sql.Tx +} + +func (session duckDBImportTransaction) exec(query string, args ...any) error { + _, err := session.transaction.Exec(query, args...) + if err != nil { + return core.E("store.duckDBImportTransaction.Exec", "execute query", err) + } + return nil +} + +func (session duckDBImportTransaction) queryRowScan(query string, dest any, args ...any) error { + return session.transaction.QueryRow(query, args...).Scan(dest) +} + // ScpFunc is a callback for executing SCP file transfers. // The function receives remote source and local destination paths. 
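The import.go hunk above narrows the import helpers' dependency from `*DuckDB` to a two-method session interface backed by `*sql.Tx`, which is what lets the new import_test.go stub the database entirely. A minimal, standalone sketch of that pattern; `session`, `txSession`, and `fakeSession` are illustrative names, not this package's API:

```go
package example

import "database/sql"

// session is the minimal surface the import helpers need.
type session interface {
	exec(query string, args ...any) error
	queryRowScan(query string, dest any, args ...any) error
}

// txSession adapts *sql.Tx so every import statement runs inside one transaction.
type txSession struct{ tx *sql.Tx }

func (s txSession) exec(query string, args ...any) error {
	_, err := s.tx.Exec(query, args...)
	return err
}

func (s txSession) queryRowScan(query string, dest any, args ...any) error {
	return s.tx.QueryRow(query, args...).Scan(dest)
}

// fakeSession counts statements for tests; no database required.
type fakeSession struct{ execCount int }

func (s *fakeSession) exec(string, ...any) error              { s.execCount++; return nil }
func (s *fakeSession) queryRowScan(string, any, ...any) error { return nil }

// Compile-time checks that both implementations satisfy the interface.
var (
	_ session = txSession{}
	_ session = (*fakeSession)(nil)
)
```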
// @@ -77,6 +99,10 @@ type ImportConfig struct { // // err := store.ImportAll(db, store.ImportConfig{DataDir: "/Volumes/Data/lem"}, os.Stdout) func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { + if db == nil || db.Conn() == nil { + return core.E("store.ImportAll", "database is nil", nil) + } + m3Host := cfg.M3Host if m3Host == "" { m3Host = "m3" @@ -93,14 +119,26 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { core.Print(w, " WARNING: could not pull golden set from M3: %v", err) } } + transaction, err := db.Conn().Begin() + if err != nil { + return core.E("store.ImportAll", "begin import transaction", err) + } + committed := false + defer func() { + if !committed { + _ = transaction.Rollback() + } + }() + importSession := duckDBImportTransaction{transaction: transaction} + if isFile(goldenPath) { - if err := db.Exec("DROP TABLE IF EXISTS golden_set"); err != nil { + if err := importSession.exec("DROP TABLE IF EXISTS golden_set"); err != nil { return core.E("store.ImportAll", "drop golden_set", err) } - err := db.Exec(core.Sprintf(` - CREATE TABLE golden_set AS - SELECT - idx::INT AS idx, + err := importSession.exec(core.Sprintf(` + CREATE TABLE golden_set AS + SELECT + idx::INT AS idx, seed_id::VARCHAR AS seed_id, domain::VARCHAR AS domain, voice::VARCHAR AS voice, @@ -115,7 +153,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { return core.E("store.ImportAll", "import golden_set", err) } else { var n int - if err := db.QueryRowScan("SELECT count(*) FROM golden_set", &n); err != nil { + if err := importSession.queryRowScan("SELECT count(*) FROM golden_set", &n); err != nil { return core.E("store.ImportAll", "count golden_set", err) } totals["golden_set"] = n @@ -160,13 +198,13 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { } } - if err := db.Exec("DROP TABLE IF EXISTS training_examples"); err != nil { + if err := importSession.exec("DROP TABLE IF EXISTS training_examples"); err != nil { return core.E("store.ImportAll", "drop training_examples", err) } - if err := db.Exec(` - CREATE TABLE training_examples ( - source VARCHAR, - split VARCHAR, + if err := importSession.exec(` + CREATE TABLE training_examples ( + source VARCHAR, + split VARCHAR, prompt TEXT, response TEXT, num_turns INT, @@ -192,7 +230,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { split = "test" } - n, err := importTrainingFile(db, local, td.name, split) + n, err := importTrainingFile(importSession, local, td.name, split) if err != nil { return core.E("store.ImportAll", core.Sprintf("import training file %s", local), err) } @@ -224,13 +262,13 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { } } - if err := db.Exec("DROP TABLE IF EXISTS benchmark_results"); err != nil { + if err := importSession.exec("DROP TABLE IF EXISTS benchmark_results"); err != nil { return core.E("store.ImportAll", "drop benchmark_results", err) } - if err := db.Exec(` - CREATE TABLE benchmark_results ( - source VARCHAR, id VARCHAR, benchmark VARCHAR, model VARCHAR, - prompt TEXT, response TEXT, elapsed_seconds DOUBLE, domain VARCHAR + if err := importSession.exec(` + CREATE TABLE benchmark_results ( + source VARCHAR, id VARCHAR, benchmark VARCHAR, model VARCHAR, + prompt TEXT, response TEXT, elapsed_seconds DOUBLE, domain VARCHAR ) `); err != nil { return core.E("store.ImportAll", "create benchmark_results", err) @@ -241,7 +279,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { resultDir := 
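The ImportAll hunks above wrap the whole import in a single transaction and roll it back unless the final Commit succeeds, so a failed import leaves the previous tables untouched. A stripped-down sketch of that commit-or-rollback guard, assuming a hypothetical `runImport` callback in place of the real table rebuilds:

```go
package example

import "database/sql"

// importAll runs the whole import atomically: either every table is
// rebuilt or none of them is touched.
func importAll(db *sql.DB, runImport func(tx *sql.Tx) error) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	committed := false
	defer func() {
		// Rolls back on any early return; a no-op once Commit has run.
		if !committed {
			_ = tx.Rollback()
		}
	}()

	if err := runImport(tx); err != nil {
		return err
	}
	if err := tx.Commit(); err != nil {
		return err
	}
	committed = true
	return nil
}
```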
core.JoinPath(benchLocal, subdir) matches := core.PathGlob(core.JoinPath(resultDir, "*.jsonl")) for _, jf := range matches { - n, err := importBenchmarkFile(db, jf, subdir) + n, err := importBenchmarkFile(importSession, jf, subdir) if err != nil { return core.E("store.ImportAll", core.Sprintf("import benchmark file %s", jf), err) } @@ -259,7 +297,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { } } if isFile(local) { - n, err := importBenchmarkFile(db, local, "benchmark") + n, err := importBenchmarkFile(importSession, local, "benchmark") if err != nil { return core.E("store.ImportAll", core.Sprintf("import benchmark file %s", local), err) } @@ -270,13 +308,13 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { core.Print(w, " benchmark_results: %d rows", benchTotal) // ── 4. Benchmark questions ── - if err := db.Exec("DROP TABLE IF EXISTS benchmark_questions"); err != nil { + if err := importSession.exec("DROP TABLE IF EXISTS benchmark_questions"); err != nil { return core.E("store.ImportAll", "drop benchmark_questions", err) } - if err := db.Exec(` - CREATE TABLE benchmark_questions ( - benchmark VARCHAR, id VARCHAR, question TEXT, - best_answer TEXT, correct_answers TEXT, incorrect_answers TEXT, category VARCHAR + if err := importSession.exec(` + CREATE TABLE benchmark_questions ( + benchmark VARCHAR, id VARCHAR, question TEXT, + best_answer TEXT, correct_answers TEXT, incorrect_answers TEXT, category VARCHAR ) `); err != nil { return core.E("store.ImportAll", "create benchmark_questions", err) @@ -286,7 +324,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { for _, bname := range []string{"truthfulqa", "gsm8k", "do_not_answer", "toxigen"} { local := core.JoinPath(benchLocal, bname+".jsonl") if isFile(local) { - n, err := importBenchmarkQuestions(db, local, bname) + n, err := importBenchmarkQuestions(importSession, local, bname) if err != nil { return core.E("store.ImportAll", core.Sprintf("import benchmark questions %s", local), err) } @@ -297,13 +335,13 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { core.Print(w, " benchmark_questions: %d rows", benchQTotal) // ── 5. 
Seeds ── - if err := db.Exec("DROP TABLE IF EXISTS seeds"); err != nil { + if err := importSession.exec("DROP TABLE IF EXISTS seeds"); err != nil { return core.E("store.ImportAll", "drop seeds", err) } - if err := db.Exec(` - CREATE TABLE seeds ( - source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT - ) + if err := importSession.exec(` + CREATE TABLE seeds ( + source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT + ) `); err != nil { return core.E("store.ImportAll", "create seeds", err) } @@ -314,7 +352,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { if !isDir(seedDir) { continue } - n, err := importSeeds(db, seedDir) + n, err := importSeeds(importSession, seedDir) if err != nil { return core.E("store.ImportAll", core.Sprintf("import seeds %s", seedDir), err) } @@ -323,6 +361,11 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { totals["seeds"] = seedTotal core.Print(w, " seeds: %d rows", seedTotal) + if err := transaction.Commit(); err != nil { + return core.E("store.ImportAll", "commit import transaction", err) + } + committed = true + // ── Summary ── grandTotal := 0 core.Print(w, "\n%s", repeat("=", 50)) @@ -339,7 +382,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error { return nil } -func importTrainingFile(db *DuckDB, path, source, split string) (int, error) { +func importTrainingFile(db duckDBImportSession, path, source, split string) (int, error) { r := localFs.Open(path) if !r.OK { return 0, core.E("store.importTrainingFile", core.Sprintf("open %s", path), r.Value.(error)) @@ -351,12 +394,15 @@ func importTrainingFile(db *DuckDB, path, source, split string) (int, error) { scanner := bufio.NewScanner(f) scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + lineNumber := 0 for scanner.Scan() { + lineNumber++ var rec struct { Messages []ChatMessage `json:"messages"` } if r := core.JSONUnmarshal(scanner.Bytes(), &rec); !r.OK { - continue + parseErr, _ := r.Value.(error) + return count, core.E("store.importTrainingFile", core.Sprintf("parse %s line %d", path, lineNumber), parseErr) } prompt := "" @@ -375,7 +421,7 @@ func importTrainingFile(db *DuckDB, path, source, split string) (int, error) { } msgsJSON := core.JSONMarshalString(rec.Messages) - if err := db.Exec(`INSERT INTO training_examples VALUES (?, ?, ?, ?, ?, ?, ?)`, + if err := db.exec(`INSERT INTO training_examples VALUES (?, ?, ?, ?, ?, ?, ?)`, source, split, prompt, response, assistantCount, msgsJSON, len(response)); err != nil { return count, core.E("store.importTrainingFile", "insert training example", err) } @@ -387,7 +433,7 @@ func importTrainingFile(db *DuckDB, path, source, split string) (int, error) { return count, nil } -func importBenchmarkFile(db *DuckDB, path, source string) (int, error) { +func importBenchmarkFile(db duckDBImportSession, path, source string) (int, error) { r := localFs.Open(path) if !r.OK { return 0, core.E("store.importBenchmarkFile", core.Sprintf("open %s", path), r.Value.(error)) @@ -399,13 +445,16 @@ func importBenchmarkFile(db *DuckDB, path, source string) (int, error) { scanner := bufio.NewScanner(f) scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + lineNumber := 0 for scanner.Scan() { + lineNumber++ var rec map[string]any if r := core.JSONUnmarshal(scanner.Bytes(), &rec); !r.OK { - continue + parseErr, _ := r.Value.(error) + return count, core.E("store.importBenchmarkFile", core.Sprintf("parse %s line %d", path, lineNumber), parseErr) } - if err := db.Exec(`INSERT INTO 
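The importTrainingFile and importBenchmarkFile hunks above stop silently skipping malformed JSONL rows and instead fail with the offending line number. A self-contained sketch of that strict scan using only the standard library rather than the core helpers; `readJSONL` is an illustrative name:

```go
package example

import (
	"bufio"
	"encoding/json"
	"fmt"
	"io"
)

// readJSONL decodes one JSON object per line and fails fast on the first
// malformed line, reporting its 1-based line number.
func readJSONL(r io.Reader, handle func(map[string]any) error) (int, error) {
	scanner := bufio.NewScanner(r)
	// Allow records up to 1 MiB, matching the import buffer size above.
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

	count := 0
	line := 0
	for scanner.Scan() {
		line++
		var rec map[string]any
		if err := json.Unmarshal(scanner.Bytes(), &rec); err != nil {
			return count, fmt.Errorf("parse line %d: %w", line, err)
		}
		if err := handle(rec); err != nil {
			return count, err
		}
		count++
	}
	return count, scanner.Err()
}
```

Returning the partial count alongside the error is what the new tests rely on: one good row is inserted before line 2 fails.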
benchmark_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + if err := db.exec(`INSERT INTO benchmark_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, source, core.Sprint(rec["id"]), strOrEmpty(rec, "benchmark"), @@ -425,7 +474,7 @@ func importBenchmarkFile(db *DuckDB, path, source string) (int, error) { return count, nil } -func importBenchmarkQuestions(db *DuckDB, path, benchmark string) (int, error) { +func importBenchmarkQuestions(db duckDBImportSession, path, benchmark string) (int, error) { r := localFs.Open(path) if !r.OK { return 0, core.E("store.importBenchmarkQuestions", core.Sprintf("open %s", path), r.Value.(error)) @@ -437,16 +486,19 @@ func importBenchmarkQuestions(db *DuckDB, path, benchmark string) (int, error) { scanner := bufio.NewScanner(f) scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + lineNumber := 0 for scanner.Scan() { + lineNumber++ var rec map[string]any if r := core.JSONUnmarshal(scanner.Bytes(), &rec); !r.OK { - continue + parseErr, _ := r.Value.(error) + return count, core.E("store.importBenchmarkQuestions", core.Sprintf("parse %s line %d", path, lineNumber), parseErr) } correctJSON := core.JSONMarshalString(rec["correct_answers"]) incorrectJSON := core.JSONMarshalString(rec["incorrect_answers"]) - if err := db.Exec(`INSERT INTO benchmark_questions VALUES (?, ?, ?, ?, ?, ?, ?)`, + if err := db.exec(`INSERT INTO benchmark_questions VALUES (?, ?, ?, ?, ?, ?, ?)`, benchmark, core.Sprint(rec["id"]), strOrEmpty(rec, "question"), @@ -465,15 +517,11 @@ func importBenchmarkQuestions(db *DuckDB, path, benchmark string) (int, error) { return count, nil } -func importSeeds(db *DuckDB, seedDir string) (int, error) { +func importSeeds(db duckDBImportSession, seedDir string) (int, error) { count := 0 - var firstErr error - walkDir(seedDir, func(path string) { - if firstErr != nil { - return - } + if err := walkDir(seedDir, func(path string) error { if !core.HasSuffix(path, ".json") { - return + return nil } rel := core.TrimPrefix(path, seedDir+"/") @@ -481,8 +529,7 @@ func importSeeds(db *DuckDB, seedDir string) (int, error) { readResult := localFs.Read(path) if !readResult.OK { - firstErr = core.E("store.importSeeds", core.Sprintf("read seed file %s", rel), readResult.Value.(error)) - return + return core.E("store.importSeeds", core.Sprintf("read seed file %s", rel), readResult.Value.(error)) } data := []byte(readResult.Value.(string)) @@ -491,8 +538,7 @@ func importSeeds(db *DuckDB, seedDir string) (int, error) { var raw any if r := core.JSONUnmarshal(data, &raw); !r.OK { err, _ := r.Value.(error) - firstErr = core.E("store.importSeeds", core.Sprintf("parse seed file %s", rel), err) - return + return core.E("store.importSeeds", core.Sprintf("parse seed file %s", rel), err) } switch v := raw.(type) { @@ -516,50 +562,53 @@ func importSeeds(db *DuckDB, seedDir string) (int, error) { if prompt == "" { prompt = strOrEmpty(seed, "question") } - if err := db.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`, + if err := db.exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`, rel, region, strOrEmpty(seed, "seed_id"), strOrEmpty(seed, "domain"), prompt, ); err != nil { - firstErr = core.E("store.importSeeds", "insert seed prompt", err) - return + return core.E("store.importSeeds", "insert seed prompt", err) } count++ case string: - if err := db.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`, + if err := db.exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`, rel, region, "", "", seed); err != nil { - firstErr = core.E("store.importSeeds", "insert seed string", err) - return + return 
core.E("store.importSeeds", "insert seed string", err) } count++ } } - }) - if firstErr != nil { - return count, firstErr + return nil + }); err != nil { + return count, err } return count, nil } // walkDir recursively visits all regular files under root, calling fn for each. -func walkDir(root string, fn func(path string)) { +func walkDir(root string, fn func(path string) error) error { r := localFs.List(root) if !r.OK { - return + return core.E("store.walkDir", core.Sprintf("list %s", root), r.Value.(error)) } entries, ok := r.Value.([]fs.DirEntry) if !ok { - return + return core.E("store.walkDir", core.Sprintf("list %s returned invalid entries", root), nil) } for _, entry := range entries { full := core.JoinPath(root, entry.Name()) if entry.IsDir() { - walkDir(full, fn) + if err := walkDir(full, fn); err != nil { + return err + } } else { - fn(full) + if err := fn(full); err != nil { + return err + } } } + return nil } // strOrEmpty extracts a string value from a map, returning an empty string if diff --git a/import_test.go b/import_test.go new file mode 100644 index 0000000..1690604 --- /dev/null +++ b/import_test.go @@ -0,0 +1,70 @@ +package store + +import ( + "testing" + + core "dappco.re/go/core" +) + +type importSessionStub struct { + inserts int +} + +func (session *importSessionStub) exec(string, ...any) error { + session.inserts++ + return nil +} + +func (session *importSessionStub) queryRowScan(string, any, ...any) error { + return nil +} + +func TestImport_ImportTrainingFile_Bad_MalformedJSONL(t *testing.T) { + path := testPath(t, "training.jsonl") + requireCoreWriteBytes(t, path, []byte("{\"messages\":[]}\n{broken\n")) + session := &importSessionStub{} + + count, err := importTrainingFile(session, path, "training", "train") + + assertError(t, err) + assertContainsString(t, err.Error(), "line 2") + assertEqual(t, 1, count) + assertEqual(t, 1, session.inserts) +} + +func TestImport_ImportBenchmarkFile_Bad_MalformedJSONL(t *testing.T) { + path := testPath(t, "benchmark.jsonl") + requireCoreWriteBytes(t, path, []byte("{\"id\":\"row-1\"}\n{broken\n")) + session := &importSessionStub{} + + count, err := importBenchmarkFile(session, path, "benchmark") + + assertError(t, err) + assertContainsString(t, err.Error(), "line 2") + assertEqual(t, 1, count) + assertEqual(t, 1, session.inserts) +} + +func TestImport_ImportBenchmarkQuestions_Bad_MalformedJSONL(t *testing.T) { + path := testPath(t, "questions.jsonl") + requireCoreWriteBytes(t, path, []byte("{\"id\":\"q-1\"}\n{broken\n")) + session := &importSessionStub{} + + count, err := importBenchmarkQuestions(session, path, "truthfulqa") + + assertError(t, err) + assertContainsString(t, err.Error(), "line 2") + assertEqual(t, 1, count) + assertEqual(t, 1, session.inserts) +} + +func TestImport_ImportSeeds_Bad_WalkFailure(t *testing.T) { + session := &importSessionStub{} + + count, err := importSeeds(session, core.JoinPath(t.TempDir(), "missing-seeds")) + + assertError(t, err) + assertContainsString(t, err.Error(), "store.walkDir") + assertEqual(t, 0, count) + assertEqual(t, 0, session.inserts) +} diff --git a/medium.go b/medium.go index 532d3ff..6f308dc 100644 --- a/medium.go +++ b/medium.go @@ -243,7 +243,10 @@ func importCSV(workspace *Workspace, kind, content string) error { } func exportJSON(workspace *Workspace, medium Medium, path string) error { - summary := workspace.Aggregate() + summary, err := workspace.aggregateFields() + if err != nil { + return core.E("store.Export", "aggregate workspace", err) + } content := 
core.JSONMarshalString(summary) if err := medium.Write(path, content); err != nil { return core.E("store.Export", "write json", err) diff --git a/medium_test.go b/medium_test.go index d8c444c..500c372 100644 --- a/medium_test.go +++ b/medium_test.go @@ -114,6 +114,19 @@ func (medium *renameFailMedium) Rename(string, string) error { return core.E("renameFailMedium.Rename", "forced rename failure", nil) } +type writeFailOnceMedium struct { + *memoryMedium + failures int +} + +func (medium *writeFailOnceMedium) Write(path, content string) error { + if medium.failures > 0 { + medium.failures-- + return core.E("writeFailOnceMedium.Write", "forced write failure", nil) + } + return medium.memoryMedium.Write(path, content) +} + func (medium *memoryMedium) List(path string) ([]fs.DirEntry, error) { return nil, nil } func (medium *memoryMedium) Stat(path string) (fs.FileInfo, error) { @@ -462,6 +475,30 @@ func TestMedium_Export_Bad_NilArguments(t *testing.T) { assertError(t, Export(workspace, medium, "")) } +func TestMedium_Export_Bad_JSONPropagatesWorkspaceFailure(t *testing.T) { + useWorkspaceStateDirectory(t) + + storeInstance, err := New(":memory:") + assertNoError(t, err) + defer func() { _ = storeInstance.Close() }() + + workspace, err := storeInstance.NewWorkspace("medium-export-json-closed") + assertNoError(t, err) + assertNoError(t, workspace.Put("like", map[string]any{"user": "@alice"})) + assertNoError(t, workspace.Close()) + + medium := newMemoryMedium() + assertNoError(t, medium.Write("report.json", `{"previous":true}`)) + + err = Export(workspace, medium, "report.json") + + assertError(t, err) + assertContainsString(t, err.Error(), "aggregate workspace") + content, readErr := medium.Read("report.json") + assertNoError(t, readErr) + assertEqual(t, `{"previous":true}`, content) +} + func TestMedium_Compact_Good_MediumRoutesArchive(t *testing.T) { useWorkspaceStateDirectory(t) useArchiveOutputDirectory(t) diff --git a/publish.go b/publish.go index fe8ec6d..799a4c4 100644 --- a/publish.go +++ b/publish.go @@ -97,11 +97,11 @@ func Publish(cfg PublishConfig, w io.Writer) error { return core.E("store.Publish", "HuggingFace token required (--token, HF_TOKEN env, or ~/.huggingface/token)", nil) } - files, err := collectUploadFiles(cfg.InputDir) + files, hasSplit, err := collectUploadFiles(cfg.InputDir) if err != nil { return err } - if len(files) == 0 { + if !hasSplit { return core.E("store.Publish", core.Sprintf("no Parquet files found in %s", cfg.InputDir), nil) } @@ -150,21 +150,33 @@ func resolveHFToken(explicit string) string { if env := core.Env("HF_TOKEN"); env != "" { return env } - home := core.Env("HOME") - if home == "" { - return "" + // Core populates DIR_HOME via os.UserHomeDir while this package keeps the + // repository-wide ban on direct os imports. + homes := []string{core.Env("DIR_HOME")} + if homeEnv := core.Env("HOME"); homeEnv != "" && homeEnv != homes[0] { + homes = append(homes, homeEnv) } - r := localFs.Read(core.JoinPath(home, ".huggingface", "token")) - if !r.OK { - return "" + for _, home := range homes { + if home == "" { + continue + } + r := localFs.Read(core.JoinPath(home, ".huggingface", "token")) + if !r.OK { + continue + } + token := core.Trim(r.Value.(string)) + if token != "" { + return token + } } - return core.Trim(r.Value.(string)) + return "" } // collectUploadFiles finds Parquet split files and an optional dataset card. 
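resolveHFToken above now tries the explicit flag, then the HF_TOKEN environment variable, then a token file under each candidate home directory (DIR_HOME first, HOME as a fallback). A rough standard-library equivalent of that fallback chain, since the real code goes through core.Env and core.Fs; `resolveToken` is an illustrative name:

```go
package example

import (
	"os"
	"path/filepath"
	"strings"
)

// resolveToken returns the first non-empty token from: the explicit value,
// the HF_TOKEN environment variable, then ~/.huggingface/token.
func resolveToken(explicit string) string {
	if explicit != "" {
		return explicit
	}
	if env := os.Getenv("HF_TOKEN"); env != "" {
		return env
	}
	home, err := os.UserHomeDir()
	if err != nil || home == "" {
		return ""
	}
	data, err := os.ReadFile(filepath.Join(home, ".huggingface", "token"))
	if err != nil {
		return ""
	}
	return strings.TrimSpace(string(data))
}
```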
-func collectUploadFiles(inputDir string) ([]uploadEntry, error) { +func collectUploadFiles(inputDir string) ([]uploadEntry, bool, error) { splits := []string{"train", "valid", "test"} var files []uploadEntry + hasSplit := false for _, split := range splits { path := core.JoinPath(inputDir, split+".parquet") @@ -172,6 +184,7 @@ func collectUploadFiles(inputDir string) ([]uploadEntry, error) { continue } files = append(files, uploadEntry{path, core.Sprintf("data/%s.parquet", split)}) + hasSplit = true } // Check for dataset card in parent directory. @@ -180,7 +193,7 @@ func collectUploadFiles(inputDir string) ([]uploadEntry, error) { files = append(files, uploadEntry{cardPath, "README.md"}) } - return files, nil + return files, hasSplit, nil } func ensureHFDatasetRepo(ctx context.Context, token, repoID string, public bool) error { diff --git a/publish_test.go b/publish_test.go index 3c0191a..f38b307 100644 --- a/publish_test.go +++ b/publish_test.go @@ -16,6 +16,18 @@ func TestPublish_Publish_Bad_EmptyRepository(t *testing.T) { assertContainsString(t, err.Error(), "repository is required") } +func TestPublish_Publish_Bad_DatasetCardWithoutParquetSplit(t *testing.T) { + inputDir := core.JoinPath(t.TempDir(), "data") + requireCoreOK(t, testFilesystem().EnsureDir(inputDir)) + requireCoreWriteBytes(t, core.JoinPath(inputDir, "..", "dataset_card.md"), []byte("# Dataset\n")) + + var output bytes.Buffer + err := Publish(PublishConfig{InputDir: inputDir, Repo: "snider/lem-training", DryRun: true}, &output) + + assertError(t, err) + assertContainsString(t, err.Error(), "no Parquet files found") +} + func TestPublish_ResolveHFToken_Good_UserHomeFallback(t *testing.T) { homeDirectory := t.TempDir() t.Setenv("HF_TOKEN", "") diff --git a/scope.go b/scope.go index ceac10c..61f00ff 100644 --- a/scope.go +++ b/scope.go @@ -397,15 +397,11 @@ func (scopedStore *ScopedStore) PurgeExpired() (int64, error) { } cutoffUnixMilli := time.Now().UnixMilli() - expiredEntries, err := listExpiredEntriesMatchingGroupPrefix(scopedStore.store.sqliteDatabase, scopedStore.namespacePrefix(), cutoffUnixMilli) - if err != nil { - return 0, core.E("store.ScopedStore.PurgeExpired", "list expired rows", err) - } - - removedRows, err := purgeExpiredMatchingGroupPrefix(scopedStore.store.sqliteDatabase, scopedStore.namespacePrefix(), cutoffUnixMilli) + expiredEntries, err := deleteExpiredEntriesMatchingGroupPrefix(scopedStore.store.sqliteDatabase, scopedStore.namespacePrefix(), cutoffUnixMilli) if err != nil { return 0, core.E("store.ScopedStore.PurgeExpired", "delete expired rows", err) } + removedRows := int64(len(expiredEntries)) if removedRows > 0 { for _, expiredEntry := range expiredEntries { scopedStore.store.notify(Event{ @@ -822,15 +818,11 @@ func (scopedStoreTransaction *ScopedStoreTransaction) PurgeExpired() (int64, err } cutoffUnixMilli := time.Now().UnixMilli() - expiredEntries, err := listExpiredEntriesMatchingGroupPrefix(scopedStoreTransaction.storeTransaction.sqliteTransaction, scopedStoreTransaction.scopedStore.namespacePrefix(), cutoffUnixMilli) - if err != nil { - return 0, core.E("store.ScopedStoreTransaction.PurgeExpired", "list expired rows", err) - } - - removedRows, err := purgeExpiredMatchingGroupPrefix(scopedStoreTransaction.storeTransaction.sqliteTransaction, scopedStoreTransaction.scopedStore.namespacePrefix(), cutoffUnixMilli) + expiredEntries, err := deleteExpiredEntriesMatchingGroupPrefix(scopedStoreTransaction.storeTransaction.sqliteTransaction, scopedStoreTransaction.scopedStore.namespacePrefix(), 
cutoffUnixMilli) if err != nil { return 0, core.E("store.ScopedStoreTransaction.PurgeExpired", "delete expired rows", err) } + removedRows := int64(len(expiredEntries)) if removedRows > 0 { for _, expiredEntry := range expiredEntries { scopedStoreTransaction.storeTransaction.recordEvent(Event{ diff --git a/store.go b/store.go index 05cdb12..ab9f316 100644 --- a/store.go +++ b/store.go @@ -154,7 +154,9 @@ type Store struct { journalConfiguration JournalConfiguration medium Medium lifecycleLock sync.Mutex + closeLock sync.Mutex isClosed bool + isClosing bool // Event dispatch state. watchers map[string][]chan Event @@ -182,7 +184,7 @@ func (storeInstance *Store) ensureReady(operation string) error { } storeInstance.lifecycleLock.Lock() - closed := storeInstance.isClosed + closed := storeInstance.isClosed || storeInstance.isClosing storeInstance.lifecycleLock.Unlock() if closed { return core.E(operation, "store is closed", nil) @@ -423,12 +425,15 @@ func (storeInstance *Store) Close() error { return nil } + storeInstance.closeLock.Lock() + defer storeInstance.closeLock.Unlock() + storeInstance.lifecycleLock.Lock() if storeInstance.isClosed { storeInstance.lifecycleLock.Unlock() return nil } - storeInstance.isClosed = true + storeInstance.isClosing = true storeInstance.lifecycleLock.Unlock() if storeInstance.cancelPurge != nil { @@ -470,6 +475,7 @@ func (storeInstance *Store) Close() error { storeInstance.sqliteDatabase = storeInstance.db } if storeInstance.sqliteDatabase == nil { + storeInstance.markClosed() return orphanCleanupErr } if err := storeInstance.sqliteDatabase.Close(); err != nil { @@ -478,12 +484,20 @@ func (storeInstance *Store) Close() error { if err := storeInstance.syncMediumBackedDatabase(); err != nil { return core.E("store.Close", "sync medium-backed database", err) } + storeInstance.markClosed() if orphanCleanupErr != nil { return core.E("store.Close", "close orphan workspaces", orphanCleanupErr) } return orphanCleanupErr } +func (storeInstance *Store) markClosed() { + storeInstance.lifecycleLock.Lock() + storeInstance.isClosed = true + storeInstance.isClosing = false + storeInstance.lifecycleLock.Unlock() +} + func (storeInstance *Store) syncMediumBackedDatabase() error { if storeInstance == nil || !storeInstance.mediumBacked || storeInstance.medium == nil { return nil @@ -965,15 +979,11 @@ func (storeInstance *Store) PurgeExpired() (int64, error) { } cutoffUnixMilli := time.Now().UnixMilli() - expiredEntries, err := listExpiredEntriesMatchingGroupPrefix(storeInstance.sqliteDatabase, "", cutoffUnixMilli) - if err != nil { - return 0, core.E("store.PurgeExpired", "list expired rows", err) - } - - removedRows, err := purgeExpiredMatchingGroupPrefix(storeInstance.sqliteDatabase, "", cutoffUnixMilli) + expiredEntries, err := deleteExpiredEntriesMatchingGroupPrefix(storeInstance.sqliteDatabase, "", cutoffUnixMilli) if err != nil { return 0, core.E("store.PurgeExpired", "delete expired rows", err) } + removedRows := int64(len(expiredEntries)) if removedRows > 0 { for _, expiredEntry := range expiredEntries { storeInstance.notify(Event{ @@ -1062,19 +1072,19 @@ type expiredEntryRef struct { key string } -func listExpiredEntriesMatchingGroupPrefix(database schemaDatabase, groupPrefix string, cutoffUnixMilli int64) ([]expiredEntryRef, error) { +func deleteExpiredEntriesMatchingGroupPrefix(database schemaDatabase, groupPrefix string, cutoffUnixMilli int64) ([]expiredEntryRef, error) { var ( rows *sql.Rows err error ) if groupPrefix == "" { rows, err = database.Query( - "SELECT 
"+entryGroupColumn+", "+entryKeyColumn+" FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? ORDER BY "+entryGroupColumn+", "+entryKeyColumn, + "DELETE FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? RETURNING "+entryGroupColumn+", "+entryKeyColumn, cutoffUnixMilli, ) } else { rows, err = database.Query( - "SELECT "+entryGroupColumn+", "+entryKeyColumn+" FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? AND "+entryGroupColumn+" LIKE ? ESCAPE '^' ORDER BY "+entryGroupColumn+", "+entryKeyColumn, + "DELETE FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? AND "+entryGroupColumn+" LIKE ? ESCAPE '^' RETURNING "+entryGroupColumn+", "+entryKeyColumn, cutoffUnixMilli, escapeLike(groupPrefix)+"%", ) } @@ -1097,35 +1107,6 @@ func listExpiredEntriesMatchingGroupPrefix(database schemaDatabase, groupPrefix return expiredEntries, nil } -// purgeExpiredMatchingGroupPrefix deletes expired rows globally when -// groupPrefix is empty, otherwise only rows whose group starts with the given -// prefix. -func purgeExpiredMatchingGroupPrefix(database schemaDatabase, groupPrefix string, cutoffUnixMilli int64) (int64, error) { - var ( - deleteResult sql.Result - err error - ) - if groupPrefix == "" { - deleteResult, err = database.Exec( - "DELETE FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ?", - cutoffUnixMilli, - ) - } else { - deleteResult, err = database.Exec( - "DELETE FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? AND "+entryGroupColumn+" LIKE ? ESCAPE '^'", - cutoffUnixMilli, escapeLike(groupPrefix)+"%", - ) - } - if err != nil { - return 0, err - } - removedRows, rowsAffectedErr := deleteResult.RowsAffected() - if rowsAffectedErr != nil { - return 0, rowsAffectedErr - } - return removedRows, nil -} - type schemaDatabase interface { Exec(query string, args ...any) (sql.Result, error) QueryRow(query string, args ...any) *sql.Row diff --git a/store_test.go b/store_test.go index 614c79d..7bf116a 100644 --- a/store_test.go +++ b/store_test.go @@ -1082,6 +1082,27 @@ func TestStore_Close_Bad_DriverCloseError(t *testing.T) { assertContainsString(t, err.Error(), "store.Close") } +func TestStore_Close_Bad_MediumSyncFailureRetryable(t *testing.T) { + useWorkspaceStateDirectory(t) + + medium := &writeFailOnceMedium{memoryMedium: newMemoryMedium(), failures: 1} + storeInstance, err := New("retryable-close.db", WithMedium(medium)) + assertNoError(t, err) + assertNoError(t, storeInstance.Set("g", "k", "v")) + + err = storeInstance.Close() + assertError(t, err) + assertContainsString(t, err.Error(), "sync medium-backed database") + assertFalse(t, storeInstance.IsClosed()) + + _, err = storeInstance.Get("g", "k") + assertError(t, err) + + assertNoError(t, storeInstance.Close()) + assertTrue(t, storeInstance.IsClosed()) + assertTrue(t, medium.Exists("retryable-close.db")) +} + // --------------------------------------------------------------------------- // Test helpers // --------------------------------------------------------------------------- @@ -1600,6 +1621,32 @@ func TestStore_PurgeExpired_Good(t *testing.T) { assertEqualf(t, 1, count, "only non-expiring key should remain") } +func TestStore_PurgeExpired_Good_NotifiesDeletedRows(t *testing.T) { + storeInstance, _ := New(":memory:") + defer func() { _ = storeInstance.Close() }() + + assertNoError(t, storeInstance.SetWithTTL("g", "expired", "1", 1*time.Millisecond)) + assertNoError(t, storeInstance.SetWithTTL("g", 
"live", "2", time.Hour)) + time.Sleep(5 * time.Millisecond) + + events := storeInstance.Watch("*") + defer storeInstance.Unwatch("*", events) + + removed, err := storeInstance.PurgeExpired() + assertNoError(t, err) + assertEqual(t, int64(1), removed) + + event := <-events + assertEqual(t, EventDelete, event.Type) + assertEqual(t, "g", event.Group) + assertEqual(t, "expired", event.Key) + select { + case extraEvent := <-events: + t.Fatalf("unexpected extra purge event: %#v", extraEvent) + default: + } +} + func TestStore_PurgeExpired_Good_NoneExpired(t *testing.T) { storeInstance, _ := New(":memory:") defer func() { _ = storeInstance.Close() }() diff --git a/transaction.go b/transaction.go index 8f52178..2bfef73 100644 --- a/transaction.go +++ b/transaction.go @@ -512,15 +512,11 @@ func (storeTransaction *StoreTransaction) PurgeExpired() (int64, error) { } cutoffUnixMilli := time.Now().UnixMilli() - expiredEntries, err := listExpiredEntriesMatchingGroupPrefix(storeTransaction.sqliteTransaction, "", cutoffUnixMilli) - if err != nil { - return 0, core.E("store.Transaction.PurgeExpired", "list expired rows", err) - } - - removedRows, err := purgeExpiredMatchingGroupPrefix(storeTransaction.sqliteTransaction, "", cutoffUnixMilli) + expiredEntries, err := deleteExpiredEntriesMatchingGroupPrefix(storeTransaction.sqliteTransaction, "", cutoffUnixMilli) if err != nil { return 0, core.E("store.Transaction.PurgeExpired", "delete expired rows", err) } + removedRows := int64(len(expiredEntries)) if removedRows > 0 { for _, expiredEntry := range expiredEntries { storeTransaction.recordEvent(Event{