Replace passthrough() + stdlib flag.FlagSet anti-pattern with proper cobra integration.

Every Run* function now takes a typed *Opts struct and returns error. Flags are registered via cli.StringFlag/IntFlag/etc. Commands participate in the Core lifecycle with full cobra flag parsing.

- 6 command groups: gen, score, data, export, infra, mon
- 25 commands converted, 0 passthrough() calls remain
- Delete passthrough() helper from lem.go
- Update export_test.go to use ExportOpts struct

Co-Authored-By: Virgil <virgil@lethean.io>
148 lines
4 KiB
Go
148 lines
4 KiB
Go
package lem
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
)
|
|
|
|
// NormalizeOpts carries the settings consumed by the normalize command.
type NormalizeOpts struct {
	// DB is the DuckDB database path (defaults to LEM_DB env).
	DB string
	// MinLen is the minimum prompt length in characters.
	MinLen int
}
|
|
|
|
// RunNormalize is the CLI entry point for the normalize command.
|
|
// Normalizes seeds into the expansion_prompts table, deduplicating against
|
|
// the golden set and existing prompts. Assigns priority based on domain
|
|
// coverage (underrepresented domains first).
|
|
func RunNormalize(cfg NormalizeOpts) error {
|
|
dbPath := cfg.DB
|
|
if dbPath == "" {
|
|
dbPath = os.Getenv("LEM_DB")
|
|
}
|
|
if dbPath == "" {
|
|
return fmt.Errorf("--db or LEM_DB required")
|
|
}
|
|
|
|
minLen := cfg.MinLen
|
|
|
|
db, err := OpenDBReadWrite(dbPath)
|
|
if err != nil {
|
|
return fmt.Errorf("open db: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
// Check source tables.
|
|
var seedCount int
|
|
if err := db.conn.QueryRow("SELECT count(*) FROM seeds").Scan(&seedCount); err != nil {
|
|
return fmt.Errorf("no seeds table: run lem import-all first")
|
|
}
|
|
fmt.Printf("Seeds table: %d rows\n", seedCount)
|
|
|
|
// Drop and recreate expansion_prompts.
|
|
_, err = db.conn.Exec("DROP TABLE IF EXISTS expansion_prompts")
|
|
if err != nil {
|
|
return fmt.Errorf("drop expansion_prompts: %v", err)
|
|
}
|
|
|
|
// Deduplicate: remove seeds whose prompt already appears in prompts or golden_set.
|
|
_, err = db.conn.Exec(fmt.Sprintf(`
|
|
CREATE TABLE expansion_prompts AS
|
|
WITH unique_seeds AS (
|
|
SELECT
|
|
ROW_NUMBER() OVER (ORDER BY region, domain, seed_id) AS idx,
|
|
seed_id,
|
|
region,
|
|
domain,
|
|
prompt
|
|
FROM (
|
|
SELECT DISTINCT ON (prompt)
|
|
seed_id, region, domain, prompt
|
|
FROM seeds
|
|
WHERE length(prompt) >= %d
|
|
ORDER BY prompt, seed_id
|
|
)
|
|
),
|
|
existing_prompts AS (
|
|
SELECT prompt FROM prompts
|
|
UNION ALL
|
|
SELECT prompt FROM golden_set
|
|
)
|
|
SELECT
|
|
us.idx,
|
|
us.seed_id,
|
|
us.region,
|
|
us.domain,
|
|
'en' AS language,
|
|
us.prompt,
|
|
'' AS prompt_en,
|
|
0 AS priority,
|
|
'pending' AS status
|
|
FROM unique_seeds us
|
|
WHERE NOT EXISTS (
|
|
SELECT 1 FROM existing_prompts ep
|
|
WHERE ep.prompt = us.prompt
|
|
)
|
|
`, minLen))
|
|
if err != nil {
|
|
return fmt.Errorf("create expansion_prompts: %v", err)
|
|
}
|
|
|
|
var total, domains, regions int
|
|
db.conn.QueryRow("SELECT count(*) FROM expansion_prompts").Scan(&total)
|
|
db.conn.QueryRow("SELECT count(DISTINCT domain) FROM expansion_prompts").Scan(&domains)
|
|
db.conn.QueryRow("SELECT count(DISTINCT region) FROM expansion_prompts").Scan(®ions)
|
|
|
|
// Assign priority based on domain coverage.
|
|
_, err = db.conn.Exec(`
|
|
UPDATE expansion_prompts SET priority = (
|
|
SELECT RANK() OVER (ORDER BY cnt ASC)
|
|
FROM (
|
|
SELECT domain, count(*) AS cnt
|
|
FROM expansion_prompts GROUP BY domain
|
|
) domain_counts
|
|
WHERE domain_counts.domain = expansion_prompts.domain
|
|
)
|
|
`)
|
|
if err != nil {
|
|
log.Printf("warning: priority assignment failed: %v", err)
|
|
}
|
|
|
|
fmt.Printf("\nExpansion Prompts: %d\n", total)
|
|
fmt.Printf(" Domains: %d\n", domains)
|
|
fmt.Printf(" Regions: %d\n", regions)
|
|
|
|
// Show region distribution.
|
|
fmt.Println("\n By region group:")
|
|
rows, err := db.conn.Query(`
|
|
SELECT
|
|
CASE
|
|
WHEN region LIKE '%cn%' THEN 'cn'
|
|
WHEN region LIKE '%en-%' OR region LIKE '%en_para%' OR region LIKE '%para%' THEN 'en'
|
|
WHEN region LIKE '%ru%' THEN 'ru'
|
|
WHEN region LIKE '%de%' AND region NOT LIKE '%deten%' THEN 'de'
|
|
WHEN region LIKE '%es%' THEN 'es'
|
|
WHEN region LIKE '%fr%' THEN 'fr'
|
|
WHEN region LIKE '%latam%' THEN 'latam'
|
|
WHEN region LIKE '%africa%' THEN 'africa'
|
|
WHEN region LIKE '%eu%' THEN 'eu'
|
|
WHEN region LIKE '%me%' AND region NOT LIKE '%premium%' THEN 'me'
|
|
ELSE 'other'
|
|
END AS lang_group,
|
|
count(*) AS n
|
|
FROM expansion_prompts GROUP BY lang_group ORDER BY n DESC
|
|
`)
|
|
if err == nil {
|
|
for rows.Next() {
|
|
var group string
|
|
var n int
|
|
rows.Scan(&group, &n)
|
|
fmt.Printf(" %-15s %6d\n", group, n)
|
|
}
|
|
rows.Close()
|
|
}
|
|
|
|
fmt.Printf("\nNormalization complete: %d expansion prompts from %d seeds\n", total, seedCount)
|
|
return nil
|
|
}
|