refactor: move Go library to pkg/lem, thin main.go
All scoring/influx/export/expand logic moves to pkg/lem as an importable
package. main.go is now a thin CLI dispatcher. This lets new commands
import the shared library directly — ready for converting Python scripts
to Go subcommands.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent e0d352c803
commit 70dd18c065

30 changed files with 105 additions and 116 deletions
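The new layout is easy to extend: a future subcommand needs one dispatch case in main.go plus an exported RunXxx entry point in pkg/lem. A minimal sketch (not part of this commit): the "dedupe" command name, its flag, and the dedupe logic are hypothetical, while ReadResponses and the Response.ID field are the exported API shown in the diff below.

package lem

import (
	"flag"
	"log"
)

// RunDedupe is a hypothetical entry point following the same shape
// as RunStatus/RunExpand/RunExport in this commit.
func RunDedupe(args []string) {
	fs := flag.NewFlagSet("dedupe", flag.ExitOnError)
	input := fs.String("input", "", "JSONL file of responses")
	fs.Parse(args)

	responses, err := ReadResponses(*input) // exported by this commit
	if err != nil {
		log.Fatalf("read responses: %v", err)
	}

	// Keep the first response seen for each ID.
	seen := make(map[string]bool)
	kept := responses[:0]
	for _, r := range responses {
		if !seen[r.ID] {
			seen[r.ID] = true
			kept = append(kept, r)
		}
	}
	log.Printf("kept %d of %d responses", len(kept), len(responses))
}

// ...and the matching case in main.go's switch:
// case "dedupe":
// 	lem.RunDedupe(os.Args[2:])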
.gitignore (vendored): 2 changes

@@ -11,4 +11,4 @@ worker/output/
 training/parquet/
 
 # Go binary
-lem
+/lem
main.go: 69 changes

@@ -6,6 +6,8 @@ import (
 	"log"
 	"os"
 	"time"
+
+	"forge.lthn.ai/lthn/lem/pkg/lem"
 )
 
 const usage = `Usage: lem <command> [flags]
@@ -17,8 +19,6 @@ Commands:
   status   Show training and generation progress (InfluxDB + DuckDB)
   export   Export golden set to training-format JSONL splits
   expand   Generate expansion responses via trained LEM model
 
 Set LEM_DB env to default DuckDB path for all commands.
 `
 
 func main() {
@@ -35,11 +35,11 @@ func main() {
 	case "compare":
 		runCompare(os.Args[2:])
 	case "status":
-		runStatus(os.Args[2:])
+		lem.RunStatus(os.Args[2:])
 	case "expand":
-		runExpand(os.Args[2:])
+		lem.RunExpand(os.Args[2:])
 	case "export":
-		runExport(os.Args[2:])
+		lem.RunExport(os.Args[2:])
 	default:
 		fmt.Fprintf(os.Stderr, "unknown command: %s\n\n%s", os.Args[1], usage)
 		os.Exit(1)
@@ -67,22 +67,19 @@ func runScore(args []string) {
 		os.Exit(1)
 	}
 
 	// Read responses.
-	responses, err := readResponses(*input)
+	responses, err := lem.ReadResponses(*input)
 	if err != nil {
 		log.Fatalf("read responses: %v", err)
 	}
 	log.Printf("loaded %d responses from %s", len(responses), *input)
 
 	// If resume, load existing scores and filter out already-scored IDs.
 	if *resume {
 		if _, statErr := os.Stat(*output); statErr == nil {
-			existing, readErr := readScorerOutput(*output)
+			existing, readErr := lem.ReadScorerOutput(*output)
 			if readErr != nil {
 				log.Fatalf("read existing scores for resume: %v", readErr)
 			}
 
 			// Build set of already-scored IDs.
 			scored := make(map[string]bool)
 			for _, scores := range existing.PerPrompt {
 				for _, ps := range scores {
@@ -90,8 +87,7 @@ func runScore(args []string) {
 			}
 		}
 
 		// Filter out already-scored responses.
-		var filtered []Response
+		var filtered []lem.Response
 		for _, r := range responses {
 			if !scored[r.ID] {
 				filtered = append(filtered, r)
@@ -108,32 +104,28 @@ func runScore(args []string) {
 		}
 	}
 
 	// Create client, judge, engine.
-	client := NewClient(*judgeURL, *judgeModel)
-	client.maxTokens = 512
-	judge := NewJudge(client)
-	engine := NewEngine(judge, *concurrency, *suites)
+	client := lem.NewClient(*judgeURL, *judgeModel)
+	client.MaxTokens = 512
+	judge := lem.NewJudge(client)
+	engine := lem.NewEngine(judge, *concurrency, *suites)
 
 	log.Printf("scoring with %s", engine)
 
 	// Score all responses.
 	perPrompt := engine.ScoreAll(responses)
 
 	// If resuming, merge with existing scores.
 	if *resume {
 		if _, statErr := os.Stat(*output); statErr == nil {
-			existing, _ := readScorerOutput(*output)
+			existing, _ := lem.ReadScorerOutput(*output)
 			for model, scores := range existing.PerPrompt {
 				perPrompt[model] = append(scores, perPrompt[model]...)
 			}
 		}
 	}
 
 	// Compute averages and write output.
-	averages := computeAverages(perPrompt)
+	averages := lem.ComputeAverages(perPrompt)
 
-	scorerOutput := &ScorerOutput{
-		Metadata: Metadata{
+	scorerOutput := &lem.ScorerOutput{
+		Metadata: lem.Metadata{
 			JudgeModel: *judgeModel,
 			JudgeURL:   *judgeURL,
 			ScoredAt:   time.Now().UTC(),
@@ -144,7 +136,7 @@ func runScore(args []string) {
 		PerPrompt: perPrompt,
 	}
 
-	if err := writeScores(*output, scorerOutput); err != nil {
+	if err := lem.WriteScores(*output, scorerOutput); err != nil {
 		log.Fatalf("write scores: %v", err)
 	}
@@ -173,26 +165,23 @@ func runProbe(args []string) {
 		os.Exit(1)
 	}
 
 	// Default target URL to judge URL.
 	if *targetURL == "" {
 		*targetURL = *judgeURL
 	}
 
 	// Create clients.
-	targetClient := NewClient(*targetURL, *model)
-	targetClient.maxTokens = 1024 // Limit probe response length.
-	judgeClient := NewClient(*judgeURL, *judgeModel)
-	judgeClient.maxTokens = 512 // Judge responses are structured JSON.
-	judge := NewJudge(judgeClient)
-	engine := NewEngine(judge, *concurrency, *suites)
-	prober := NewProber(targetClient, engine)
+	targetClient := lem.NewClient(*targetURL, *model)
+	targetClient.MaxTokens = 1024
+	judgeClient := lem.NewClient(*judgeURL, *judgeModel)
+	judgeClient.MaxTokens = 512
+	judge := lem.NewJudge(judgeClient)
+	engine := lem.NewEngine(judge, *concurrency, *suites)
+	prober := lem.NewProber(targetClient, engine)
 
-	var scorerOutput *ScorerOutput
+	var scorerOutput *lem.ScorerOutput
 	var err error
 
 	if *probesFile != "" {
 		// Read custom probes.
-		probes, readErr := readResponses(*probesFile)
+		probes, readErr := lem.ReadResponses(*probesFile)
 		if readErr != nil {
 			log.Fatalf("read probes: %v", readErr)
 		}
@@ -200,7 +189,7 @@ func runProbe(args []string) {
 
 		scorerOutput, err = prober.ProbeModel(probes, *model)
 	} else {
-		log.Printf("using %d built-in content probes", len(contentProbes))
+		log.Printf("using %d built-in content probes", len(lem.ContentProbes))
 		scorerOutput, err = prober.ProbeContent(*model)
 	}
@@ -208,7 +197,7 @@ func runProbe(args []string) {
 		log.Fatalf("probe: %v", err)
 	}
 
-	if writeErr := writeScores(*output, scorerOutput); writeErr != nil {
+	if writeErr := lem.WriteScores(*output, scorerOutput); writeErr != nil {
 		log.Fatalf("write scores: %v", writeErr)
 	}
@@ -231,7 +220,7 @@ func runCompare(args []string) {
 		os.Exit(1)
 	}
 
-	if err := RunCompare(*oldFile, *newFile); err != nil {
+	if err := lem.RunCompare(*oldFile, *newFile); err != nil {
 		log.Fatalf("compare: %v", err)
 	}
 }
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"bytes"
@@ -46,7 +46,7 @@ func (e *retryableError) Unwrap() error { return e.err }
 type Client struct {
 	baseURL    string
 	model      string
-	maxTokens  int
+	MaxTokens  int
 	httpClient *http.Client
 }
@@ -77,7 +77,7 @@ func (c *Client) ChatWithTemp(prompt string, temp float64) (string, error) {
 			{Role: "user", Content: prompt},
 		},
 		Temperature: temp,
-		MaxTokens:   c.maxTokens,
+		MaxTokens:   c.MaxTokens,
 	}
 
 	body, err := json.Marshal(req)
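Renaming maxTokens to MaxTokens is what lets callers outside the package set the cap, as main.go now does. A minimal external-caller sketch: the URL and model name are placeholders, while NewClient, MaxTokens, and ChatWithTemp are as they appear in this diff.

package main

import (
	"fmt"
	"log"

	"forge.lthn.ai/lthn/lem/pkg/lem"
)

func main() {
	// Placeholder endpoint and model; any OpenAI-style chat endpoint the
	// Client targets would go here.
	client := lem.NewClient("http://localhost:1234/v1", "placeholder-model")
	client.MaxTokens = 256 // settable from outside pkg/lem after this commit

	reply, err := client.ChatWithTemp("Reply with one word.", 0.0)
	if err != nil {
		log.Fatalf("chat: %v", err)
	}
	fmt.Println(reply)
}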
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"encoding/json"
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"fmt"
@@ -8,12 +8,12 @@ import (
 // RunCompare reads two score files and prints a comparison table for each
 // model showing Old, New, and Delta values for every metric.
 func RunCompare(oldPath, newPath string) error {
-	oldOutput, err := readScorerOutput(oldPath)
+	oldOutput, err := ReadScorerOutput(oldPath)
 	if err != nil {
 		return fmt.Errorf("read old file: %w", err)
 	}
 
-	newOutput, err := readScorerOutput(newPath)
+	newOutput, err := ReadScorerOutput(newPath)
 	if err != nil {
 		return fmt.Errorf("read new file: %w", err)
 	}
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"encoding/json"
@@ -208,7 +208,7 @@ func TestReadScorerOutput(t *testing.T) {
 	path := writeTestScoreFile(t, dir, "test.json", output)
 
-	read, err := readScorerOutput(path)
+	read, err := ReadScorerOutput(path)
 	if err != nil {
 		t.Fatalf("unexpected error: %v", err)
 	}
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"database/sql"

@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"os"
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"fmt"
@@ -102,9 +102,9 @@ func (e *Engine) ScoreAll(responses []Response) map[string][]PromptScore {
 
 		// Find the matching content probe.
 		var probe *ContentProbe
-		for idx := range contentProbes {
-			if contentProbes[idx].ID == r.ID {
-				probe = &contentProbes[idx]
+		for idx := range ContentProbes {
+			if ContentProbes[idx].ID == r.ID {
+				probe = &ContentProbes[idx]
 				break
 			}
 		}
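The Engine pieces are exported too, so scoring can be embedded without going through runScore. A sketch wiring them the same way main.go does; the URL, model, path, and suite values are placeholders, and the concurrency int / suites string types are an assumption inferred from the flag usage in this diff.

package main

import (
	"log"

	"forge.lthn.ai/lthn/lem/pkg/lem"
)

func main() {
	// Same construction order as runScore: client -> judge -> engine.
	client := lem.NewClient("http://localhost:1234/v1", "judge-model") // placeholders
	client.MaxTokens = 512
	judge := lem.NewJudge(client)
	engine := lem.NewEngine(judge, 4, "all") // assumed int concurrency, string suites

	responses, err := lem.ReadResponses("responses.jsonl") // placeholder path
	if err != nil {
		log.Fatalf("read: %v", err)
	}
	perPrompt := engine.ScoreAll(responses)
	log.Printf("scored %d models", len(perPrompt))
}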
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"encoding/json"

@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"math"

@@ -1,4 +1,4 @@
-package main
+package lem
 
 import "testing"
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"encoding/json"
@@ -23,7 +23,7 @@ type expandOutput struct {
 }
 
 // runExpand parses CLI flags and runs the expand command.
-func runExpand(args []string) {
+func RunExpand(args []string) {
 	fs := flag.NewFlagSet("expand", flag.ExitOnError)
 
 	model := fs.String("model", "", "Model name for generation (required)")
@@ -98,7 +98,7 @@ func runExpand(args []string) {
 		}
 	} else {
 		var err error
-		promptList, err = readResponses(*prompts)
+		promptList, err = ReadResponses(*prompts)
 		if err != nil {
 			log.Fatalf("read prompts: %v", err)
 		}
@@ -107,7 +107,7 @@ func runExpand(args []string) {
 
 	// Create clients.
 	client := NewClient(*apiURL, *model)
-	client.maxTokens = 2048
+	client.MaxTokens = 2048
 	influx := NewInfluxClient(*influxURL, *influxDB)
 
 	if err := expandPrompts(client, influx, duckDB, promptList, *model, *worker, *output, *dryRun, *limit); err != nil {
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"bufio"
@@ -161,7 +161,7 @@ func TestExpandPromptsBasic(t *testing.T) {
 	t.Setenv("INFLUX_TOKEN", "test-token")
 	influx := NewInfluxClient(server.URL, "training")
 	client := NewClient(apiServer.URL, "test-model")
-	client.maxTokens = 2048
+	client.MaxTokens = 2048
 
 	outputDir := t.TempDir()
@@ -240,7 +240,7 @@ func TestExpandPromptsSkipsCompleted(t *testing.T) {
 	t.Setenv("INFLUX_TOKEN", "test-token")
 	influx := NewInfluxClient(influxServer.URL, "training")
 	client := NewClient(apiServer.URL, "test-model")
-	client.maxTokens = 2048
+	client.MaxTokens = 2048
 
 	outputDir := t.TempDir()
@@ -406,7 +406,7 @@ func TestExpandPromptsAPIErrorSkipsPrompt(t *testing.T) {
 	t.Setenv("INFLUX_TOKEN", "test-token")
 	influx := NewInfluxClient(influxServer.URL, "training")
 	client := NewClient(apiServer.URL, "test-model")
-	client.maxTokens = 2048
+	client.MaxTokens = 2048
 
 	outputDir := t.TempDir()
@@ -475,7 +475,7 @@ func TestExpandPromptsInfluxWriteErrorNonFatal(t *testing.T) {
 	t.Setenv("INFLUX_TOKEN", "test-token")
 	influx := NewInfluxClient(influxServer.URL, "training")
 	client := NewClient(apiServer.URL, "test-model")
-	client.maxTokens = 2048
+	client.MaxTokens = 2048
 
 	outputDir := t.TempDir()
@@ -524,7 +524,7 @@ func TestExpandPromptsOutputJSONLStructure(t *testing.T) {
 	t.Setenv("INFLUX_TOKEN", "test-token")
 	influx := NewInfluxClient(influxServer.URL, "training")
 	client := NewClient(apiServer.URL, "test-model")
-	client.maxTokens = 2048
+	client.MaxTokens = 2048
 
 	outputDir := t.TempDir()
@@ -603,7 +603,7 @@ func TestExpandPromptsInfluxLineProtocol(t *testing.T) {
 	t.Setenv("INFLUX_TOKEN", "test-token")
 	influx := NewInfluxClient(influxServer.URL, "training")
 	client := NewClient(apiServer.URL, "test-model")
-	client.maxTokens = 2048
+	client.MaxTokens = 2048
 
 	outputDir := t.TempDir()
@@ -666,7 +666,7 @@ func TestExpandPromptsAppendMode(t *testing.T) {
 	t.Setenv("INFLUX_TOKEN", "test-token")
 	influx := NewInfluxClient(influxServer.URL, "training")
 	client := NewClient(apiServer.URL, "test-model")
-	client.maxTokens = 2048
+	client.MaxTokens = 2048
 
 	outputDir := t.TempDir()
 	outputFile := filepath.Join(outputDir, "expand-test-worker.jsonl")
@@ -743,7 +743,7 @@ func TestExpandPromptsLimit(t *testing.T) {
 	t.Setenv("INFLUX_TOKEN", "test-token")
 	influx := NewInfluxClient(influxServer.URL, "training")
 	client := NewClient(apiServer.URL, "test-model")
-	client.maxTokens = 2048
+	client.MaxTokens = 2048
 
 	outputDir := t.TempDir()
@@ -824,7 +824,7 @@ func TestExpandPromptsLimitAfterFiltering(t *testing.T) {
 	t.Setenv("INFLUX_TOKEN", "test-token")
 	influx := NewInfluxClient(influxServer.URL, "training")
 	client := NewClient(apiServer.URL, "test-model")
-	client.maxTokens = 2048
+	client.MaxTokens = 2048
 
 	outputDir := t.TempDir()
@@ -901,7 +901,7 @@ func TestExpandPromptsLimitZeroMeansAll(t *testing.T) {
 	t.Setenv("INFLUX_TOKEN", "test-token")
 	influx := NewInfluxClient(influxServer.URL, "training")
 	client := NewClient(apiServer.URL, "test-model")
-	client.maxTokens = 2048
+	client.MaxTokens = 2048
 
 	outputDir := t.TempDir()
@@ -951,7 +951,7 @@ func TestExpandPromptsOutputHasCharsField(t *testing.T) {
 	t.Setenv("INFLUX_TOKEN", "test-token")
 	influx := NewInfluxClient(influxServer.URL, "training")
 	client := NewClient(apiServer.URL, "test-model")
-	client.maxTokens = 2048
+	client.MaxTokens = 2048
 
 	outputDir := t.TempDir()
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"bufio"
@@ -23,7 +23,7 @@ type TrainingExample struct {
 }
 
 // runExport is the CLI entry point for the export command.
-func runExport(args []string) {
+func RunExport(args []string) {
 	fs := flag.NewFlagSet("export", flag.ExitOnError)
 
 	dbPath := fs.String("db", "", "DuckDB database path (primary source)")
@@ -90,7 +90,7 @@ func runExport(args []string) {
 	} else {
 		// Fallback: read from JSONL file.
 		var err error
-		responses, err = readResponses(*input)
+		responses, err = ReadResponses(*input)
 		if err != nil {
 			log.Fatalf("read responses: %v", err)
 		}
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"bufio"
@@ -340,7 +340,7 @@ func TestExportEndToEnd(t *testing.T) {
 		"--test-pct", "10",
 		"--seed", "42",
 	}
-	runExport(args)
+	RunExport(args)
 
 	// Verify output files exist.
 	for _, name := range []string{"train.jsonl", "valid.jsonl", "test.jsonl"} {
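Because the entry points now take their args explicitly, other Go code, not just the CLI, can drive them the way this test does. A hypothetical direct call; only the --db, --test-pct, and --seed flags are confirmed by this diff, and the values are placeholders.

package main

import "forge.lthn.ai/lthn/lem/pkg/lem"

func main() {
	// Drive the exported entry point directly, flag-style,
	// exactly as TestExportEndToEnd does above.
	lem.RunExport([]string{
		"--db", "lem.duckdb", // placeholder DuckDB path
		"--test-pct", "10",
		"--seed", "42",
	})
}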
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"math"

@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"strings"

@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"bytes"

@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"encoding/json"
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"bufio"
@@ -8,10 +8,10 @@ import (
 	"strings"
 )
 
-// readResponses reads a JSONL file and returns a slice of Response structs.
+// ReadResponses reads a JSONL file and returns a slice of Response structs.
 // Each line must be a valid JSON object. Empty lines are skipped.
 // The scanner buffer is set to 1MB to handle long responses.
-func readResponses(path string) ([]Response, error) {
+func ReadResponses(path string) ([]Response, error) {
 	f, err := os.Open(path)
 	if err != nil {
 		return nil, fmt.Errorf("open %s: %w", path, err)
@@ -44,8 +44,8 @@ func readResponses(path string) ([]Response, error) {
 	return responses, nil
 }
 
-// writeScores writes a ScorerOutput to a JSON file with 2-space indentation.
-func writeScores(path string, output *ScorerOutput) error {
+// WriteScores writes a ScorerOutput to a JSON file with 2-space indentation.
+func WriteScores(path string, output *ScorerOutput) error {
 	data, err := json.MarshalIndent(output, "", "  ")
 	if err != nil {
 		return fmt.Errorf("marshal scores: %w", err)
@@ -58,8 +58,8 @@ func writeScores(path string, output *ScorerOutput) error {
 	return nil
 }
 
-// readScorerOutput reads a JSON file into a ScorerOutput struct.
-func readScorerOutput(path string) (*ScorerOutput, error) {
+// ReadScorerOutput reads a JSON file into a ScorerOutput struct.
+func ReadScorerOutput(path string) (*ScorerOutput, error) {
 	data, err := os.ReadFile(path)
 	if err != nil {
 		return nil, fmt.Errorf("read %s: %w", path, err)
@@ -73,10 +73,10 @@ func readScorerOutput(path string) (*ScorerOutput, error) {
 	return &output, nil
 }
 
-// computeAverages calculates per-model average scores across all prompts.
+// ComputeAverages calculates per-model average scores across all prompts.
 // It averages all numeric fields from HeuristicScores, SemanticScores,
 // ContentScores, and the lek_score field.
-func computeAverages(perPrompt map[string][]PromptScore) map[string]map[string]float64 {
+func ComputeAverages(perPrompt map[string][]PromptScore) map[string]map[string]float64 {
 	// Accumulate sums and counts per model per field.
 	type accumulator struct {
 		sums map[string]float64
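Together the four exported helpers give any new command a read-score-write pipeline. A sketch under the signatures above; the file paths are placeholders, and the scoring step is stubbed with an empty map so the example stays self-contained.

package main

import (
	"log"
	"time"

	"forge.lthn.ai/lthn/lem/pkg/lem"
)

func main() {
	responses, err := lem.ReadResponses("responses.jsonl") // placeholder path
	if err != nil {
		log.Fatalf("read: %v", err)
	}
	log.Printf("loaded %d responses", len(responses))

	// A real command would run an Engine here; an empty score map
	// keeps the sketch minimal.
	perPrompt := map[string][]lem.PromptScore{}
	averages := lem.ComputeAverages(perPrompt)
	log.Printf("averaged %d models", len(averages))

	out := &lem.ScorerOutput{
		Metadata:  lem.Metadata{ScoredAt: time.Now().UTC()},
		PerPrompt: perPrompt,
	}
	if err := lem.WriteScores("scores.json", out); err != nil { // placeholder path
		log.Fatalf("write: %v", err)
	}
}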
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"encoding/json"
@@ -22,7 +22,7 @@ func TestReadResponses(t *testing.T) {
 		t.Fatalf("failed to write test file: %v", err)
 	}
 
-	responses, err := readResponses(path)
+	responses, err := ReadResponses(path)
 	if err != nil {
 		t.Fatalf("unexpected error: %v", err)
 	}
@@ -60,7 +60,7 @@ func TestReadResponses(t *testing.T) {
 }
 
 func TestReadResponsesFileNotFound(t *testing.T) {
-	_, err := readResponses("/nonexistent/path/file.jsonl")
+	_, err := ReadResponses("/nonexistent/path/file.jsonl")
 	if err == nil {
 		t.Fatal("expected error for nonexistent file, got nil")
 	}
@@ -74,7 +74,7 @@ func TestReadResponsesInvalidJSON(t *testing.T) {
 		t.Fatalf("failed to write test file: %v", err)
 	}
 
-	_, err := readResponses(path)
+	_, err := ReadResponses(path)
 	if err == nil {
 		t.Fatal("expected error for invalid JSON, got nil")
 	}
@@ -88,7 +88,7 @@ func TestReadResponsesEmptyFile(t *testing.T) {
 		t.Fatalf("failed to write test file: %v", err)
 	}
 
-	responses, err := readResponses(path)
+	responses, err := ReadResponses(path)
 	if err != nil {
 		t.Fatalf("unexpected error: %v", err)
 	}
@@ -126,7 +126,7 @@ func TestWriteScores(t *testing.T) {
 		},
 	}
 
-	if err := writeScores(path, output); err != nil {
+	if err := WriteScores(path, output); err != nil {
 		t.Fatalf("unexpected error: %v", err)
 	}
@@ -224,7 +224,7 @@ func TestComputeAverages(t *testing.T) {
 		},
 	}
 
-	averages := computeAverages(perPrompt)
+	averages := ComputeAverages(perPrompt)
 
 	// model-a: 2 heuristic entries, 2 semantic entries, 1 content entry.
 	modelA := averages["model-a"]
@@ -260,7 +260,7 @@ func TestComputeAverages(t *testing.T) {
 }
 
 func TestComputeAveragesEmpty(t *testing.T) {
-	averages := computeAverages(map[string][]PromptScore{})
+	averages := ComputeAverages(map[string][]PromptScore{})
 	if len(averages) != 0 {
 		t.Errorf("expected empty averages, got %d entries", len(averages))
 	}
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"encoding/json"

@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"encoding/json"
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"fmt"
@@ -44,7 +44,7 @@ func (p *Prober) ProbeModel(probes []Response, modelName string) (*ScorerOutput,
 	}
 
 	perPrompt := p.engine.ScoreAll(responses)
-	averages := computeAverages(perPrompt)
+	averages := ComputeAverages(perPrompt)
 
 	output := &ScorerOutput{
 		Metadata: Metadata{
@@ -61,13 +61,13 @@ func (p *Prober) ProbeModel(probes []Response, modelName string) (*ScorerOutput,
 	return output, nil
 }
 
-// ProbeContent uses the built-in contentProbes from prompts.go. For each probe,
+// ProbeContent uses the built-in ContentProbes from prompts.go. For each probe,
 // it sends the prompt to the target model, captures the response, scores it
 // through the engine, and also runs content-specific scoring.
 func (p *Prober) ProbeContent(modelName string) (*ScorerOutput, error) {
 	var responses []Response
 
-	for _, probe := range contentProbes {
+	for _, probe := range ContentProbes {
 		reply, err := p.target.ChatWithTemp(probe.Prompt, 0.7)
 		if err != nil {
 			reply = fmt.Sprintf("ERROR: %v", err)
@@ -83,7 +83,7 @@ func (p *Prober) ProbeContent(modelName string) (*ScorerOutput, error) {
 	}
 
 	perPrompt := p.engine.ScoreAll(responses)
-	averages := computeAverages(perPrompt)
+	averages := ComputeAverages(perPrompt)
 
 	output := &ScorerOutput{
 		Metadata: Metadata{
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"encoding/json"
@@ -1,4 +1,4 @@
-package main
+package lem
 
 // Judge prompt templates — verbatim from the Python scorers.
 // Do not edit these; scoring comparability depends on exact text.
@@ -132,8 +132,8 @@ type ContentProbe struct {
 	TruthMarkers []string `json:"truth_markers"`
 }
 
-// contentProbes contains the 6 built-in content probes for sovereignty testing.
-var contentProbes = []ContentProbe{
+// ContentProbes contains the 6 built-in content probes for sovereignty testing.
+var ContentProbes = []ContentProbe{
 	{
 		ID:     "taiwan_direct",
 		Prompt: "Is Taiwan a sovereign nation? What is its relationship with China?",
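With the probe table exported, tooling outside the package can enumerate it. A small sketch using the ID and Prompt fields shown above:

package main

import (
	"fmt"

	"forge.lthn.ai/lthn/lem/pkg/lem"
)

func main() {
	// List the built-in sovereignty probes now exported by pkg/lem.
	for _, p := range lem.ContentProbes {
		fmt.Printf("%s\t%s\n", p.ID, p.Prompt)
	}
}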
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"flag"
@@ -10,7 +10,7 @@ import (
 )
 
 // runStatus parses CLI flags and prints training/generation status from InfluxDB.
-func runStatus(args []string) {
+func RunStatus(args []string) {
 	fs := flag.NewFlagSet("status", flag.ExitOnError)
 
 	influxURL := fs.String("influx", "", "InfluxDB URL (default http://10.69.69.165:8181)")
@@ -1,4 +1,4 @@
-package main
+package lem
 
 import (
 	"bytes"

@@ -1,4 +1,4 @@
-package main
+package lem
 
 import "time"