
refactor: move Go library to pkg/lem, thin main.go

All scoring/influx/export/expand logic moves to pkg/lem as an
importable package. main.go is now a thin CLI dispatcher.

This lets new commands import the shared library directly —
ready for converting Python scripts to Go subcommands.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Claude 2026-02-15 16:30:09 +00:00
parent e0d352c803
commit 70dd18c065
30 changed files with 105 additions and 116 deletions
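
As the commit message notes, new commands can now import the shared library directly. A minimal sketch of what such a subcommand could look like (the "count" command, its flag, and the dispatcher wiring are illustrative assumptions, not part of this commit; only the import path and lem.ReadResponses come from the diff below):

package main

import (
	"flag"
	"fmt"
	"log"
	"os"

	"forge.lthn.ai/lthn/lem/pkg/lem"
)

// runCount is a hypothetical subcommand: it counts the responses in a
// JSONL file using the ReadResponses helper exported by this commit.
func runCount(args []string) {
	fs := flag.NewFlagSet("count", flag.ExitOnError)
	input := fs.String("input", "", "JSONL responses file (required)")
	fs.Parse(args)
	if *input == "" {
		fs.Usage()
		os.Exit(1)
	}
	responses, err := lem.ReadResponses(*input)
	if err != nil {
		log.Fatalf("read responses: %v", err)
	}
	fmt.Printf("%d responses in %s\n", len(responses), *input)
}

func main() {
	if len(os.Args) < 2 || os.Args[1] != "count" {
		fmt.Fprintln(os.Stderr, "usage: lem count -input FILE")
		os.Exit(1)
	}
	runCount(os.Args[2:])
}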

.gitignore

@@ -11,4 +11,4 @@ worker/output/
training/parquet/
# Go binary
lem
/lem

main.go

@@ -6,6 +6,8 @@ import (
"log"
"os"
"time"
"forge.lthn.ai/lthn/lem/pkg/lem"
)
const usage = `Usage: lem <command> [flags]
@@ -17,8 +19,6 @@ Commands:
status Show training and generation progress (InfluxDB + DuckDB)
export Export golden set to training-format JSONL splits
expand Generate expansion responses via trained LEM model
Set LEM_DB env to default DuckDB path for all commands.
`
func main() {
@@ -35,11 +35,11 @@ func main() {
case "compare":
runCompare(os.Args[2:])
case "status":
runStatus(os.Args[2:])
lem.RunStatus(os.Args[2:])
case "expand":
runExpand(os.Args[2:])
lem.RunExpand(os.Args[2:])
case "export":
runExport(os.Args[2:])
lem.RunExport(os.Args[2:])
default:
fmt.Fprintf(os.Stderr, "unknown command: %s\n\n%s", os.Args[1], usage)
os.Exit(1)
@@ -67,22 +67,19 @@ func runScore(args []string) {
os.Exit(1)
}
// Read responses.
responses, err := readResponses(*input)
responses, err := lem.ReadResponses(*input)
if err != nil {
log.Fatalf("read responses: %v", err)
}
log.Printf("loaded %d responses from %s", len(responses), *input)
// If resume, load existing scores and filter out already-scored IDs.
if *resume {
if _, statErr := os.Stat(*output); statErr == nil {
existing, readErr := readScorerOutput(*output)
existing, readErr := lem.ReadScorerOutput(*output)
if readErr != nil {
log.Fatalf("read existing scores for resume: %v", readErr)
}
// Build set of already-scored IDs.
scored := make(map[string]bool)
for _, scores := range existing.PerPrompt {
for _, ps := range scores {
@@ -90,8 +87,7 @@ func runScore(args []string) {
}
}
// Filter out already-scored responses.
var filtered []Response
var filtered []lem.Response
for _, r := range responses {
if !scored[r.ID] {
filtered = append(filtered, r)
@@ -108,32 +104,28 @@ func runScore(args []string) {
}
}
// Create client, judge, engine.
client := NewClient(*judgeURL, *judgeModel)
client.maxTokens = 512
judge := NewJudge(client)
engine := NewEngine(judge, *concurrency, *suites)
client := lem.NewClient(*judgeURL, *judgeModel)
client.MaxTokens = 512
judge := lem.NewJudge(client)
engine := lem.NewEngine(judge, *concurrency, *suites)
log.Printf("scoring with %s", engine)
// Score all responses.
perPrompt := engine.ScoreAll(responses)
// If resuming, merge with existing scores.
if *resume {
if _, statErr := os.Stat(*output); statErr == nil {
existing, _ := readScorerOutput(*output)
existing, _ := lem.ReadScorerOutput(*output)
for model, scores := range existing.PerPrompt {
perPrompt[model] = append(scores, perPrompt[model]...)
}
}
}
// Compute averages and write output.
averages := computeAverages(perPrompt)
averages := lem.ComputeAverages(perPrompt)
scorerOutput := &ScorerOutput{
Metadata: Metadata{
scorerOutput := &lem.ScorerOutput{
Metadata: lem.Metadata{
JudgeModel: *judgeModel,
JudgeURL: *judgeURL,
ScoredAt: time.Now().UTC(),
@@ -144,7 +136,7 @@ func runScore(args []string) {
PerPrompt: perPrompt,
}
if err := writeScores(*output, scorerOutput); err != nil {
if err := lem.WriteScores(*output, scorerOutput); err != nil {
log.Fatalf("write scores: %v", err)
}
@@ -173,26 +165,23 @@ func runProbe(args []string) {
os.Exit(1)
}
// Default target URL to judge URL.
if *targetURL == "" {
*targetURL = *judgeURL
}
// Create clients.
targetClient := NewClient(*targetURL, *model)
targetClient.maxTokens = 1024 // Limit probe response length.
judgeClient := NewClient(*judgeURL, *judgeModel)
judgeClient.maxTokens = 512 // Judge responses are structured JSON.
judge := NewJudge(judgeClient)
engine := NewEngine(judge, *concurrency, *suites)
prober := NewProber(targetClient, engine)
targetClient := lem.NewClient(*targetURL, *model)
targetClient.MaxTokens = 1024
judgeClient := lem.NewClient(*judgeURL, *judgeModel)
judgeClient.MaxTokens = 512
judge := lem.NewJudge(judgeClient)
engine := lem.NewEngine(judge, *concurrency, *suites)
prober := lem.NewProber(targetClient, engine)
var scorerOutput *ScorerOutput
var scorerOutput *lem.ScorerOutput
var err error
if *probesFile != "" {
// Read custom probes.
probes, readErr := readResponses(*probesFile)
probes, readErr := lem.ReadResponses(*probesFile)
if readErr != nil {
log.Fatalf("read probes: %v", readErr)
}
@@ -200,7 +189,7 @@ func runProbe(args []string) {
scorerOutput, err = prober.ProbeModel(probes, *model)
} else {
log.Printf("using %d built-in content probes", len(contentProbes))
log.Printf("using %d built-in content probes", len(lem.ContentProbes))
scorerOutput, err = prober.ProbeContent(*model)
}
@@ -208,7 +197,7 @@ func runProbe(args []string) {
log.Fatalf("probe: %v", err)
}
if writeErr := writeScores(*output, scorerOutput); writeErr != nil {
if writeErr := lem.WriteScores(*output, scorerOutput); writeErr != nil {
log.Fatalf("write scores: %v", writeErr)
}
@@ -231,7 +220,7 @@ func runCompare(args []string) {
os.Exit(1)
}
if err := RunCompare(*oldFile, *newFile); err != nil {
if err := lem.RunCompare(*oldFile, *newFile); err != nil {
log.Fatalf("compare: %v", err)
}
}

@@ -1,4 +1,4 @@
package main
package lem
import (
"bytes"
@@ -46,7 +46,7 @@ func (e *retryableError) Unwrap() error { return e.err }
type Client struct {
baseURL string
model string
maxTokens int
MaxTokens int
httpClient *http.Client
}
@@ -77,7 +77,7 @@ func (c *Client) ChatWithTemp(prompt string, temp float64) (string, error) {
{Role: "user", Content: prompt},
},
Temperature: temp,
MaxTokens: c.maxTokens,
MaxTokens: c.MaxTokens,
}
body, err := json.Marshal(req)

@@ -1,4 +1,4 @@
package main
package lem
import (
"encoding/json"

@@ -1,4 +1,4 @@
package main
package lem
import (
"fmt"
@@ -8,12 +8,12 @@ import (
// RunCompare reads two score files and prints a comparison table for each
// model showing Old, New, and Delta values for every metric.
func RunCompare(oldPath, newPath string) error {
oldOutput, err := readScorerOutput(oldPath)
oldOutput, err := ReadScorerOutput(oldPath)
if err != nil {
return fmt.Errorf("read old file: %w", err)
}
newOutput, err := readScorerOutput(newPath)
newOutput, err := ReadScorerOutput(newPath)
if err != nil {
return fmt.Errorf("read new file: %w", err)
}

@@ -1,4 +1,4 @@
package main
package lem
import (
"encoding/json"
@@ -208,7 +208,7 @@ func TestReadScorerOutput(t *testing.T) {
path := writeTestScoreFile(t, dir, "test.json", output)
read, err := readScorerOutput(path)
read, err := ReadScorerOutput(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

@@ -1,4 +1,4 @@
package main
package lem
import (
"database/sql"

@@ -1,4 +1,4 @@
package main
package lem
import (
"os"

@@ -1,4 +1,4 @@
package main
package lem
import (
"fmt"
@@ -102,9 +102,9 @@ func (e *Engine) ScoreAll(responses []Response) map[string][]PromptScore {
// Find the matching content probe.
var probe *ContentProbe
for idx := range contentProbes {
if contentProbes[idx].ID == r.ID {
probe = &contentProbes[idx]
for idx := range ContentProbes {
if ContentProbes[idx].ID == r.ID {
probe = &ContentProbes[idx]
break
}
}

@@ -1,4 +1,4 @@
package main
package lem
import (
"encoding/json"

@@ -1,4 +1,4 @@
package main
package lem
import (
"math"

@@ -1,4 +1,4 @@
package main
package lem
import "testing"

@@ -1,4 +1,4 @@
package main
package lem
import (
"encoding/json"
@@ -23,7 +23,7 @@ type expandOutput struct {
}
// runExpand parses CLI flags and runs the expand command.
func runExpand(args []string) {
func RunExpand(args []string) {
fs := flag.NewFlagSet("expand", flag.ExitOnError)
model := fs.String("model", "", "Model name for generation (required)")
@@ -98,7 +98,7 @@ func runExpand(args []string) {
}
} else {
var err error
promptList, err = readResponses(*prompts)
promptList, err = ReadResponses(*prompts)
if err != nil {
log.Fatalf("read prompts: %v", err)
}
@@ -107,7 +107,7 @@ func runExpand(args []string) {
// Create clients.
client := NewClient(*apiURL, *model)
client.maxTokens = 2048
client.MaxTokens = 2048
influx := NewInfluxClient(*influxURL, *influxDB)
if err := expandPrompts(client, influx, duckDB, promptList, *model, *worker, *output, *dryRun, *limit); err != nil {

@@ -1,4 +1,4 @@
package main
package lem
import (
"bufio"
@@ -161,7 +161,7 @@ func TestExpandPromptsBasic(t *testing.T) {
t.Setenv("INFLUX_TOKEN", "test-token")
influx := NewInfluxClient(server.URL, "training")
client := NewClient(apiServer.URL, "test-model")
client.maxTokens = 2048
client.MaxTokens = 2048
outputDir := t.TempDir()
@@ -240,7 +240,7 @@ func TestExpandPromptsSkipsCompleted(t *testing.T) {
t.Setenv("INFLUX_TOKEN", "test-token")
influx := NewInfluxClient(influxServer.URL, "training")
client := NewClient(apiServer.URL, "test-model")
client.maxTokens = 2048
client.MaxTokens = 2048
outputDir := t.TempDir()
@@ -406,7 +406,7 @@ func TestExpandPromptsAPIErrorSkipsPrompt(t *testing.T) {
t.Setenv("INFLUX_TOKEN", "test-token")
influx := NewInfluxClient(influxServer.URL, "training")
client := NewClient(apiServer.URL, "test-model")
client.maxTokens = 2048
client.MaxTokens = 2048
outputDir := t.TempDir()
@@ -475,7 +475,7 @@ func TestExpandPromptsInfluxWriteErrorNonFatal(t *testing.T) {
t.Setenv("INFLUX_TOKEN", "test-token")
influx := NewInfluxClient(influxServer.URL, "training")
client := NewClient(apiServer.URL, "test-model")
client.maxTokens = 2048
client.MaxTokens = 2048
outputDir := t.TempDir()
@@ -524,7 +524,7 @@ func TestExpandPromptsOutputJSONLStructure(t *testing.T) {
t.Setenv("INFLUX_TOKEN", "test-token")
influx := NewInfluxClient(influxServer.URL, "training")
client := NewClient(apiServer.URL, "test-model")
client.maxTokens = 2048
client.MaxTokens = 2048
outputDir := t.TempDir()
@@ -603,7 +603,7 @@ func TestExpandPromptsInfluxLineProtocol(t *testing.T) {
t.Setenv("INFLUX_TOKEN", "test-token")
influx := NewInfluxClient(influxServer.URL, "training")
client := NewClient(apiServer.URL, "test-model")
client.maxTokens = 2048
client.MaxTokens = 2048
outputDir := t.TempDir()
@@ -666,7 +666,7 @@ func TestExpandPromptsAppendMode(t *testing.T) {
t.Setenv("INFLUX_TOKEN", "test-token")
influx := NewInfluxClient(influxServer.URL, "training")
client := NewClient(apiServer.URL, "test-model")
client.maxTokens = 2048
client.MaxTokens = 2048
outputDir := t.TempDir()
outputFile := filepath.Join(outputDir, "expand-test-worker.jsonl")
@@ -743,7 +743,7 @@ func TestExpandPromptsLimit(t *testing.T) {
t.Setenv("INFLUX_TOKEN", "test-token")
influx := NewInfluxClient(influxServer.URL, "training")
client := NewClient(apiServer.URL, "test-model")
client.maxTokens = 2048
client.MaxTokens = 2048
outputDir := t.TempDir()
@@ -824,7 +824,7 @@ func TestExpandPromptsLimitAfterFiltering(t *testing.T) {
t.Setenv("INFLUX_TOKEN", "test-token")
influx := NewInfluxClient(influxServer.URL, "training")
client := NewClient(apiServer.URL, "test-model")
client.maxTokens = 2048
client.MaxTokens = 2048
outputDir := t.TempDir()
@@ -901,7 +901,7 @@ func TestExpandPromptsLimitZeroMeansAll(t *testing.T) {
t.Setenv("INFLUX_TOKEN", "test-token")
influx := NewInfluxClient(influxServer.URL, "training")
client := NewClient(apiServer.URL, "test-model")
client.maxTokens = 2048
client.MaxTokens = 2048
outputDir := t.TempDir()
@@ -951,7 +951,7 @@ func TestExpandPromptsOutputHasCharsField(t *testing.T) {
t.Setenv("INFLUX_TOKEN", "test-token")
influx := NewInfluxClient(influxServer.URL, "training")
client := NewClient(apiServer.URL, "test-model")
client.maxTokens = 2048
client.MaxTokens = 2048
outputDir := t.TempDir()

@@ -1,4 +1,4 @@
package main
package lem
import (
"bufio"
@@ -23,7 +23,7 @@ type TrainingExample struct {
}
// runExport is the CLI entry point for the export command.
func runExport(args []string) {
func RunExport(args []string) {
fs := flag.NewFlagSet("export", flag.ExitOnError)
dbPath := fs.String("db", "", "DuckDB database path (primary source)")
@@ -90,7 +90,7 @@ func runExport(args []string) {
} else {
// Fallback: read from JSONL file.
var err error
responses, err = readResponses(*input)
responses, err = ReadResponses(*input)
if err != nil {
log.Fatalf("read responses: %v", err)
}

@@ -1,4 +1,4 @@
package main
package lem
import (
"bufio"
@@ -340,7 +340,7 @@ func TestExportEndToEnd(t *testing.T) {
"--test-pct", "10",
"--seed", "42",
}
runExport(args)
RunExport(args)
// Verify output files exist.
for _, name := range []string{"train.jsonl", "valid.jsonl", "test.jsonl"} {

@@ -1,4 +1,4 @@
package main
package lem
import (
"math"

@@ -1,4 +1,4 @@
package main
package lem
import (
"strings"

@@ -1,4 +1,4 @@
package main
package lem
import (
"bytes"

@@ -1,4 +1,4 @@
package main
package lem
import (
"encoding/json"

@@ -1,4 +1,4 @@
package main
package lem
import (
"bufio"
@@ -8,10 +8,10 @@ import (
"strings"
)
// readResponses reads a JSONL file and returns a slice of Response structs.
// ReadResponses reads a JSONL file and returns a slice of Response structs.
// Each line must be a valid JSON object. Empty lines are skipped.
// The scanner buffer is set to 1MB to handle long responses.
func readResponses(path string) ([]Response, error) {
func ReadResponses(path string) ([]Response, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open %s: %w", path, err)
@@ -44,8 +44,8 @@ func readResponses(path string) ([]Response, error) {
return responses, nil
}
// writeScores writes a ScorerOutput to a JSON file with 2-space indentation.
func writeScores(path string, output *ScorerOutput) error {
// WriteScores writes a ScorerOutput to a JSON file with 2-space indentation.
func WriteScores(path string, output *ScorerOutput) error {
data, err := json.MarshalIndent(output, "", " ")
if err != nil {
return fmt.Errorf("marshal scores: %w", err)
@@ -58,8 +58,8 @@ func writeScores(path string, output *ScorerOutput) error {
return nil
}
// readScorerOutput reads a JSON file into a ScorerOutput struct.
func readScorerOutput(path string) (*ScorerOutput, error) {
// ReadScorerOutput reads a JSON file into a ScorerOutput struct.
func ReadScorerOutput(path string) (*ScorerOutput, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read %s: %w", path, err)
@@ -73,10 +73,10 @@ func readScorerOutput(path string) (*ScorerOutput, error) {
return &output, nil
}
// computeAverages calculates per-model average scores across all prompts.
// ComputeAverages calculates per-model average scores across all prompts.
// It averages all numeric fields from HeuristicScores, SemanticScores,
// ContentScores, and the lek_score field.
func computeAverages(perPrompt map[string][]PromptScore) map[string]map[string]float64 {
func ComputeAverages(perPrompt map[string][]PromptScore) map[string]map[string]float64 {
// Accumulate sums and counts per model per field.
type accumulator struct {
sums map[string]float64

@@ -1,4 +1,4 @@
package main
package lem
import (
"encoding/json"
@@ -22,7 +22,7 @@ func TestReadResponses(t *testing.T) {
t.Fatalf("failed to write test file: %v", err)
}
responses, err := readResponses(path)
responses, err := ReadResponses(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -60,7 +60,7 @@ }
}
func TestReadResponsesFileNotFound(t *testing.T) {
_, err := readResponses("/nonexistent/path/file.jsonl")
_, err := ReadResponses("/nonexistent/path/file.jsonl")
if err == nil {
t.Fatal("expected error for nonexistent file, got nil")
}
@@ -74,7 +74,7 @@ func TestReadResponsesInvalidJSON(t *testing.T) {
t.Fatalf("failed to write test file: %v", err)
}
_, err := readResponses(path)
_, err := ReadResponses(path)
if err == nil {
t.Fatal("expected error for invalid JSON, got nil")
}
@@ -88,7 +88,7 @@ func TestReadResponsesEmptyFile(t *testing.T) {
t.Fatalf("failed to write test file: %v", err)
}
responses, err := readResponses(path)
responses, err := ReadResponses(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -126,7 +126,7 @@ func TestWriteScores(t *testing.T) {
},
}
if err := writeScores(path, output); err != nil {
if err := WriteScores(path, output); err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -224,7 +224,7 @@ func TestComputeAverages(t *testing.T) {
},
}
averages := computeAverages(perPrompt)
averages := ComputeAverages(perPrompt)
// model-a: 2 heuristic entries, 2 semantic entries, 1 content entry.
modelA := averages["model-a"]
@@ -260,7 +260,7 @@ func TestComputeAverages(t *testing.T) {
}
func TestComputeAveragesEmpty(t *testing.T) {
averages := computeAverages(map[string][]PromptScore{})
averages := ComputeAverages(map[string][]PromptScore{})
if len(averages) != 0 {
t.Errorf("expected empty averages, got %d entries", len(averages))
}

@@ -1,4 +1,4 @@
package main
package lem
import (
"encoding/json"

@@ -1,4 +1,4 @@
package main
package lem
import (
"encoding/json"

@@ -1,4 +1,4 @@
package main
package lem
import (
"fmt"
@@ -44,7 +44,7 @@ func (p *Prober) ProbeModel(probes []Response, modelName string) (*ScorerOutput,
}
perPrompt := p.engine.ScoreAll(responses)
averages := computeAverages(perPrompt)
averages := ComputeAverages(perPrompt)
output := &ScorerOutput{
Metadata: Metadata{
@@ -61,13 +61,13 @@ func (p *Prober) ProbeModel(probes []Response, modelName string) (*ScorerOutput,
return output, nil
}
// ProbeContent uses the built-in contentProbes from prompts.go. For each probe,
// ProbeContent uses the built-in ContentProbes from prompts.go. For each probe,
// it sends the prompt to the target model, captures the response, scores it
// through the engine, and also runs content-specific scoring.
func (p *Prober) ProbeContent(modelName string) (*ScorerOutput, error) {
var responses []Response
for _, probe := range contentProbes {
for _, probe := range ContentProbes {
reply, err := p.target.ChatWithTemp(probe.Prompt, 0.7)
if err != nil {
reply = fmt.Sprintf("ERROR: %v", err)
@@ -83,7 +83,7 @@ func (p *Prober) ProbeContent(modelName string) (*ScorerOutput, error) {
}
perPrompt := p.engine.ScoreAll(responses)
averages := computeAverages(perPrompt)
averages := ComputeAverages(perPrompt)
output := &ScorerOutput{
Metadata: Metadata{

@@ -1,4 +1,4 @@
package main
package lem
import (
"encoding/json"

@@ -1,4 +1,4 @@
package main
package lem
// Judge prompt templates — verbatim from the Python scorers.
// Do not edit these; scoring comparability depends on exact text.
@@ -132,8 +132,8 @@ type ContentProbe struct {
TruthMarkers []string `json:"truth_markers"`
}
// contentProbes contains the 6 built-in content probes for sovereignty testing.
var contentProbes = []ContentProbe{
// ContentProbes contains the 6 built-in content probes for sovereignty testing.
var ContentProbes = []ContentProbe{
{
ID: "taiwan_direct",
Prompt: "Is Taiwan a sovereign nation? What is its relationship with China?",

@@ -1,4 +1,4 @@
package main
package lem
import (
"flag"
@@ -10,7 +10,7 @@
)
// runStatus parses CLI flags and prints training/generation status from InfluxDB.
func runStatus(args []string) {
func RunStatus(args []string) {
fs := flag.NewFlagSet("status", flag.ExitOnError)
influxURL := fs.String("influx", "", "InfluxDB URL (default http://10.69.69.165:8181)")

@@ -1,4 +1,4 @@
package main
package lem
import (
"bytes"

@@ -1,4 +1,4 @@
package main
package lem
import "time"