cli/pkg/ml/db.go

242 lines
6.9 KiB
Go
Raw Normal View History

package ml
import (
"database/sql"
"fmt"
_ "github.com/marcboeker/go-duckdb"
)
// DB is a thin handle around a database/sql connection pool backed by the
// "duckdb" driver, together with the file path it was opened from.
type DB struct {
	// conn is the underlying database/sql handle used by every method.
	conn *sql.DB
	// path is the path passed to the Open* constructor, without any DSN options.
	path string
}
// OpenDB opens the DuckDB file at path with access_mode=READ_ONLY, to avoid
// locking problems with the Python pipeline that also uses the database.
//
// The connection is verified with a ping before it is returned. If the ping
// fails, the handle is closed and the ping error is returned, wrapped with the
// path. The returned DB records the bare path, not the DSN.
func OpenDB(path string) (*DB, error) {
	dsn := path + "?access_mode=READ_ONLY"
	handle, openErr := sql.Open("duckdb", dsn)
	if openErr != nil {
		return nil, fmt.Errorf("open duckdb %s: %w", path, openErr)
	}
	if pingErr := handle.Ping(); pingErr != nil {
		// Best effort: the ping failure is the error worth reporting.
		handle.Close()
		return nil, fmt.Errorf("ping duckdb %s: %w", path, pingErr)
	}
	return &DB{path: path, conn: handle}, nil
}
// OpenDBReadWrite opens the DuckDB file at path for reading and writing. No
// DSN options are appended, so the driver's default access mode applies.
//
// Like OpenDB, it pings the database before returning. On a ping failure the
// handle is closed and the error is returned, wrapped with the path.
func OpenDBReadWrite(path string) (*DB, error) {
	handle, err := sql.Open("duckdb", path)
	if err != nil {
		return nil, fmt.Errorf("open duckdb %s: %w", path, err)
	}
	if err = handle.Ping(); err != nil {
		handle.Close() // best effort; report the ping error instead
		return nil, fmt.Errorf("ping duckdb %s: %w", path, err)
	}
	db := &DB{conn: handle}
	db.path = path
	return db, nil
}
// Close shuts down the underlying database/sql handle and reports any error
// from doing so. Queries issued on this DB afterwards will fail.
func (db *DB) Close() error {
	if err := db.conn.Close(); err != nil {
		return err
	}
	return nil
}
// GoldenSetRow holds one record from the golden_set table. The fields appear
// in the same order as the columns selected by QueryGoldenSet.
type GoldenSetRow struct {
	// Idx is the idx column.
	Idx int
	// Text columns: seed_id, domain, voice, prompt and response.
	SeedID, Domain, Voice, Prompt, Response string
	// GenTime is the gen_time column (units not established here; presumably
	// seconds, TODO confirm).
	GenTime float64
	// CharCount is the char_count column; QueryGoldenSet filters on it.
	CharCount int
}
// ExpansionPromptRow holds one record from the expansion_prompts table. The
// fields appear in the same order as the columns selected by
// QueryExpansionPrompts.
type ExpansionPromptRow struct {
	// Idx is the idx column; UpdateExpansionStatus uses it to select a row.
	Idx int64
	// Text columns: seed_id, region, domain, language, prompt and prompt_en.
	SeedID, Region, Domain, Language, Prompt, PromptEn string
	// Priority is the priority column, the primary sort key when querying.
	Priority int
	// Status is the status column (CountExpansionPrompts counts 'pending').
	Status string
}
// QueryGoldenSet loads every golden_set row whose char_count is at least
// minChars, sorted by idx in ascending order. It returns a nil slice when no
// row matches. Any iteration error from the driver is returned together with
// the rows collected so far.
func (db *DB) QueryGoldenSet(minChars int) ([]GoldenSetRow, error) {
	const q = "SELECT idx, seed_id, domain, voice, prompt, response, gen_time, char_count " +
		"FROM golden_set WHERE char_count >= ? ORDER BY idx"

	rs, err := db.conn.Query(q, minChars)
	if err != nil {
		return nil, fmt.Errorf("query golden_set: %w", err)
	}
	defer rs.Close()

	var out []GoldenSetRow
	for rs.Next() {
		row := GoldenSetRow{}
		scanErr := rs.Scan(
			&row.Idx, &row.SeedID, &row.Domain, &row.Voice,
			&row.Prompt, &row.Response, &row.GenTime, &row.CharCount,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("scan golden_set row: %w", scanErr)
		}
		out = append(out, row)
	}
	return out, rs.Err()
}
// CountGoldenSet reports how many rows golden_set holds, with no filtering.
func (db *DB) CountGoldenSet() (int, error) {
	var n int
	if err := db.conn.QueryRow("SELECT COUNT(*) FROM golden_set").Scan(&n); err != nil {
		return 0, fmt.Errorf("count golden_set: %w", err)
	}
	return n, nil
}
// QueryExpansionPrompts lists expansion_prompts rows ordered by priority, then
// idx.
//
// An empty status disables the status filter. A limit of zero or less means
// no limit. It returns a nil slice when no row matches.
func (db *DB) QueryExpansionPrompts(status string, limit int) ([]ExpansionPromptRow, error) {
	where := ""
	var args []interface{}
	if status != "" {
		where = " WHERE status = ?"
		args = []interface{}{status}
	}

	tail := ""
	if limit > 0 {
		// limit is an int, so formatting it with %d cannot inject SQL.
		tail = fmt.Sprintf(" LIMIT %d", limit)
	}

	query := "SELECT idx, seed_id, region, domain, language, prompt, prompt_en, priority, status " +
		"FROM expansion_prompts" + where + " ORDER BY priority, idx" + tail

	rs, err := db.conn.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("query expansion_prompts: %w", err)
	}
	defer rs.Close()

	var out []ExpansionPromptRow
	for rs.Next() {
		row := ExpansionPromptRow{}
		scanErr := rs.Scan(
			&row.Idx, &row.SeedID, &row.Region, &row.Domain,
			&row.Language, &row.Prompt, &row.PromptEn, &row.Priority, &row.Status,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("scan expansion_prompt row: %w", scanErr)
		}
		out = append(out, row)
	}
	return out, rs.Err()
}
// CountExpansionPrompts returns counts by status.
func (db *DB) CountExpansionPrompts() (total int, pending int, err error) {
err = db.conn.QueryRow("SELECT COUNT(*) FROM expansion_prompts").Scan(&total)
if err != nil {
return 0, 0, fmt.Errorf("count expansion_prompts: %w", err)
}
err = db.conn.QueryRow("SELECT COUNT(*) FROM expansion_prompts WHERE status = 'pending'").Scan(&pending)
if err != nil {
return total, 0, fmt.Errorf("count pending expansion_prompts: %w", err)
}
return total, pending, nil
}
// UpdateExpansionStatus updates the status of an expansion prompt by idx.
func (db *DB) UpdateExpansionStatus(idx int64, status string) error {
_, err := db.conn.Exec("UPDATE expansion_prompts SET status = ? WHERE idx = ?", status, idx)
if err != nil {
return fmt.Errorf("update expansion_prompt %d: %w", idx, err)
}
return nil
}
// QueryRows runs an ad-hoc query with args and returns each result row as a
// map from column name to the value the driver produced.
//
// Values are scanned into interface{} targets, so database/sql performs no type
// conversion. If two columns share a name, the right-most one wins. It returns
// a nil slice when the query yields no rows.
func (db *DB) QueryRows(query string, args ...interface{}) ([]map[string]interface{}, error) {
	rs, err := db.conn.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("query: %w", err)
	}
	defer rs.Close()

	names, err := rs.Columns()
	if err != nil {
		return nil, fmt.Errorf("columns: %w", err)
	}

	var out []map[string]interface{}
	for rs.Next() {
		cells := make([]interface{}, len(names))
		targets := make([]interface{}, len(cells))
		for i := range cells {
			targets[i] = &cells[i]
		}
		if scanErr := rs.Scan(targets...); scanErr != nil {
			return nil, fmt.Errorf("scan: %w", scanErr)
		}

		record := make(map[string]interface{}, len(names))
		for i, name := range names {
			record[name] = cells[i]
		}
		out = append(out, record)
	}
	return out, rs.Err()
}
// EnsureScoringTables creates the scoring tables if they don't exist.
func (db *DB) EnsureScoringTables() {
db.conn.Exec(`CREATE TABLE IF NOT EXISTS checkpoint_scores (
model TEXT, run_id TEXT, label TEXT, iteration INTEGER,
correct INTEGER, total INTEGER, accuracy DOUBLE,
scored_at TIMESTAMP DEFAULT current_timestamp,
PRIMARY KEY (run_id, label)
)`)
db.conn.Exec(`CREATE TABLE IF NOT EXISTS probe_results (
model TEXT, run_id TEXT, label TEXT, probe_id TEXT,
passed BOOLEAN, response TEXT, iteration INTEGER,
scored_at TIMESTAMP DEFAULT current_timestamp,
PRIMARY KEY (run_id, label, probe_id)
)`)
db.conn.Exec(`CREATE TABLE IF NOT EXISTS scoring_results (
model TEXT, prompt_id TEXT, suite TEXT,
dimension TEXT, score DOUBLE,
scored_at TIMESTAMP DEFAULT current_timestamp
)`)
}
// WriteScoringResult inserts one row into scoring_results for the given model,
// prompt, suite and dimension. scored_at is not supplied. When the table was
// created by EnsureScoringTables, it therefore takes its current_timestamp
// default.
//
// Errors are wrapped with the model, prompt and dimension, like the other DB
// methods. The driver error stays reachable via errors.Is and errors.As.
func (db *DB) WriteScoringResult(model, promptID, suite, dimension string, score float64) error {
	_, err := db.conn.Exec(
		`INSERT INTO scoring_results (model, prompt_id, suite, dimension, score) VALUES (?, ?, ?, ?, ?)`,
		model, promptID, suite, dimension, score,
	)
	if err != nil {
		return fmt.Errorf("insert scoring_result %s/%s/%s: %w", model, promptID, dimension, err)
	}
	return nil
}
// TableCounts returns the COUNT(*) of each table in a fixed list of known
// tables.
//
// It is best-effort: a table whose count query fails, for example because the
// table does not exist in this database, is left out of the map. The returned
// error is always nil.
func (db *DB) TableCounts() (map[string]int, error) {
	known := [...]string{
		"golden_set", "expansion_prompts", "seeds", "prompts",
		"training_examples", "gemini_responses", "benchmark_questions", "benchmark_results", "validations",
		"checkpoint_scores", "probe_results", "scoring_results",
	}

	counts := make(map[string]int, len(known))
	for i := range known {
		name := known[i]
		// Table names come only from the fixed list above, never from callers,
		// so splicing them into the SQL text is safe.
		var n int
		if err := db.conn.QueryRow("SELECT COUNT(*) FROM " + name).Scan(&n); err == nil {
			counts[name] = n
		}
	}
	return counts, nil
}