fix(store): r3 — transactional import + DELETE RETURNING + token home order on PR #4
Some checks are pending
Security Scan / security (push) Waiting to run
Test / test (push) Waiting to run

Round 3 follow-up to ebe5377. Closes residual CodeRabbit findings.

Code:
- import.go: ImportAll DB mutations wrapped in transaction with
  rollback-on-error
- import.go: malformed JSONL returns file/line parse errors in all
  three import helpers (was silently swallowing per-line errors)
- import.go: walkDir returns + propagates traversal/list/type errors
- medium.go: JSON export uses aggregateFields() + propagates
  workspace failures
- publish.go: dataset_card.md excluded from Parquet split count
- store.go: medium-backed Close() remains retryable after sync
  failure; operations see closing state as closed
- store.go + scope.go + transaction.go: purge uses
  DELETE ... RETURNING so notifications come from rows actually
  deleted (was reading first then deleting separately)
- publish.go: token lookup uses Core's DIR_HOME (populated via
  os.UserHomeDir) then falls back to HOME — preserves direct-os
  import ban while picking up real home

Tests:
- import_test.go (new): coverage of transactional import +
  malformed-JSONL error path

Doc:
- README.md: footer licence link targets LICENCE.md (UK English)

Verification: gofmt clean, golangci-lint v2 0 issues, GOWORK=off
go vet + go test -count=1 ./... pass with explicit cache paths.

Closes residual findings on https://github.com/dAppCore/go-store/pull/4

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Snider 2026-04-27 18:29:59 +01:00
parent ebe5377871
commit fc77445de0
12 changed files with 341 additions and 141 deletions

View file

@ -94,4 +94,4 @@ go build ./...
## Licence
European Union Public Licence 1.2 — see [LICENCE](LICENCE) for details.
European Union Public Licence 1.2 — see [LICENCE.md](LICENCE.md) for details.

View file

@ -76,7 +76,7 @@ func (storeInstance *Store) Watch(group string) <-chan Event {
storeInstance.lifecycleLock.Lock()
defer storeInstance.lifecycleLock.Unlock()
if storeInstance.isClosed {
if storeInstance.isClosed || storeInstance.isClosing {
return closedEventChannel()
}
@ -97,7 +97,7 @@ func (storeInstance *Store) Unwatch(group string, events <-chan Event) {
}
storeInstance.lifecycleLock.Lock()
closed := storeInstance.isClosed
closed := storeInstance.isClosed || storeInstance.isClosing
storeInstance.lifecycleLock.Unlock()
if closed {
return
@ -146,7 +146,7 @@ func (storeInstance *Store) OnChange(callback func(Event)) func() {
storeInstance.lifecycleLock.Lock()
defer storeInstance.lifecycleLock.Unlock()
if storeInstance.isClosed {
if storeInstance.isClosed || storeInstance.isClosing {
return func() {}
}
@ -188,7 +188,7 @@ func (storeInstance *Store) notify(event Event) {
}
storeInstance.lifecycleLock.Lock()
if storeInstance.isClosed {
if storeInstance.isClosed || storeInstance.isClosing {
storeInstance.lifecycleLock.Unlock()
return
}
@ -210,7 +210,7 @@ func (storeInstance *Store) notify(event Event) {
storeInstance.watcherLock.RUnlock()
storeInstance.lifecycleLock.Lock()
if storeInstance.isClosed {
if storeInstance.isClosed || storeInstance.isClosing {
storeInstance.lifecycleLock.Unlock()
return
}

179
import.go
View file

@ -4,6 +4,7 @@ package store
import (
"bufio"
"database/sql"
"io"
"io/fs"
@ -13,6 +14,27 @@ import (
// localFs provides unrestricted filesystem access for import operations.
var localFs = (&core.Fs{}).New("/")
// duckDBImportSession abstracts the two DuckDB statement shapes used during
// import (exec and single-value row scan) so ImportAll can run everything
// inside one transaction and tests can substitute a stub.
type duckDBImportSession interface {
exec(query string, args ...any) error
queryRowScan(query string, dest any, args ...any) error
}
// duckDBImportTransaction adapts a *sql.Tx to the duckDBImportSession
// interface so the import helpers execute within a single transaction.
type duckDBImportTransaction struct {
transaction *sql.Tx
}
// exec runs a statement on the wrapped transaction, discarding the result
// and wrapping any failure with call-site context via core.E.
func (session duckDBImportTransaction) exec(query string, args ...any) error {
_, err := session.transaction.Exec(query, args...)
if err != nil {
return core.E("store.duckDBImportTransaction.Exec", "execute query", err)
}
return nil
}
// queryRowScan runs a single-row query on the transaction and scans the
// first column into dest; the raw sql error (if any) is returned unwrapped.
func (session duckDBImportTransaction) queryRowScan(query string, dest any, args ...any) error {
return session.transaction.QueryRow(query, args...).Scan(dest)
}
// ScpFunc is a callback for executing SCP file transfers.
// The function receives remote source and local destination paths.
//
@ -77,6 +99,10 @@ type ImportConfig struct {
//
// err := store.ImportAll(db, store.ImportConfig{DataDir: "/Volumes/Data/lem"}, os.Stdout)
func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
if db == nil || db.Conn() == nil {
return core.E("store.ImportAll", "database is nil", nil)
}
m3Host := cfg.M3Host
if m3Host == "" {
m3Host = "m3"
@ -93,14 +119,26 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
core.Print(w, " WARNING: could not pull golden set from M3: %v", err)
}
}
transaction, err := db.Conn().Begin()
if err != nil {
return core.E("store.ImportAll", "begin import transaction", err)
}
committed := false
defer func() {
if !committed {
_ = transaction.Rollback()
}
}()
importSession := duckDBImportTransaction{transaction: transaction}
if isFile(goldenPath) {
if err := db.Exec("DROP TABLE IF EXISTS golden_set"); err != nil {
if err := importSession.exec("DROP TABLE IF EXISTS golden_set"); err != nil {
return core.E("store.ImportAll", "drop golden_set", err)
}
err := db.Exec(core.Sprintf(`
CREATE TABLE golden_set AS
SELECT
idx::INT AS idx,
err := importSession.exec(core.Sprintf(`
CREATE TABLE golden_set AS
SELECT
idx::INT AS idx,
seed_id::VARCHAR AS seed_id,
domain::VARCHAR AS domain,
voice::VARCHAR AS voice,
@ -115,7 +153,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
return core.E("store.ImportAll", "import golden_set", err)
} else {
var n int
if err := db.QueryRowScan("SELECT count(*) FROM golden_set", &n); err != nil {
if err := importSession.queryRowScan("SELECT count(*) FROM golden_set", &n); err != nil {
return core.E("store.ImportAll", "count golden_set", err)
}
totals["golden_set"] = n
@ -160,13 +198,13 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
}
}
if err := db.Exec("DROP TABLE IF EXISTS training_examples"); err != nil {
if err := importSession.exec("DROP TABLE IF EXISTS training_examples"); err != nil {
return core.E("store.ImportAll", "drop training_examples", err)
}
if err := db.Exec(`
CREATE TABLE training_examples (
source VARCHAR,
split VARCHAR,
if err := importSession.exec(`
CREATE TABLE training_examples (
source VARCHAR,
split VARCHAR,
prompt TEXT,
response TEXT,
num_turns INT,
@ -192,7 +230,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
split = "test"
}
n, err := importTrainingFile(db, local, td.name, split)
n, err := importTrainingFile(importSession, local, td.name, split)
if err != nil {
return core.E("store.ImportAll", core.Sprintf("import training file %s", local), err)
}
@ -224,13 +262,13 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
}
}
if err := db.Exec("DROP TABLE IF EXISTS benchmark_results"); err != nil {
if err := importSession.exec("DROP TABLE IF EXISTS benchmark_results"); err != nil {
return core.E("store.ImportAll", "drop benchmark_results", err)
}
if err := db.Exec(`
CREATE TABLE benchmark_results (
source VARCHAR, id VARCHAR, benchmark VARCHAR, model VARCHAR,
prompt TEXT, response TEXT, elapsed_seconds DOUBLE, domain VARCHAR
if err := importSession.exec(`
CREATE TABLE benchmark_results (
source VARCHAR, id VARCHAR, benchmark VARCHAR, model VARCHAR,
prompt TEXT, response TEXT, elapsed_seconds DOUBLE, domain VARCHAR
)
`); err != nil {
return core.E("store.ImportAll", "create benchmark_results", err)
@ -241,7 +279,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
resultDir := core.JoinPath(benchLocal, subdir)
matches := core.PathGlob(core.JoinPath(resultDir, "*.jsonl"))
for _, jf := range matches {
n, err := importBenchmarkFile(db, jf, subdir)
n, err := importBenchmarkFile(importSession, jf, subdir)
if err != nil {
return core.E("store.ImportAll", core.Sprintf("import benchmark file %s", jf), err)
}
@ -259,7 +297,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
}
}
if isFile(local) {
n, err := importBenchmarkFile(db, local, "benchmark")
n, err := importBenchmarkFile(importSession, local, "benchmark")
if err != nil {
return core.E("store.ImportAll", core.Sprintf("import benchmark file %s", local), err)
}
@ -270,13 +308,13 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
core.Print(w, " benchmark_results: %d rows", benchTotal)
// ── 4. Benchmark questions ──
if err := db.Exec("DROP TABLE IF EXISTS benchmark_questions"); err != nil {
if err := importSession.exec("DROP TABLE IF EXISTS benchmark_questions"); err != nil {
return core.E("store.ImportAll", "drop benchmark_questions", err)
}
if err := db.Exec(`
CREATE TABLE benchmark_questions (
benchmark VARCHAR, id VARCHAR, question TEXT,
best_answer TEXT, correct_answers TEXT, incorrect_answers TEXT, category VARCHAR
if err := importSession.exec(`
CREATE TABLE benchmark_questions (
benchmark VARCHAR, id VARCHAR, question TEXT,
best_answer TEXT, correct_answers TEXT, incorrect_answers TEXT, category VARCHAR
)
`); err != nil {
return core.E("store.ImportAll", "create benchmark_questions", err)
@ -286,7 +324,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
for _, bname := range []string{"truthfulqa", "gsm8k", "do_not_answer", "toxigen"} {
local := core.JoinPath(benchLocal, bname+".jsonl")
if isFile(local) {
n, err := importBenchmarkQuestions(db, local, bname)
n, err := importBenchmarkQuestions(importSession, local, bname)
if err != nil {
return core.E("store.ImportAll", core.Sprintf("import benchmark questions %s", local), err)
}
@ -297,13 +335,13 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
core.Print(w, " benchmark_questions: %d rows", benchQTotal)
// ── 5. Seeds ──
if err := db.Exec("DROP TABLE IF EXISTS seeds"); err != nil {
if err := importSession.exec("DROP TABLE IF EXISTS seeds"); err != nil {
return core.E("store.ImportAll", "drop seeds", err)
}
if err := db.Exec(`
CREATE TABLE seeds (
source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT
)
if err := importSession.exec(`
CREATE TABLE seeds (
source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT
)
`); err != nil {
return core.E("store.ImportAll", "create seeds", err)
}
@ -314,7 +352,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
if !isDir(seedDir) {
continue
}
n, err := importSeeds(db, seedDir)
n, err := importSeeds(importSession, seedDir)
if err != nil {
return core.E("store.ImportAll", core.Sprintf("import seeds %s", seedDir), err)
}
@ -323,6 +361,11 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
totals["seeds"] = seedTotal
core.Print(w, " seeds: %d rows", seedTotal)
if err := transaction.Commit(); err != nil {
return core.E("store.ImportAll", "commit import transaction", err)
}
committed = true
// ── Summary ──
grandTotal := 0
core.Print(w, "\n%s", repeat("=", 50))
@ -339,7 +382,7 @@ func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
return nil
}
func importTrainingFile(db *DuckDB, path, source, split string) (int, error) {
func importTrainingFile(db duckDBImportSession, path, source, split string) (int, error) {
r := localFs.Open(path)
if !r.OK {
return 0, core.E("store.importTrainingFile", core.Sprintf("open %s", path), r.Value.(error))
@ -351,12 +394,15 @@ func importTrainingFile(db *DuckDB, path, source, split string) (int, error) {
scanner := bufio.NewScanner(f)
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
lineNumber := 0
for scanner.Scan() {
lineNumber++
var rec struct {
Messages []ChatMessage `json:"messages"`
}
if r := core.JSONUnmarshal(scanner.Bytes(), &rec); !r.OK {
continue
parseErr, _ := r.Value.(error)
return count, core.E("store.importTrainingFile", core.Sprintf("parse %s line %d", path, lineNumber), parseErr)
}
prompt := ""
@ -375,7 +421,7 @@ func importTrainingFile(db *DuckDB, path, source, split string) (int, error) {
}
msgsJSON := core.JSONMarshalString(rec.Messages)
if err := db.Exec(`INSERT INTO training_examples VALUES (?, ?, ?, ?, ?, ?, ?)`,
if err := db.exec(`INSERT INTO training_examples VALUES (?, ?, ?, ?, ?, ?, ?)`,
source, split, prompt, response, assistantCount, msgsJSON, len(response)); err != nil {
return count, core.E("store.importTrainingFile", "insert training example", err)
}
@ -387,7 +433,7 @@ func importTrainingFile(db *DuckDB, path, source, split string) (int, error) {
return count, nil
}
func importBenchmarkFile(db *DuckDB, path, source string) (int, error) {
func importBenchmarkFile(db duckDBImportSession, path, source string) (int, error) {
r := localFs.Open(path)
if !r.OK {
return 0, core.E("store.importBenchmarkFile", core.Sprintf("open %s", path), r.Value.(error))
@ -399,13 +445,16 @@ func importBenchmarkFile(db *DuckDB, path, source string) (int, error) {
scanner := bufio.NewScanner(f)
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
lineNumber := 0
for scanner.Scan() {
lineNumber++
var rec map[string]any
if r := core.JSONUnmarshal(scanner.Bytes(), &rec); !r.OK {
continue
parseErr, _ := r.Value.(error)
return count, core.E("store.importBenchmarkFile", core.Sprintf("parse %s line %d", path, lineNumber), parseErr)
}
if err := db.Exec(`INSERT INTO benchmark_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
if err := db.exec(`INSERT INTO benchmark_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
source,
core.Sprint(rec["id"]),
strOrEmpty(rec, "benchmark"),
@ -425,7 +474,7 @@ func importBenchmarkFile(db *DuckDB, path, source string) (int, error) {
return count, nil
}
func importBenchmarkQuestions(db *DuckDB, path, benchmark string) (int, error) {
func importBenchmarkQuestions(db duckDBImportSession, path, benchmark string) (int, error) {
r := localFs.Open(path)
if !r.OK {
return 0, core.E("store.importBenchmarkQuestions", core.Sprintf("open %s", path), r.Value.(error))
@ -437,16 +486,19 @@ func importBenchmarkQuestions(db *DuckDB, path, benchmark string) (int, error) {
scanner := bufio.NewScanner(f)
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
lineNumber := 0
for scanner.Scan() {
lineNumber++
var rec map[string]any
if r := core.JSONUnmarshal(scanner.Bytes(), &rec); !r.OK {
continue
parseErr, _ := r.Value.(error)
return count, core.E("store.importBenchmarkQuestions", core.Sprintf("parse %s line %d", path, lineNumber), parseErr)
}
correctJSON := core.JSONMarshalString(rec["correct_answers"])
incorrectJSON := core.JSONMarshalString(rec["incorrect_answers"])
if err := db.Exec(`INSERT INTO benchmark_questions VALUES (?, ?, ?, ?, ?, ?, ?)`,
if err := db.exec(`INSERT INTO benchmark_questions VALUES (?, ?, ?, ?, ?, ?, ?)`,
benchmark,
core.Sprint(rec["id"]),
strOrEmpty(rec, "question"),
@ -465,15 +517,11 @@ func importBenchmarkQuestions(db *DuckDB, path, benchmark string) (int, error) {
return count, nil
}
func importSeeds(db *DuckDB, seedDir string) (int, error) {
func importSeeds(db duckDBImportSession, seedDir string) (int, error) {
count := 0
var firstErr error
walkDir(seedDir, func(path string) {
if firstErr != nil {
return
}
if err := walkDir(seedDir, func(path string) error {
if !core.HasSuffix(path, ".json") {
return
return nil
}
rel := core.TrimPrefix(path, seedDir+"/")
@ -481,8 +529,7 @@ func importSeeds(db *DuckDB, seedDir string) (int, error) {
readResult := localFs.Read(path)
if !readResult.OK {
firstErr = core.E("store.importSeeds", core.Sprintf("read seed file %s", rel), readResult.Value.(error))
return
return core.E("store.importSeeds", core.Sprintf("read seed file %s", rel), readResult.Value.(error))
}
data := []byte(readResult.Value.(string))
@ -491,8 +538,7 @@ func importSeeds(db *DuckDB, seedDir string) (int, error) {
var raw any
if r := core.JSONUnmarshal(data, &raw); !r.OK {
err, _ := r.Value.(error)
firstErr = core.E("store.importSeeds", core.Sprintf("parse seed file %s", rel), err)
return
return core.E("store.importSeeds", core.Sprintf("parse seed file %s", rel), err)
}
switch v := raw.(type) {
@ -516,50 +562,53 @@ func importSeeds(db *DuckDB, seedDir string) (int, error) {
if prompt == "" {
prompt = strOrEmpty(seed, "question")
}
if err := db.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
if err := db.exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
rel, region,
strOrEmpty(seed, "seed_id"),
strOrEmpty(seed, "domain"),
prompt,
); err != nil {
firstErr = core.E("store.importSeeds", "insert seed prompt", err)
return
return core.E("store.importSeeds", "insert seed prompt", err)
}
count++
case string:
if err := db.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
if err := db.exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
rel, region, "", "", seed); err != nil {
firstErr = core.E("store.importSeeds", "insert seed string", err)
return
return core.E("store.importSeeds", "insert seed string", err)
}
count++
}
}
})
if firstErr != nil {
return count, firstErr
return nil
}); err != nil {
return count, err
}
return count, nil
}
// walkDir recursively visits all regular files under root, calling fn for each.
func walkDir(root string, fn func(path string)) {
func walkDir(root string, fn func(path string) error) error {
r := localFs.List(root)
if !r.OK {
return
return core.E("store.walkDir", core.Sprintf("list %s", root), r.Value.(error))
}
entries, ok := r.Value.([]fs.DirEntry)
if !ok {
return
return core.E("store.walkDir", core.Sprintf("list %s returned invalid entries", root), nil)
}
for _, entry := range entries {
full := core.JoinPath(root, entry.Name())
if entry.IsDir() {
walkDir(full, fn)
if err := walkDir(full, fn); err != nil {
return err
}
} else {
fn(full)
if err := fn(full); err != nil {
return err
}
}
}
return nil
}
// strOrEmpty extracts a string value from a map, returning an empty string if

70
import_test.go Normal file
View file

@ -0,0 +1,70 @@
package store
import (
"testing"
core "dappco.re/go/core"
)
// importSessionStub is a duckDBImportSession test double that counts exec
// calls and reports success for every statement.
type importSessionStub struct {
inserts int
}
// exec records one successful statement execution and never fails.
func (session *importSessionStub) exec(string, ...any) error {
session.inserts++
return nil
}
// queryRowScan is a no-op that leaves dest untouched and reports success.
func (session *importSessionStub) queryRowScan(string, any, ...any) error {
return nil
}
// TestImport_ImportTrainingFile_Bad_MalformedJSONL verifies that a malformed
// JSONL line aborts importTrainingFile with an error naming the offending
// line, after the valid row before it was inserted.
func TestImport_ImportTrainingFile_Bad_MalformedJSONL(t *testing.T) {
path := testPath(t, "training.jsonl")
// Line 1 is valid JSON; line 2 is deliberately broken.
requireCoreWriteBytes(t, path, []byte("{\"messages\":[]}\n{broken\n"))
session := &importSessionStub{}
count, err := importTrainingFile(session, path, "training", "train")
assertError(t, err)
assertContainsString(t, err.Error(), "line 2")
// Exactly the one valid row was counted and inserted before the failure.
assertEqual(t, 1, count)
assertEqual(t, 1, session.inserts)
}
// TestImport_ImportBenchmarkFile_Bad_MalformedJSONL verifies that a malformed
// JSONL line aborts importBenchmarkFile with a file/line parse error while
// keeping the rows inserted before it.
func TestImport_ImportBenchmarkFile_Bad_MalformedJSONL(t *testing.T) {
path := testPath(t, "benchmark.jsonl")
// Line 1 is valid JSON; line 2 is deliberately broken.
requireCoreWriteBytes(t, path, []byte("{\"id\":\"row-1\"}\n{broken\n"))
session := &importSessionStub{}
count, err := importBenchmarkFile(session, path, "benchmark")
assertError(t, err)
assertContainsString(t, err.Error(), "line 2")
// Only the valid first row made it through before the parse failure.
assertEqual(t, 1, count)
assertEqual(t, 1, session.inserts)
}
// TestImport_ImportBenchmarkQuestions_Bad_MalformedJSONL verifies that a
// malformed JSONL line aborts importBenchmarkQuestions with an error naming
// the offending line, after the preceding valid question was inserted.
func TestImport_ImportBenchmarkQuestions_Bad_MalformedJSONL(t *testing.T) {
path := testPath(t, "questions.jsonl")
// Line 1 is valid JSON; line 2 is deliberately broken.
requireCoreWriteBytes(t, path, []byte("{\"id\":\"q-1\"}\n{broken\n"))
session := &importSessionStub{}
count, err := importBenchmarkQuestions(session, path, "truthfulqa")
assertError(t, err)
assertContainsString(t, err.Error(), "line 2")
// Exactly one row was counted and inserted before the failure surfaced.
assertEqual(t, 1, count)
assertEqual(t, 1, session.inserts)
}
// TestImport_ImportSeeds_Bad_WalkFailure verifies that a missing seed
// directory surfaces walkDir's listing error and performs no inserts.
func TestImport_ImportSeeds_Bad_WalkFailure(t *testing.T) {
session := &importSessionStub{}
// The path intentionally does not exist, so directory listing fails.
count, err := importSeeds(session, core.JoinPath(t.TempDir(), "missing-seeds"))
assertError(t, err)
assertContainsString(t, err.Error(), "store.walkDir")
assertEqual(t, 0, count)
assertEqual(t, 0, session.inserts)
}

View file

@ -243,7 +243,10 @@ func importCSV(workspace *Workspace, kind, content string) error {
}
func exportJSON(workspace *Workspace, medium Medium, path string) error {
summary := workspace.Aggregate()
summary, err := workspace.aggregateFields()
if err != nil {
return core.E("store.Export", "aggregate workspace", err)
}
content := core.JSONMarshalString(summary)
if err := medium.Write(path, content); err != nil {
return core.E("store.Export", "write json", err)

View file

@ -114,6 +114,19 @@ func (medium *renameFailMedium) Rename(string, string) error {
return core.E("renameFailMedium.Rename", "forced rename failure", nil)
}
// writeFailOnceMedium wraps memoryMedium and fails the first `failures`
// Write calls, letting tests exercise retryable medium-sync errors.
type writeFailOnceMedium struct {
*memoryMedium
failures int
}
// Write fails while the failure budget is positive, decrementing it each
// call; once exhausted it delegates to the embedded memoryMedium.
func (medium *writeFailOnceMedium) Write(path, content string) error {
if medium.failures > 0 {
medium.failures--
return core.E("writeFailOnceMedium.Write", "forced write failure", nil)
}
return medium.memoryMedium.Write(path, content)
}
func (medium *memoryMedium) List(path string) ([]fs.DirEntry, error) { return nil, nil }
func (medium *memoryMedium) Stat(path string) (fs.FileInfo, error) {
@ -462,6 +475,30 @@ func TestMedium_Export_Bad_NilArguments(t *testing.T) {
assertError(t, Export(workspace, medium, ""))
}
// TestMedium_Export_Bad_JSONPropagatesWorkspaceFailure verifies that a JSON
// export against a closed workspace propagates the aggregation failure and
// leaves any previously written export content on the medium untouched.
func TestMedium_Export_Bad_JSONPropagatesWorkspaceFailure(t *testing.T) {
useWorkspaceStateDirectory(t)
storeInstance, err := New(":memory:")
assertNoError(t, err)
defer func() { _ = storeInstance.Close() }()
workspace, err := storeInstance.NewWorkspace("medium-export-json-closed")
assertNoError(t, err)
assertNoError(t, workspace.Put("like", map[string]any{"user": "@alice"}))
// Closing the workspace makes the subsequent aggregation fail.
assertNoError(t, workspace.Close())
medium := newMemoryMedium()
assertNoError(t, medium.Write("report.json", `{"previous":true}`))
err = Export(workspace, medium, "report.json")
assertError(t, err)
assertContainsString(t, err.Error(), "aggregate workspace")
// The pre-existing export must not have been overwritten on failure.
content, readErr := medium.Read("report.json")
assertNoError(t, readErr)
assertEqual(t, `{"previous":true}`, content)
}
func TestMedium_Compact_Good_MediumRoutesArchive(t *testing.T) {
useWorkspaceStateDirectory(t)
useArchiveOutputDirectory(t)

View file

@ -97,11 +97,11 @@ func Publish(cfg PublishConfig, w io.Writer) error {
return core.E("store.Publish", "HuggingFace token required (--token, HF_TOKEN env, or ~/.huggingface/token)", nil)
}
files, err := collectUploadFiles(cfg.InputDir)
files, hasSplit, err := collectUploadFiles(cfg.InputDir)
if err != nil {
return err
}
if len(files) == 0 {
if !hasSplit {
return core.E("store.Publish", core.Sprintf("no Parquet files found in %s", cfg.InputDir), nil)
}
@ -150,21 +150,33 @@ func resolveHFToken(explicit string) string {
if env := core.Env("HF_TOKEN"); env != "" {
return env
}
home := core.Env("HOME")
if home == "" {
return ""
// Core populates DIR_HOME via os.UserHomeDir while this package keeps the
// repository-wide ban on direct os imports.
homes := []string{core.Env("DIR_HOME")}
if homeEnv := core.Env("HOME"); homeEnv != "" && homeEnv != homes[0] {
homes = append(homes, homeEnv)
}
r := localFs.Read(core.JoinPath(home, ".huggingface", "token"))
if !r.OK {
return ""
for _, home := range homes {
if home == "" {
continue
}
r := localFs.Read(core.JoinPath(home, ".huggingface", "token"))
if !r.OK {
continue
}
token := core.Trim(r.Value.(string))
if token != "" {
return token
}
}
return core.Trim(r.Value.(string))
return ""
}
// collectUploadFiles finds Parquet split files and an optional dataset card.
func collectUploadFiles(inputDir string) ([]uploadEntry, error) {
func collectUploadFiles(inputDir string) ([]uploadEntry, bool, error) {
splits := []string{"train", "valid", "test"}
var files []uploadEntry
hasSplit := false
for _, split := range splits {
path := core.JoinPath(inputDir, split+".parquet")
@ -172,6 +184,7 @@ func collectUploadFiles(inputDir string) ([]uploadEntry, error) {
continue
}
files = append(files, uploadEntry{path, core.Sprintf("data/%s.parquet", split)})
hasSplit = true
}
// Check for dataset card in parent directory.
@ -180,7 +193,7 @@ func collectUploadFiles(inputDir string) ([]uploadEntry, error) {
files = append(files, uploadEntry{cardPath, "README.md"})
}
return files, nil
return files, hasSplit, nil
}
func ensureHFDatasetRepo(ctx context.Context, token, repoID string, public bool) error {

View file

@ -16,6 +16,18 @@ func TestPublish_Publish_Bad_EmptyRepository(t *testing.T) {
assertContainsString(t, err.Error(), "repository is required")
}
// TestPublish_Publish_Bad_DatasetCardWithoutParquetSplit verifies that a
// dataset card alone (with no Parquet split files) still fails Publish with
// the "no Parquet files found" error — the card must not count as a split.
func TestPublish_Publish_Bad_DatasetCardWithoutParquetSplit(t *testing.T) {
inputDir := core.JoinPath(t.TempDir(), "data")
requireCoreOK(t, testFilesystem().EnsureDir(inputDir))
// The card lives in the parent of the input directory, as Publish expects.
requireCoreWriteBytes(t, core.JoinPath(inputDir, "..", "dataset_card.md"), []byte("# Dataset\n"))
var output bytes.Buffer
err := Publish(PublishConfig{InputDir: inputDir, Repo: "snider/lem-training", DryRun: true}, &output)
assertError(t, err)
assertContainsString(t, err.Error(), "no Parquet files found")
}
func TestPublish_ResolveHFToken_Good_UserHomeFallback(t *testing.T) {
homeDirectory := t.TempDir()
t.Setenv("HF_TOKEN", "")

View file

@ -397,15 +397,11 @@ func (scopedStore *ScopedStore) PurgeExpired() (int64, error) {
}
cutoffUnixMilli := time.Now().UnixMilli()
expiredEntries, err := listExpiredEntriesMatchingGroupPrefix(scopedStore.store.sqliteDatabase, scopedStore.namespacePrefix(), cutoffUnixMilli)
if err != nil {
return 0, core.E("store.ScopedStore.PurgeExpired", "list expired rows", err)
}
removedRows, err := purgeExpiredMatchingGroupPrefix(scopedStore.store.sqliteDatabase, scopedStore.namespacePrefix(), cutoffUnixMilli)
expiredEntries, err := deleteExpiredEntriesMatchingGroupPrefix(scopedStore.store.sqliteDatabase, scopedStore.namespacePrefix(), cutoffUnixMilli)
if err != nil {
return 0, core.E("store.ScopedStore.PurgeExpired", "delete expired rows", err)
}
removedRows := int64(len(expiredEntries))
if removedRows > 0 {
for _, expiredEntry := range expiredEntries {
scopedStore.store.notify(Event{
@ -822,15 +818,11 @@ func (scopedStoreTransaction *ScopedStoreTransaction) PurgeExpired() (int64, err
}
cutoffUnixMilli := time.Now().UnixMilli()
expiredEntries, err := listExpiredEntriesMatchingGroupPrefix(scopedStoreTransaction.storeTransaction.sqliteTransaction, scopedStoreTransaction.scopedStore.namespacePrefix(), cutoffUnixMilli)
if err != nil {
return 0, core.E("store.ScopedStoreTransaction.PurgeExpired", "list expired rows", err)
}
removedRows, err := purgeExpiredMatchingGroupPrefix(scopedStoreTransaction.storeTransaction.sqliteTransaction, scopedStoreTransaction.scopedStore.namespacePrefix(), cutoffUnixMilli)
expiredEntries, err := deleteExpiredEntriesMatchingGroupPrefix(scopedStoreTransaction.storeTransaction.sqliteTransaction, scopedStoreTransaction.scopedStore.namespacePrefix(), cutoffUnixMilli)
if err != nil {
return 0, core.E("store.ScopedStoreTransaction.PurgeExpired", "delete expired rows", err)
}
removedRows := int64(len(expiredEntries))
if removedRows > 0 {
for _, expiredEntry := range expiredEntries {
scopedStoreTransaction.storeTransaction.recordEvent(Event{

View file

@ -154,7 +154,9 @@ type Store struct {
journalConfiguration JournalConfiguration
medium Medium
lifecycleLock sync.Mutex
closeLock sync.Mutex
isClosed bool
isClosing bool
// Event dispatch state.
watchers map[string][]chan Event
@ -182,7 +184,7 @@ func (storeInstance *Store) ensureReady(operation string) error {
}
storeInstance.lifecycleLock.Lock()
closed := storeInstance.isClosed
closed := storeInstance.isClosed || storeInstance.isClosing
storeInstance.lifecycleLock.Unlock()
if closed {
return core.E(operation, "store is closed", nil)
@ -423,12 +425,15 @@ func (storeInstance *Store) Close() error {
return nil
}
storeInstance.closeLock.Lock()
defer storeInstance.closeLock.Unlock()
storeInstance.lifecycleLock.Lock()
if storeInstance.isClosed {
storeInstance.lifecycleLock.Unlock()
return nil
}
storeInstance.isClosed = true
storeInstance.isClosing = true
storeInstance.lifecycleLock.Unlock()
if storeInstance.cancelPurge != nil {
@ -470,6 +475,7 @@ func (storeInstance *Store) Close() error {
storeInstance.sqliteDatabase = storeInstance.db
}
if storeInstance.sqliteDatabase == nil {
storeInstance.markClosed()
return orphanCleanupErr
}
if err := storeInstance.sqliteDatabase.Close(); err != nil {
@ -478,12 +484,20 @@ func (storeInstance *Store) Close() error {
if err := storeInstance.syncMediumBackedDatabase(); err != nil {
return core.E("store.Close", "sync medium-backed database", err)
}
storeInstance.markClosed()
if orphanCleanupErr != nil {
return core.E("store.Close", "close orphan workspaces", orphanCleanupErr)
}
return orphanCleanupErr
}
// markClosed transitions the store from the closing state to fully closed
// under the lifecycle lock, so concurrent operations observe a consistent
// lifecycle state.
func (storeInstance *Store) markClosed() {
storeInstance.lifecycleLock.Lock()
storeInstance.isClosed = true
storeInstance.isClosing = false
storeInstance.lifecycleLock.Unlock()
}
func (storeInstance *Store) syncMediumBackedDatabase() error {
if storeInstance == nil || !storeInstance.mediumBacked || storeInstance.medium == nil {
return nil
@ -965,15 +979,11 @@ func (storeInstance *Store) PurgeExpired() (int64, error) {
}
cutoffUnixMilli := time.Now().UnixMilli()
expiredEntries, err := listExpiredEntriesMatchingGroupPrefix(storeInstance.sqliteDatabase, "", cutoffUnixMilli)
if err != nil {
return 0, core.E("store.PurgeExpired", "list expired rows", err)
}
removedRows, err := purgeExpiredMatchingGroupPrefix(storeInstance.sqliteDatabase, "", cutoffUnixMilli)
expiredEntries, err := deleteExpiredEntriesMatchingGroupPrefix(storeInstance.sqliteDatabase, "", cutoffUnixMilli)
if err != nil {
return 0, core.E("store.PurgeExpired", "delete expired rows", err)
}
removedRows := int64(len(expiredEntries))
if removedRows > 0 {
for _, expiredEntry := range expiredEntries {
storeInstance.notify(Event{
@ -1062,19 +1072,19 @@ type expiredEntryRef struct {
key string
}
func listExpiredEntriesMatchingGroupPrefix(database schemaDatabase, groupPrefix string, cutoffUnixMilli int64) ([]expiredEntryRef, error) {
func deleteExpiredEntriesMatchingGroupPrefix(database schemaDatabase, groupPrefix string, cutoffUnixMilli int64) ([]expiredEntryRef, error) {
var (
rows *sql.Rows
err error
)
if groupPrefix == "" {
rows, err = database.Query(
"SELECT "+entryGroupColumn+", "+entryKeyColumn+" FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? ORDER BY "+entryGroupColumn+", "+entryKeyColumn,
"DELETE FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? RETURNING "+entryGroupColumn+", "+entryKeyColumn,
cutoffUnixMilli,
)
} else {
rows, err = database.Query(
"SELECT "+entryGroupColumn+", "+entryKeyColumn+" FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? AND "+entryGroupColumn+" LIKE ? ESCAPE '^' ORDER BY "+entryGroupColumn+", "+entryKeyColumn,
"DELETE FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? AND "+entryGroupColumn+" LIKE ? ESCAPE '^' RETURNING "+entryGroupColumn+", "+entryKeyColumn,
cutoffUnixMilli, escapeLike(groupPrefix)+"%",
)
}
@ -1097,35 +1107,6 @@ func listExpiredEntriesMatchingGroupPrefix(database schemaDatabase, groupPrefix
return expiredEntries, nil
}
// purgeExpiredMatchingGroupPrefix deletes expired rows globally when
// groupPrefix is empty, otherwise only rows whose group starts with the given
// prefix.
func purgeExpiredMatchingGroupPrefix(database schemaDatabase, groupPrefix string, cutoffUnixMilli int64) (int64, error) {
var (
deleteResult sql.Result
err error
)
if groupPrefix == "" {
deleteResult, err = database.Exec(
"DELETE FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ?",
cutoffUnixMilli,
)
} else {
deleteResult, err = database.Exec(
"DELETE FROM "+entriesTableName+" WHERE expires_at IS NOT NULL AND expires_at <= ? AND "+entryGroupColumn+" LIKE ? ESCAPE '^'",
cutoffUnixMilli, escapeLike(groupPrefix)+"%",
)
}
if err != nil {
return 0, err
}
removedRows, rowsAffectedErr := deleteResult.RowsAffected()
if rowsAffectedErr != nil {
return 0, rowsAffectedErr
}
return removedRows, nil
}
type schemaDatabase interface {
Exec(query string, args ...any) (sql.Result, error)
QueryRow(query string, args ...any) *sql.Row

View file

@ -1082,6 +1082,27 @@ func TestStore_Close_Bad_DriverCloseError(t *testing.T) {
assertContainsString(t, err.Error(), "store.Close")
}
// TestStore_Close_Bad_MediumSyncFailureRetryable verifies that a failed
// medium sync during Close leaves the store not-closed (so Close can be
// retried), that operations during the closing state fail, and that a second
// Close succeeds and persists the database to the medium.
func TestStore_Close_Bad_MediumSyncFailureRetryable(t *testing.T) {
useWorkspaceStateDirectory(t)
// Fail exactly the first medium write so the initial sync attempt errors.
medium := &writeFailOnceMedium{memoryMedium: newMemoryMedium(), failures: 1}
storeInstance, err := New("retryable-close.db", WithMedium(medium))
assertNoError(t, err)
assertNoError(t, storeInstance.Set("g", "k", "v"))
err = storeInstance.Close()
assertError(t, err)
assertContainsString(t, err.Error(), "sync medium-backed database")
// The store must remain retryable: not reported closed yet...
assertFalse(t, storeInstance.IsClosed())
// ...but operations still see the closing state as closed.
_, err = storeInstance.Get("g", "k")
assertError(t, err)
assertNoError(t, storeInstance.Close())
assertTrue(t, storeInstance.IsClosed())
assertTrue(t, medium.Exists("retryable-close.db"))
}
// ---------------------------------------------------------------------------
// Test helpers
// ---------------------------------------------------------------------------
@ -1600,6 +1621,32 @@ func TestStore_PurgeExpired_Good(t *testing.T) {
assertEqualf(t, 1, count, "only non-expiring key should remain")
}
// TestStore_PurgeExpired_Good_NotifiesDeletedRows verifies that PurgeExpired
// emits exactly one delete event per row actually removed (via
// DELETE ... RETURNING) and none for rows that are still live.
func TestStore_PurgeExpired_Good_NotifiesDeletedRows(t *testing.T) {
storeInstance, _ := New(":memory:")
defer func() { _ = storeInstance.Close() }()
assertNoError(t, storeInstance.SetWithTTL("g", "expired", "1", 1*time.Millisecond))
assertNoError(t, storeInstance.SetWithTTL("g", "live", "2", time.Hour))
// Give the short TTL time to lapse before purging.
time.Sleep(5 * time.Millisecond)
events := storeInstance.Watch("*")
defer storeInstance.Unwatch("*", events)
removed, err := storeInstance.PurgeExpired()
assertNoError(t, err)
assertEqual(t, int64(1), removed)
event := <-events
assertEqual(t, EventDelete, event.Type)
assertEqual(t, "g", event.Group)
assertEqual(t, "expired", event.Key)
// No further events: the live row must not produce a notification.
select {
case extraEvent := <-events:
t.Fatalf("unexpected extra purge event: %#v", extraEvent)
default:
}
}
func TestStore_PurgeExpired_Good_NoneExpired(t *testing.T) {
storeInstance, _ := New(":memory:")
defer func() { _ = storeInstance.Close() }()

View file

@ -512,15 +512,11 @@ func (storeTransaction *StoreTransaction) PurgeExpired() (int64, error) {
}
cutoffUnixMilli := time.Now().UnixMilli()
expiredEntries, err := listExpiredEntriesMatchingGroupPrefix(storeTransaction.sqliteTransaction, "", cutoffUnixMilli)
if err != nil {
return 0, core.E("store.Transaction.PurgeExpired", "list expired rows", err)
}
removedRows, err := purgeExpiredMatchingGroupPrefix(storeTransaction.sqliteTransaction, "", cutoffUnixMilli)
expiredEntries, err := deleteExpiredEntriesMatchingGroupPrefix(storeTransaction.sqliteTransaction, "", cutoffUnixMilli)
if err != nil {
return 0, core.E("store.Transaction.PurgeExpired", "delete expired rows", err)
}
removedRows := int64(len(expiredEntries))
if removedRows > 0 {
for _, expiredEntry := range expiredEntries {
storeTransaction.recordEvent(Event{