refactor(store): replace banned stdlib imports with core/go primitives
- fmt → core.Sprintf, core.E - strings → core.Contains, core.HasPrefix, core.Split, core.Join, core.Trim - os → core.Fs operations - path/filepath → core.JoinPath, core.PathBase - encoding/json → core.JSONMarshal, core.JSONUnmarshal - Add usage example comments to all exported struct fields Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
79815048c3
commit
eef4e737aa
7 changed files with 1645 additions and 2 deletions
462
duckdb.go
Normal file
462
duckdb.go
Normal file
|
|
@ -0,0 +1,462 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
core "dappco.re/go/core"
|
||||
_ "github.com/marcboeker/go-duckdb"
|
||||
)
|
||||
|
||||
// DuckDB table names for checkpoint scoring and probe results. Centralizing
// the names lets query builders interpolate them without string literals
// drifting between call sites.
//
// Usage example:
//
//	db.EnsureScoringTables()
//	db.Exec(core.Sprintf("SELECT * FROM %s", store.TableCheckpointScores))
const (
	// TableCheckpointScores is the table name for checkpoint scoring data.
	//
	// Usage example:
	//
	//	store.TableCheckpointScores // "checkpoint_scores"
	TableCheckpointScores = "checkpoint_scores"

	// TableProbeResults is the table name for probe result data.
	//
	// Usage example:
	//
	//	store.TableProbeResults // "probe_results"
	TableProbeResults = "probe_results"
)
|
||||
|
||||
// DuckDB wraps a DuckDB connection for analytical queries against training
// data, benchmark results, and scoring tables.
//
// Usage example:
//
//	db, err := store.OpenDuckDB("/Volumes/Data/lem/lem.duckdb")
//	if err != nil { return }
//	defer db.Close()
//	rows, _ := db.QueryGoldenSet(500)
type DuckDB struct {
	// conn is the underlying database/sql handle backed by the go-duckdb driver.
	conn *sql.DB
	// path is the database file path recorded at open time; see Path().
	path string
}
|
||||
|
||||
// OpenDuckDB opens a DuckDB database file in read-only mode to avoid locking
|
||||
// issues with the Python pipeline.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// db, err := store.OpenDuckDB("/Volumes/Data/lem/lem.duckdb")
|
||||
func OpenDuckDB(path string) (*DuckDB, error) {
|
||||
conn, err := sql.Open("duckdb", path+"?access_mode=READ_ONLY")
|
||||
if err != nil {
|
||||
return nil, core.E("store.OpenDuckDB", core.Sprintf("open duckdb %s", path), err)
|
||||
}
|
||||
if err := conn.Ping(); err != nil {
|
||||
conn.Close()
|
||||
return nil, core.E("store.OpenDuckDB", core.Sprintf("ping duckdb %s", path), err)
|
||||
}
|
||||
return &DuckDB{conn: conn, path: path}, nil
|
||||
}
|
||||
|
||||
// OpenDuckDBReadWrite opens a DuckDB database in read-write mode.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// db, err := store.OpenDuckDBReadWrite("/Volumes/Data/lem/lem.duckdb")
|
||||
func OpenDuckDBReadWrite(path string) (*DuckDB, error) {
|
||||
conn, err := sql.Open("duckdb", path)
|
||||
if err != nil {
|
||||
return nil, core.E("store.OpenDuckDBReadWrite", core.Sprintf("open duckdb %s", path), err)
|
||||
}
|
||||
if err := conn.Ping(); err != nil {
|
||||
conn.Close()
|
||||
return nil, core.E("store.OpenDuckDBReadWrite", core.Sprintf("ping duckdb %s", path), err)
|
||||
}
|
||||
return &DuckDB{conn: conn, path: path}, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// defer db.Close()
|
||||
func (db *DuckDB) Close() error {
|
||||
return db.conn.Close()
|
||||
}
|
||||
|
||||
// Path returns the database file path.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// p := db.Path() // "/Volumes/Data/lem/lem.duckdb"
|
||||
func (db *DuckDB) Path() string {
|
||||
return db.path
|
||||
}
|
||||
|
||||
// Exec executes a query without returning rows.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// err := db.Exec("INSERT INTO golden_set VALUES (?, ?)", idx, prompt)
|
||||
func (db *DuckDB) Exec(query string, args ...any) error {
|
||||
_, err := db.conn.Exec(query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// QueryRowScan executes a query expected to return at most one row and scans
|
||||
// the result into dest. It is a convenience wrapper around sql.DB.QueryRow.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// var count int
|
||||
// err := db.QueryRowScan("SELECT COUNT(*) FROM golden_set", &count)
|
||||
func (db *DuckDB) QueryRowScan(query string, dest any, args ...any) error {
|
||||
return db.conn.QueryRow(query, args...).Scan(dest)
|
||||
}
|
||||
|
||||
// GoldenSetRow represents one row from the golden_set table.
//
// Usage example:
//
//	rows, err := db.QueryGoldenSet(500)
//	for _, row := range rows { core.Println(row.Prompt) }
type GoldenSetRow struct {
	// Idx is the row index.
	// NOTE(review): typed int here but int64 on ExpansionPromptRow.Idx —
	// presumably mirrors the column types; confirm against the schema.
	//
	// Usage example:
	//
	//	row.Idx // 42
	Idx int

	// SeedID is the seed identifier that produced this row.
	//
	// Usage example:
	//
	//	row.SeedID // "seed-001"
	SeedID string

	// Domain is the content domain (e.g. "philosophy", "science").
	//
	// Usage example:
	//
	//	row.Domain // "philosophy"
	Domain string

	// Voice is the writing voice/style used for generation.
	//
	// Usage example:
	//
	//	row.Voice // "watts"
	Voice string

	// Prompt is the input prompt text.
	//
	// Usage example:
	//
	//	row.Prompt // "What is sovereignty?"
	Prompt string

	// Response is the generated response text.
	//
	// Usage example:
	//
	//	row.Response // "Sovereignty is..."
	Response string

	// GenTime is the generation time in seconds.
	//
	// Usage example:
	//
	//	row.GenTime // 2.5
	GenTime float64

	// CharCount is the character count of the response.
	//
	// Usage example:
	//
	//	row.CharCount // 1500
	CharCount int
}
|
||||
|
||||
// ExpansionPromptRow represents one row from the expansion_prompts table.
//
// Usage example:
//
//	prompts, err := db.QueryExpansionPrompts("pending", 100)
//	for _, p := range prompts { core.Println(p.Prompt) }
type ExpansionPromptRow struct {
	// Idx is the row index.
	//
	// Usage example:
	//
	//	p.Idx // 42
	Idx int64

	// SeedID is the seed identifier that produced this prompt.
	//
	// Usage example:
	//
	//	p.SeedID // "seed-001"
	SeedID string

	// Region is the geographic/cultural region for the prompt.
	//
	// Usage example:
	//
	//	p.Region // "western"
	Region string

	// Domain is the content domain (e.g. "philosophy", "science").
	//
	// Usage example:
	//
	//	p.Domain // "philosophy"
	Domain string

	// Language is the ISO language code for the prompt.
	//
	// Usage example:
	//
	//	p.Language // "en"
	Language string

	// Prompt is the prompt text in the original language.
	//
	// Usage example:
	//
	//	p.Prompt // "What is sovereignty?"
	Prompt string

	// PromptEn is the English translation of the prompt.
	//
	// Usage example:
	//
	//	p.PromptEn // "What is sovereignty?"
	PromptEn string

	// Priority is the generation priority (lower is higher priority).
	//
	// Usage example:
	//
	//	p.Priority // 1
	Priority int

	// Status is the processing status (e.g. "pending", "done").
	//
	// Usage example:
	//
	//	p.Status // "pending"
	Status string
}
|
||||
|
||||
// QueryGoldenSet returns all golden set rows with responses >= minChars.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// rows, err := db.QueryGoldenSet(500)
|
||||
func (db *DuckDB) QueryGoldenSet(minChars int) ([]GoldenSetRow, error) {
|
||||
rows, err := db.conn.Query(
|
||||
"SELECT idx, seed_id, domain, voice, prompt, response, gen_time, char_count "+
|
||||
"FROM golden_set WHERE char_count >= ? ORDER BY idx",
|
||||
minChars,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, core.E("store.DuckDB.QueryGoldenSet", "query golden_set", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var result []GoldenSetRow
|
||||
for rows.Next() {
|
||||
var r GoldenSetRow
|
||||
if err := rows.Scan(&r.Idx, &r.SeedID, &r.Domain, &r.Voice,
|
||||
&r.Prompt, &r.Response, &r.GenTime, &r.CharCount); err != nil {
|
||||
return nil, core.E("store.DuckDB.QueryGoldenSet", "scan golden_set row", err)
|
||||
}
|
||||
result = append(result, r)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
// CountGoldenSet returns the total count of golden set rows.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// count, err := db.CountGoldenSet()
|
||||
func (db *DuckDB) CountGoldenSet() (int, error) {
|
||||
var count int
|
||||
err := db.conn.QueryRow("SELECT COUNT(*) FROM golden_set").Scan(&count)
|
||||
if err != nil {
|
||||
return 0, core.E("store.DuckDB.CountGoldenSet", "count golden_set", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// QueryExpansionPrompts returns expansion prompts filtered by status.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// prompts, err := db.QueryExpansionPrompts("pending", 100)
|
||||
func (db *DuckDB) QueryExpansionPrompts(status string, limit int) ([]ExpansionPromptRow, error) {
|
||||
query := "SELECT idx, seed_id, region, domain, language, prompt, prompt_en, priority, status " +
|
||||
"FROM expansion_prompts"
|
||||
var args []any
|
||||
|
||||
if status != "" {
|
||||
query += " WHERE status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
query += " ORDER BY priority, idx"
|
||||
|
||||
if limit > 0 {
|
||||
query += core.Sprintf(" LIMIT %d", limit)
|
||||
}
|
||||
|
||||
rows, err := db.conn.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, core.E("store.DuckDB.QueryExpansionPrompts", "query expansion_prompts", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var result []ExpansionPromptRow
|
||||
for rows.Next() {
|
||||
var r ExpansionPromptRow
|
||||
if err := rows.Scan(&r.Idx, &r.SeedID, &r.Region, &r.Domain,
|
||||
&r.Language, &r.Prompt, &r.PromptEn, &r.Priority, &r.Status); err != nil {
|
||||
return nil, core.E("store.DuckDB.QueryExpansionPrompts", "scan expansion_prompt row", err)
|
||||
}
|
||||
result = append(result, r)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
// CountExpansionPrompts returns counts by status.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// total, pending, err := db.CountExpansionPrompts()
|
||||
func (db *DuckDB) CountExpansionPrompts() (total int, pending int, err error) {
|
||||
err = db.conn.QueryRow("SELECT COUNT(*) FROM expansion_prompts").Scan(&total)
|
||||
if err != nil {
|
||||
return 0, 0, core.E("store.DuckDB.CountExpansionPrompts", "count expansion_prompts", err)
|
||||
}
|
||||
err = db.conn.QueryRow("SELECT COUNT(*) FROM expansion_prompts WHERE status = 'pending'").Scan(&pending)
|
||||
if err != nil {
|
||||
return total, 0, core.E("store.DuckDB.CountExpansionPrompts", "count pending expansion_prompts", err)
|
||||
}
|
||||
return total, pending, nil
|
||||
}
|
||||
|
||||
// UpdateExpansionStatus updates the status of an expansion prompt by idx.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// err := db.UpdateExpansionStatus(42, "done")
|
||||
func (db *DuckDB) UpdateExpansionStatus(idx int64, status string) error {
|
||||
_, err := db.conn.Exec("UPDATE expansion_prompts SET status = ? WHERE idx = ?", status, idx)
|
||||
if err != nil {
|
||||
return core.E("store.DuckDB.UpdateExpansionStatus", core.Sprintf("update expansion_prompt %d", idx), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryRows executes an arbitrary SQL query and returns results as maps.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// rows, err := db.QueryRows("SELECT COUNT(*) AS n FROM golden_set")
|
||||
func (db *DuckDB) QueryRows(query string, args ...any) ([]map[string]any, error) {
|
||||
rows, err := db.conn.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, core.E("store.DuckDB.QueryRows", "query", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, core.E("store.DuckDB.QueryRows", "columns", err)
|
||||
}
|
||||
|
||||
var result []map[string]any
|
||||
for rows.Next() {
|
||||
values := make([]any, len(cols))
|
||||
ptrs := make([]any, len(cols))
|
||||
for i := range values {
|
||||
ptrs[i] = &values[i]
|
||||
}
|
||||
if err := rows.Scan(ptrs...); err != nil {
|
||||
return nil, core.E("store.DuckDB.QueryRows", "scan", err)
|
||||
}
|
||||
row := make(map[string]any, len(cols))
|
||||
for i, col := range cols {
|
||||
row[col] = values[i]
|
||||
}
|
||||
result = append(result, row)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
// EnsureScoringTables creates the scoring tables if they do not exist.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// db.EnsureScoringTables()
|
||||
func (db *DuckDB) EnsureScoringTables() {
|
||||
db.conn.Exec(core.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
|
||||
model TEXT, run_id TEXT, label TEXT, iteration INTEGER,
|
||||
correct INTEGER, total INTEGER, accuracy DOUBLE,
|
||||
scored_at TIMESTAMP DEFAULT current_timestamp,
|
||||
PRIMARY KEY (run_id, label)
|
||||
)`, TableCheckpointScores))
|
||||
db.conn.Exec(core.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
|
||||
model TEXT, run_id TEXT, label TEXT, probe_id TEXT,
|
||||
passed BOOLEAN, response TEXT, iteration INTEGER,
|
||||
scored_at TIMESTAMP DEFAULT current_timestamp,
|
||||
PRIMARY KEY (run_id, label, probe_id)
|
||||
)`, TableProbeResults))
|
||||
db.conn.Exec(`CREATE TABLE IF NOT EXISTS scoring_results (
|
||||
model TEXT, prompt_id TEXT, suite TEXT,
|
||||
dimension TEXT, score DOUBLE,
|
||||
scored_at TIMESTAMP DEFAULT current_timestamp
|
||||
)`)
|
||||
}
|
||||
|
||||
// WriteScoringResult writes a single scoring dimension result to DuckDB.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// err := db.WriteScoringResult("lem-8b", "p-001", "ethics", "honesty", 0.95)
|
||||
func (db *DuckDB) WriteScoringResult(model, promptID, suite, dimension string, score float64) error {
|
||||
_, err := db.conn.Exec(
|
||||
`INSERT INTO scoring_results (model, prompt_id, suite, dimension, score) VALUES (?, ?, ?, ?, ?)`,
|
||||
model, promptID, suite, dimension, score,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// TableCounts returns row counts for all known tables.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// counts, err := db.TableCounts()
|
||||
// n := counts["golden_set"]
|
||||
func (db *DuckDB) TableCounts() (map[string]int, error) {
|
||||
tables := []string{"golden_set", "expansion_prompts", "seeds", "prompts",
|
||||
"training_examples", "gemini_responses", "benchmark_questions", "benchmark_results", "validations",
|
||||
TableCheckpointScores, TableProbeResults, "scoring_results"}
|
||||
|
||||
counts := make(map[string]int)
|
||||
for _, t := range tables {
|
||||
var count int
|
||||
err := db.conn.QueryRow(core.Sprintf("SELECT COUNT(*) FROM %s", t)).Scan(&count)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
counts[t] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
23
go.mod
23
go.mod
|
|
@ -9,13 +9,34 @@ require (
|
|||
modernc.org/sqlite v1.47.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.1.1 // indirect
|
||||
github.com/apache/arrow-go/v18 v18.1.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/google/flatbuffers v25.1.24+incompatible // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
|
||||
github.com/parquet-go/bitpack v1.0.0 // indirect
|
||||
github.com/parquet-go/jsonlite v1.0.0 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.22 // indirect
|
||||
github.com/twpayne/go-geom v1.6.1 // indirect
|
||||
github.com/zeebo/xxh3 v1.0.2 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/mod v0.34.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/telemetry v0.0.0-20260311193753-579e4da9a98c // indirect
|
||||
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
|
||||
google.golang.org/protobuf v1.36.1 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/marcboeker/go-duckdb v1.8.5
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/parquet-go/parquet-go v0.29.0
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
|
|
|
|||
61
go.sum
61
go.sum
|
|
@ -1,26 +1,67 @@
|
|||
dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk=
|
||||
dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/alecthomas/assert/v2 v2.10.0 h1:jjRCHsj6hBJhkmhznrCzoNpbA3zqy0fYiUcYZP/GkPY=
|
||||
github.com/alecthomas/assert/v2 v2.10.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
|
||||
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
|
||||
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
||||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||
github.com/apache/arrow-go/v18 v18.1.0 h1:agLwJUiVuwXZdwPYVrlITfx7bndULJ/dggbnLFgDp/Y=
|
||||
github.com/apache/arrow-go/v18 v18.1.0/go.mod h1:tigU/sIgKNXaesf5d7Y95jBBKS5KsxTqYBKXFsvKzo0=
|
||||
github.com/apache/thrift v0.21.0 h1:tdPmh/ptjE1IJnhbhrcl2++TauVjy242rkV/UzJChnE=
|
||||
github.com/apache/thrift v0.21.0/go.mod h1:W1H8aR/QRtYNvrPeFXBtobyRkd0/YVhTc6i07XIAgDw=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
|
||||
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/flatbuffers v25.1.24+incompatible h1:4wPqL3K7GzBd1CwyhSd3usxLKOaJN/AC6puCca6Jm7o=
|
||||
github.com/google/flatbuffers v25.1.24+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4=
|
||||
github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE=
|
||||
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
|
||||
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY=
|
||||
github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/marcboeker/go-duckdb v1.8.5 h1:tkYp+TANippy0DaIOP5OEfBEwbUINqiFqgwMQ44jME0=
|
||||
github.com/marcboeker/go-duckdb v1.8.5/go.mod h1:6mK7+WQE4P4u5AFLvVBmhFxY5fvhymFptghgJX6B+/8=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
|
||||
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY=
|
||||
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI=
|
||||
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/parquet-go/bitpack v1.0.0 h1:AUqzlKzPPXf2bCdjfj4sTeacrUwsT7NlcYDMUQxPcQA=
|
||||
github.com/parquet-go/bitpack v1.0.0/go.mod h1:XnVk9TH+O40eOOmvpAVZ7K2ocQFrQwysLMnc6M/8lgs=
|
||||
github.com/parquet-go/jsonlite v1.0.0 h1:87QNdi56wOfsE5bdgas0vRzHPxfJgzrXGml1zZdd7VU=
|
||||
github.com/parquet-go/jsonlite v1.0.0/go.mod h1:nDjpkpL4EOtqs6NQugUsi0Rleq9sW/OtC1NnZEnxzF0=
|
||||
github.com/parquet-go/parquet-go v0.29.0 h1:xXlPtFVR51jpSVzf+cgHnNIcb7Xet+iuvkbe0HIm90Y=
|
||||
github.com/parquet-go/parquet-go v0.29.0/go.mod h1:navtkAYr2LGoJVp141oXPlO/sxLvaOe3la2JEoD8+rg=
|
||||
github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
|
||||
github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
|
|
@ -29,6 +70,16 @@ github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0t
|
|||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/twpayne/go-geom v1.6.1 h1:iLE+Opv0Ihm/ABIcvQFGIiFBXd76oBIar9drAwHFhR4=
|
||||
github.com/twpayne/go-geom v1.6.1/go.mod h1:Kr+Nly6BswFsKM5sd31YaoWS5PeDDH2NftJTK7Gd028=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
|
||||
github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
|
||||
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
|
||||
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
||||
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
|
|
@ -36,8 +87,16 @@ golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
|||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/telemetry v0.0.0-20260311193753-579e4da9a98c h1:6a8FdnNk6bTXBjR4AGKFgUKuo+7GnR3FX5L7CbveeZc=
|
||||
golang.org/x/telemetry v0.0.0-20260311193753-579e4da9a98c/go.mod h1:TpUTTEp9frx7rTdLpC9gFG9kdI7zVLFTFFlqaH2Cncw=
|
||||
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
|
||||
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
||||
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY=
|
||||
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90=
|
||||
gonum.org/v1/gonum v0.15.1 h1:FNy7N6OUZVUaWG9pTiD+jlhdQ3lMP+/LcTpJ6+a8sQ0=
|
||||
gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o=
|
||||
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
|
||||
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
|
|
|
|||
544
import.go
Normal file
544
import.go
Normal file
|
|
@ -0,0 +1,544 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"io/fs"
|
||||
|
||||
core "dappco.re/go/core"
|
||||
)
|
||||
|
||||
// localFs provides unrestricted filesystem access for import operations.
// Rooted at "/" so import paths anywhere on the local disk are reachable.
// NOTE(review): core.Fs semantics are defined in the core package — confirm
// that New("/") does not restrict to a chroot-style subtree.
var localFs = (&core.Fs{}).New("/")
|
||||
|
||||
// ScpFunc is a callback for executing SCP file transfers.
// The function receives remote source and local destination paths and
// returns any transfer error. Injecting it keeps this package free of a
// direct exec dependency.
//
// Usage example:
//
//	scp := func(remote, local string) error { return exec.Command("scp", remote, local).Run() }
type ScpFunc func(remote, local string) error
|
||||
|
||||
// ScpDirFunc is a callback for executing recursive SCP directory transfers.
// The function receives remote source and local destination directory paths
// and returns any transfer error.
//
// Usage example:
//
//	scpDir := func(remote, localDir string) error { return exec.Command("scp", "-r", remote, localDir).Run() }
type ScpDirFunc func(remote, localDir string) error
|
||||
|
||||
// ImportConfig holds options for the import-all operation.
//
// Usage example:
//
//	cfg := store.ImportConfig{DataDir: "/Volumes/Data/lem", SkipM3: true}
type ImportConfig struct {
	// SkipM3 disables pulling files from the M3 host; only local files in
	// DataDir are imported.
	//
	// Usage example:
	//
	//	cfg.SkipM3 // true
	SkipM3 bool

	// DataDir is the local directory containing LEM data files.
	//
	// Usage example:
	//
	//	cfg.DataDir // "/Volumes/Data/lem"
	DataDir string

	// M3Host is the SSH hostname for SCP operations. Defaults to "m3"
	// when left empty.
	//
	// Usage example:
	//
	//	cfg.M3Host // "m3"
	M3Host string

	// Scp copies a single file from the remote host. If nil, SCP is skipped.
	//
	// Usage example:
	//
	//	cfg.Scp("m3:/path/file.jsonl", "/local/file.jsonl")
	Scp ScpFunc

	// ScpDir copies a directory recursively from the remote host. If nil, SCP is skipped.
	//
	// Usage example:
	//
	//	cfg.ScpDir("m3:/path/dir/", "/local/dir/")
	ScpDir ScpDirFunc
}
|
||||
|
||||
// ImportAll imports all LEM data into DuckDB from M3 and local files.
//
// Five phases run in order: golden set, training examples, benchmark
// results, benchmark questions, and seeds. Each phase is best-effort:
// a failed pull or import is reported to w as a WARNING and the remaining
// phases still run, so the function returns nil even when some phases
// imported nothing. Progress and a per-table summary are written to w.
//
// Usage example:
//
//	err := store.ImportAll(db, store.ImportConfig{DataDir: "/Volumes/Data/lem"}, os.Stdout)
func ImportAll(db *DuckDB, cfg ImportConfig, w io.Writer) error {
	m3Host := cfg.M3Host
	if m3Host == "" {
		m3Host = "m3" // default SSH host alias for the M3 machine
	}

	// Rows imported per table, accumulated for the final summary.
	totals := make(map[string]int)

	// ── 1. Golden set ──
	goldenPath := core.JoinPath(cfg.DataDir, "gold-15k.jsonl")
	if !cfg.SkipM3 && cfg.Scp != nil {
		core.Print(w, " Pulling golden set from M3...")
		remote := core.Sprintf("%s:/Volumes/Data/lem/responses/gold-15k.jsonl", m3Host)
		if err := cfg.Scp(remote, goldenPath); err != nil {
			core.Print(w, " WARNING: could not pull golden set from M3: %v", err)
		}
	}
	if isFile(goldenPath) {
		db.Exec("DROP TABLE IF EXISTS golden_set")
		// maximum_object_size is raised because some responses exceed
		// DuckDB's default JSON object size limit. word_count is an
		// approximation: spaces in the response plus one.
		err := db.Exec(core.Sprintf(`
			CREATE TABLE golden_set AS
			SELECT
				idx::INT AS idx,
				seed_id::VARCHAR AS seed_id,
				domain::VARCHAR AS domain,
				voice::VARCHAR AS voice,
				prompt::VARCHAR AS prompt,
				response::VARCHAR AS response,
				gen_time::DOUBLE AS gen_time,
				length(response)::INT AS char_count,
				length(response) - length(replace(response, ' ', '')) + 1 AS word_count
			FROM read_json_auto('%s', maximum_object_size=1048576)
		`, escapeSQLPath(goldenPath)))
		if err != nil {
			core.Print(w, " WARNING: golden set import failed: %v", err)
		} else {
			var n int
			db.QueryRowScan("SELECT count(*) FROM golden_set", &n)
			totals["golden_set"] = n
			core.Print(w, " golden_set: %d rows", n)
		}
	}

	// ── 2. Training examples ──
	// Each named dataset lists the split files it may provide; absent
	// files are skipped silently during import below.
	trainingDirs := []struct {
		name  string
		files []string
	}{
		{"training", []string{"training/train.jsonl", "training/valid.jsonl", "training/test.jsonl"}},
		{"training-2k", []string{"training-2k/train.jsonl", "training-2k/valid.jsonl", "training-2k/test.jsonl"}},
		{"training-expanded", []string{"training-expanded/train.jsonl", "training-expanded/valid.jsonl"}},
		{"training-book", []string{"training-book/train.jsonl", "training-book/valid.jsonl", "training-book/test.jsonl"}},
		{"training-conv", []string{"training-conv/train.jsonl", "training-conv/valid.jsonl", "training-conv/test.jsonl"}},
		{"gold-full", []string{"gold-full/train.jsonl", "gold-full/valid.jsonl"}},
		{"sovereignty-gold", []string{"sovereignty-gold/train.jsonl", "sovereignty-gold/valid.jsonl"}},
		{"composure-lessons", []string{"composure-lessons/train.jsonl", "composure-lessons/valid.jsonl"}},
		{"watts-full", []string{"watts-full/train.jsonl", "watts-full/valid.jsonl"}},
		{"watts-expanded", []string{"watts-expanded/train.jsonl", "watts-expanded/valid.jsonl"}},
		{"watts-composure", []string{"watts-composure-merged/train.jsonl", "watts-composure-merged/valid.jsonl"}},
		{"western-fresh", []string{"western-fresh/train.jsonl", "western-fresh/valid.jsonl"}},
		{"deepseek-soak", []string{"deepseek-western-soak/train.jsonl", "deepseek-western-soak/valid.jsonl"}},
		{"russian-bridge", []string{"russian-bridge/train.jsonl", "russian-bridge/valid.jsonl"}},
	}

	trainingLocal := core.JoinPath(cfg.DataDir, "training")
	localFs.EnsureDir(trainingLocal)

	if !cfg.SkipM3 && cfg.Scp != nil {
		core.Print(w, " Pulling training sets from M3...")
		for _, td := range trainingDirs {
			for _, rel := range td.files {
				local := core.JoinPath(trainingLocal, rel)
				localFs.EnsureDir(core.PathDir(local))
				remote := core.Sprintf("%s:/Volumes/Data/lem/%s", m3Host, rel)
				cfg.Scp(remote, local) // ignore errors, file might not exist
			}
		}
	}

	db.Exec("DROP TABLE IF EXISTS training_examples")
	db.Exec(`
		CREATE TABLE training_examples (
			source VARCHAR,
			split VARCHAR,
			prompt TEXT,
			response TEXT,
			num_turns INT,
			full_messages TEXT,
			char_count INT
		)
	`)

	trainingTotal := 0
	for _, td := range trainingDirs {
		for _, rel := range td.files {
			local := core.JoinPath(trainingLocal, rel)
			if !isFile(local) {
				continue
			}

			// Infer the split from the filename; default is "train".
			split := "train"
			if core.Contains(rel, "valid") {
				split = "valid"
			} else if core.Contains(rel, "test") {
				split = "test"
			}

			n := importTrainingFile(db, local, td.name, split)
			trainingTotal += n
		}
	}
	totals["training_examples"] = trainingTotal
	core.Print(w, " training_examples: %d rows", trainingTotal)

	// ── 3. Benchmark results ──
	benchLocal := core.JoinPath(cfg.DataDir, "benchmarks")
	localFs.EnsureDir(benchLocal)

	if !cfg.SkipM3 {
		core.Print(w, " Pulling benchmarks from M3...")
		if cfg.Scp != nil {
			for _, bname := range []string{"truthfulqa", "gsm8k", "do_not_answer", "toxigen"} {
				remote := core.Sprintf("%s:/Volumes/Data/lem/benchmarks/%s.jsonl", m3Host, bname)
				cfg.Scp(remote, core.JoinPath(benchLocal, bname+".jsonl"))
			}
		}
		if cfg.ScpDir != nil {
			for _, subdir := range []string{"results", "scale_results", "cross_arch_results", "deepseek-r1-7b"} {
				localSub := core.JoinPath(benchLocal, subdir)
				localFs.EnsureDir(localSub)
				remote := core.Sprintf("%s:/Volumes/Data/lem/benchmarks/%s/", m3Host, subdir)
				// NOTE(review): the destination is benchLocal, not localSub
				// (localSub is created but never used as the copy target).
				// This relies on ScpDir having scp -r semantics, where the
				// remote directory itself is created under the destination.
				// If ScpDir has rsync trailing-slash semantics, contents
				// would land directly in benchLocal and the glob below would
				// find nothing — confirm, and consider localSub+"/" here.
				cfg.ScpDir(remote, core.JoinPath(benchLocal)+"/")
			}
		}
	}

	db.Exec("DROP TABLE IF EXISTS benchmark_results")
	db.Exec(`
		CREATE TABLE benchmark_results (
			source VARCHAR, id VARCHAR, benchmark VARCHAR, model VARCHAR,
			prompt TEXT, response TEXT, elapsed_seconds DOUBLE, domain VARCHAR
		)
	`)

	benchTotal := 0
	for _, subdir := range []string{"results", "scale_results", "cross_arch_results", "deepseek-r1-7b"} {
		resultDir := core.JoinPath(benchLocal, subdir)
		matches := core.PathGlob(core.JoinPath(resultDir, "*.jsonl"))
		for _, jf := range matches {
			n := importBenchmarkFile(db, jf, subdir)
			benchTotal += n
		}
	}

	// Also import standalone benchmark files.
	for _, bfile := range []string{"lem_bench", "lem_ethics", "lem_ethics_allen", "instruction_tuned", "abliterated", "base_pt"} {
		local := core.JoinPath(benchLocal, bfile+".jsonl")
		if !isFile(local) {
			if !cfg.SkipM3 && cfg.Scp != nil {
				// NOTE(review): remote path uses "benchmark/" (singular)
				// while earlier pulls use "benchmarks/" — confirm this is
				// the intended M3 layout and not a typo.
				remote := core.Sprintf("%s:/Volumes/Data/lem/benchmark/%s.jsonl", m3Host, bfile)
				cfg.Scp(remote, local)
			}
		}
		if isFile(local) {
			n := importBenchmarkFile(db, local, "benchmark")
			benchTotal += n
		}
	}
	totals["benchmark_results"] = benchTotal
	core.Print(w, " benchmark_results: %d rows", benchTotal)

	// ── 4. Benchmark questions ──
	db.Exec("DROP TABLE IF EXISTS benchmark_questions")
	db.Exec(`
		CREATE TABLE benchmark_questions (
			benchmark VARCHAR, id VARCHAR, question TEXT,
			best_answer TEXT, correct_answers TEXT, incorrect_answers TEXT, category VARCHAR
		)
	`)

	benchQTotal := 0
	for _, bname := range []string{"truthfulqa", "gsm8k", "do_not_answer", "toxigen"} {
		local := core.JoinPath(benchLocal, bname+".jsonl")
		if isFile(local) {
			n := importBenchmarkQuestions(db, local, bname)
			benchQTotal += n
		}
	}
	totals["benchmark_questions"] = benchQTotal
	core.Print(w, " benchmark_questions: %d rows", benchQTotal)

	// ── 5. Seeds ──
	db.Exec("DROP TABLE IF EXISTS seeds")
	db.Exec(`
		CREATE TABLE seeds (
			source_file VARCHAR, region VARCHAR, seed_id VARCHAR, domain VARCHAR, prompt TEXT
		)
	`)

	// Seeds may live in the data dir or in checked-out repo copies under /tmp.
	seedTotal := 0
	seedDirs := []string{core.JoinPath(cfg.DataDir, "seeds"), "/tmp/lem-data/seeds", "/tmp/lem-repo/seeds"}
	for _, seedDir := range seedDirs {
		if !isDir(seedDir) {
			continue
		}
		n := importSeeds(db, seedDir)
		seedTotal += n
	}
	totals["seeds"] = seedTotal
	core.Print(w, " seeds: %d rows", seedTotal)

	// ── Summary ──
	grandTotal := 0
	core.Print(w, "\n%s", repeat("=", 50))
	core.Print(w, "LEM Database Import Complete")
	core.Print(w, "%s", repeat("=", 50))
	for table, count := range totals {
		core.Print(w, " %-25s %8d", table, count)
		grandTotal += count
	}
	core.Print(w, " %s", repeat("-", 35))
	core.Print(w, " %-25s %8d", "TOTAL", grandTotal)
	core.Print(w, "\nDatabase: %s", db.Path())

	return nil
}
|
||||
|
||||
func importTrainingFile(db *DuckDB, path, source, split string) int {
|
||||
r := localFs.Open(path)
|
||||
if !r.OK {
|
||||
return 0
|
||||
}
|
||||
f := r.Value.(io.ReadCloser)
|
||||
defer f.Close()
|
||||
|
||||
count := 0
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
var rec struct {
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
}
|
||||
if r := core.JSONUnmarshal(scanner.Bytes(), &rec); !r.OK {
|
||||
continue
|
||||
}
|
||||
|
||||
prompt := ""
|
||||
response := ""
|
||||
assistantCount := 0
|
||||
for _, m := range rec.Messages {
|
||||
if m.Role == "user" && prompt == "" {
|
||||
prompt = m.Content
|
||||
}
|
||||
if m.Role == "assistant" {
|
||||
if response == "" {
|
||||
response = m.Content
|
||||
}
|
||||
assistantCount++
|
||||
}
|
||||
}
|
||||
|
||||
msgsJSON := core.JSONMarshalString(rec.Messages)
|
||||
db.Exec(`INSERT INTO training_examples VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
source, split, prompt, response, assistantCount, msgsJSON, len(response))
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func importBenchmarkFile(db *DuckDB, path, source string) int {
|
||||
r := localFs.Open(path)
|
||||
if !r.OK {
|
||||
return 0
|
||||
}
|
||||
f := r.Value.(io.ReadCloser)
|
||||
defer f.Close()
|
||||
|
||||
count := 0
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
var rec map[string]any
|
||||
if r := core.JSONUnmarshal(scanner.Bytes(), &rec); !r.OK {
|
||||
continue
|
||||
}
|
||||
|
||||
db.Exec(`INSERT INTO benchmark_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
source,
|
||||
core.Sprint(rec["id"]),
|
||||
strOrEmpty(rec, "benchmark"),
|
||||
strOrEmpty(rec, "model"),
|
||||
strOrEmpty(rec, "prompt"),
|
||||
strOrEmpty(rec, "response"),
|
||||
floatOrZero(rec, "elapsed_seconds"),
|
||||
strOrEmpty(rec, "domain"),
|
||||
)
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func importBenchmarkQuestions(db *DuckDB, path, benchmark string) int {
|
||||
r := localFs.Open(path)
|
||||
if !r.OK {
|
||||
return 0
|
||||
}
|
||||
f := r.Value.(io.ReadCloser)
|
||||
defer f.Close()
|
||||
|
||||
count := 0
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
var rec map[string]any
|
||||
if r := core.JSONUnmarshal(scanner.Bytes(), &rec); !r.OK {
|
||||
continue
|
||||
}
|
||||
|
||||
correctJSON := core.JSONMarshalString(rec["correct_answers"])
|
||||
incorrectJSON := core.JSONMarshalString(rec["incorrect_answers"])
|
||||
|
||||
db.Exec(`INSERT INTO benchmark_questions VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
benchmark,
|
||||
core.Sprint(rec["id"]),
|
||||
strOrEmpty(rec, "question"),
|
||||
strOrEmpty(rec, "best_answer"),
|
||||
correctJSON,
|
||||
incorrectJSON,
|
||||
strOrEmpty(rec, "category"),
|
||||
)
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// importSeeds walks seedDir recursively and inserts every seed prompt found
// in .json files into the seeds table. A file may be either a bare JSON
// array or an object carrying a "prompts" or "seeds" array; each element is
// either an object (prompt text under "prompt", "text", or "question") or a
// plain string. Unreadable files, unparseable JSON, and unrecognized shapes
// are skipped silently. Returns the number of rows inserted.
func importSeeds(db *DuckDB, seedDir string) int {
	count := 0
	walkDir(seedDir, func(path string) {
		if !core.HasSuffix(path, ".json") {
			return
		}

		readResult := localFs.Read(path)
		if !readResult.OK {
			return
		}
		data := []byte(readResult.Value.(string))

		// source_file column is the path relative to seedDir; region is
		// the file basename without its .json extension.
		rel := core.TrimPrefix(path, seedDir+"/")
		region := core.TrimSuffix(core.PathBase(path), ".json")

		// Try parsing as array or object with prompts/seeds field.
		var seedsList []any
		var raw any
		if r := core.JSONUnmarshal(data, &raw); !r.OK {
			return
		}

		switch v := raw.(type) {
		case []any:
			seedsList = v
		case map[string]any:
			if prompts, ok := v["prompts"].([]any); ok {
				seedsList = prompts
			} else if seeds, ok := v["seeds"].([]any); ok {
				seedsList = seeds
			}
		}

		for _, s := range seedsList {
			switch seed := s.(type) {
			case map[string]any:
				// Accept several field names for the prompt text, in
				// priority order.
				prompt := strOrEmpty(seed, "prompt")
				if prompt == "" {
					prompt = strOrEmpty(seed, "text")
				}
				if prompt == "" {
					prompt = strOrEmpty(seed, "question")
				}
				db.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
					rel, region,
					strOrEmpty(seed, "seed_id"),
					strOrEmpty(seed, "domain"),
					prompt,
				)
				count++
			case string:
				// Bare string seeds carry no seed_id or domain.
				db.Exec(`INSERT INTO seeds VALUES (?, ?, ?, ?, ?)`,
					rel, region, "", "", seed)
				count++
			}
		}
	})
	return count
}
|
||||
|
||||
// walkDir recursively visits all regular files under root, calling fn for each.
|
||||
func walkDir(root string, fn func(path string)) {
|
||||
r := localFs.List(root)
|
||||
if !r.OK {
|
||||
return
|
||||
}
|
||||
entries, ok := r.Value.([]fs.DirEntry)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, entry := range entries {
|
||||
full := core.JoinPath(root, entry.Name())
|
||||
if entry.IsDir() {
|
||||
walkDir(full, fn)
|
||||
} else {
|
||||
fn(full)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// strOrEmpty extracts a string value from a map, returning an empty string if
|
||||
// the key is absent.
|
||||
func strOrEmpty(m map[string]any, key string) string {
|
||||
if v, ok := m[key]; ok {
|
||||
return core.Sprint(v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// floatOrZero extracts a float64 value from a map, returning zero if the key
// is absent or not a float64 (JSON numbers decode as float64).
func floatOrZero(m map[string]any, key string) float64 {
	// A missing key yields nil, which fails the type assertion and
	// leaves the zero value.
	f, _ := m[key].(float64)
	return f
}
|
||||
|
||||
// repeat returns a string consisting of count copies of s.
|
||||
func repeat(s string, count int) string {
|
||||
if count <= 0 {
|
||||
return ""
|
||||
}
|
||||
b := core.NewBuilder()
|
||||
for range count {
|
||||
b.WriteString(s)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// escapeSQLPath escapes single quotes in a file path for use in DuckDB SQL
// string literals, doubling each ' per the SQL quoting convention. This only
// makes the value safe inside a single-quoted literal; it is not a general
// SQL-injection defense.
func escapeSQLPath(p string) string {
	return core.Replace(p, "'", "''")
}
|
||||
|
||||
// isFile returns true if the path exists and is a regular file. Thin
// package-local wrapper over localFs.IsFile.
func isFile(path string) bool {
	return localFs.IsFile(path)
}
|
||||
|
||||
// isDir returns true if the path exists and is a directory. Thin
// package-local wrapper over localFs.IsDir.
func isDir(path string) bool {
	return localFs.IsDir(path)
}
|
||||
166
inventory.go
Normal file
166
inventory.go
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
core "dappco.re/go/core"
|
||||
)
|
||||
|
||||
// TargetTotal is the golden set target size used for progress reporting
// (gatherDuckDBDetails reports golden_set as a percentage of this value).
//
// Usage example:
//
//	pct := float64(count) / float64(store.TargetTotal) * 100
const TargetTotal = 15000
|
||||
|
||||
// duckDBTableOrder defines the canonical display order for DuckDB inventory
// tables. PrintDuckDBInventory iterates this slice so output is stable;
// tables absent from the database are simply skipped.
var duckDBTableOrder = []string{
	"golden_set", "expansion_prompts", "seeds", "prompts",
	"training_examples", "gemini_responses", "benchmark_questions",
	"benchmark_results", "validations", TableCheckpointScores,
	TableProbeResults, "scoring_results",
}
|
||||
|
||||
// duckDBTableDetail holds extra context for a single table beyond its row count.
type duckDBTableDetail struct {
	// notes are short annotations rendered in parentheses after the row
	// count, e.g. "12 sources" or "42.0% of 15000 target".
	notes []string
}
|
||||
|
||||
// PrintDuckDBInventory queries all known DuckDB tables and prints a formatted
|
||||
// inventory with row counts, detail breakdowns, and a grand total.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// err := store.PrintDuckDBInventory(db, os.Stdout)
|
||||
func PrintDuckDBInventory(db *DuckDB, w io.Writer) error {
|
||||
counts, err := db.TableCounts()
|
||||
if err != nil {
|
||||
return core.E("store.PrintDuckDBInventory", "table counts", err)
|
||||
}
|
||||
|
||||
details := gatherDuckDBDetails(db, counts)
|
||||
|
||||
core.Print(w, "DuckDB Inventory")
|
||||
core.Print(w, "%s", repeat("-", 52))
|
||||
|
||||
grand := 0
|
||||
for _, table := range duckDBTableOrder {
|
||||
count, ok := counts[table]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
grand += count
|
||||
line := core.Sprintf(" %-24s %8d rows", table, count)
|
||||
|
||||
if d, has := details[table]; has && len(d.notes) > 0 {
|
||||
line += core.Sprintf(" (%s)", core.Join(", ", d.notes...))
|
||||
}
|
||||
core.Print(w, "%s", line)
|
||||
}
|
||||
|
||||
core.Print(w, "%s", repeat("-", 52))
|
||||
core.Print(w, " %-24s %8d rows", "TOTAL", grand)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// gatherDuckDBDetails runs per-table detail queries and returns annotations
// keyed by table name. Only tables present in counts are queried. Errors on
// individual queries are silently ignored so the inventory always prints.
func gatherDuckDBDetails(db *DuckDB, counts map[string]int) map[string]*duckDBTableDetail {
	details := make(map[string]*duckDBTableDetail)

	// golden_set: progress towards target
	if count, ok := counts["golden_set"]; ok {
		pct := float64(count) / float64(TargetTotal) * 100
		details["golden_set"] = &duckDBTableDetail{
			notes: []string{core.Sprintf("%.1f%% of %d target", pct, TargetTotal)},
		}
	}

	// training_examples: distinct sources
	if _, ok := counts["training_examples"]; ok {
		rows, err := db.QueryRows("SELECT COUNT(DISTINCT source) AS n FROM training_examples")
		if err == nil && len(rows) > 0 {
			n := duckDBToInt(rows[0]["n"])
			details["training_examples"] = &duckDBTableDetail{
				notes: []string{core.Sprintf("%d sources", n)},
			}
		}
	}

	// prompts: distinct domains and voices (either note may be absent if
	// its query fails; the entry is only added when at least one succeeds)
	if _, ok := counts["prompts"]; ok {
		d := &duckDBTableDetail{}
		rows, err := db.QueryRows("SELECT COUNT(DISTINCT domain) AS n FROM prompts")
		if err == nil && len(rows) > 0 {
			d.notes = append(d.notes, core.Sprintf("%d domains", duckDBToInt(rows[0]["n"])))
		}
		rows, err = db.QueryRows("SELECT COUNT(DISTINCT voice) AS n FROM prompts")
		if err == nil && len(rows) > 0 {
			d.notes = append(d.notes, core.Sprintf("%d voices", duckDBToInt(rows[0]["n"])))
		}
		if len(d.notes) > 0 {
			details["prompts"] = d
		}
	}

	// gemini_responses: group by source_model, one "model:count" note per
	// model, largest first; rows with an empty model name are dropped.
	if _, ok := counts["gemini_responses"]; ok {
		rows, err := db.QueryRows(
			"SELECT source_model, COUNT(*) AS n FROM gemini_responses GROUP BY source_model ORDER BY n DESC",
		)
		if err == nil && len(rows) > 0 {
			var parts []string
			for _, row := range rows {
				model := duckDBStrVal(row, "source_model")
				n := duckDBToInt(row["n"])
				if model != "" {
					parts = append(parts, core.Sprintf("%s:%d", model, n))
				}
			}
			if len(parts) > 0 {
				details["gemini_responses"] = &duckDBTableDetail{notes: parts}
			}
		}
	}

	// benchmark_results: distinct source categories
	if _, ok := counts["benchmark_results"]; ok {
		rows, err := db.QueryRows("SELECT COUNT(DISTINCT source) AS n FROM benchmark_results")
		if err == nil && len(rows) > 0 {
			n := duckDBToInt(rows[0]["n"])
			details["benchmark_results"] = &duckDBTableDetail{
				notes: []string{core.Sprintf("%d categories", n)},
			}
		}
	}

	return details
}
|
||||
|
||||
// duckDBToInt converts a DuckDB value to int. DuckDB returns integers as
// int64 (not float64 like InfluxDB), so int64, int32, and float64 are all
// handled; any other type yields 0.
func duckDBToInt(v any) int {
	if n, ok := v.(int64); ok {
		return int(n)
	}
	if n, ok := v.(int32); ok {
		return int(n)
	}
	if f, ok := v.(float64); ok {
		return int(f)
	}
	return 0
}
|
||||
|
||||
// duckDBStrVal extracts a string value from a row map.
|
||||
func duckDBStrVal(row map[string]any, key string) string {
|
||||
if v, ok := row[key]; ok {
|
||||
return core.Sprint(v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
195
parquet.go
Normal file
195
parquet.go
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
|
||||
core "dappco.re/go/core"
|
||||
"github.com/parquet-go/parquet-go"
|
||||
)
|
||||
|
||||
// ChatMessage represents a single message in a chat conversation, used for
// reading JSONL training data during Parquet export and data import.
// Marshals to/from the {"role": ..., "content": ...} JSON shape.
//
// Usage example:
//
//	msg := store.ChatMessage{Role: "user", Content: "What is sovereignty?"}
type ChatMessage struct {
	// Role is the message author role (e.g. "user", "assistant", "system").
	//
	// Usage example:
	//
	//	msg.Role // "user"
	Role string `json:"role"`

	// Content is the message text.
	//
	// Usage example:
	//
	//	msg.Content // "What is sovereignty?"
	Content string `json:"content"`
}
|
||||
|
||||
// ParquetRow is the schema for exported Parquet files. One row corresponds
// to one JSONL training example; ExportSplitParquet populates it from the
// first user/assistant/system message of each conversation.
//
// Usage example:
//
//	row := store.ParquetRow{Prompt: "What is sovereignty?", Response: "...", System: "You are LEM."}
type ParquetRow struct {
	// Prompt is the user prompt text (first user message).
	//
	// Usage example:
	//
	//	row.Prompt // "What is sovereignty?"
	Prompt string `parquet:"prompt"`

	// Response is the assistant response text (first assistant message).
	//
	// Usage example:
	//
	//	row.Response // "Sovereignty is..."
	Response string `parquet:"response"`

	// System is the system prompt text.
	//
	// Usage example:
	//
	//	row.System // "You are LEM."
	System string `parquet:"system"`

	// Messages is the JSON-encoded full conversation messages.
	//
	// Usage example:
	//
	//	row.Messages // `[{"role":"user","content":"..."}]`
	Messages string `parquet:"messages"`
}
|
||||
|
||||
// ExportParquet reads JSONL training splits (train.jsonl, valid.jsonl, test.jsonl)
|
||||
// from trainingDir and writes Parquet files with snappy compression to outputDir.
|
||||
// Returns total rows exported.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// total, err := store.ExportParquet("/Volumes/Data/lem/training", "/Volumes/Data/lem/parquet")
|
||||
func ExportParquet(trainingDir, outputDir string) (int, error) {
|
||||
if outputDir == "" {
|
||||
outputDir = core.JoinPath(trainingDir, "parquet")
|
||||
}
|
||||
if r := localFs.EnsureDir(outputDir); !r.OK {
|
||||
return 0, core.E("store.ExportParquet", "create output directory", r.Value.(error))
|
||||
}
|
||||
|
||||
total := 0
|
||||
for _, split := range []string{"train", "valid", "test"} {
|
||||
jsonlPath := core.JoinPath(trainingDir, split+".jsonl")
|
||||
if !localFs.IsFile(jsonlPath) {
|
||||
continue
|
||||
}
|
||||
|
||||
n, err := ExportSplitParquet(jsonlPath, outputDir, split)
|
||||
if err != nil {
|
||||
return total, core.E("store.ExportParquet", core.Sprintf("export %s", split), err)
|
||||
}
|
||||
total += n
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// ExportSplitParquet reads a chat JSONL file and writes a Parquet file for
// the given split name (<outputDir>/<split>.parquet, snappy-compressed).
// Blank and unparseable lines are skipped; for each record the first user,
// assistant, and system messages populate the prompt, response, and system
// columns, with the full message list JSON-encoded into messages. If no
// valid rows are found, no file is written and (0, nil) is returned.
// Returns the number of rows written.
//
// Usage example:
//
//	n, err := store.ExportSplitParquet("/data/train.jsonl", "/data/parquet", "train")
func ExportSplitParquet(jsonlPath, outputDir, split string) (int, error) {
	openResult := localFs.Open(jsonlPath)
	if !openResult.OK {
		return 0, core.E("store.ExportSplitParquet", core.Sprintf("open %s", jsonlPath), openResult.Value.(error))
	}
	f := openResult.Value.(io.ReadCloser)
	defer f.Close()

	var rows []ParquetRow
	scanner := bufio.NewScanner(f)
	// Chat transcripts can be long; raise the per-line limit to 1 MiB.
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

	for scanner.Scan() {
		text := core.Trim(scanner.Text())
		if text == "" {
			continue
		}

		var data struct {
			Messages []ChatMessage `json:"messages"`
		}
		if r := core.JSONUnmarshal([]byte(text), &data); !r.OK {
			continue // skip malformed lines
		}

		// Keep only the first message of each role.
		var prompt, response, system string
		for _, m := range data.Messages {
			switch m.Role {
			case "user":
				if prompt == "" {
					prompt = m.Content
				}
			case "assistant":
				if response == "" {
					response = m.Content
				}
			case "system":
				if system == "" {
					system = m.Content
				}
			}
		}

		msgsJSON := core.JSONMarshalString(data.Messages)
		rows = append(rows, ParquetRow{
			Prompt:   prompt,
			Response: response,
			System:   system,
			Messages: msgsJSON,
		})
	}

	if err := scanner.Err(); err != nil {
		return 0, core.E("store.ExportSplitParquet", core.Sprintf("scan %s", jsonlPath), err)
	}

	if len(rows) == 0 {
		return 0, nil
	}

	outPath := core.JoinPath(outputDir, split+".parquet")

	createResult := localFs.Create(outPath)
	if !createResult.OK {
		return 0, core.E("store.ExportSplitParquet", core.Sprintf("create %s", outPath), createResult.Value.(error))
	}
	out := createResult.Value.(io.WriteCloser)

	writer := parquet.NewGenericWriter[ParquetRow](out,
		parquet.Compression(&parquet.Snappy),
	)

	// On any writer failure, close the file handle before returning; the
	// happy path closes the writer first (which flushes the footer), then
	// the file, checking both errors.
	if _, err := writer.Write(rows); err != nil {
		out.Close()
		return 0, core.E("store.ExportSplitParquet", "write parquet rows", err)
	}

	if err := writer.Close(); err != nil {
		out.Close()
		return 0, core.E("store.ExportSplitParquet", "close parquet writer", err)
	}

	if err := out.Close(); err != nil {
		return 0, core.E("store.ExportSplitParquet", "close file", err)
	}

	return len(rows), nil
}
|
||||
196
publish.go
Normal file
196
publish.go
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
core "dappco.re/go/core"
|
||||
)
|
||||
|
||||
// PublishConfig holds options for the publish operation.
//
// Usage example:
//
//	cfg := store.PublishConfig{InputDir: "/data/parquet", Repo: "snider/lem-training", Public: true}
type PublishConfig struct {
	// InputDir is the directory containing Parquet files to upload.
	// Required; Publish returns an error when it is empty.
	//
	// Usage example:
	//
	//	cfg.InputDir // "/data/parquet"
	InputDir string

	// Repo is the HuggingFace dataset repository (e.g. "user/dataset").
	//
	// Usage example:
	//
	//	cfg.Repo // "snider/lem-training"
	Repo string

	// Public sets the dataset visibility to public when true.
	// NOTE(review): currently only reflected in dry-run output; the
	// upload path does not send visibility to the API — confirm intent.
	//
	// Usage example:
	//
	//	cfg.Public // true
	Public bool

	// Token is the HuggingFace API token. Falls back to HF_TOKEN env or
	// ~/.huggingface/token (see resolveHFToken).
	//
	// Usage example:
	//
	//	cfg.Token // "hf_..."
	Token string

	// DryRun lists files that would be uploaded (with sizes) without
	// actually uploading; no token is required in this mode.
	//
	// Usage example:
	//
	//	cfg.DryRun // true
	DryRun bool
}
|
||||
|
||||
// uploadEntry pairs a local file path with its remote destination.
type uploadEntry struct {
	// local is the absolute or working-directory-relative path on disk.
	local string
	// remote is the path within the HuggingFace dataset repo,
	// e.g. "data/train.parquet" or "README.md".
	remote string
}
|
||||
|
||||
// Publish uploads Parquet files to HuggingFace Hub.
//
// It looks for train.parquet, valid.parquet, and test.parquet in InputDir,
// plus an optional dataset_card.md in the parent directory (uploaded as README.md).
// The token is resolved from PublishConfig.Token, the HF_TOKEN environment variable,
// or ~/.huggingface/token, in that order. In dry-run mode the files and their
// sizes are listed and nothing is uploaded; no token is required.
//
// NOTE(review): cfg.Public is only echoed in dry-run output — the upload
// path never communicates visibility to the Hub API. Confirm whether the
// repo must be created/configured separately.
//
// Usage example:
//
//	err := store.Publish(store.PublishConfig{InputDir: "/data/parquet", Repo: "snider/lem-training"}, os.Stdout)
func Publish(cfg PublishConfig, w io.Writer) error {
	if cfg.InputDir == "" {
		return core.E("store.Publish", "input directory is required", nil)
	}

	// A token is only mandatory when we will actually upload.
	token := resolveHFToken(cfg.Token)
	if token == "" && !cfg.DryRun {
		return core.E("store.Publish", "HuggingFace token required (--token, HF_TOKEN env, or ~/.huggingface/token)", nil)
	}

	files, err := collectUploadFiles(cfg.InputDir)
	if err != nil {
		return err
	}
	if len(files) == 0 {
		return core.E("store.Publish", core.Sprintf("no Parquet files found in %s", cfg.InputDir), nil)
	}

	if cfg.DryRun {
		// Report what would be uploaded, with sizes, then stop.
		core.Print(w, "Dry run: would publish to %s", cfg.Repo)
		if cfg.Public {
			core.Print(w, " Visibility: public")
		} else {
			core.Print(w, " Visibility: private")
		}
		for _, f := range files {
			statResult := localFs.Stat(f.local)
			if !statResult.OK {
				return core.E("store.Publish", core.Sprintf("stat %s", f.local), statResult.Value.(error))
			}
			info := statResult.Value.(fs.FileInfo)
			sizeMB := float64(info.Size()) / 1024 / 1024
			core.Print(w, " %s -> %s (%.1f MB)", core.PathBase(f.local), f.remote, sizeMB)
		}
		return nil
	}

	core.Print(w, "Publishing to https://huggingface.co/datasets/%s", cfg.Repo)

	// Upload sequentially; the first failure aborts the publish.
	for _, f := range files {
		if err := uploadFileToHF(token, cfg.Repo, f.local, f.remote); err != nil {
			return core.E("store.Publish", core.Sprintf("upload %s", core.PathBase(f.local)), err)
		}
		core.Print(w, " Uploaded %s -> %s", core.PathBase(f.local), f.remote)
	}

	core.Print(w, "\nPublished to https://huggingface.co/datasets/%s", cfg.Repo)
	return nil
}
|
||||
|
||||
// resolveHFToken returns a HuggingFace API token from the given value,
|
||||
// HF_TOKEN env var, or ~/.huggingface/token file.
|
||||
func resolveHFToken(explicit string) string {
|
||||
if explicit != "" {
|
||||
return explicit
|
||||
}
|
||||
if env := core.Env("HF_TOKEN"); env != "" {
|
||||
return env
|
||||
}
|
||||
home := core.Env("DIR_HOME")
|
||||
if home == "" {
|
||||
return ""
|
||||
}
|
||||
r := localFs.Read(core.JoinPath(home, ".huggingface", "token"))
|
||||
if !r.OK {
|
||||
return ""
|
||||
}
|
||||
return core.Trim(r.Value.(string))
|
||||
}
|
||||
|
||||
// collectUploadFiles finds Parquet split files and an optional dataset card.
|
||||
func collectUploadFiles(inputDir string) ([]uploadEntry, error) {
|
||||
splits := []string{"train", "valid", "test"}
|
||||
var files []uploadEntry
|
||||
|
||||
for _, split := range splits {
|
||||
path := core.JoinPath(inputDir, split+".parquet")
|
||||
if !isFile(path) {
|
||||
continue
|
||||
}
|
||||
files = append(files, uploadEntry{path, core.Sprintf("data/%s.parquet", split)})
|
||||
}
|
||||
|
||||
// Check for dataset card in parent directory.
|
||||
cardPath := core.JoinPath(inputDir, "..", "dataset_card.md")
|
||||
if isFile(cardPath) {
|
||||
files = append(files, uploadEntry{cardPath, "README.md"})
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// uploadFileToHF uploads a single file to a HuggingFace dataset repo via the
|
||||
// Hub API.
|
||||
func uploadFileToHF(token, repoID, localPath, remotePath string) error {
|
||||
readResult := localFs.Read(localPath)
|
||||
if !readResult.OK {
|
||||
return core.E("store.uploadFileToHF", core.Sprintf("read %s", localPath), readResult.Value.(error))
|
||||
}
|
||||
raw := []byte(readResult.Value.(string))
|
||||
|
||||
url := core.Sprintf("https://huggingface.co/api/datasets/%s/upload/main/%s", repoID, remotePath)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPut, url, bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return core.E("store.uploadFileToHF", "create request", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return core.E("store.uploadFileToHF", "upload request", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return core.E("store.uploadFileToHF", core.Sprintf("upload failed: HTTP %d: %s", resp.StatusCode, string(body)), nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue